{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sampling Test\n", "\n", "One of the most amazing advantages of using MelGAN is that \"it works in real time on CPU!\".\n", "\n", "Try it!" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# All imports live in this one cell so the notebook survives Restart & Run All.\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "import IPython.display as ipd\n", "import tqdm\n", "\n", "import torch\n", "\n", "import models\n", "from torch.utils.data import DataLoader\n", "from data import LJspeechDataset, collate_fn, collate_fn_synthesize\n", "\n", "import commons\n", "\n", "import librosa\n", "import numpy as np\n", "import os\n", "import json" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Build the LJSpeech train/test loaders (batch size 1 for synthesis) and\n", "# instantiate the generator on CPU — this demo is about CPU-realtime inference.\n", "kwargs = {'num_workers': 0, 'pin_memory': True}\n", "train_dataset = LJspeechDataset('./DATASETS/ljspeech/', True, 0.1)\n", "test_dataset = LJspeechDataset('./DATASETS/ljspeech/', False, 0.1)\n", "train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn,\n", "                          **kwargs)\n", "test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn_synthesize,\n", "                         **kwargs)\n", "\n", "model_dir = \"./logs/test/\"\n", "configs = json.load(open(os.path.join(model_dir, \"config.json\"), \"r\"))\n", "\n", "model = models.Generator(configs[\"data\"][\"n_channels\"])#.to(\"cuda\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load generator weights from the checkpoint. Keys missing from the checkpoint\n", "# keep their freshly initialised values (and are reported), so partial\n", "# checkpoints still load. Weight norm is folded in for faster inference.\n", "checkpoint_path = os.path.join(model_dir, 'G_205.pth')\n", "checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')\n", "state_dict = checkpoint_dict['model']\n", "new_state_dict = {}\n", "for k, v in model.state_dict().items():\n", "    try:\n", "        new_state_dict[k] = state_dict[k]\n", "    except KeyError:\n", "        # Only a missing key is expected here; a bare except would also\n", "        # swallow KeyboardInterrupt and real errors.\n", "        print(\"%s is not in the checkpoint\" % k)\n", "        new_state_dict[k] = v\n", "model.load_state_dict(new_state_dict)\n", "model.remove_weight_norm()" ] }, { "cell_type": "code",
"execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Pull the idx_stop-th example from the test loader and play the\n", "# ground-truth waveform for reference (LJSpeech is 22050 Hz).\n", "idx_stop = 0\n", "for i, (x, c, _) in enumerate(test_loader):\n", "    #x, c = x.to(\"cuda\"), c.to(\"cuda\")\n", "    if i == idx_stop:\n", "        break\n", "ipd.Audio(x.cpu().numpy().reshape(-1), rate=22050)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Synthesize audio from the mel spectrogram `c` with the generator on CPU;\n", "# no_grad keeps inference memory-light.\n", "with torch.no_grad():\n", "    x_hat = model(c)\n", "ipd.Audio(x_hat.cpu().detach().numpy().reshape(-1), rate=22050)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }