{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sampling Test\n",
"\n",
"\n",
"One of the most amazing advantages of using MelGAN is that \"it works in real time on CPU!\".\n",
"\n",
"Try it!"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import IPython.display as ipd\n",
"import tqdm\n",
"\n",
"import torch\n",
"\n",
"import models\n",
"from torch.utils.data import DataLoader\n",
"from data import LJspeechDataset, collate_fn, collate_fn_synthesize\n",
"\n",
"import commons\n",
"\n",
"import librosa\n",
"import numpy as np\n",
"import os\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# DataLoader options shared by both loaders (CPU-only: no worker processes).\n",
"kwargs = {'num_workers': 0, 'pin_memory': True}\n",
"# LJSpeech dataset: second arg selects the train split, third is the held-out fraction.\n",
"train_dataset = LJspeechDataset('./DATASETS/ljspeech/', True, 0.1)\n",
"test_dataset = LJspeechDataset('./DATASETS/ljspeech/', False, 0.1)\n",
"train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn,\n",
"                          **kwargs)\n",
"test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn_synthesize,\n",
"                         **kwargs)\n",
"\n",
"# Training-run directory holding config.json and generator checkpoints.\n",
"model_dir = \"./logs/test/\"\n",
"configs = json.load(open(os.path.join(model_dir, \"config.json\"), \"r\"))\n",
"\n",
"# Build the generator on CPU; re-enable .to(\"cuda\") for GPU inference.\n",
"model = models.Generator(configs[\"data\"][\"n_channels\"])#.to(\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load generator weights from a training checkpoint onto the CPU.\n",
"checkpoint_path = os.path.join(model_dir, 'G_205.pth')\n",
"checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')\n",
"state_dict = checkpoint_dict['model']\n",
"# Copy each parameter that exists in the checkpoint; for any parameter the\n",
"# checkpoint is missing, warn and keep the model's current (initial) value.\n",
"# Catch only KeyError so real errors are not silently swallowed.\n",
"new_state_dict = {}\n",
"for k, v in model.state_dict().items():\n",
"    try:\n",
"        new_state_dict[k] = state_dict[k]\n",
"    except KeyError:\n",
"        print(\"%s is not in the checkpoint\" % k)\n",
"        new_state_dict[k] = v\n",
"model.load_state_dict(new_state_dict)\n",
"# Weight norm is only needed for training; removing it speeds up inference.\n",
"model.remove_weight_norm()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Walk the test loader until the batch at position `target_idx`, then play\n",
"# back the ground-truth waveform (LJSpeech audio is sampled at 22050 Hz).\n",
"target_idx = 0\n",
"for step, (x, c, _) in enumerate(test_loader):\n",
"    #x, c = x.to(\"cuda\"), c.to(\"cuda\")\n",
"    if step == target_idx:\n",
"        break\n",
"ipd.Audio(x.cpu().numpy().reshape(-1), rate=22050)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Synthesize audio from the mel-spectrogram `c` produced by the cell above\n",
"# (hidden cross-cell state: re-run the previous cell first after a restart).\n",
"# no_grad() skips autograd bookkeeping since this is inference only.\n",
"with torch.no_grad():\n",
"    x_hat = model(c)\n",
"ipd.Audio(x_hat.cpu().detach().numpy().reshape(-1), rate=22050)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}