{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ASR DeepSpeech Examples\n",
"\n",
"This notebook demonstrates ART's DeepSpeech estimator and the Imperceptible ASR attack.\n",
"\n",
"---\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Preliminaries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import IPython.display as ipd\n",
"import matplotlib.pyplot as plt\n",
"from deepspeech_pytorch.loader.data_loader import load_audio\n",
"\n",
"from art.estimators.speech_recognition import PyTorchDeepSpeech\n",
"from art.attacks.evasion.imperceptible_asr.imperceptible_asr_pytorch import ImperceptibleASRPyTorch\n",
"from art import config\n",
"from art.utils import get_file\n",
"\n",
"\n",
"# Set seed\n",
"np.random.seed(1234)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Audio Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1 Download Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/minhtn/.art/data/deepspeech_audio\n",
"Skipping url: http://www.openslr.org/resources/12/train-clean-100.tar.gz\n",
"Skipping url: http://www.openslr.org/resources/12/train-clean-360.tar.gz\n",
"Skipping url: http://www.openslr.org/resources/12/train-other-500.tar.gz\n",
"Sorting manifests...\n",
"Pruning manifests between 1 and 15 seconds\n",
"0it [00:00, ?it/s]\n",
"\n",
"\n",
"Skipping url: http://www.openslr.org/resources/12/dev-clean.tar.gz\n",
"Skipping url: http://www.openslr.org/resources/12/dev-other.tar.gz\n",
"Sorting manifests...\n",
"0it [00:00, ?it/s]\n",
"\n",
"\n",
"100% [..................................................] 346663984 / 346663984Unpacking test-clean.tar.gz...\n",
"Converting flac files to wav and extracting transcripts...\n",
"129it [00:29, 4.38it/s]\n",
"Finished http://www.openslr.org/resources/12/test-clean.tar.gz\n",
"Sorting manifests...\n",
"100%|████████████████████████████████████| 2620/2620 [00:00<00:00, 69321.65it/s]\n",
"\n",
"\n",
"Skipping url: http://www.openslr.org/resources/12/test-other.tar.gz\n",
"Sorting manifests...\n",
"0it [00:00, ?it/s]\n",
"\n",
"\n",
"/home/minhtn/ibm/projects/adversarial-robustness-toolbox/notebooks\n"
]
}
],
"source": [
"# Prepare to download data\n",
"data_dir = os.path.join(config.ART_DATA_PATH, \"deepspeech_audio\")\n",
"current_dir = %pwd\n",
"\n",
"if not os.path.exists(data_dir):\n",
" os.makedirs(data_dir)\n",
"\n",
"# Download audio data\n",
"get_file('librispeech.py', 'https://raw.githubusercontent.com/SeanNaren/deepspeech.pytorch/master/data/librispeech.py', path=data_dir)\n",
"\n",
"%cd $data_dir\n",
"!python librispeech.py --files-to-use test-clean.tar.gz\n",
"%cd $current_dir"
]
},
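{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the cell below lists a few of the extracted files. This assumes `librispeech.py` produced the `LibriSpeech_dataset/test_clean/wav` and `LibriSpeech_dataset/test_clean/txt` directories used in the rest of the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: a few of the wav/txt pairs extracted by librispeech.py\n",
"wav_dir = os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/wav\")\n",
"txt_dir = os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/txt\")\n",
"print(sorted(os.listdir(wav_dir))[:3])\n",
"print(sorted(os.listdir(txt_dir))[:3])"
]
},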
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.2 Create Model and Data Utilities"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create a DeepSpeech estimator\n",
"speech_recognizer = PyTorchDeepSpeech(pretrained_model=\"librispeech\")"
]
},
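{
"cell_type": "markdown",
"metadata": {},
"source": [
"The estimator wraps a pretrained `deepspeech.pytorch` model whose output alphabet is exposed as `speech_recognizer.model.labels`; the transcript parser below maps characters to indices in this alphabet. A quick look:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the model's output alphabet\n",
"print(\"Number of labels: \", len(speech_recognizer.model.labels))\n",
"print(\"Labels: \", speech_recognizer.model.labels)"
]
},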
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def display_waveform(waveform, title=\"\", sample_rate=16000):\n",
" \"\"\"\n",
" Display waveform plot and audio play UI.\n",
" \"\"\"\n",
" plt.figure()\n",
" plt.title(title)\n",
" plt.plot(waveform)\n",
" ipd.display(ipd.Audio(waveform, rate=sample_rate))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"labels_map = dict([(speech_recognizer.model.labels[i], i) for i in range(len(speech_recognizer.model.labels))])\n",
"def parse_transcript(path):\n",
" with open(path, 'r', encoding='utf8') as f:\n",
" transcript = f.read().replace('\\n', '')\n",
" result = list(filter(None, [labels_map.get(x) for x in list(transcript)]))\n",
" return transcript, result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3 Play with Some Audios"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Encoded label: [9, 6, 28, 9, 16, 17, 6, 5, 28, 21, 9, 6, 19, 6, 28, 24, 16, 22, 13, 5, 28, 3, 6, 28, 20, 21, 6, 24, 28, 7, 16, 19, 28, 5, 10, 15, 15, 6, 19, 28, 21, 22, 19, 15, 10, 17, 20, 28, 2, 15, 5, 28, 4, 2, 19, 19, 16, 21, 20, 28, 2, 15, 5, 28, 3, 19, 22, 10, 20, 6, 5, 28, 17, 16, 21, 2, 21, 16, 6, 20, 28, 2, 15, 5, 28, 7, 2, 21, 28, 14, 22, 21, 21, 16, 15, 28, 17, 10, 6, 4, 6, 20, 28, 21, 16, 28, 3, 6, 28, 13, 2, 5, 13, 6, 5, 28, 16, 22, 21, 28, 10, 15, 28, 21, 9, 10, 4, 12, 28, 17, 6, 17, 17, 6, 19, 6, 5, 28, 7, 13, 16, 22, 19, 28, 7, 2, 21, 21, 6, 15, 6, 5, 28, 20, 2, 22, 4, 6]\n",
"Groundtrue label: HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# A long audio sample\n",
"x1 = load_audio(os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/wav/1089-134686-0000.wav\"))\n",
"label1, encoded_label1 = parse_transcript(os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/txt/1089-134686-0000.txt\"))\n",
"print(\"Encoded label: \", encoded_label1)\n",
"print(\"Groundtrue label: \", label1)\n",
"display_waveform(x1, title=\"Long Sample\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Encoded label: [21, 9, 6, 28, 22, 15, 10, 23, 6, 19, 20, 10, 21, 26]\n",
"Groundtrue label: THE UNIVERSITY\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# A short audio sample\n",
"x2 = load_audio(os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/wav/1089-134691-0003.wav\"))\n",
"label2, encoded_label2 = parse_transcript(os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/txt/1089-134691-0003.txt\"))\n",
"print(\"Encoded label: \", encoded_label2)\n",
"print(\"Groundtrue label: \", label2)\n",
"display_waveform(x2, title=\"Short Sample\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Encoded label: [2, 8, 2, 10, 15, 28, 2, 8, 2, 10, 15]\n",
"Groundtrue label: AGAIN AGAIN\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Another short audio sample\n",
"x3 = load_audio(os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/wav/1089-134691-0018.wav\"))\n",
"label3, encoded_label3 = parse_transcript(os.path.join(data_dir, \"LibriSpeech_dataset/test_clean/txt/1089-134691-0018.txt\"))\n",
"print(\"Encoded label: \", encoded_label3)\n",
"print(\"Groundtrue label: \", label3)\n",
"display_waveform(x3, title=\"Short Sample\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. The Estimator Performance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1 Get Transcription Outputs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Groundtruth label: HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE\n",
"Predicted label: HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERD FLOUR FAT AND SAUCE\n"
]
}
],
"source": [
"pred1 = speech_recognizer.predict(np.array([x1]), transcription_output=True)\n",
"print(\"Groundtruth label: \", label1)\n",
"print(\"Predicted label: \", pred1[0])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Groundtruth label: THE UNIVERSITY\n",
"Predicted label: THE UNIVERSITY\n"
]
}
],
"source": [
"pred2 = speech_recognizer.predict(np.array([x2]), transcription_output=True)\n",
"print(\"Groundtruth label: \", label2)\n",
"print(\"Predicted label: \", pred2[0])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Groundtruth label: AGAIN AGAIN\n",
"Predicted label: AGAIN AGAIN\n"
]
}
],
"source": [
"pred3 = speech_recognizer.predict(np.array([x3]), transcription_output=True)\n",
"print(\"Groundtruth label: \", label3)\n",
"print(\"Predicted label: \", pred3[0])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted labels: ['HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERD FLOUR FAT AND SAUCE'\n",
" 'THE UNIVERSITY' 'AGAIN AGAIN']\n"
]
}
],
"source": [
"x = np.array([x1, x2, x3])\n",
"pred_all = speech_recognizer.predict(x, transcription_output=True)\n",
"print(\"Predicted labels: \", pred_all)"
]
},
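{
"cell_type": "markdown",
"metadata": {},
"source": [
"The long sample is transcribed with a few word errors, while both short samples are transcribed perfectly. To quantify this, the cell below computes a word error rate (WER) with a minimal word-level Levenshtein helper; this is a sketch defined inline here, not part of the ART API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def word_error_rate(reference, hypothesis):\n",
"    \"\"\"Word-level Levenshtein distance divided by the reference length.\"\"\"\n",
"    ref, hyp = reference.split(), hypothesis.split()\n",
"    # dist[i][j] = edits needed to turn ref[:i] into hyp[:j]\n",
"    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]\n",
"    for i in range(len(ref) + 1):\n",
"        dist[i][0] = i\n",
"    for j in range(len(hyp) + 1):\n",
"        dist[0][j] = j\n",
"    for i in range(1, len(ref) + 1):\n",
"        for j in range(1, len(hyp) + 1):\n",
"            cost = 0 if ref[i - 1] == hyp[j - 1] else 1\n",
"            dist[i][j] = min(dist[i - 1][j] + 1,\n",
"                             dist[i][j - 1] + 1,\n",
"                             dist[i - 1][j - 1] + cost)\n",
"    return dist[-1][-1] / max(len(ref), 1)\n",
"\n",
"\n",
"for truth, pred in zip([label1, label2, label3], pred_all):\n",
"    print(\"WER %.3f: %s\" % (word_error_rate(truth, pred), pred))"
]
},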
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Imperceptible ASR Attack"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"global_max_length = int(np.max([len(x2), len(x3)]))\n",
"\n",
"# Define an Imperceptible ASR attack\n",
"asr_attack = ImperceptibleASRPyTorch(\n",
" estimator=speech_recognizer,\n",
" eps=0.05,\n",
" max_iter_1=100,\n",
" max_iter_2=500,\n",
" learning_rate_1=0.00002,\n",
" learning_rate_2=0.00002,\n",
" optimizer_1=torch.optim.Adam,\n",
" optimizer_2=torch.optim.Adam,\n",
" global_max_length=global_max_length,\n",
" initial_rescale=1.0,\n",
" decrease_factor_eps=0.8,\n",
" num_iter_decrease_eps=20,\n",
" alpha=1.2,\n",
" increase_factor_alpha=1.2,\n",
" num_iter_increase_alpha=20,\n",
" decrease_factor_alpha=0.8,\n",
" num_iter_decrease_alpha=20,\n",
" win_length=2048,\n",
" hop_length=512,\n",
" n_fft=2048,\n",
" batch_size=2,\n",
" use_amp=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First stage, step 0, loss 69.741150\n",
"First stage, step 5, loss 58.773037\n",
"First stage, step 10, loss 51.136898\n",
"First stage, step 15, loss 48.239479\n",
"First stage, step 20, loss 40.845539\n",
"First stage, step 25, loss 36.802544\n",
"First stage, step 30, loss 32.711205\n",
"First stage, step 35, loss 29.145208\n",
"First stage, step 40, loss 26.394424\n",
"First stage, step 45, loss 44.487350\n",
"First stage, step 50, loss 42.173428\n",
"First stage, step 55, loss 39.783951\n",
"First stage, step 60, loss 37.151741\n",
"First stage, step 65, loss 34.456093\n",
"First stage, step 70, loss 30.969639\n",
"First stage, step 75, loss 29.559433\n",
"First stage, step 80, loss 27.740614\n",
"First stage, step 85, loss 26.211428\n",
"First stage, step 90, loss 24.573879\n",
"First stage, step 95, loss 22.899168\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/minhtn/ibm/installation/miniconda3/lib/python3.6/site-packages/torch/functional.py:581: UserWarning: stft will soon require the return_complex parameter be given for real inputs, and will further require that return_complex=True in a future PyTorch release. (Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:639.)\n",
" normalized, onesided, return_complex)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Second stage, step 0, loss 887.815700\n",
"Second stage, step 5, loss 543.191106\n",
"Second stage, step 10, loss 395.780205\n",
"Second stage, step 15, loss 291.439771\n",
"Second stage, step 20, loss 216.803986\n",
"Second stage, step 25, loss 132.597252\n",
"Second stage, step 30, loss 102.339386\n",
"Second stage, step 35, loss 79.895121\n",
"Second stage, step 40, loss 62.953748\n",
"Second stage, step 45, loss 42.079926\n",
"Second stage, step 50, loss 34.634669\n",
"Second stage, step 55, loss 28.961950\n",
"Second stage, step 60, loss 24.520864\n",
"Second stage, step 65, loss 18.534713\n",
"Second stage, step 70, loss 16.328669\n",
"Second stage, step 75, loss 14.545206\n",
"Second stage, step 80, loss 13.060859\n",
"Second stage, step 85, loss 10.806125\n",
"Second stage, step 90, loss 9.900532\n",
"Second stage, step 95, loss 9.133247\n",
"Second stage, step 100, loss 8.474490\n",
"Second stage, step 105, loss 7.365411\n",
"Second stage, step 110, loss 6.920150\n",
"Second stage, step 115, loss 6.530938\n",
"Second stage, step 120, loss 6.185102\n",
"Second stage, step 125, loss 5.526468\n",
"Second stage, step 130, loss 5.271661\n",
"Second stage, step 135, loss 5.045556\n",
"Second stage, step 140, loss 4.838332\n",
"Second stage, step 145, loss 4.390867\n",
"Second stage, step 150, loss 4.219669\n",
"Second stage, step 155, loss 4.059496\n",
"Second stage, step 160, loss 3.907564\n",
"Second stage, step 165, loss 3.558051\n",
"Second stage, step 170, loss 3.421292\n",
"Second stage, step 175, loss 3.290129\n",
"Second stage, step 180, loss 3.162481\n",
"Second stage, step 185, loss 2.858444\n",
"Second stage, step 190, loss 2.736934\n",
"Second stage, step 195, loss 2.617957\n",
"Second stage, step 200, loss 2.500414\n",
"Second stage, step 205, loss 2.229840\n",
"Second stage, step 210, loss 2.124433\n",
"Second stage, step 215, loss 2.024690\n",
"Second stage, step 220, loss 1.932897\n",
"Second stage, step 225, loss 1.704043\n",
"Second stage, step 230, loss 1.623649\n",
"Second stage, step 235, loss 1.550399\n",
"Second stage, step 240, loss 1.483242\n",
"Second stage, step 245, loss 1.556927\n",
"Second stage, step 250, loss 1.504139\n",
"Second stage, step 255, loss 1.453637\n",
"Second stage, step 260, loss 1.400224\n",
"Second stage, step 265, loss 1.508619\n",
"Second stage, step 270, loss 1.453266\n",
"Second stage, step 275, loss 1.413941\n",
"Second stage, step 280, loss 1.369114\n",
"Second stage, step 285, loss 1.469223\n",
"Second stage, step 290, loss 1.431475\n",
"Second stage, step 295, loss 1.389646\n",
"Second stage, step 300, loss 1.353665\n",
"Second stage, step 305, loss 1.458392\n",
"Second stage, step 310, loss 1.424629\n",
"Second stage, step 315, loss 1.391908\n",
"Second stage, step 320, loss 1.348441\n",
"Second stage, step 325, loss 1.456609\n",
"Second stage, step 330, loss 1.423668\n",
"Second stage, step 335, loss 1.381091\n",
"Second stage, step 340, loss 1.349383\n",
"Second stage, step 345, loss 1.442660\n",
"Second stage, step 350, loss 1.427358\n",
"Second stage, step 355, loss 1.400758\n",
"Second stage, step 360, loss 1.349413\n",
"Second stage, step 365, loss 1.455723\n",
"Second stage, step 370, loss 1.457263\n",
"Second stage, step 375, loss 1.430490\n",
"Second stage, step 380, loss 1.408006\n",
"Second stage, step 385, loss 1.471980\n",
"Second stage, step 390, loss 1.503218\n",
"Second stage, step 395, loss 1.448039\n",
"Second stage, step 400, loss 1.399196\n",
"Second stage, step 405, loss 1.478893\n",
"Second stage, step 410, loss 1.549467\n",
"Second stage, step 415, loss 1.458022\n",
"Second stage, step 420, loss 1.429721\n",
"Second stage, step 425, loss 1.547260\n",
"Second stage, step 430, loss 1.545492\n",
"Second stage, step 435, loss 1.543032\n",
"Second stage, step 440, loss 1.577160\n",
"Second stage, step 445, loss 1.664575\n",
"Second stage, step 450, loss 1.583724\n",
"Second stage, step 455, loss 1.655049\n",
"Second stage, step 460, loss 1.600740\n",
"Second stage, step 465, loss 1.613063\n",
"Second stage, step 470, loss 1.850034\n",
"Second stage, step 475, loss 1.667144\n",
"Second stage, step 480, loss 1.567518\n",
"Second stage, step 485, loss 1.633470\n",
"Second stage, step 490, loss 1.770336\n",
"Second stage, step 495, loss 1.781754\n"
]
}
],
"source": [
"# Target labels\n",
"y = np.array(['THE UNIVERSAL', 'GAIN GAIN'])\n",
"\n",
"# Generate adversarial examples\n",
"x_adv = asr_attack.generate(np.array([x2, x3]), y)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"adv_transcriptions = speech_recognizer.predict(x_adv, batch_size=2, transcription_output=True)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Groundtruth transcriptions: ['THE UNIVERSITY' 'AGAIN AGAIN']\n",
"Target transcriptions: ['THE UNIVERSAL' 'GAIN GAIN']\n",
"Adversarial transcriptions: ['THE UNIVERSAL' 'GAIN GAIN']\n"
]
}
],
"source": [
"print(\"Groundtruth transcriptions: \", np.array([label2, label3]))\n",
"print(\"Target transcriptions: \", y)\n",
"print(\"Adversarial transcriptions: \", adv_transcriptions)"
]
},
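{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both targeted transcriptions are achieved. To gauge how strong the added perturbation is, the cell below computes its signal-to-noise ratio against the clean audio. This is a minimal sketch; since `x_adv` is padded up to `global_max_length`, each sample is truncated back to its original length first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def snr_db(clean, adversarial):\n",
"    \"\"\"Signal-to-noise ratio of the adversarial perturbation in dB.\"\"\"\n",
"    noise = adversarial - clean\n",
"    return 10.0 * np.log10(np.sum(clean ** 2) / np.sum(noise ** 2))\n",
"\n",
"\n",
"print(\"SNR for sample 1: %.2f dB\" % snr_db(x2, x_adv[0][:len(x2)]))\n",
"print(\"SNR for sample 2: %.2f dB\" % snr_db(x3, x_adv[1][:len(x3)]))"
]
},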
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display_waveform(x_adv[0][:len(x2)], title=\"THE UNIVERSITY is attacked to THE UNIVERSAL\")"
]
},
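{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also instructive to look at (and listen to) the perturbation in isolation, i.e. the difference between the adversarial and clean waveforms. Note that `ipd.Audio` normalizes the waveform for playback by default, so the isolated perturbation sounds much louder here than it does inside the adversarial sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The perturbation alone; playback is amplitude-normalized, so it sounds\n",
"# much louder here than it does mixed into the adversarial sample\n",
"perturbation = x_adv[0][:len(x2)] - x2\n",
"display_waveform(perturbation, title=\"Perturbation for sample 1\")"
]
},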
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display_waveform(x_adv[1][:len(x3)], title=\"AGAIN AGAIN is attacked to GAIN GAIN\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}