# @title Code to quickly train MNIST VAE { display-mode: "form" }
#
# NOTE(review): this file previously held a captured notebook-execution error
# log instead of runnable source: the message "An error occurred while
# executing the following cell:", then this cell flattened onto one line, then
# the traceback below. The cell has been restored to normal formatting; the
# original failure is kept for reference (line numbers refer to the original
# cell, not to this file):
#
#     NameError                                 Traceback (most recent call last)
#     /tmp/ipykernel_5386/376362601.py in
#          65 kl_loss = partial(kl_loss, {"kl_coeff": 1})
#          66 mmd_loss = partial(mmd_loss, {"beta": 1})
#     ---> 67 vanilla_vae = VAE("vanilla_vae", kl_loss, encoder, decoder)
#          68 vanilla_vae = VAEModule(vanilla_vae, lr, latent_dim)
#     NameError: name 'VAE' is not defined

from functools import partial

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

# Names this cell uses but neither defines nor imports. They are presumably
# provided by earlier notebook cells -- TODO confirm. Checking them up front
# makes the cell fail immediately with one message listing everything that is
# missing, instead of downloading MNIST and building four networks before
# hitting the first NameError (which is what happened with VAE).
_missing_names = [
    name
    for name in (
        "MNIST",
        "transforms",
        "rearrange",
        "Trainer",
        "VAE",
        "VAEModule",
        "kl_divergence",
        "MMD",
    )
    if name not in globals()
]
if _missing_names:
    raise NameError(
        "This cell needs names that are not defined yet: "
        + ", ".join(_missing_names)
        + ". Run the earlier notebook cells that define or import them first."
    )


def kl_loss(config, x, x_hat, z, mu, logvar):
    """Return the vanilla-VAE loss: BCE reconstruction plus a weighted KL term.

    Args:
        config: dict; ``config["kl_coeff"]`` weights the KL term.
        x: target tensor for the reconstruction.
        x_hat: reconstruction of ``x``; must lie in [0, 1] for
            ``binary_cross_entropy`` (the Decoder's Sigmoid guarantees this).
        z: latent code; unused here, kept so both loss functions share one
            signature.
        mu, logvar: encoder outputs, passed to ``kl_divergence`` (defined
            outside this cell).

    Returns:
        Scalar loss tensor: mean BCE + kl_coeff * kl_divergence(mu, logvar).
    """
    recons_loss = F.binary_cross_entropy(x_hat, x, reduction="mean")
    kld_loss = kl_divergence(mu, logvar)
    return recons_loss + config["kl_coeff"] * kld_loss


def mmd_loss(config, x, x_hat, z, mu, logvar):
    """Return the MMD-VAE loss: BCE reconstruction plus a weighted MMD term.

    The MMD term compares ``z`` with samples drawn from a standard normal of
    the same shape (``torch.randn_like(z)``), using ``MMD`` defined outside
    this cell.

    Args:
        config: dict; ``config["beta"]`` weights the MMD term.
        x: target tensor for the reconstruction.
        x_hat: reconstruction of ``x``; must lie in [0, 1] for
            ``binary_cross_entropy``.
        z: latent code compared against standard-normal samples.
        mu, logvar: unused here, kept so both loss functions share one
            signature.

    Returns:
        Scalar loss tensor: mean BCE + beta * MMD(randn_like(z), z).
    """
    recons_loss = F.binary_cross_entropy(x_hat, x, reduction="mean")
    mmd = MMD(torch.randn_like(z), z)
    return recons_loss + config["beta"] * mmd


class Encoder(nn.Module):
    """MLP encoder mapping flattened 28x28 inputs to ``(mu, log_var)``."""

    def __init__(self, latent_dim: int = 256):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 512), nn.ReLU())
        self.fc_mu = nn.Linear(512, latent_dim)
        # Despite its name, this head's output is returned as ``log_var``.
        self.fc_var = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Return ``(mu, log_var)``, each with last dimension ``latent_dim``.

        ``x`` must have last dimension 28 * 28 (a flattened image).
        """
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var


class Decoder(nn.Module):
    """MLP decoder mapping a latent code to a flattened 28x28 output."""

    def __init__(self, latent_dim: int = 256):
        super().__init__()
        # The final Sigmoid keeps outputs in [0, 1], which the
        # binary_cross_entropy reconstruction term requires.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Return the reconstruction for ``z`` (last dimension 28 * 28)."""
        return self.decoder(z)


lr = 0.001
latent_dim = 2

# Separate encoder/decoder instances so the two models do not share weights.
encoder = Encoder(latent_dim)
decoder = Decoder(latent_dim)
encoder2 = Encoder(latent_dim)
decoder2 = Decoder(latent_dim)

mnist_full = MNIST(
    ".",
    download=True,
    train=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            # Flatten each image to a vector so it fits the Linear encoder.
            lambda x: rearrange(x, "c h w -> (c h w)"),
        ]
    ),
)
dm = DataLoader(mnist_full, batch_size=500, shuffle=True)

# Bind each loss's config dict. This rebinds the names kl_loss / mmd_loss to
# the partials (kept as-is in case later cells use them). Re-running only
# these two lines would wrap the partials a second time and pass the config
# twice, so re-run the whole cell instead.
kl_loss = partial(kl_loss, {"kl_coeff": 1})
mmd_loss = partial(mmd_loss, {"beta": 1})

vanilla_vae = VAE("vanilla_vae", kl_loss, encoder, decoder)
vanilla_vae = VAEModule(vanilla_vae, lr, latent_dim)
mmd_vae = VAE("mmd_vae", mmd_loss, encoder2, decoder2)
mmd_vae = VAEModule(mmd_vae, lr, latent_dim)

# NOTE(review): `gpus` and `weights_summary` are PyTorch Lightning 1.x Trainer
# arguments (weights_summary was removed in 1.7, gpus in 2.0). If the
# installed Lightning is newer, switch to e.g. accelerator="gpu", devices=1
# -- TODO confirm the Lightning version this notebook targets.
trainer1 = Trainer(gpus=1, weights_summary="full", max_epochs=10)
trainer1.fit(vanilla_vae, dm)
trainer2 = Trainer(gpus=1, weights_summary="full", max_epochs=10)
trainer2.fit(mmd_vae, dm)