{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# The Very Basics of Musical Instrument Classification using Machine Learning\n", "## MFCC, SVM Grid Search\n", "\n", "
\n", "\n", "

\n", "\"Business\n", "

\n", "
\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%html\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Imports\n", "\n", "# General\n", "import numpy as np\n", "import pickle\n", "import itertools\n", "\n", "# System\n", "import os, fnmatch\n", "\n", "# Visualization\n", "import seaborn  # statistical plotting library built on matplotlib\n", "import matplotlib.pyplot as plt\n", "# HTML/display live in IPython.display (importing them from IPython.core.display is deprecated)\n", "from IPython.display import HTML, display\n", "\n", "# Machine Learning\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.metrics import recall_score, precision_score, accuracy_score\n", "from sklearn.metrics import confusion_matrix, f1_score, classification_report\n", "from sklearn.svm import LinearSVC, SVC\n", "# Model persistence: sklearn.externals.joblib was removed in scikit-learn 0.23;\n", "# the standalone joblib package (a scikit-learn dependency) provides the same dump/load API.\n", "import joblib\n", "\n", "# Random Seed\n", "from numpy.random import seed\n", "seed(1)\n", "\n", "# Audio\n", "import librosa.display, librosa\n", "\n", "# Configurations\n", "path='./audio/london_phill_dataset_multi/'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Auxiliary Functions" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Function to Display a Website\n", "def show_web(url):\n", "    \"\"\"Display the web page at `url` inside the notebook by rendering an HTML snippet built from it.\"\"\"\n", "    html_code='
' \\\n", "\t\t% (url)\n", "    display(HTML(html_code))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get filenames" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "found 600 audio files in ./audio/london_phill_dataset_multi/\n" ] } ], "source": [ "# Get all *.mp3 files below the data path (searched recursively)\n", "\n", "files = []\n", "for root, dirnames, filenames in os.walk(path):\n", "    for filename in fnmatch.filter(filenames, '*.mp3'):\n", "        files.append(os.path.join(root, filename))\n", "\n", "print(\"found %d audio files in %s\" % (len(files), path))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Labels" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "filename=\"instruments_labels.pl\"\n", "# Load the integer class labels from the saved pickle file.\n", "# Unpickle from the handle opened by `with` so it is closed afterwards\n", "# (only unpickle files you created yourself: pickle can execute arbitrary code).\n", "with open(filename, \"rb\") as f:\n", "    classes_num = pickle.load(f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parameters for MFCC" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Parameters\n", "# Signal Processing Parameters (not used below: the MFCC features are loaded precomputed;\n", "# presumably these are the settings they were extracted with)\n", "fs = 44100 # Sampling Frequency\n", "n_fft = 2048 # length of the FFT window\n", "hop_length = 512 # Number of samples between successive frames\n", "n_mels = 128 # Number of Mel bands\n", "n_mfcc = 13 # Number of MFCCs\n", "\n", "# Machine Learning Parameters\n", "testset_size = 0.25 # Fraction of the data held out for testing\n", "n_neighbors=1 # Number of neighbors for kNN Classifier (kNN is not used in this notebook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Feature Vectors" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "filename=\"mfcc_feature_vectors.pl\"\n", "# Load the MFCC feature vectors (one row per audio file; presumably already scaled,\n", "# as the variable name says) from the handle opened by `with`, so it is closed afterwards.\n", "with open(filename, \"rb\") as f:\n", "    scaled_feature_vectors = pickle.load(f)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Train and Test Sets" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Create Train and Test Set\n", "# One stratified shuffle split: each class keeps its proportion in both sets.\n", "# n_splits=1, so take the single (train, test) index pair directly instead of looping\n", "# (a loop would silently keep only the last split if n_splits were ever raised).\n", "splitter = StratifiedShuffleSplit(n_splits=1, test_size=testset_size, random_state=0)\n", "train_index, test_index = next(splitter.split(scaled_feature_vectors, classes_num))\n", "train_set = scaled_feature_vectors[train_index]\n", "test_set = scaled_feature_vectors[test_index]\n", "train_classes = classes_num[train_index]\n", "test_classes = classes_num[test_index]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_set shape: (450, 13)\n", "test_set shape: (150, 13)\n", "train_classes shape: (450,)\n", "test_classes shape: (150,)\n" ] } ], "source": [ "# Check Set Shapes\n", "print(\"train_set shape:\",train_set.shape)\n", "print(\"test_set shape:\",test_set.shape)\n", "print(\"train_classes shape:\",train_classes.shape)\n", "print(\"test_classes shape:\",test_classes.shape)\n", "\n", "# Sanity checks: the split partitions the data and features/labels stay aligned\n", "assert train_set.shape[0] + test_set.shape[0] == scaled_feature_vectors.shape[0]\n", "assert train_set.shape[0] == train_classes.shape[0]\n", "assert test_set.shape[0] == test_classes.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SVM Classification with Grid Search" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_web(\"https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# SVM Grid Search\n", "# SVC with its default RBF kernel, tuned over 13 log-spaced values each of C and gamma:\n", "# 13 x 13 = 169 candidates, each scored by 5-fold cross-validation on the training set.\n", "C_range = np.logspace(-2, 10, 13)\n", "gamma_range = np.logspace(-9, 3, 13)\n", "param_grid = dict(gamma=gamma_range, C=C_range)\n", "grid_svm = GridSearchCV(SVC(), param_grid=param_grid, cv=5)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'gamma': array([1.e-09, 1.e-08, 1.e-07, 1.e-06, 1.e-05, 1.e-04, 1.e-03, 1.e-02,\n", " 1.e-01, 1.e+00, 1.e+01, 1.e+02, 1.e+03]), 'C': array([1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05,\n", " 1.e+06, 1.e+07, 1.e+08, 1.e+09, 1.e+10])}\n" ] } ], "source": [ "print(param_grid)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The best parameters are {'C': 10.0, 'gamma': 0.1} with a score of 0.98\n" ] } ], "source": [ "# SVM: fit all candidates; GridSearchCV then refits the best one on the whole training set\n", "# (refit=True by default). best_score_ is the mean cross-validated accuracy of that candidate.\n", "grid_svm.fit(train_set, train_classes)\n", "print(\"The best parameters are %s with a score of %0.2f\"\n", "      % (grid_svm.best_params_, grid_svm.best_score_))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Save / Load Trained Model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['trained_grid_SVM.joblib']" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Save the fitted grid search (it holds the refitted best model)\n", "joblib.dump(grid_svm, 'trained_grid_SVM.joblib')\n", "# Load: uncomment to restore a saved model into `grid_svm`, the name the prediction cell uses\n", "# grid_svm = joblib.load('trained_grid_SVM.joblib')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Predict using the Test Set (GridSearchCV.predict uses the refitted best estimator)\n", "predicted_labels = grid_svm.predict(test_set)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Recall: [1. 1. 0.96 1. 1. 1. ]\n", "Precision: [1. 1. 1. 1. 0.96153846 1. ]\n", "F1-Score: [1. 1. 0.97959184 1. 0.98039216 1. ]\n", "Accuracy: 0.99 , 149\n", "Number of samples: 150\n", " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 25\n", " 1 1.00 1.00 1.00 25\n", " 2 1.00 0.96 0.98 25\n", " 3 1.00 1.00 1.00 25\n", " 4 0.96 1.00 0.98 25\n", " 5 1.00 1.00 1.00 25\n", "\n", " micro avg 0.99 0.99 0.99 150\n", " macro avg 0.99 0.99 0.99 150\n", "weighted avg 0.99 0.99 0.99 150\n", "\n" ] } ], "source": [ "# Per-class scores: average=None returns one value per class label, in sorted label order\n", "\n", "# Recall: of the test samples truly in a class, the fraction predicted as that class\n", "print(\"Recall: \", recall_score(test_classes, predicted_labels, average=None))\n", "\n", "# Precision: of the test samples predicted as a class, the fraction truly in that class\n", "print(\"Precision: \", precision_score(test_classes, predicted_labels, average=None))\n", "\n", "# F1-Score: harmonic mean of precision and recall, 2*P*R / (P + R)\n", "print(\"F1-Score: \", f1_score(test_classes, predicted_labels, average=None))\n", "\n", "# Accuracy: reported as a fraction (normalize=True) and as a count of correct predictions (normalize=False)\n", "accuracy_fraction = accuracy_score(test_classes, predicted_labels, normalize=True)\n", "n_correct = accuracy_score(test_classes, predicted_labels, normalize=False)\n", "print(\"Accuracy: %.2f ,\" % accuracy_fraction, n_correct)\n", "print(\"Number of samples:\", test_classes.shape[0])\n", "\n", "# Full per-class summary table (precision, recall, f1-score, support)\n", "print(classification_report(test_classes, predicted_labels))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Print numpy arrays with 2 decimals from here on\n", "np.set_printoptions(precision=2)\n", "\n", "# Compute confusion matrix: entry [i, j] counts samples of true label i predicted as label j\n", "cnf_matrix = confusion_matrix(test_classes, predicted_labels)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ 
"# Function to Plot Confusion Matrix\n", "# Adapted from http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html\n", "def plot_confusion_matrix(cm, classes,\n", "                          normalize=False,\n", "                          title='Confusion matrix',\n", "                          cmap=plt.cm.Blues):\n", "    \"\"\"\n", "    Print which variant is shown and plot the confusion matrix as an annotated\n", "    heat map on the current matplotlib figure.\n", "\n", "    Parameters\n", "    ----------\n", "    cm : 2-D array of counts; rows = true labels, columns = predicted labels\n", "        (as returned by sklearn.metrics.confusion_matrix).\n", "    classes : sequence of str\n", "        Tick labels, one per row/column of `cm`, in label-index order.\n", "    normalize : bool, default False\n", "        If True, divide each row by its total so every row sums to 1;\n", "        rows with no samples are shown as zeros instead of NaN.\n", "    title : str\n", "        Plot title.\n", "    cmap : matplotlib colormap\n", "        Colormap for the heat map.\n", "    \"\"\"\n", "    if normalize:\n", "        row_totals = cm.sum(axis=1)[:, np.newaxis]\n", "        # Avoid 0/0 -> NaN for classes that have no true samples\n", "        cm = np.divide(cm.astype('float'), row_totals,\n", "                       out=np.zeros(cm.shape, dtype=float),\n", "                       where=row_totals != 0)\n", "        print(\"Normalized confusion matrix\")\n", "    else:\n", "        print('Confusion matrix, without normalization')\n", "\n", "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", "    plt.title(title)\n", "    plt.colorbar()\n", "    tick_marks = np.arange(len(classes))\n", "    plt.xticks(tick_marks, classes, rotation=45)\n", "    plt.yticks(tick_marks, classes)\n", "\n", "    # Annotate every cell; white text on cells above half the maximum keeps it readable\n", "    fmt = '.2f' if normalize else 'd'\n", "    thresh = cm.max() / 2.\n", "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", "        plt.text(j, i, format(cm[i, j], fmt),\n", "                 horizontalalignment=\"center\",\n", "                 color=\"white\" if cm[i, j] > thresh else \"black\")\n", "\n", "    plt.tight_layout()\n", "    plt.ylabel('True label')\n", "    plt.xlabel('Predicted label')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot non-normalized confusion matrix\n", "# NOTE(review): tick labels must be listed in the order of the integer labels in\n", "# classes_num (label 0 first). If those labels were produced with sklearn's LabelEncoder,\n", "# that order is alphabetical (cello, flute, oboe, sax, trumpet, viola), not the order\n", "# below -- confirm against the step that created instruments_labels.pl.\n", "classes=['flute','sax','oboe', 'cello','trumpet','viola']\n", "plt.figure(figsize=(18,13))\n", "plot_confusion_matrix(cnf_matrix, classes=classes,\n", "                      title='Confusion matrix, without normalization')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }