{
  "ckpt_dir": "$@bundle_root + '/models'",
  "train_batch_size": 4,
  "lr": 1e-05,
  "train_patch_size": [144, 176, 112],
  "latent_shape": ["@latent_channels", 36, 44, 28],
  "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
  "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path, map_location=@device))",
  "autoencoder": "$@autoencoder_def.to(@device)",
  "network_def": {
    "_target_": "generative.networks.nets.DiffusionModelUNet",
    "spatial_dims": "@spatial_dims",
    "in_channels": "@latent_channels",
    "out_channels": "@latent_channels",
    "num_channels": [256, 256, 512],
    "attention_levels": [false, true, true],
    "num_head_channels": [0, 64, 64],
    "num_res_blocks": 2
  },
  "diffusion": "$@network_def.to(@device)",
  "optimizer": {
    "_target_": "torch.optim.Adam",
    "params": "$@diffusion.parameters()",
    "lr": "@lr"
  },
  "lr_scheduler": {
    "_target_": "torch.optim.lr_scheduler.MultiStepLR",
    "optimizer": "@optimizer",
    "milestones": [100, 1000],
    "gamma": 0.1
  },
  "scale_factor": "$scripts.utils.compute_scale_factor(@autoencoder,@train#dataloader,@device)",
  "noise_scheduler": {
    "_target_": "generative.networks.schedulers.DDPMScheduler",
    "_requires_": ["@load_autoencoder"],
    "schedule": "scaled_linear_beta",
    "num_train_timesteps": 1000,
    "beta_start": 0.0015,
    "beta_end": 0.0195
  },
  "loss": {
    "_target_": "torch.nn.MSELoss"
  },
  "train": {
    "inferer": {
      "_target_": "generative.inferers.LatentDiffusionInferer",
      "scheduler": "@noise_scheduler",
      "scale_factor": "@scale_factor"
    },
    "crop_transforms": [
      {
        "_target_": "CenterSpatialCropd",
        "keys": "image",
        "roi_size": "@train_patch_size"
      }
    ],
    "preprocessing": {
      "_target_": "Compose",
      "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
    },
    "dataset": {
      "_target_": "monai.apps.DecathlonDataset",
      "root_dir": "@dataset_dir",
      "task": "Task01_BrainTumour",
      "section": "training",
      "cache_rate": 1.0,
      "num_workers": 8,
      "download": false,
      "transform": "@train#preprocessing"
    },
    "dataloader": {
      "_target_": "DataLoader",
      "dataset": "@train#dataset",
      "batch_size": "@train_batch_size",
      "shuffle": true,
      "num_workers": 0
    },
    "handlers": [
      {
        "_target_": "LrScheduleHandler",
        "lr_scheduler": "@lr_scheduler",
        "print_lr": true
      },
      {
        "_target_": "CheckpointSaver",
        "save_dir": "@ckpt_dir",
        "save_dict": {
          "model": "@diffusion"
        },
        "save_interval": 0,
        "save_final": true,
        "epoch_level": true,
        "final_filename": "model.pt"
      },
      {
        "_target_": "StatsHandler",
        "tag_name": "train_diffusion_loss",
        "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
      },
      {
        "_target_": "TensorBoardStatsHandler",
        "log_dir": "@tf_dir",
        "tag_name": "train_diffusion_loss",
        "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
      }
    ],
    "trainer": {
      "_target_": "scripts.ldm_trainer.LDMTrainer",
      "device": "@device",
      "max_epochs": 5000,
      "train_data_loader": "@train#dataloader",
      "network": "@diffusion",
      "autoencoder_model": "@autoencoder",
      "optimizer": "@optimizer",
      "loss_function": "@loss",
      "latent_shape": "@latent_shape",
      "inferer": "@train#inferer",
      "key_train_metric": "$None",
      "train_handlers": "@train#handlers"
    }
  },
  "initialize": [
    "$monai.utils.set_determinism(seed=0)"
  ],
  "run": [
    "@load_autoencoder",
    "$@autoencoder.eval()",
    "$print('scale factor:',@scale_factor)",
    "$@train#trainer.run()"
  ]
}