{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PixelDiffusionConditional setup\n",
    "\n",
    "Loads the train/validation Zarr datasets described by a WeatherDiff dataset config (id set in the *Configuration* cell), builds an area-weighted MSE loss from the target lat/lon grid, and instantiates a `PixelDiffusionConditional` model.\n",
    "\n",
    "**Status:** model construction currently fails with `KeyError: 'ipython_dir'` — see the FIXME in the *Model* section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from pytorch_lightning import loggers as pl_loggers\n",
    "\n",
    "from dm_zoo.dff.PixelDiffusion import PixelDiffusionConditional\n",
    "from WD.io import write_config, load_config\n",
    "from WD.datasets import Conditional_Dataset_Zarr_Iterable\n",
    "from WD.utils import check_devices, create_dir, generate_uid, AreaWeightedMSELoss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Everything a reader might want to tune: which dataset config to load and the model/training hyperparameters passed to `PixelDiffusionConditional`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_id = \"96FE8A\"\n",
    "CONFIG_DIR = \"/data/compoundx/WeatherDiff/config_file\"\n",
    "\n",
    "# Model / training hyperparameters\n",
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 1e-4\n",
    "NUM_WORKERS = 4\n",
    "LR_SCHEDULER_NAME = \"Constant\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load datasets\n",
    "\n",
    "The dataset config (`<ds_id>.yml`) provides the directory holding the `<ds_id>_<split>.zarr` stores; the same config path is also handed to each dataset object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_config_path = os.path.join(CONFIG_DIR, f\"{ds_id}.yml\")\n",
    "print(f\"Loading dataset config from {ds_config_path}\")\n",
    "ds_config = load_config(ds_config_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_split_dataset(split: str) -> Conditional_Dataset_Zarr_Iterable:\n",
    "    \"\"\"Open `<dir_model_input>/<ds_id>_<split>.zarr` with shuffle_chunks and shuffle_in_chunks enabled.\n",
    "\n",
    "    Reads the notebook-level `ds_id`, `ds_config` and `ds_config_path` defined above.\n",
    "    os.path.join is used so the path is correct whether or not dir_model_input ends with a slash.\n",
    "    \"\"\"\n",
    "    zarr_path = os.path.join(ds_config.file_structure.dir_model_input, f\"{ds_id}_{split}.zarr\")\n",
    "    return Conditional_Dataset_Zarr_Iterable(\n",
    "        zarr_path, ds_config_path, shuffle_chunks=True, shuffle_in_chunks=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = make_split_dataset(\"train\")\n",
    "val_ds = make_split_dataset(\"val\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss function\n",
    "\n",
    "Area-weighted MSE built from the latitude/longitude coordinates of the training targets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lat_grid = train_ds.data.targets.lat[:]\n",
    "lon_grid = train_ds.data.targets.lon[:]\n",
    "\n",
    "loss_fn = AreaWeightedMSELoss(lat_grid, lon_grid).loss_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "**FIXME:** this cell currently raises `KeyError: 'ipython_dir'` from `self.save_hyperparameters()` in `dm_zoo/dff/PixelDiffusion.py` (line 47). In the recorded traceback, pytorch_lightning's `collect_init_args` kept recursing up the call stack with `inside=False`. It stopped when `_get_init_args` reached a frame whose class `__init__` takes an `ipython_dir` argument that is not a local of that frame — apparently IPython's own machinery.\n",
    "\n",
    "NOTE(review): this suggests Lightning never treated `PixelDiffusionConditional.__init__` as a hyperparameter frame. Lightning only does so when the frame has a `__class__` local, e.g. from a zero-argument `super()` call. Confirm this, then fix it in `PixelDiffusion.py`, not here. Also consider passing `ignore=[...]` for the dataset and `loss_fn` arguments so they are not stored as hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = PixelDiffusionConditional(\n",
    "    train_dataset=train_ds,\n",
    "    valid_dataset=val_ds,\n",
    "    generated_channels=ds_config.n_generated_channels,\n",
    "    condition_channels=ds_config.n_condition_channels,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    cylindrical_padding=True,\n",
    "    lr=LEARNING_RATE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    loss_fn=loss_fn,\n",
    "    lr_scheduler_name=LR_SCHEDULER_NAME,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TORCH311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}