{
    "imports": [
        "$import functools",
        "$import glob",
        "$import scripts"
    ],
    "bundle_root": ".",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "ckpt_dir": "$@bundle_root + '/models'",
    "tf_dir": "$@bundle_root + '/eval'",
    "dataset_dir": "/workspace/data/medical",
    "pretrained": false,
    "perceptual_loss_model_weights_path": null,
    "train_batch_size": 2,
    "lr": 1e-05,
    "train_patch_size": [
        112,
        128,
        80
    ],
    "channel": 0,
    "spacing": [
        1.1,
        1.1,
        1.1
    ],
    "spatial_dims": 3,
    "image_channels": 1,
    "latent_channels": 8,
    "discriminator_def": {
        "_target_": "generative.networks.nets.PatchDiscriminator",
        "spatial_dims": "@spatial_dims",
        "num_layers_d": 3,
        "num_channels": 32,
        "in_channels": "@image_channels",
        "out_channels": 1,
        "norm": "INSTANCE"
    },
    "autoencoder_def": {
        "_target_": "generative.networks.nets.AutoencoderKL",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@image_channels",
        "out_channels": "@image_channels",
        "latent_channels": "@latent_channels",
        "num_channels": [
            64,
            128,
            256
        ],
        "num_res_blocks": 2,
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [
            false,
            false,
            false
        ],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false
    },
    "perceptual_loss_def": {
        "_target_": "generative.losses.PerceptualLoss",
        "spatial_dims": "@spatial_dims",
        "network_type": "resnet50",
        "is_fake_3d": true,
        "fake_3d_ratio": 0.2,
        "pretrained": "@pretrained",
        "pretrained_path": "@perceptual_loss_model_weights_path",
        "pretrained_state_dict_key": "state_dict"
    },
    "dnetwork": "$@discriminator_def.to(@device)",
    "gnetwork": "$@autoencoder_def.to(@device)",
    "loss_perceptual": "$@perceptual_loss_def.to(@device)",
    "doptimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@dnetwork.parameters()",
        "lr": "@lr"
    },
    "goptimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@gnetwork.parameters()",
        "lr": "@lr"
    },
    "preprocessing_transforms": [
        {
            "_target_": "LoadImaged",
            "keys": "image"
        },
        {
            "_target_": "EnsureChannelFirstd",
            "keys": "image"
        },
        {
            "_target_": "Lambdad",
            "keys": "image",
            "func": "$lambda x: x[@channel, :, :, :]"
        },
        {
            "_target_": "EnsureChannelFirstd",
            "keys": "image",
            "channel_dim": "no_channel"
        },
        {
            "_target_": "EnsureTyped",
            "keys": "image"
        },
        {
            "_target_": "Orientationd",
            "keys": "image",
            "axcodes": "RAS"
        },
        {
            "_target_": "Spacingd",
            "keys": "image",
            "pixdim": "@spacing",
            "mode": "bilinear"
        }
    ],
    "final_transforms": [
        {
            "_target_": "ScaleIntensityRangePercentilesd",
            "keys": "image",
            "lower": 0,
            "upper": 99.5,
            "b_min": 0,
            "b_max": 1
        }
    ],
    "train": {
        "crop_transforms": [
            {
                "_target_": "RandSpatialCropd",
                "keys": "image",
                "roi_size": "@train_patch_size",
                "random_size": false
            }
        ],
        "preprocessing": {
            "_target_": "Compose",
            "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
        },
        "dataset": {
            "_target_": "monai.apps.DecathlonDataset",
            "root_dir": "@dataset_dir",
            "task": "Task01_BrainTumour",
            "section": "training",
            "cache_rate": 1.0,
            "num_workers": 8,
            "download": false,
            "transform": "@train#preprocessing"
        },
        "dataloader": {
            "_target_": "DataLoader",
            "dataset": "@train#dataset",
            "batch_size": "@train_batch_size",
            "shuffle": true,
            "num_workers": 0
        },
        "handlers": [
            {
                "_target_": "CheckpointSaver",
                "save_dir": "@ckpt_dir",
                "save_dict": {
                    "model": "@gnetwork"
                },
                "save_interval": 0,
                "save_final": true,
                "epoch_level": true,
                "final_filename": "model_autoencoder.pt"
            },
            {
                "_target_": "StatsHandler",
                "tag_name": "train_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
            },
            {
                "_target_": "TensorBoardStatsHandler",
                "log_dir": "@tf_dir",
                "tag_name": "train_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
            }
        ],
        "trainer": {
            "_target_": "scripts.ldm_trainer.VaeGanTrainer",
            "device": "@device",
            "max_epochs": 1500,
            "train_data_loader": "@train#dataloader",
            "g_network": "@gnetwork",
            "g_optimizer": "@goptimizer",
            "g_loss_function": "$functools.partial(scripts.losses.generator_loss, disc_net=@dnetwork, loss_perceptual=@loss_perceptual)",
            "d_network": "@dnetwork",
            "d_optimizer": "@doptimizer",
            "d_loss_function": "$functools.partial(scripts.losses.discriminator_loss, disc_net=@dnetwork)",
            "d_train_steps": 5,
            "g_update_latents": true,
            "latent_shape": "@latent_channels",
            "key_train_metric": "$None",
            "train_handlers": "@train#handlers"
        }
    },
    "initialize": [
        "$monai.utils.set_determinism(seed=0)"
    ],
    "run": [
        "$@train#trainer.run()"
    ]
}