{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Downstream Task Demonstration using Our Model as a Foundation Model\n",
    "\n",
    "Searching stars by stellar spectroscopy - stellar parameters pairing using contrastive objective"
   ]
  },
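  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two projection heads defined below (`SpecEncoder` and `StellarEncoder`) were trained so that a star's XP spectrum and its stellar parameters land close together in a shared embedding space. As a rough illustration only (not the training code of this repository), a CLIP-style symmetric contrastive loss over paired batches could look like the sketch below; the `temperature` value and the toy tensors are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def contrastive_loss(spec_emb, star_emb, temperature=0.07):\n",
    "    # L2-normalize both embedding sets so dot products become cosine similarities\n",
    "    spec_emb = F.normalize(spec_emb, p=2, dim=1)\n",
    "    star_emb = F.normalize(star_emb, p=2, dim=1)\n",
    "    # pairwise similarity matrix; entry (i, j) compares spectrum i with parameter set j\n",
    "    logits = spec_emb @ star_emb.T / temperature\n",
    "    # matching spectrum-parameter pairs sit on the diagonal\n",
    "    targets = torch.arange(logits.shape[0])\n",
    "    # symmetric cross-entropy: spectra -> parameters and parameters -> spectra\n",
    "    return 0.5 * (\n",
    "        F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)\n",
    "    )\n",
    "\n",
    "\n",
    "# toy batch of 8 paired embeddings with 64 dimensions, matching the projection heads below\n",
    "loss = contrastive_loss(torch.randn(8, 64), torch.randn(8, 64))\n",
    "```"
   ]
  },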
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "import warnings\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "from stellarperceptron.model import StellarPerceptron\n",
    "\n",
    "from astropy.io import fits\n",
    "from astroNN.apogee import allstar\n",
    "\n",
    "allstar_f = fits.getdata(allstar(dr=17))\n",
    "\n",
    "# ================== hardware-related settings ==================\n",
    "device = \"cpu\"  # \"cpu\" for CPU or \"cuda:x\" for a NVIDIA GPU\n",
    "mixed_precision = False\n",
    "torch.backends.cuda.matmul.allow_tf32 = False\n",
    "torch.backends.cudnn.allow_tf32 = False\n",
    "# ================== hardware-related settings ==================\n",
    "\n",
    "# need to load the trained main model first since we need the trained encoder and embeddings\n",
    "nn_model = StellarPerceptron.load(\n",
    "    \"./model_torch/\", mixed_precision=mixed_precision, device=device\n",
    ")\n",
    "\n",
    "\n",
    "def find_topk_matches(source_id, spec_embeddings, queries_embedding, k=10):\n",
    "    \"\"\"\n",
    "    Function to lookup stars in the embedding space\n",
    "    \"\"\"\n",
    "    spec_embeddings = torch.nn.functional.normalize(spec_embeddings, p=2, dim=1)\n",
    "    queries_embedding = torch.nn.functional.normalize(queries_embedding, p=2, dim=1)\n",
    "    dot_similarity = torch.matmul(\n",
    "        queries_embedding, torch.transpose(spec_embeddings, 0, 1)\n",
    "    )\n",
    "    results = torch.topk(dot_similarity, k).indices.cpu().numpy()\n",
    "\n",
    "    return [[source_id[idx] for idx in indices] for indices in results]\n",
    "\n",
    "\n",
    "class SpecEncoder(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        trained_model: StellarPerceptron = nn_model,\n",
    "        projection_dims: int = 32,\n",
    "        context_length: int = 64,\n",
    "        dropout_rate: float = 0.1,\n",
    "        device: str = \"cpu\",\n",
    "        dtype: torch.dtype = torch.float32,\n",
    "        **kwargs,\n",
    "    ) -> None:\n",
    "        super().__init__(**kwargs)\n",
    "        self.factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "        self.base_trained_model = trained_model\n",
    "        self.trained_encoder = trained_model.torch_encoder.eval()\n",
    "        self.trained_nonlinear_embedding = nn_model.embedding_layer\n",
    "        self.embedding_dim = trained_model.embedding_dim\n",
    "        self.dropout_rate = dropout_rate\n",
    "        self.projection_dims = projection_dims\n",
    "        self.context_length = context_length\n",
    "        self.dense_base = torch.nn.Linear(\n",
    "            self.embedding_dim * self.context_length,\n",
    "            self.projection_dims,\n",
    "            **self.factory_kwargs,\n",
    "        )\n",
    "        self.dropout_1 = torch.nn.Dropout(self.dropout_rate)\n",
    "        self.dense_1 = torch.nn.Linear(\n",
    "            self.projection_dims, self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.layernorm_1 = torch.nn.LayerNorm(\n",
    "            self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.dense_2 = torch.nn.Linear(\n",
    "            self.projection_dims, self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.layernorm_2 = torch.nn.LayerNorm(\n",
    "            self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.dense_3 = torch.nn.Linear(\n",
    "            self.projection_dims, self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.layernorm_3 = torch.nn.LayerNorm(\n",
    "            self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "\n",
    "    def forward(self, inputs, inputs_names):\n",
    "        with torch.no_grad():  # non-trainable trained encoder\n",
    "            with warnings.catch_warnings():\n",
    "                warnings.simplefilter(\"ignore\")\n",
    "                self.base_trained_model.perceive(\n",
    "                    inputs, inputs_names, inference_mode=False\n",
    "                )\n",
    "                embeddings = (\n",
    "                    torch.flatten(nn_model._perception_memory, start_dim=1, end_dim=2)\n",
    "                    * 1.0\n",
    "                )\n",
    "        projected_embeddings = self.dense_base(embeddings)\n",
    "        # ================== #\n",
    "        x = F.gelu(projected_embeddings)\n",
    "        x = self.dense_1(x)\n",
    "        x = self.dropout_1(x)\n",
    "        projected_embeddings = self.layernorm_1(projected_embeddings + x)\n",
    "        # ================== #\n",
    "        x = F.gelu(projected_embeddings)\n",
    "        x = self.dense_2(x)\n",
    "        x = self.dropout_1(x)\n",
    "        projected_embeddings = self.layernorm_2(projected_embeddings + x)\n",
    "        # ================== #\n",
    "        x = F.gelu(projected_embeddings)\n",
    "        x = self.dense_3(x)\n",
    "        x = self.dropout_1(x)\n",
    "        projected_embeddings = self.layernorm_3(projected_embeddings + x)\n",
    "        # ================== #\n",
    "        return projected_embeddings\n",
    "\n",
    "    def predict(self, inputs, inputs_names):\n",
    "        with torch.inference_mode():\n",
    "            return self(inputs, inputs_names)\n",
    "\n",
    "\n",
    "class StellarEncoder(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        trained_model: StellarPerceptron = nn_model,\n",
    "        projection_dims: int = 128,\n",
    "        context_length: int = 64,\n",
    "        dropout_rate: float = 0.1,\n",
    "        device: str = \"cpu\",\n",
    "        dtype: torch.dtype = torch.float32,\n",
    "        **kwargs,\n",
    "    ) -> None:\n",
    "        super().__init__(**kwargs)\n",
    "        self.factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "        self.base_trained_model = trained_model\n",
    "        self.trained_encoder = trained_model.torch_encoder.eval()\n",
    "        self.embedding_dim = trained_model.embedding_dim\n",
    "        self.dropout_rate = dropout_rate\n",
    "        self.projection_dims = projection_dims\n",
    "        self.context_length = context_length\n",
    "\n",
    "        self.dense_base = torch.nn.Linear(\n",
    "            self.embedding_dim * self.context_length,\n",
    "            self.projection_dims,\n",
    "            **self.factory_kwargs,\n",
    "        )\n",
    "        self.dropout_1 = torch.nn.Dropout(self.dropout_rate)\n",
    "        self.dense_1 = torch.nn.Linear(\n",
    "            self.projection_dims, self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.layernorm_1 = torch.nn.LayerNorm(\n",
    "            self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "\n",
    "        self.dense_2 = torch.nn.Linear(\n",
    "            self.projection_dims, self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.layernorm_2 = torch.nn.LayerNorm(\n",
    "            self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "\n",
    "        self.dense_3 = torch.nn.Linear(\n",
    "            self.projection_dims, self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "        self.layernorm_3 = torch.nn.LayerNorm(\n",
    "            self.projection_dims, **self.factory_kwargs\n",
    "        )\n",
    "\n",
    "    def forward(self, inputs, inputs_names):\n",
    "        with torch.no_grad():  # non-trainable trained encoder\n",
    "            with warnings.catch_warnings():\n",
    "                warnings.simplefilter(\"ignore\")\n",
    "                self.base_trained_model.perceive(\n",
    "                    inputs, inputs_names, inference_mode=False\n",
    "                )\n",
    "                embeddings = (\n",
    "                    torch.flatten(nn_model._perception_memory, start_dim=1, end_dim=2)\n",
    "                    * 1.0\n",
    "                )\n",
    "        projected_embeddings = self.dense_base(embeddings)\n",
    "        # ================== #\n",
    "        x = F.gelu(projected_embeddings)\n",
    "        x = self.dense_1(x)\n",
    "        x = self.dropout_1(x)\n",
    "        projected_embeddings = self.layernorm_1(projected_embeddings + x)\n",
    "        # ================== #\n",
    "        x = F.gelu(projected_embeddings)\n",
    "        x = self.dense_2(x)\n",
    "        x = self.dropout_1(x)\n",
    "        projected_embeddings = self.layernorm_2(projected_embeddings + x)\n",
    "        # ================== #\n",
    "        x = F.gelu(projected_embeddings)\n",
    "        x = self.dense_3(x)\n",
    "        x = self.dropout_1(x)\n",
    "        projected_embeddings = self.layernorm_3(projected_embeddings + x)\n",
    "        # ================== #\n",
    "        return projected_embeddings\n",
    "\n",
    "    def predict(self, inputs, inputs_names):\n",
    "        with torch.inference_mode():\n",
    "            return self(inputs, inputs_names)\n",
    "\n",
    "\n",
    "spec_nn = SpecEncoder(device=device, projection_dims=64)\n",
    "stars_nn = StellarEncoder(device=device, projection_dims=64)\n",
    "\n",
    "# load the trained model\n",
    "modelsearch = torch.load(\n",
    "    f\"./model_torch_search/model_torch_search.pt\", weights_only=True\n",
    ")\n",
    "spec_nn.load_state_dict(\n",
    "    modelsearch[\"specmodel_state_dict\"],\n",
    "    strict=True,\n",
    ")\n",
    "stars_nn.load_state_dict(\n",
    "    modelsearch[\"starmodel_state_dict\"],\n",
    "    strict=True,\n",
    ")\n",
    "spec_nn.eval()\n",
    "stars_nn.eval()\n",
    "\n",
    "# load database\n",
    "stars_database = h5py.File(\"./model_torch_search/gaia_small_db.h5\", \"r\")\n",
    "\n",
    "# calculage embeddings from XP spectra only, only using the first 32 bp and rp\n",
    "inputs_names = [*[f\"bp{i}\" for i in range(32)], *[f\"rp{i}\" for i in range(32)]]\n",
    "spec_embedding = torch.zeros(\n",
    "    (len(stars_database[\"source_id\"][()]), spec_nn.projection_dims)\n",
    ")\n",
    "batch_size = 1024\n",
    "\n",
    "for i in range(len(stars_database[\"source_id\"][()]) // batch_size):\n",
    "    spec_embedding[i * batch_size : (i + 1) * batch_size] = spec_nn.predict(\n",
    "        stars_database[\"rp32bp32\"][()][i * batch_size : (i + 1) * batch_size],\n",
    "        inputs_names,\n",
    "    )\n",
    "spec_embedding[(i + 1) * batch_size :] = spec_nn.predict(\n",
    "    stars_database[\"rp32bp32\"][()][(i + 1) * batch_size :], inputs_names\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Top-$k$ Searching\n",
    "\n",
    "* In the first case we will search for stars in the database with typical giants Teff and log(g) to see if the result has simiar ground truth in search paramters\n",
    "* In the second case we will search for stars in the database with typical dwarfs Teff and log(g) to see if the result has simiar ground truth in search paramters\n",
    "* In the third cases we will search for stars in the database with $J-K$ only to see if the result has simiar ground truth in search paramters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Most similar star in database (Gaia DR3 Source ID):\n",
      " [704126323012532864, 2608184212654581888, 4761510940323057792, 1632894240354190720, 2266653189281248384, 1422406139514522368, 2133086543966514432, 3933250136788680064, 2105722860647339776, 1223271389585049728]\n",
      "Their T_eff:\n",
      " [4701.993  4706.58   4705.3115 4589.256  4802.4175 4737.822  4703.627\n",
      " 4613.678  4365.014  4674.269 ]\n",
      "Their log(g):\n",
      " [2.4414918 2.4398525 2.3873532 2.4079142 2.50701   2.3648117 2.3617556\n",
      " 2.6147337 3.3534677 2.690805 ]\n"
     ]
    }
   ],
   "source": [
    "# chagne the parameters here to setup a query star, similar to the usage of percieve() in our main model\n",
    "q_embedding = stars_nn.predict([[4700.0, 2.5, 0.0]], [\"teff\", \"logg\", \"m_h\"])\n",
    "\n",
    "# find the stars with source id in the database that are most similar to the query star\n",
    "source_id = find_topk_matches(\n",
    "    stars_database[\"source_id\"][()], spec_embedding, q_embedding.cpu(), k=10\n",
    ")\n",
    "\n",
    "# cross-match APOGEE allstar catalog to find the T_eff and log(g) of the most similar stars\n",
    "allstar_idx = np.intersect1d(\n",
    "    np.array(source_id[0], dtype=np.int64),\n",
    "    allstar_f[\"GAIAEDR3_SOURCE_ID\"],\n",
    "    return_indices=True,\n",
    ")[2]\n",
    "\n",
    "print(\"Most similar star in database (Gaia DR3 Source ID):\\n\", source_id[0])\n",
    "print(\"Their T_eff:\\n\", allstar_f[\"TEFF\"][allstar_idx])\n",
    "print(\"Their log(g):\\n\", allstar_f[\"LOGG\"][allstar_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Most similar star in database (Gaia DR3 Source ID):\n",
      " [1634950739415310848, 3251625804573815552, 3336303180759424640, 2635281985958371584, 1577014963485955968, 2828370624526504064, 41383388583719168, 2255693154297886976, 1273850642449726464, 3337904997402262016]\n",
      "Their T_eff:\n",
      " [3819.9282 3651.9966 4044.0454 3769.9504 3624.199  3859.134  3741.5781\n",
      " 3765.8489 4074.1426 3713.9307]\n",
      "Their log(g):\n",
      " [4.658435  4.6930223 4.6133075 4.652622  4.6738906 4.7005982 4.6687317\n",
      " 4.6484175 3.5922148 4.6665177]\n"
     ]
    }
   ],
   "source": [
    "# chagne the parameters here to setup a query star, similar to the usage of percieve() in our main model\n",
    "q_embedding = stars_nn([[3900.0, 4.65, 0.0]], [\"teff\", \"logg\", \"m_h\"])\n",
    "\n",
    "# find the stars with source id in the database that are most similar to the query star\n",
    "source_id = find_topk_matches(\n",
    "    stars_database[\"source_id\"][()], spec_embedding, q_embedding.cpu(), k=10\n",
    ")\n",
    "\n",
    "# cross-match APOGEE allstar catalog to find the T_eff and log(g) of the most similar stars\n",
    "allstar_idx = np.intersect1d(\n",
    "    np.array(source_id[0], dtype=np.int64),\n",
    "    allstar_f[\"GAIAEDR3_SOURCE_ID\"],\n",
    "    return_indices=True,\n",
    ")[2]\n",
    "\n",
    "print(\"Most similar star in database (Gaia DR3 Source ID):\\n\", source_id[0])\n",
    "print(\"Their T_eff:\\n\", allstar_f[\"TEFF\"][allstar_idx])\n",
    "print(\"Their log(g):\\n\", allstar_f[\"LOGG\"][allstar_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Most similar star in database (Gaia DR3 Source ID):\n",
      " [2653942966024313472, 4474962404650333184, 1031840951989590272, 438295266462409216, 1248344858902045568, 1876562612822977664, 4522365924339728512, 3379761206049876352, 1393180845569802880, 65161976802151680]\n",
      "Their J-H:\n",
      " [0.625      0.6040001  0.54100037 0.7119999  0.599      0.5550003\n",
      " 0.5959997  0.5770006  0.5830002  0.60200024]\n"
     ]
    }
   ],
   "source": [
    "# chagne the parameters here to setup a query star, similar to the usage of percieve() in our main model\n",
    "q_embedding = stars_nn.predict([[0.6]], [\"jh\"])\n",
    "\n",
    "# find the stars with source id in the database that are most similar to the query star\n",
    "source_id = find_topk_matches(\n",
    "    stars_database[\"source_id\"][()], spec_embedding, q_embedding.cpu(), k=10\n",
    ")\n",
    "\n",
    "# cross-match APOGEE allstar catalog to find the T_eff and log(g) of the most similar stars\n",
    "allstar_idx = np.intersect1d(\n",
    "    np.array(source_id[0], dtype=np.int64),\n",
    "    allstar_f[\"GAIAEDR3_SOURCE_ID\"],\n",
    "    return_indices=True,\n",
    ")[2]\n",
    "\n",
    "print(\"Most similar star in database (Gaia DR3 Source ID):\\n\", source_id[0])\n",
    "print(\"Their J-H:\\n\", allstar_f[\"J\"][allstar_idx] - allstar_f[\"H\"][allstar_idx])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}