{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Embedding Tutorial\n", "\n", "This tutorial shows how to embed light curves with pretrained ONNX models\n", "from the `light_curve.embed` submodule.\n", "\n", "The models are distributed as ONNX files, which `from_hf()` downloads from the Hugging Face Hub.\n", "\n", "Requires: `pip install onnxruntime` (and optionally `huggingface_hub` for automatic downloads)." ] }, { "cell_type": "code", "execution_count": null, "id": "1", "metadata": {}, "outputs": [], "source": [ "# %pip install light-curve huggingface_hub onnxruntime" ] }, { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [ "## Astromer2 — single-band embeddings\n", "\n", "[Astromer2](https://ui.adsabs.harvard.edu/abs/2026A%26A...707A.170D/abstract) is pretrained on\n", "MACHO light curves and produces 256-dimensional embeddings from irregularly sampled `(time, mag)` pairs." ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from light_curve.embed import Astromer2\n", "\n", "# Download the ONNX model from the Hugging Face Hub and load it\n", "model = Astromer2.from_hf(output=\"mean\")\n", "print(f\"Model loaded. 
Max sequence length: {model.seq_size}\")\n", "\n", "# Synthetic single-band light curve: 120 sorted random epochs, Gaussian magnitudes around 15\n", "rng = np.random.default_rng(0)\n", "time = np.sort(rng.uniform(0, 500, 120)).astype(np.float64)\n", "mag = rng.normal(15, 0.5, 120).astype(np.float64)\n", "\n", "embedding = model(time, mag)\n", "print(f\"Output shape: {embedding.shape} # (n_bands, n_subsamples, seq_windows, embed_dim)\")\n", "print(f\"Squeezed: {embedding.squeeze().shape}\")\n" ] }, { "cell_type": "markdown", "id": "4", "metadata": {}, "source": [ "## Astromer2 — multi-band\n", "\n", "Pass `bands=[...]` to `from_hf()` and a per-observation `band` array to the model call.\n", "Each band is embedded independently, so the output contains one embedding per band:" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], "source": [ "model_gr = Astromer2.from_hf(output=\"mean\", bands=[\"g\", \"r\"])\n", "\n", "rng2 = np.random.default_rng(1)\n", "n = 120\n", "time_gr = np.sort(rng2.uniform(0, 400, n)).astype(np.float64)\n", "mag_gr = rng2.normal(15, 0.4, n).astype(np.float64)\n", "# Band label for each observation: alternating g and r\n", "band_gr = np.array([\"g\", \"r\"] * (n // 2))\n", "\n", "emb_gr = model_gr(time_gr, mag_gr, band=band_gr)\n", "print(f\"Output shape: {emb_gr.shape} # (2 bands, n_subsamples, seq_windows, embed_dim)\")\n" ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "## ATCAT — 6-band LSST model\n", "\n", "[ATCAT](https://ui.adsabs.harvard.edu/abs/2025arXiv251100614T/abstract) processes all ugrizY bands\n", "jointly and returns 384-dimensional embeddings. It takes time, flux, flux error, and an integer band\n", "index per observation (u=0, g=1, r=2, i=3, z=4, Y=5)." ] }, { "cell_type": "code", "execution_count": null, "id": "7", "metadata": {}, "outputs": [], "source": [ "from light_curve.embed import ATCAT\n", "\n", "model_atcat = ATCAT.from_hf(output=\"last\")\n", "print(f\"ATCAT loaded. 
Max sequence length: {model_atcat.seq_size}\")\n", "\n", "# Synthetic 6-band light curve: 150 sorted random epochs with constant flux errors\n", "rng3 = np.random.default_rng(2)\n", "n3 = 150\n", "time3 = np.sort(rng3.uniform(0, 500, n3)).astype(np.float32)\n", "flux3 = rng3.normal(100, 15, n3).astype(np.float32) # flux in nJy\n", "flux_err3 = np.full(n3, 5.0, dtype=np.float32)\n", "band3 = np.arange(n3) % 6 # cycle through u=0, g=1, r=2, i=3, z=4, Y=5\n", "\n", "# Argument order: time, flux, flux error, band index\n", "emb3 = model_atcat(time3, flux3, flux_err3, band3)\n", "print(f\"Output shape: {emb3.shape} # (1, 1, 1, {emb3.shape[-1]})\")\n" ] }, { "cell_type": "markdown", "id": "8", "metadata": {}, "source": [ "## Notes\n", "\n", "- Embeddings have shape `(n_bands, n_subsamples, seq_windows, embed_dim)`. When every axis except `embed_dim` has size 1, `.squeeze()` reduces this to a flat `(embed_dim,)` vector for a single object.\n", "- `huggingface_hub` is only needed for automatic downloads via `from_hf()`. If you already have the ONNX file, it is not required.\n", "\n", "## Next steps\n", "\n", "- [Similarity search](../pre-executed/similarity_search.ipynb) — nearest-neighbour retrieval with embeddings\n", "- [Classification](../pre-executed/classification.ipynb) — training a classifier on embeddings\n", "- [onnxruntime tips](../onnxruntime.md) — thread control on shared HPC nodes, GPU/CUDA setup\n", "- [API reference](../api.md)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.12.0" } }, "nbformat": 4, "nbformat_minor": 5 }