{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JaxQSOFit Tutorial: SDSS Example + Broad-Line Measurements\n", "\n", "This notebook demonstrates the main features of `jaxqsofit` using a real SDSS spectrum near:\n", "\n", "```python\n", "coord = SkyCoord(184.0307, -2.2383, unit='deg')\n", "```\n", "\n", "Recommended fitting mode: `fit_method='optax+nuts'`. It runs the staged SVI/Optax MAP optimization (continuum warm start, then the full model) and uses that MAP point to initialize full posterior sampling with NUTS.\n", "\n", "This notebook covers:\n", "- Fetching a spectrum with `astroquery`\n", "- Running fits (`nuts`, `optax`, `optax+nuts`)\n", "- Overriding priors with `prior_config`\n", "- Measuring broad-line FWHM and luminosity from fitted components\n", "\n", "Note that the current default is `fit_method='optax'`, which performs only the staged MAP optimization and returns no posterior samples.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ecad8de2", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from astroquery.sdss import SDSS\n", "from astropy.coordinates import SkyCoord\n", "from astropy import units as u\n", "from astropy.cosmology import FlatLambdaCDM\n", "\n", "from jaxqsofit import QSOFit, build_default_prior_config\n" ] }, { "cell_type": "markdown", "id": "58ce3b9e", "metadata": {}, "source": [ "## 1. 
Download one SDSS spectrum\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ec2817d1", "metadata": {}, "outputs": [], "source": [ "coord = SkyCoord(184.0307, -2.2383, unit='deg')\n", "xid = SDSS.query_region(coord, spectro=True, radius=5 * u.arcsec)\n", "if xid is None or len(xid) == 0:\n", "    raise RuntimeError('No SDSS spectrum found within 5 arcsec of the target')\n", "sp = SDSS.get_spectra(matches=xid[:1])[0]\n", "\n", "tb = sp[1].data  # HDU 1: coadded spectrum\n", "lam = np.asarray(10 ** tb['loglam'], dtype=float)\n", "flux = np.asarray(tb['flux'], dtype=float)\n", "ivar = np.asarray(tb['ivar'], dtype=float)\n", "\n", "# Drop masked pixels (ivar == 0) instead of giving them a tiny error,\n", "# which would give those pixels the largest weight in the fit.\n", "good = np.isfinite(flux) & np.isfinite(ivar) & (ivar > 0)\n", "lam, flux, ivar = lam[good], flux[good], ivar[good]\n", "err = 1.0 / np.sqrt(ivar)\n", "\n", "z = float(sp[2].data['z'][0])  # pipeline redshift (HDU 2)\n", "ra = float(coord.ra.deg)\n", "dec = float(coord.dec.deg)\n", "\n", "print(f'Nspec pixels: {lam.size} (dropped {np.sum(~good)} masked)')\n", "print(f'z = {z:.5f}')\n", "\n", "plate = int(sp[0].header.get('plateid', 0))\n", "mjd = int(sp[0].header.get('mjd', 0))\n", "fiber = int(sp[0].header.get('fiberid', 0))\n", "sdss_filename = f\"{plate:04d}-{mjd}-{fiber:04d}\"\n" ] }, { "cell_type": "markdown", "id": "3cc9dc7f", "metadata": {}, "source": [ "## 2. 
Build a default prior config (auto-scaled from flux)\n", "\n", "You can use the defaults as-is or override selected entries, as in the cell below.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "37500d89", "metadata": {}, "outputs": [], "source": [ "prior_config = build_default_prior_config(flux)\n", "\n", "# Example override: slightly tighter priors on the power-law slope and Fe II normalizations\n", "prior_config['PL_slope'] = {'loc': -1.5, 'scale': 0.3, 'low': -3.5, 'high': 0.5}\n", "prior_config['log_Fe_uv_norm'] = {'loc': np.log(max(1e-3 * np.median(np.abs(flux)), 1e-10)), 'scale': 0.4}\n", "prior_config['log_Fe_op_over_uv'] = {'loc': 0.0, 'scale': 0.4}\n", "\n", "# Optional: multipliers on the line prior scales (centroid shift, width, amplitude)\n", "prior_config['line_dmu_scale_mult'] = 0.25\n", "prior_config['line_sig_scale_mult'] = 0.25\n", "prior_config['line_amp_scale_mult'] = 0.20\n" ] }, { "cell_type": "markdown", "id": "4062831a", "metadata": {}, "source": [ "## 3. Run a fit\n", "\n", "Recommended:\n", "- `fit_method='optax+nuts'` (staged SVI/Optax initialization, then full NUTS posterior)\n", "\n", "Other options:\n", "- `fit_method='nuts'` (full posterior, slower initialization)\n", "- `fit_method='optax'` (staged MAP only, fastest, no posterior samples)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0d0bf521", "metadata": {}, "outputs": [], "source": [ "q = QSOFit(\n", "    lam=lam,\n", "    flux=flux,\n", "    err=err,\n", "    z=z,\n", "    ra=ra,\n", "    dec=dec,\n", "    filename=sdss_filename,\n", "    output_path='.',\n", ")\n", "\n", "q.fit(\n", "    deredden=True,\n", "    fit_method='optax+nuts',\n", "    fit_lines=True,\n", "    decompose_host=True,\n", "    fit_fe=True,\n", "    fit_bc=False,\n", "    fit_poly=True,\n", "    prior_config=prior_config,\n", "    dsps_ssp_fn='../tempdata.h5',  # local DSPS SSP data file; adjust the path to your setup\n", "    optax_steps=600,\n", "    optax_lr=1e-2,\n", "    nuts_warmup=50,   # short chains for a quick demo;\n", "    nuts_samples=50,  # increase both for publication-quality posteriors\n", "    nuts_chains=1,\n", "    plot_fig=True,\n", "    save_fig=False,\n", "    save_result=False,\n", ")\n" ] }, { "cell_type": "markdown", "id": "5ea1e441", "metadata": 
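{}, "source": [ "The goodness-of-fit check below compares data and model using an effective per-pixel uncertainty that adds the fitted jitter terms in quadrature to the pipeline errors. Assuming the sampled `frac_jitter` ($f$) and `add_jitter` ($a$) enter the noise model this way, as that check does:\n", "\n", "$$\\sigma_{\\mathrm{eff},i}^2 = \\sigma_i^2 + \\left(f\\,|M_i|\\right)^2 + a^2, \\qquad r_i = \\frac{F_i - M_i}{\\sigma_{\\mathrm{eff},i}}$$\n", "\n", "$$\\chi^2 = \\sum_{i=1}^{N} r_i^2, \\qquad \\frac{\\chi^2}{N} = \\frac{1}{N}\\sum_{i=1}^{N} r_i^2, \\qquad \\mathrm{RMS} = \\sqrt{\\chi^2/N}$$\n", "\n", "Here $F_i$ is the observed flux (`q.flux`), $\\sigma_i$ the pipeline error (`q.err`), $M_i$ the total model (`q.model_total`), $f$ and $a$ the posterior medians of the jitter parameters, and $N$ the number of valid pixels. A reduced $\\chi^2$ would divide by $N - p$ instead, but the effective number of free parameters $p$ is not well defined for a model with informative priors, so $\\chi^2/N$ is reported. Values near 1 indicate residuals consistent with the noise model.\n" ] }, { "cell_type": "markdown", "id": "a3f9c2e1", "metadata": 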
{}, "source": [ "Estimate the goodness of fit ($\\chi^2$, $\\chi^2$ per pixel, and the RMS of the normalized residuals):" ] }, { "cell_type": "code", "execution_count": null, "id": "62d9623b", "metadata": {}, "outputs": [], "source": [ "resid = np.asarray(q.flux) - np.asarray(q.model_total)\n", "sigma = np.asarray(q.err)\n", "\n", "# Include the fitted jitter terms (recommended); they come from the posterior samples,\n", "# so this step is skipped for optax-only fits\n", "s = getattr(q, 'numpyro_samples', None)\n", "if s is not None:\n", "    frac_j = float(np.median(np.asarray(s.get('frac_jitter', 0.0))))\n", "    add_j = float(np.median(np.asarray(s.get('add_jitter', 0.0))))\n", "    sigma = np.sqrt(sigma**2 + (frac_j * np.abs(np.asarray(q.model_total)))**2 + add_j**2)\n", "\n", "good = np.isfinite(resid) & np.isfinite(sigma) & (sigma > 0)\n", "# Normalized residuals; named `chi` so the redshift `z` is not overwritten\n", "chi = resid[good] / sigma[good]\n", "\n", "chi2 = float(np.sum(chi**2))\n", "chi2_per_pixel = float(np.mean(chi**2))  # more stable than reduced chi2 here\n", "wrms = float(np.sqrt(chi2_per_pixel))\n", "\n", "print('chi2:', chi2)\n", "print('chi2_per_pixel:', chi2_per_pixel)\n", "print('wrms (RMS of normalized residuals):', wrms)\n" ] }, { "cell_type": "markdown", "id": "ea711bbb", "metadata": {}, "source": [ "## 4. MCMC diagnostics (trace + corner)\n", "\n", "These plots require posterior samples, i.e. a fit with `fit_method='nuts'` or `fit_method='optax+nuts'`.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e04a26e7", "metadata": {}, "outputs": [], "source": [ "q.plot_mcmc_diagnostics(\n", "    # param_names='all',\n", "    do_trace=True,\n", "    do_corner=True,\n", "    max_vector_elems=2,\n", "    corner_bins=25,\n", "    corner_max_points=1500,\n", ")\n" ] }, { "cell_type": "markdown", "id": "85abe89a", "metadata": {}, "source": [ "## 5. 
Measure broad-line FWHM and luminosity with posterior uncertainties\n", "\n", "`QSOFit` provides:\n", "- `line_profile_from_components(line_key)` for posterior-median component profiles\n", "- `line_profile_from_draw(draw_index, line_key)` for one posterior draw\n", "- `line_props(profile, wave=None)` -> `(fwhm_kms, integrated_area)`\n", "\n", "The integrated area is a line flux in units of $10^{-17}\\,\\mathrm{erg\\,s^{-1}\\,cm^{-2}}$. Convert it to a luminosity with $L = 4\\pi d_L^2 F$, where $d_L$ is the luminosity distance in the adopted cosmology (flat $\\Lambda$CDM with $H_0 = 70\\,\\mathrm{km\\,s^{-1}\\,Mpc^{-1}}$ and $\\Omega_m = 0.3$ below). Repeating the measurement for every posterior draw yields posteriors for the FWHM and $\\log L$, summarized by their 16th, 50th, and 84th percentiles.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0d86db32", "metadata": {}, "outputs": [], "source": [ "cosmo = FlatLambdaCDM(H0=70, Om0=0.3)\n", "\n", "def flux_to_luminosity(area_1e17, redshift):\n", "    # area_1e17: integrated line flux in 1e-17 erg s^-1 cm^-2; returns L in erg s^-1\n", "    d_l_cm = cosmo.luminosity_distance(redshift).to(u.cm).value\n", "    return area_1e17 * 1e-17 * 4.0 * np.pi * d_l_cm**2\n", "\n", "amp_draws = np.asarray(q.pred_out['line_amp_per_component'])\n", "n_draws = amp_draws.shape[0]  # one entry per posterior draw\n", "\n", "for line_key in ['CIV_br', 'MgII_br', 'Hb_br', 'Ha_br']:\n", "    fwhm_samp, logL_samp = [], []\n", "    for i in range(n_draws):\n", "        prof = q.line_profile_from_draw(i, line_key)\n", "        fwhm, area = q.line_props(prof)\n", "        fwhm_samp.append(fwhm)\n", "        if np.isfinite(area) and area > 0:\n", "            logL_samp.append(np.log10(flux_to_luminosity(area, q.z)))\n", "        else:\n", "            logL_samp.append(np.nan)\n", "\n", "    fwhm_samp = np.asarray(fwhm_samp, dtype=float)\n", "    logL_samp = np.asarray(logL_samp, dtype=float)\n", "    if not (np.any(np.isfinite(fwhm_samp)) and np.any(np.isfinite(logL_samp))):\n", "        print(f'{line_key:8s} no finite posterior measurements; skipped')\n", "        continue\n", "\n", "    # Median and 16th-84th percentile interval of each posterior\n", "    f16, f50, f84 = np.nanpercentile(fwhm_samp, [16, 50, 84])\n", "    l16, l50, l84 = np.nanpercentile(logL_samp, [16, 50, 84])\n", "\n", "    print(\n", "        f\"{line_key:8s} \"\n", "        f\"FWHM={f50:.1f} (+{f84-f50:.1f}/-{f50-f16:.1f}) km/s \"\n", "        f\"logL={l50:.3f} (+{l84-l50:.3f}/-{l50-l16:.3f})\"\n", "    )\n" ] }, { "cell_type": "markdown", "id": "7c084c14", "metadata": {}, "source": [ "Before reporting these FWHM values, correct them for instrumental broadening by subtracting the instrumental FWHM in quadrature: $\\mathrm{FWHM}_{\\rm intr} = \\sqrt{\\mathrm{FWHM}_{\\rm obs}^2 - \\mathrm{FWHM}_{\\rm inst}^2}$." ] }, { "cell_type": "markdown", "id": "fb3ea7f7", "metadata": {}, "source": [ "## 6. 
Inspect posterior samples (recommended with `optax+nuts`)\n", "\n", "If you used `fit_method='nuts'` or `'optax+nuts'`, the posterior samples are stored in `q.numpyro_samples`, a dictionary keyed by parameter name.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "05f9e255", "metadata": {}, "outputs": [], "source": [ "samples = getattr(q, 'numpyro_samples', None)\n", "if samples is not None:\n", "    keys = sorted(samples.keys())\n", "    print('Num posterior params:', len(keys))\n", "    print('First 20 keys:', keys[:20])\n", "else:\n", "    print('No NumPyro samples available (likely fit_method=optax).')\n" ] }, { "cell_type": "markdown", "id": "cbf58d7c", "metadata": {}, "source": [ "## 7. Quick component diagnostics\n", "\n", "As a sanity check, compare the peak of the data with the peak of the total model and of each fitted component.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e5ca9a4a", "metadata": {}, "outputs": [], "source": [ "print('max data        :', np.nanmax(q.flux))\n", "print('max total model :', np.nanmax(q.model_total))\n", "print('max PL          :', np.nanmax(q.f_pl_model))\n", "print('max host        :', np.nanmax(q.host))\n", "print('max FeII        :', np.nanmax(q.f_fe_mgii_model + q.f_fe_balmer_model))\n", "print('max Balmer cont :', np.nanmax(q.f_bc_model))\n", "print('max lines       :', np.nanmax(q.f_line_model))\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3" } }, "nbformat": 4, "nbformat_minor": 5 }