{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# [FMA: A Dataset For Music Analysis](https://github.com/mdeff/fma)\n", "\n", "Michaël Defferrard, Kirell Benzi, Pierre Vandergheynst, Xavier Bresson, EPFL LTS2.\n", "\n", "## Baselines\n", "\n", "* This notebook evaluates standard classifiers from scikit-learn on the provided features.\n", "* Moreover, it evaluates Deep Learning models on both audio and spectrograms." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import time\n", "import os\n", "\n", "import IPython.display as ipd\n", "from tqdm import tqdm_notebook\n", "import numpy as np\n", "import pandas as pd\n", "import keras\n", "from keras.layers import Activation, Dense, Conv1D, Conv2D, MaxPooling1D, Flatten, Reshape\n", "\n", "from sklearn.utils import shuffle\n", "from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder, LabelBinarizer, StandardScaler\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.svm import SVC, LinearSVC\n", "#from sklearn.gaussian_process import GaussianProcessClassifier\n", "#from sklearn.gaussian_process.kernels import RBF\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier\n", "from sklearn.neural_network import MLPClassifier\n", "from sklearn.naive_bayes import GaussianNB\n", "from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis\n", "from sklearn.multiclass import OneVsRestClassifier\n", "\n", "import utils" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "((106574, 52), (106574, 518), (14511, 249))" ] }, 
"execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "AUDIO_DIR = os.environ.get('AUDIO_DIR')\n", "\n", "tracks = utils.load('data/fma_metadata/tracks.csv')\n", "features = utils.load('data/fma_metadata/features.csv')\n", "echonest = utils.load('data/fma_metadata/echonest.csv')\n", "\n", "np.testing.assert_array_equal(features.index, tracks.index)\n", "assert echonest.index.isin(tracks.index).all()\n", "\n", "tracks.shape, features.shape, echonest.shape" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Subset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Not enough Echonest features: (13554, 767)\n" ] }, { "data": { "text/plain": [ "((25000, 52), (25000, 518))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "subset = tracks.index[tracks['set', 'subset'] <= 'medium']\n", "\n", "assert subset.isin(tracks.index).all()\n", "assert subset.isin(features.index).all()\n", "\n", "features_all = features.join(echonest, how='inner').sort_index(axis=1)\n", "print('Not enough Echonest features: {}'.format(features_all.shape))\n", "\n", "tracks = tracks.loc[subset]\n", "features_all = features.loc[subset]\n", "\n", "tracks.shape, features_all.shape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "19922 training examples, 2505 validation examples, 2573 testing examples\n", "Top genres (16): ['Blues', 'Classical', 'Country', 'Easy Listening', 'Electronic', 'Experimental', 'Folk', 'Hip-Hop', 'Instrumental', 'International', 'Jazz', 'Old-Time / Historic', 'Pop', 'Rock', 'Soul-RnB', 'Spoken']\n", "All genres (151): [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 
17, 18, 19, 20, 21, 22, 25, 26, 27, 30, 31, 32, 33, 36, 37, 38, 41, 42, 43, 45, 46, 47, 49, 53, 58, 63, 64, 65, 66, 70, 71, 74, 76, 77, 79, 81, 83, 85, 86, 88, 89, 90, 92, 94, 97, 98, 100, 101, 102, 103, 107, 109, 111, 113, 117, 118, 125, 130, 137, 138, 166, 167, 169, 171, 172, 174, 177, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 214, 224, 232, 236, 240, 247, 250, 267, 286, 296, 297, 311, 314, 322, 337, 359, 360, 361, 362, 374, 378, 400, 401, 404, 428, 439, 440, 441, 442, 443, 456, 468, 491, 495, 502, 504, 514, 524, 538, 539, 542, 580, 602, 619, 651, 659, 695, 741, 763, 808, 810, 811, 906, 1032, 1060, 1193, 1235]\n" ] } ], "source": [
"# Track IDs of each predefined split.\n",
"train = tracks.index[tracks['set', 'split'] == 'training']\n",
"val = tracks.index[tracks['set', 'split'] == 'validation']\n",
"test = tracks.index[tracks['set', 'split'] == 'test']\n",
"\n",
"# Every track of the subset must belong to exactly one of the three splits.\n",
"assert len(train) + len(val) + len(test) == len(tracks)\n",
"\n",
"print('{} training examples, {} validation examples, {} testing examples'.format(*map(len, [train, val, test])))\n",
"\n",
"# Top-level genre names and the full list of genre ids are different things: keep them\n",
"# under different names. `genres` still ends up holding the full list of genre ids.\n",
"genres_top = list(LabelEncoder().fit(tracks['track', 'genre_top']).classes_)\n",
"print('Top genres ({}): {}'.format(len(genres_top), genres_top))\n",
"genres = list(MultiLabelBinarizer().fit(tracks['track', 'genres_all']).classes_)\n",
"print('All genres ({}): {}'.format(len(genres), genres))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1 Multiple classifiers and feature sets\n", "\n", "Todo:\n", "* Cross-validation for hyper-parameters.\n", "* Dimensionality reduction?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 Pre-processing" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "deletable": true, "editable": true, "scrolled": false }, "outputs": [], "source": [
"def pre_process(tracks, features, columns, multi_label=False, verbose=False):\n",
"    \"\"\"Encode the labels and build standardized train/validation/test sets.\n",
"\n",
"    The rows of each set come from the notebook-level `train`, `val` and `test` indices.\n",
"\n",
"    Args:\n",
"        tracks: frame providing the ('track', 'genre_top') and ('track', 'genres_all') columns.\n",
"        features: feature frame, indexed by the same labels as `train`/`val`/`test`.\n",
"        columns: column selector forwarded to `features.loc[rows, columns]`.\n",
"        multi_label: if False, integer-encode 'genre_top' with LabelEncoder;\n",
"            if True, build an indicator matrix from 'genres_all' with MultiLabelBinarizer.\n",
"        verbose: unused, kept so existing calls keep working.\n",
"\n",
"    Returns:\n",
"        y_train, y_val, y_test, X_train, X_val, X_test. The training pair is shuffled\n",
"        with a fixed seed and every X is standardized with the training statistics.\n",
"    \"\"\"\n",
"    if multi_label:\n",
"        # Create an indicator matrix.\n",
"        enc = MultiLabelBinarizer()\n",
"        labels = tracks['track', 'genres_all']\n",
"    else:\n",
"        # Assign an integer value to each genre.\n",
"        enc = LabelEncoder()\n",
"        labels = tracks['track', 'genre_top']\n",
"\n",
"    # Split in training, validation and testing sets.\n",
"    y_train = enc.fit_transform(labels.loc[train])\n",
"    y_val = enc.transform(labels.loc[val])\n",
"    y_test = enc.transform(labels.loc[test])\n",
"    # `.values` instead of `.as_matrix()`, which is deprecated and gone from current pandas.\n",
"    X_train = features.loc[train, columns].values\n",
"    X_val = features.loc[val, columns].values\n",
"    X_test = features.loc[test, columns].values\n",
"\n",
"    X_train, y_train = shuffle(X_train, y_train, random_state=42)\n",
"\n",
"    # Standardize features by removing the mean and scaling to unit variance.\n",
"    # Keep the returned arrays: copy=False only avoids the copy when scikit-learn can work\n",
"    # in place, otherwise the scaled data exists only in the return value.\n",
"    scaler = StandardScaler(copy=False)\n",
"    X_train = scaler.fit_transform(X_train)\n",
"    X_val = scaler.transform(X_val)\n",
"    X_test = scaler.transform(X_test)\n",
"\n",
"    return y_train, y_val, y_test, X_train, X_val, X_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 Single genre" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [
"def test_classifiers_features(classifiers, feature_sets, multi_label=False):\n",
"    \"\"\"Fit every classifier on every feature set and score it on the test set.\n",
"\n",
"    Uses the notebook-level `tracks` and `features_all` frames through `pre_process`.\n",
"\n",
"    Args:\n",
"        classifiers: mapping of name -> estimator exposing `fit` and `score`.\n",
"        feature_sets: mapping of name -> column selector understood by `pre_process`.\n",
"        multi_label: forwarded to `pre_process`.\n",
"\n",
"    Returns:\n",
"        (scores, times), two frames indexed by feature-set name. `scores` has a leading\n",
"        'dim' column (number of features) followed by one test score per classifier;\n",
"        `times` holds the process CPU time spent fitting and scoring each classifier.\n",
"    \"\"\"\n",
"    clf_names = list(classifiers.keys())\n",
"    fset_names = list(feature_sets.keys())\n",
"    # list.insert() works in place and returns None, so the column list is built explicitly.\n",
"    scores = pd.DataFrame(columns=['dim'] + clf_names, index=fset_names)\n",
"    times = pd.DataFrame(columns=clf_names, index=fset_names)\n",
"    for fset_name, fset in tqdm_notebook(feature_sets.items(), desc='features'):\n",
"        y_train, y_val, y_test, X_train, X_val, X_test = pre_process(tracks, features_all, fset, multi_label)\n",
"        scores.loc[fset_name, 'dim'] = X_train.shape[1]\n",
"        for clf_name, clf in classifiers.items():\n",
"            t = time.process_time()\n",
"            clf.fit(X_train, y_train)\n",
"            score = clf.score(X_test, y_test)\n",
"            scores.loc[fset_name, clf_name] = score\n",
"            times.loc[fset_name, clf_name] = time.process_time() - t\n",
"    return scores, times\n",
"\n",
"def format_scores(scores):\n",
"    \"\"\"Style a scores table: highlight the best score of each row, show scores as percentages.\n",
"\n",
"    The first column ('dim' in the tables built by `test_classifiers_features`) is not a\n",
"    score: it is never highlighted and is not formatted as a percentage.\n",
"    \"\"\"\n",
"    def highlight(row):\n",
"        best = max(row.iloc[1:])\n",
"        # The leading '' keeps the first (non-score) column unstyled.\n",
"        return [''] + ['background-color: yellow' if v == best else '' for v in row.iloc[1:]]\n",
"    styler = scores.style.apply(highlight, axis=1)\n",
"    return styler.format('{:.2%}', subset=pd.IndexSlice[:, scores.columns[1]:])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "deletable": true, "editable": true, "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "532abec0e8b54a56948a944e0d991621" } }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " 
warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n", "/home/ubuntu/.pyenv/versions/3.6.0/envs/fma/lib/python3.6/site-packages/sklearn/discriminant_analysis.py:695: UserWarning: Variables are collinear\n", " warnings.warn(\"Variables are collinear\")\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "\n", " \n", "\n", " \n", " \n", "\n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " dim\n", " \n", " \n", " \n", " \n", " LR\n", " \n", " \n", " \n", " \n", " kNN\n", " \n", " \n", " \n", " \n", " SVCrbf\n", " \n", " \n", " \n", " \n", " SVCpoly1\n", " \n", " \n", " \n", " \n", " linSVC1\n", " \n", " \n", " \n", " \n", " linSVC2\n", " \n", " \n", " \n", " \n", " DT\n", " \n", " \n", " \n", " \n", " RF\n", " \n", " \n", " \n", " \n", " AdaBoost\n", " \n", " \n", " \n", " \n", " MLP1\n", " \n", " \n", " \n", " \n", " MLP2\n", " \n", " \n", " \n", " \n", " NB\n", " \n", " \n", " \n", " \n", " QDA\n", " \n", " \n", "
\n", " chroma_cens\n", " \n", " \n", " \n", " \n", " 84\n", " \n", " \n", " \n", " \n", " 39.25%\n", " \n", " \n", " \n", " \n", " 37.50%\n", " \n", " \n", " \n", " \n", " 42.29%\n", " \n", " \n", " \n", " \n", " 38.63%\n", " \n", " \n", " \n", " \n", " 39.29%\n", " \n", " \n", " \n", " \n", " 39.29%\n", " \n", " \n", " \n", " \n", " 35.68%\n", " \n", " \n", " \n", " \n", " 33.77%\n", " \n", " \n", " \n", " \n", " 30.86%\n", " \n", " \n", " \n", " \n", " 40.19%\n", " \n", " \n", " \n", " \n", " 34.55%\n", " \n", " \n", " \n", " \n", " 9.99%\n", " \n", " \n", " \n", " \n", " 24.64%\n", " \n", " \n", "
\n", " chroma_cqt\n", " \n", " \n", " \n", " \n", " 84\n", " \n", " \n", " \n", " \n", " 40.07%\n", " \n", " \n", " \n", " \n", " 40.03%\n", " \n", " \n", " \n", " \n", " 44.27%\n", " \n", " \n", " \n", " \n", " 39.99%\n", " \n", " \n", " \n", " \n", " 41.39%\n", " \n", " \n", " \n", " \n", " 40.58%\n", " \n", " \n", " \n", " \n", " 35.45%\n", " \n", " \n", " \n", " \n", " 36.46%\n", " \n", " \n", " \n", " \n", " 35.72%\n", " \n", " \n", " \n", " \n", " 44.81%\n", " \n", " \n", " \n", " \n", " 39.60%\n", " \n", " \n", " \n", " \n", " 1.55%\n", " \n", " \n", " \n", " \n", " 3.42%\n", " \n", " \n", "
\n", " chroma_stft\n", " \n", " \n", " \n", " \n", " 84\n", " \n", " \n", " \n", " \n", " 43.61%\n", " \n", " \n", " \n", " \n", " 43.92%\n", " \n", " \n", " \n", " \n", " 48.31%\n", " \n", " \n", " \n", " \n", " 43.65%\n", " \n", " \n", " \n", " \n", " 44.35%\n", " \n", " \n", " \n", " \n", " 43.10%\n", " \n", " \n", " \n", " \n", " 39.88%\n", " \n", " \n", " \n", " \n", " 37.31%\n", " \n", " \n", " \n", " \n", " 35.25%\n", " \n", " \n", " \n", " \n", " 48.50%\n", " \n", " \n", " \n", " \n", " 44.77%\n", " \n", " \n", " \n", " \n", " 4.20%\n", " \n", " \n", " \n", " \n", " 5.91%\n", " \n", " \n", "
\n", " mfcc\n", " \n", " \n", " \n", " \n", " 140\n", " \n", " \n", " \n", " \n", " 57.83%\n", " \n", " \n", " \n", " \n", " 54.99%\n", " \n", " \n", " \n", " \n", " 60.98%\n", " \n", " \n", " \n", " \n", " 59.66%\n", " \n", " \n", " \n", " \n", " 59.19%\n", " \n", " \n", " \n", " \n", " 56.98%\n", " \n", " \n", " \n", " \n", " 45.86%\n", " \n", " \n", " \n", " \n", " 44.77%\n", " \n", " \n", " \n", " \n", " 41.31%\n", " \n", " \n", " \n", " \n", " 53.17%\n", " \n", " \n", " \n", " \n", " 53.21%\n", " \n", " \n", " \n", " \n", " 41.86%\n", " \n", " \n", " \n", " \n", " 48.39%\n", " \n", " \n", "
\n", " rmse\n", " \n", " \n", " \n", " \n", " 7\n", " \n", " \n", " \n", " \n", " 37.31%\n", " \n", " \n", " \n", " \n", " 38.52%\n", " \n", " \n", " \n", " \n", " 38.90%\n", " \n", " \n", " \n", " \n", " 37.70%\n", " \n", " \n", " \n", " \n", " 37.54%\n", " \n", " \n", " \n", " \n", " 37.35%\n", " \n", " \n", " \n", " \n", " 38.63%\n", " \n", " \n", " \n", " \n", " 36.65%\n", " \n", " \n", " \n", " \n", " 34.67%\n", " \n", " \n", " \n", " \n", " 39.06%\n", " \n", " \n", " \n", " \n", " 38.75%\n", " \n", " \n", " \n", " \n", " 11.78%\n", " \n", " \n", " \n", " \n", " 15.04%\n", " \n", " \n", "
\n", " spectral_bandwidth\n", " \n", " \n", " \n", " \n", " 7\n", " \n", " \n", " \n", " \n", " 40.54%\n", " \n", " \n", " \n", " \n", " 45.39%\n", " \n", " \n", " \n", " \n", " 44.46%\n", " \n", " \n", " \n", " \n", " 40.38%\n", " \n", " \n", " \n", " \n", " 40.42%\n", " \n", " \n", " \n", " \n", " 40.61%\n", " \n", " \n", " \n", " \n", " 42.91%\n", " \n", " \n", " \n", " \n", " 43.65%\n", " \n", " \n", " \n", " \n", " 37.47%\n", " \n", " \n", " \n", " \n", " 44.97%\n", " \n", " \n", " \n", " \n", " 44.66%\n", " \n", " \n", " \n", " \n", " 36.18%\n", " \n", " \n", " \n", " \n", " 34.16%\n", " \n", " \n", "
\n", " spectral_centroid\n", " \n", " \n", " \n", " \n", " 7\n", " \n", " \n", " \n", " \n", " 42.40%\n", " \n", " \n", " \n", " \n", " 45.36%\n", " \n", " \n", " \n", " \n", " 45.71%\n", " \n", " \n", " \n", " \n", " 42.09%\n", " \n", " \n", " \n", " \n", " 42.09%\n", " \n", " \n", " \n", " \n", " 42.21%\n", " \n", " \n", " \n", " \n", " 42.67%\n", " \n", " \n", " \n", " \n", " 43.41%\n", " \n", " \n", " \n", " \n", " 42.60%\n", " \n", " \n", " \n", " \n", " 47.84%\n", " \n", " \n", " \n", " \n", " 47.53%\n", " \n", " \n", " \n", " \n", " 33.31%\n", " \n", " \n", " \n", " \n", " 36.11%\n", " \n", " \n", "
\n", " spectral_contrast\n", " \n", " \n", " \n", " \n", " 49\n", " \n", " \n", " \n", " \n", " 50.91%\n", " \n", " \n", " \n", " \n", " 49.55%\n", " \n", " \n", " \n", " \n", " 54.45%\n", " \n", " \n", " \n", " \n", " 49.59%\n", " \n", " \n", " \n", " \n", " 51.81%\n", " \n", " \n", " \n", " \n", " 49.24%\n", " \n", " \n", " \n", " \n", " 43.53%\n", " \n", " \n", " \n", " \n", " 44.38%\n", " \n", " \n", " \n", " \n", " 39.53%\n", " \n", " \n", " \n", " \n", " 52.90%\n", " \n", " \n", " \n", " \n", " 49.16%\n", " \n", " \n", " \n", " \n", " 39.41%\n", " \n", " \n", " \n", " \n", " 41.78%\n", " \n", " \n", "
\n", " spectral_rolloff\n", " \n", " \n", " \n", " \n", " 7\n", " \n", " \n", " \n", " \n", " 41.74%\n", " \n", " \n", " \n", " \n", " 46.25%\n", " \n", " \n", " \n", " \n", " 47.53%\n", " \n", " \n", " \n", " \n", " 41.43%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 41.47%\n", " \n", " \n", " \n", " \n", " 45.36%\n", " \n", " \n", " \n", " \n", " 45.47%\n", " \n", " \n", " \n", " \n", " 41.66%\n", " \n", " \n", " \n", " \n", " 48.08%\n", " \n", " \n", " \n", " \n", " 48.54%\n", " \n", " \n", " \n", " \n", " 28.49%\n", " \n", " \n", " \n", " \n", " 28.53%\n", " \n", " \n", "
\n", " tonnetz\n", " \n", " \n", " \n", " \n", " 42\n", " \n", " \n", " \n", " \n", " 40.11%\n", " \n", " \n", " \n", " \n", " 37.31%\n", " \n", " \n", " \n", " \n", " 42.25%\n", " \n", " \n", " \n", " \n", " 40.23%\n", " \n", " \n", " \n", " \n", " 40.15%\n", " \n", " \n", " \n", " \n", " 39.56%\n", " \n", " \n", " \n", " \n", " 35.91%\n", " \n", " \n", " \n", " \n", " 36.96%\n", " \n", " \n", " \n", " \n", " 34.16%\n", " \n", " \n", " \n", " \n", " 40.85%\n", " \n", " \n", " \n", " \n", " 37.16%\n", " \n", " \n", " \n", " \n", " 22.31%\n", " \n", " \n", " \n", " \n", " 23.05%\n", " \n", " \n", "
\n", " zcr\n", " \n", " \n", " \n", " \n", " 7\n", " \n", " \n", " \n", " \n", " 42.29%\n", " \n", " \n", " \n", " \n", " 44.73%\n", " \n", " \n", " \n", " \n", " 45.43%\n", " \n", " \n", " \n", " \n", " 42.95%\n", " \n", " \n", " \n", " \n", " 42.71%\n", " \n", " \n", " \n", " \n", " 42.13%\n", " \n", " \n", " \n", " \n", " 43.61%\n", " \n", " \n", " \n", " \n", " 44.27%\n", " \n", " \n", " \n", " \n", " 40.89%\n", " \n", " \n", " \n", " \n", " 46.44%\n", " \n", " \n", " \n", " \n", " 46.25%\n", " \n", " \n", " \n", " \n", " 30.39%\n", " \n", " \n", " \n", " \n", " 32.10%\n", " \n", " \n", "
\n", " mfcc/contrast\n", " \n", " \n", " \n", " \n", " 189\n", " \n", " \n", " \n", " \n", " 59.77%\n", " \n", " \n", " \n", " \n", " 55.31%\n", " \n", " \n", " \n", " \n", " 63.04%\n", " \n", " \n", " \n", " \n", " 61.02%\n", " \n", " \n", " \n", " \n", " 59.58%\n", " \n", " \n", " \n", " \n", " 58.10%\n", " \n", " \n", " \n", " \n", " 47.61%\n", " \n", " \n", " \n", " \n", " 44.77%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 53.75%\n", " \n", " \n", " \n", " \n", " 55.65%\n", " \n", " \n", " \n", " \n", " 44.03%\n", " \n", " \n", " \n", " \n", " 51.85%\n", " \n", " \n", "
\n", " mfcc/contrast/chroma\n", " \n", " \n", " \n", " \n", " 273\n", " \n", " \n", " \n", " \n", " 60.20%\n", " \n", " \n", " \n", " \n", " 53.13%\n", " \n", " \n", " \n", " \n", " 62.92%\n", " \n", " \n", " \n", " \n", " 61.48%\n", " \n", " \n", " \n", " \n", " 59.11%\n", " \n", " \n", " \n", " \n", " 59.19%\n", " \n", " \n", " \n", " \n", " 47.57%\n", " \n", " \n", " \n", " \n", " 43.22%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 54.64%\n", " \n", " \n", " \n", " \n", " 56.98%\n", " \n", " \n", " \n", " \n", " 39.02%\n", " \n", " \n", " \n", " \n", " 51.34%\n", " \n", " \n", "
\n", " mfcc/contrast/centroid\n", " \n", " \n", " \n", " \n", " 196\n", " \n", " \n", " \n", " \n", " 59.81%\n", " \n", " \n", " \n", " \n", " 55.23%\n", " \n", " \n", " \n", " \n", " 63.39%\n", " \n", " \n", " \n", " \n", " 61.48%\n", " \n", " \n", " \n", " \n", " 60.28%\n", " \n", " \n", " \n", " \n", " 59.35%\n", " \n", " \n", " \n", " \n", " 47.61%\n", " \n", " \n", " \n", " \n", " 43.57%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 52.62%\n", " \n", " \n", " \n", " \n", " 56.12%\n", " \n", " \n", " \n", " \n", " 43.76%\n", " \n", " \n", " \n", " \n", " 51.69%\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid\n", " \n", " \n", " \n", " \n", " 280\n", " \n", " \n", " \n", " \n", " 60.44%\n", " \n", " \n", " \n", " \n", " 53.01%\n", " \n", " \n", " \n", " \n", " 63.08%\n", " \n", " \n", " \n", " \n", " 61.29%\n", " \n", " \n", " \n", " \n", " 60.12%\n", " \n", " \n", " \n", " \n", " 59.42%\n", " \n", " \n", " \n", " \n", " 47.57%\n", " \n", " \n", " \n", " \n", " 43.61%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 54.33%\n", " \n", " \n", " \n", " \n", " 55.23%\n", " \n", " \n", " \n", " \n", " 38.87%\n", " \n", " \n", " \n", " \n", " 51.34%\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/tonnetz\n", " \n", " \n", " \n", " \n", " 322\n", " \n", " \n", " \n", " \n", " 60.36%\n", " \n", " \n", " \n", " \n", " 52.62%\n", " \n", " \n", " \n", " \n", " 63.12%\n", " \n", " \n", " \n", " \n", " 62.50%\n", " \n", " \n", " \n", " \n", " 60.20%\n", " \n", " \n", " \n", " \n", " 59.15%\n", " \n", " \n", " \n", " \n", " 47.57%\n", " \n", " \n", " \n", " \n", " 43.61%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 56.32%\n", " \n", " \n", " \n", " \n", " 57.25%\n", " \n", " \n", " \n", " \n", " 39.06%\n", " \n", " \n", " \n", " \n", " 50.72%\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/zcr\n", " \n", " \n", " \n", " \n", " 287\n", " \n", " \n", " \n", " \n", " 60.94%\n", " \n", " \n", " \n", " \n", " 53.01%\n", " \n", " \n", " \n", " \n", " 62.81%\n", " \n", " \n", " \n", " \n", " 61.48%\n", " \n", " \n", " \n", " \n", " 59.77%\n", " \n", " \n", " \n", " \n", " 59.58%\n", " \n", " \n", " \n", " \n", " 47.69%\n", " \n", " \n", " \n", " \n", " 43.37%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 55.65%\n", " \n", " \n", " \n", " \n", " 54.41%\n", " \n", " \n", " \n", " \n", " 38.90%\n", " \n", " \n", " \n", " \n", " 51.42%\n", " \n", " \n", "
\n", " all_non-echonest\n", " \n", " \n", " \n", " \n", " 518\n", " \n", " \n", " \n", " \n", " 61.10%\n", " \n", " \n", " \n", " \n", " 51.77%\n", " \n", " \n", " \n", " \n", " 62.88%\n", " \n", " \n", " \n", " \n", " 61.95%\n", " \n", " \n", " \n", " \n", " 59.08%\n", " \n", " \n", " \n", " \n", " 58.65%\n", " \n", " \n", " \n", " \n", " 47.30%\n", " \n", " \n", " \n", " \n", " 43.65%\n", " \n", " \n", " \n", " \n", " 41.62%\n", " \n", " \n", " \n", " \n", " 58.14%\n", " \n", " \n", " \n", " \n", " 57.95%\n", " \n", " \n", " \n", " \n", " 9.91%\n", " \n", " \n", " \n", " \n", " 20.25%\n", " \n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " LR\n", " \n", " \n", " \n", " \n", " kNN\n", " \n", " \n", " \n", " \n", " SVCrbf\n", " \n", " \n", " \n", " \n", " SVCpoly1\n", " \n", " \n", " \n", " \n", " linSVC1\n", " \n", " \n", " \n", " \n", " linSVC2\n", " \n", " \n", " \n", " \n", " DT\n", " \n", " \n", " \n", " \n", " RF\n", " \n", " \n", " \n", " \n", " AdaBoost\n", " \n", " \n", " \n", " \n", " MLP1\n", " \n", " \n", " \n", " \n", " MLP2\n", " \n", " \n", " \n", " \n", " NB\n", " \n", " \n", " \n", " \n", " QDA\n", " \n", " \n", "
\n", " chroma_cens\n", " \n", " \n", " \n", " \n", " 18.6645\n", " \n", " \n", " \n", " \n", " 9.4855\n", " \n", " \n", " \n", " \n", " 69.7959\n", " \n", " \n", " \n", " \n", " 53.8920\n", " \n", " \n", " \n", " \n", " 189.2902\n", " \n", " \n", " \n", " \n", " 97.0045\n", " \n", " \n", " \n", " \n", " 0.7902\n", " \n", " \n", " \n", " \n", " 0.1034\n", " \n", " \n", " \n", " \n", " 1.8475\n", " \n", " \n", " \n", " \n", " 281.4669\n", " \n", " \n", " \n", " \n", " 502.8007\n", " \n", " \n", " \n", " \n", " 0.4817\n", " \n", " \n", " \n", " \n", " 1.7261\n", " \n", " \n", "
\n", " chroma_cqt\n", " \n", " \n", " \n", " \n", " 25.5985\n", " \n", " \n", " \n", " \n", " 9.0593\n", " \n", " \n", " \n", " \n", " 64.3567\n", " \n", " \n", " \n", " \n", " 53.6505\n", " \n", " \n", " \n", " \n", " 244.0052\n", " \n", " \n", " \n", " \n", " 102.6698\n", " \n", " \n", " \n", " \n", " 0.7459\n", " \n", " \n", " \n", " \n", " 0.0910\n", " \n", " \n", " \n", " \n", " 1.7464\n", " \n", " \n", " \n", " \n", " 244.7545\n", " \n", " \n", " \n", " \n", " 408.5998\n", " \n", " \n", " \n", " \n", " 0.4746\n", " \n", " \n", " \n", " \n", " 1.6424\n", " \n", " \n", "
\n", " chroma_stft\n", " \n", " \n", " \n", " \n", " 32.5938\n", " \n", " \n", " \n", " \n", " 7.5791\n", " \n", " \n", " \n", " \n", " 57.7469\n", " \n", " \n", " \n", " \n", " 54.2342\n", " \n", " \n", " \n", " \n", " 170.7127\n", " \n", " \n", " \n", " \n", " 94.5633\n", " \n", " \n", " \n", " \n", " 0.7266\n", " \n", " \n", " \n", " \n", " 0.0884\n", " \n", " \n", " \n", " \n", " 1.7054\n", " \n", " \n", " \n", " \n", " 247.2694\n", " \n", " \n", " \n", " \n", " 351.9716\n", " \n", " \n", " \n", " \n", " 0.4740\n", " \n", " \n", " \n", " \n", " 1.6017\n", " \n", " \n", "
\n", " mfcc\n", " \n", " \n", " \n", " \n", " 38.7095\n", " \n", " \n", " \n", " \n", " 18.5090\n", " \n", " \n", " \n", " \n", " 64.0735\n", " \n", " \n", " \n", " \n", " 50.0887\n", " \n", " \n", " \n", " \n", " 173.5436\n", " \n", " \n", " \n", " \n", " 96.6151\n", " \n", " \n", " \n", " \n", " 1.5194\n", " \n", " \n", " \n", " \n", " 0.1084\n", " \n", " \n", " \n", " \n", " 3.3614\n", " \n", " \n", " \n", " \n", " 395.0526\n", " \n", " \n", " \n", " \n", " 269.8539\n", " \n", " \n", " \n", " \n", " 0.7230\n", " \n", " \n", " \n", " \n", " 3.4761\n", " \n", " \n", "
\n", " rmse\n", " \n", " \n", " \n", " \n", " 1.3585\n", " \n", " \n", " \n", " \n", " 0.3469\n", " \n", " \n", " \n", " \n", " 28.6978\n", " \n", " \n", " \n", " \n", " 13.7399\n", " \n", " \n", " \n", " \n", " 20.5466\n", " \n", " \n", " \n", " \n", " 17.5076\n", " \n", " \n", " \n", " \n", " 0.0698\n", " \n", " \n", " \n", " \n", " 0.0954\n", " \n", " \n", " \n", " \n", " 0.2715\n", " \n", " \n", " \n", " \n", " 126.3215\n", " \n", " \n", " \n", " \n", " 159.3460\n", " \n", " \n", " \n", " \n", " 0.1219\n", " \n", " \n", " \n", " \n", " 0.1632\n", " \n", " \n", "
\n", " spectral_bandwidth\n", " \n", " \n", " \n", " \n", " 1.0436\n", " \n", " \n", " \n", " \n", " 0.2820\n", " \n", " \n", " \n", " \n", " 29.3073\n", " \n", " \n", " \n", " \n", " 14.5891\n", " \n", " \n", " \n", " \n", " 23.1915\n", " \n", " \n", " \n", " \n", " 18.1089\n", " \n", " \n", " \n", " \n", " 0.0739\n", " \n", " \n", " \n", " \n", " 0.0953\n", " \n", " \n", " \n", " \n", " 0.2725\n", " \n", " \n", " \n", " \n", " 107.5405\n", " \n", " \n", " \n", " \n", " 216.2522\n", " \n", " \n", " \n", " \n", " 0.1261\n", " \n", " \n", " \n", " \n", " 0.1633\n", " \n", " \n", "
\n", " spectral_centroid\n", " \n", " \n", " \n", " \n", " 1.0393\n", " \n", " \n", " \n", " \n", " 0.2639\n", " \n", " \n", " \n", " \n", " 25.3846\n", " \n", " \n", " \n", " \n", " 15.5176\n", " \n", " \n", " \n", " \n", " 26.4575\n", " \n", " \n", " \n", " \n", " 17.7886\n", " \n", " \n", " \n", " \n", " 0.0703\n", " \n", " \n", " \n", " \n", " 0.0994\n", " \n", " \n", " \n", " \n", " 0.2774\n", " \n", " \n", " \n", " \n", " 147.5693\n", " \n", " \n", " \n", " \n", " 229.9029\n", " \n", " \n", " \n", " \n", " 0.1215\n", " \n", " \n", " \n", " \n", " 0.1624\n", " \n", " \n", "
\n", " spectral_contrast\n", " \n", " \n", " \n", " \n", " 11.8101\n", " \n", " \n", " \n", " \n", " 4.6273\n", " \n", " \n", " \n", " \n", " 34.9987\n", " \n", " \n", " \n", " \n", " 27.6479\n", " \n", " \n", " \n", " \n", " 69.5350\n", " \n", " \n", " \n", " \n", " 47.8968\n", " \n", " \n", " \n", " \n", " 0.5047\n", " \n", " \n", " \n", " \n", " 0.1000\n", " \n", " \n", " \n", " \n", " 1.2169\n", " \n", " \n", " \n", " \n", " 253.3954\n", " \n", " \n", " \n", " \n", " 483.1398\n", " \n", " \n", " \n", " \n", " 0.3123\n", " \n", " \n", " \n", " \n", " 0.9229\n", " \n", " \n", "
\n", " spectral_rolloff\n", " \n", " \n", " \n", " \n", " 1.3367\n", " \n", " \n", " \n", " \n", " 0.2738\n", " \n", " \n", " \n", " \n", " 26.9192\n", " \n", " \n", " \n", " \n", " 15.3378\n", " \n", " \n", " \n", " \n", " 23.9110\n", " \n", " \n", " \n", " \n", " 17.8559\n", " \n", " \n", " \n", " \n", " 0.0543\n", " \n", " \n", " \n", " \n", " 0.0799\n", " \n", " \n", " \n", " \n", " 0.2349\n", " \n", " \n", " \n", " \n", " 110.4304\n", " \n", " \n", " \n", " \n", " 242.6179\n", " \n", " \n", " \n", " \n", " 0.1227\n", " \n", " \n", " \n", " \n", " 0.1638\n", " \n", " \n", "
\n", " tonnetz\n", " \n", " \n", " \n", " \n", " 6.2082\n", " \n", " \n", " \n", " \n", " 3.9319\n", " \n", " \n", " \n", " \n", " 46.9757\n", " \n", " \n", " \n", " \n", " 29.3071\n", " \n", " \n", " \n", " \n", " 73.9021\n", " \n", " \n", " \n", " \n", " 49.4196\n", " \n", " \n", " \n", " \n", " 0.4315\n", " \n", " \n", " \n", " \n", " 0.0999\n", " \n", " \n", " \n", " \n", " 1.0565\n", " \n", " \n", " \n", " \n", " 274.3477\n", " \n", " \n", " \n", " \n", " 443.6555\n", " \n", " \n", " \n", " \n", " 0.2756\n", " \n", " \n", " \n", " \n", " 0.9004\n", " \n", " \n", "
\n", " zcr\n", " \n", " \n", " \n", " \n", " 1.1362\n", " \n", " \n", " \n", " \n", " 0.2366\n", " \n", " \n", " \n", " \n", " 25.2766\n", " \n", " \n", " \n", " \n", " 15.6943\n", " \n", " \n", " \n", " \n", " 25.4284\n", " \n", " \n", " \n", " \n", " 17.9923\n", " \n", " \n", " \n", " \n", " 0.0543\n", " \n", " \n", " \n", " \n", " 0.0838\n", " \n", " \n", " \n", " \n", " 0.2379\n", " \n", " \n", " \n", " \n", " 141.2034\n", " \n", " \n", " \n", " \n", " 151.4443\n", " \n", " \n", " \n", " \n", " 0.1202\n", " \n", " \n", " \n", " \n", " 0.1636\n", " \n", " \n", "
\n", " mfcc/contrast\n", " \n", " \n", " \n", " \n", " 54.6594\n", " \n", " \n", " \n", " \n", " 23.5665\n", " \n", " \n", " \n", " \n", " 81.2173\n", " \n", " \n", " \n", " \n", " 63.0799\n", " \n", " \n", " \n", " \n", " 232.5177\n", " \n", " \n", " \n", " \n", " 109.4360\n", " \n", " \n", " \n", " \n", " 2.1096\n", " \n", " \n", " \n", " \n", " 0.1172\n", " \n", " \n", " \n", " \n", " 4.6625\n", " \n", " \n", " \n", " \n", " 392.4555\n", " \n", " \n", " \n", " \n", " 203.1467\n", " \n", " \n", " \n", " \n", " 0.9384\n", " \n", " \n", " \n", " \n", " 5.6662\n", " \n", " \n", "
\n", " mfcc/contrast/chroma\n", " \n", " \n", " \n", " \n", " 82.9940\n", " \n", " \n", " \n", " \n", " 24.4655\n", " \n", " \n", " \n", " \n", " 111.2726\n", " \n", " \n", " \n", " \n", " 85.5987\n", " \n", " \n", " \n", " \n", " 366.9468\n", " \n", " \n", " \n", " \n", " 135.2644\n", " \n", " \n", " \n", " \n", " 3.0404\n", " \n", " \n", " \n", " \n", " 0.1178\n", " \n", " \n", " \n", " \n", " 6.6595\n", " \n", " \n", " \n", " \n", " 354.3567\n", " \n", " \n", " \n", " \n", " 176.9491\n", " \n", " \n", " \n", " \n", " 0.9664\n", " \n", " \n", " \n", " \n", " 10.5037\n", " \n", " \n", "
\n", " mfcc/contrast/centroid\n", " \n", " \n", " \n", " \n", " 57.2624\n", " \n", " \n", " \n", " \n", " 23.9886\n", " \n", " \n", " \n", " \n", " 83.3180\n", " \n", " \n", " \n", " \n", " 64.4825\n", " \n", " \n", " \n", " \n", " 234.9260\n", " \n", " \n", " \n", " \n", " 110.7972\n", " \n", " \n", " \n", " \n", " 2.1913\n", " \n", " \n", " \n", " \n", " 0.1162\n", " \n", " \n", " \n", " \n", " 4.8540\n", " \n", " \n", " \n", " \n", " 455.7037\n", " \n", " \n", " \n", " \n", " 181.1979\n", " \n", " \n", " \n", " \n", " 0.9364\n", " \n", " \n", " \n", " \n", " 6.5866\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid\n", " \n", " \n", " \n", " \n", " 85.3098\n", " \n", " \n", " \n", " \n", " 25.1057\n", " \n", " \n", " \n", " \n", " 115.0022\n", " \n", " \n", " \n", " \n", " 88.4598\n", " \n", " \n", " \n", " \n", " 386.0993\n", " \n", " \n", " \n", " \n", " 138.8553\n", " \n", " \n", " \n", " \n", " 3.1158\n", " \n", " \n", " \n", " \n", " 0.1430\n", " \n", " \n", " \n", " \n", " 6.7839\n", " \n", " \n", " \n", " \n", " 346.2436\n", " \n", " \n", " \n", " \n", " 169.2297\n", " \n", " \n", " \n", " \n", " 0.9710\n", " \n", " \n", " \n", " \n", " 10.3880\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/tonnetz\n", " \n", " \n", " \n", " \n", " 104.7722\n", " \n", " \n", " \n", " \n", " 33.7060\n", " \n", " \n", " \n", " \n", " 133.2679\n", " \n", " \n", " \n", " \n", " 101.1152\n", " \n", " \n", " \n", " \n", " 443.0719\n", " \n", " \n", " \n", " \n", " 154.5318\n", " \n", " \n", " \n", " \n", " 3.6442\n", " \n", " \n", " \n", " \n", " 0.1206\n", " \n", " \n", " \n", " \n", " 7.9675\n", " \n", " \n", " \n", " \n", " 272.0756\n", " \n", " \n", " \n", " \n", " 188.3379\n", " \n", " \n", " \n", " \n", " 0.9944\n", " \n", " \n", " \n", " \n", " 13.5507\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/zcr\n", " \n", " \n", " \n", " \n", " 89.7276\n", " \n", " \n", " \n", " \n", " 30.2974\n", " \n", " \n", " \n", " \n", " 119.1835\n", " \n", " \n", " \n", " \n", " 91.9383\n", " \n", " \n", " \n", " \n", " 391.2368\n", " \n", " \n", " \n", " \n", " 140.8637\n", " \n", " \n", " \n", " \n", " 3.2338\n", " \n", " \n", " \n", " \n", " 0.1205\n", " \n", " \n", " \n", " \n", " 6.9843\n", " \n", " \n", " \n", " \n", " 296.5767\n", " \n", " \n", " \n", " \n", " 187.6073\n", " \n", " \n", " \n", " \n", " 0.9748\n", " \n", " \n", " \n", " \n", " 11.5799\n", " \n", " \n", "
\n", " all_non-echonest\n", " \n", " \n", " \n", " \n", " 234.5713\n", " \n", " \n", " \n", " \n", " 41.1457\n", " \n", " \n", " \n", " \n", " 192.4517\n", " \n", " \n", " \n", " \n", " 152.3811\n", " \n", " \n", " \n", " \n", " 654.4032\n", " \n", " \n", " \n", " \n", " 198.2208\n", " \n", " \n", " \n", " \n", " 5.8524\n", " \n", " \n", " \n", " \n", " 0.1311\n", " \n", " \n", " \n", " \n", " 12.6438\n", " \n", " \n", " \n", " \n", " 286.9855\n", " \n", " \n", " \n", " \n", " 171.3685\n", " \n", " \n", " \n", " \n", " 1.0756\n", " \n", " \n", " \n", " \n", " 30.3718\n", " \n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "classifiers = {\n", " 'LR': LogisticRegression(),\n", " 'kNN': KNeighborsClassifier(n_neighbors=200),\n", " 'SVCrbf': SVC(kernel='rbf'),\n", " 'SVCpoly1': SVC(kernel='poly', degree=1),\n", " 'linSVC1': SVC(kernel=\"linear\"),\n", " 'linSVC2': LinearSVC(),\n", " #GaussianProcessClassifier(1.0 * RBF(1.0), warm_start=True),\n", " 'DT': DecisionTreeClassifier(max_depth=5),\n", " 'RF': RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),\n", " 'AdaBoost': AdaBoostClassifier(n_estimators=10),\n", " 'MLP1': MLPClassifier(hidden_layer_sizes=(100,), max_iter=2000),\n", " 'MLP2': MLPClassifier(hidden_layer_sizes=(200, 50), max_iter=2000),\n", " 'NB': GaussianNB(),\n", " 'QDA': QuadraticDiscriminantAnalysis(),\n", "}\n", "\n", "feature_sets = {\n", "# 'echonest_audio': ('echonest', 'audio_features'),\n", "# 'echonest_social': ('echonest', 'social_features'),\n", "# 'echonest_temporal': ('echonest', 'temporal_features'),\n", "# 'echonest_audio/social': ('echonest', ('audio_features', 'social_features')),\n", "# 'echonest_all': ('echonest', ('audio_features', 'social_features', 'temporal_features')),\n", "}\n", "for name in features.columns.levels[0]:\n", " feature_sets[name] = name\n", "feature_sets.update({\n", " 'mfcc/contrast': ['mfcc', 'spectral_contrast'],\n", " 'mfcc/contrast/chroma': ['mfcc', 'spectral_contrast', 'chroma_cens'],\n", " 'mfcc/contrast/centroid': ['mfcc', 'spectral_contrast', 'spectral_centroid'],\n", " 'mfcc/contrast/chroma/centroid': ['mfcc', 'spectral_contrast', 'chroma_cens', 'spectral_centroid'],\n", " 'mfcc/contrast/chroma/centroid/tonnetz': ['mfcc', 'spectral_contrast', 'chroma_cens', 'spectral_centroid', 'tonnetz'],\n", " 'mfcc/contrast/chroma/centroid/zcr': ['mfcc', 'spectral_contrast', 'chroma_cens', 'spectral_centroid', 'zcr'],\n", " 'all_non-echonest': list(features.columns.levels[0])\n", "})\n", "\n", "scores, times = 
test_classifiers_features(classifiers, feature_sets)\n", "\n", "ipd.display(format_scores(scores))\n", "ipd.display(times.style.format('{:.4f}'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.3 Multiple genres\n", "\n", "Todo:\n", "* Ignore rare genres? Count them higher up in the genre tree? On the other hand it's not much tracks." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c6f692d542b845fcbaf30add205e30ad" } }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "\n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " dim\n", " \n", " \n", " \n", " \n", " LR\n", " \n", " \n", " \n", " \n", " SVC\n", " \n", " \n", " \n", " \n", " MLP\n", " \n", " \n", "
\n", " mfcc\n", " \n", " \n", " \n", " \n", " 140\n", " \n", " \n", " \n", " \n", " 11.39%\n", " \n", " \n", " \n", " \n", " 12.13%\n", " \n", " \n", " \n", " \n", " 12.40%\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/tonnetz\n", " \n", " \n", " \n", " \n", " 322\n", " \n", " \n", " \n", " \n", " 13.45%\n", " \n", " \n", " \n", " \n", " 13.41%\n", " \n", " \n", " \n", " \n", " 10.53%\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/zcr\n", " \n", " \n", " \n", " \n", " 287\n", " \n", " \n", " \n", " \n", " 13.06%\n", " \n", " \n", " \n", " \n", " 13.64%\n", " \n", " \n", " \n", " \n", " 10.92%\n", " \n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " LR\n", " \n", " \n", " \n", " \n", " SVC\n", " \n", " \n", " \n", " \n", " MLP\n", " \n", " \n", "
\n", " mfcc\n", " \n", " \n", " \n", " \n", " 214.9928\n", " \n", " \n", " \n", " \n", " 1095.3422\n", " \n", " \n", " \n", " \n", " 1178.3974\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/tonnetz\n", " \n", " \n", " \n", " \n", " 646.5655\n", " \n", " \n", " \n", " \n", " 2513.0338\n", " \n", " \n", " \n", " \n", " 1881.1462\n", " \n", " \n", "
\n", " mfcc/contrast/chroma/centroid/zcr\n", " \n", " \n", " \n", " \n", " 553.7129\n", " \n", " \n", " \n", " \n", " 2110.5772\n", " \n", " \n", " \n", " \n", " 1750.6880\n", " \n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "classifiers = {\n", " #LogisticRegression(),\n", " 'LR': OneVsRestClassifier(LogisticRegression()),\n", " 'SVC': OneVsRestClassifier(SVC()),\n", " 'MLP': MLPClassifier(max_iter=700),\n", "}\n", "\n", "feature_sets = {\n", "# 'echonest_audio': ('echonest', 'audio_features'),\n", "# 'echonest_temporal': ('echonest', 'temporal_features'),\n", " 'mfcc': 'mfcc',\n", " 'mfcc/contrast/chroma/centroid/tonnetz': ['mfcc', 'spectral_contrast', 'chroma_cens', 'spectral_centroid', 'tonnetz'],\n", " 'mfcc/contrast/chroma/centroid/zcr': ['mfcc', 'spectral_contrast', 'chroma_cens', 'spectral_centroid', 'zcr'],\n", "}\n", "\n", "scores, times = test_classifiers_features(classifiers, feature_sets, multi_label=True)\n", "\n", "ipd.display(format_scores(scores))\n", "ipd.display(times.style.format('{:.4f}'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2 Deep learning on raw audio\n", "\n", "Other architectures:\n", "* [Learning Features of Music from Scratch (MusicNet)](https://arxiv.org/abs/1611.09827), John Thickstun, Zaid Harchaoui, Sham Kakade." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "labels_onehot = LabelBinarizer().fit_transform(tracks['track', 'genre_top'])\n", "labels_onehot = pd.DataFrame(labels_onehot, index=tracks.index)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load audio samples in parallel using `multiprocessing` so as to maximize CPU usage when decoding MP3s and making some optional pre-processing. There are multiple ways to load a waveform from a compressed MP3:\n", "* librosa uses audioread in the backend which can use many native libraries, e.g. 
ffmpeg\n", "    * resampling is very slow --> use `kaiser_fast`\n", "    * does not work with multi-processing, for keras `fit_generator()`\n", "* pydub is a high-level interface for audio modification, uses ffmpeg to load\n", "    * store a temporary `.wav`\n", "* directly pipe ffmpeg output\n", "    * fastest method\n", "* [pyAV](https://github.com/mikeboers/PyAV) may be a faster alternative by linking to ffmpeg libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Just be sure that everything is fine. Multiprocessing is tricky to debug.\n", "utils.FfmpegLoader().load(utils.get_audio_path(AUDIO_DIR, 2))\n", "SampleLoader = utils.build_sample_loader(AUDIO_DIR, labels_onehot, utils.FfmpegLoader())\n", "SampleLoader(train, batch_size=2).__next__()[0].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Keras parameters.\n", "NB_WORKER = len(os.sched_getaffinity(0)) # number of usables CPUs\n", "params = {'pickle_safe': True, 'nb_worker': NB_WORKER, 'max_q_size': 10}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Fully connected neural network\n", "\n", "* Two layers with 10 hidden units are no better than random, ~11%.\n", "\n", "Optimize data loading to be CPU / GPU bound, not IO bound. Larger batches mean reduced training time, so increase the batch size until memory is exhausted. Number of workers and queue size have no influence on speed."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Fully connected baseline on the raw waveform, decoded at a 2 kHz sampling rate.\n",
 "loader = utils.FfmpegLoader(sampling_rate=2000)\n",
 "SampleLoader = utils.build_sample_loader(AUDIO_DIR, labels_onehot, loader)\n",
 "print('Dimensionality: {}'.format(loader.shape))\n",
 "\n",
 "# Drop any graph left over from a previously built model.\n",
 "keras.backend.clear_session()\n",
 "\n",
 "model = keras.models.Sequential()\n",
 "model.add(Dense(output_dim=1000, input_shape=loader.shape))\n",
 "model.add(Activation(\"relu\"))\n",
 "model.add(Dense(output_dim=100))\n",
 "model.add(Activation(\"relu\"))\n",
 "# One output unit per column of the one-hot encoded labels.\n",
 "model.add(Dense(output_dim=labels_onehot.shape[1]))\n",
 "model.add(Activation(\"softmax\"))\n",
 "\n",
 "optimizer = keras.optimizers.SGD(lr=0.1, momentum=0.9, nesterov=True)\n",
 "model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n",
 "\n",
 "model.fit_generator(SampleLoader(train, batch_size=64), train.size, nb_epoch=2, **params)\n",
 "# Keep the validation and test results in separate variables so that both can be reported.\n",
 "loss_val = model.evaluate_generator(SampleLoader(val, batch_size=64), val.size, **params)\n",
 "loss_test = model.evaluate_generator(SampleLoader(test, batch_size=64), test.size, **params)\n",
 "#Y = model.predict_generator(SampleLoader(test, batch_size=64), test.size, **params);\n",
 "\n",
 "loss_val, loss_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "### 2.2 Convolutional neural network\n",
 "\n",
 "* Architecture: [End-to-end learning for music audio](http://www.mirlab.org/conference_papers/International_Conference/ICASSP%202014/papers/p7014-dieleman.pdf), Sander Dieleman, Benjamin Schrauwen.\n",
 "* Missing: track segmentation and class averaging (majority voting)\n",
 "* Compared with log-scaled mel-spectrograms instead of strided convolution as first layer.\n",
 "* Larger net: http://benanne.github.io/2014/08/05/spotify-cnns.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# 1D ConvNet on the raw waveform, decoded at a 16 kHz sampling rate.\n",
 "loader = utils.FfmpegLoader(sampling_rate=16000)\n",
 "#loader = utils.LibrosaLoader(sampling_rate=16000)\n",
 "SampleLoader = utils.build_sample_loader(AUDIO_DIR, labels_onehot, loader)\n",
 "\n",
 "# Drop any graph left over from a previously built model.\n",
 "keras.backend.clear_session()\n",
 "\n",
 "model = keras.models.Sequential()\n",
 "# Append a channel axis of size 1 so that the waveform can be fed to Conv1D.\n",
 "model.add(Reshape((-1, 1), input_shape=loader.shape))\n",
 "print(model.output_shape)\n",
 "\n",
 "# Strided convolution: filters of length 512 applied with a stride of 512.\n",
 "model.add(Conv1D(128, 512, subsample_length=512))\n",
 "print(model.output_shape)\n",
 "model.add(Activation(\"relu\"))\n",
 "\n",
 "model.add(Conv1D(32, 8))\n",
 "print(model.output_shape)\n",
 "model.add(Activation(\"relu\"))\n",
 "model.add(MaxPooling1D(4))\n",
 "\n",
 "model.add(Conv1D(32, 8))\n",
 "print(model.output_shape)\n",
 "model.add(Activation(\"relu\"))\n",
 "model.add(MaxPooling1D(4))\n",
 "\n",
 "print(model.output_shape)\n",
 "#model.add(Dropout(0.25))\n",
 "model.add(Flatten())\n",
 "print(model.output_shape)\n",
 "model.add(Dense(100))\n",
 "model.add(Activation(\"relu\"))\n",
 "print(model.output_shape)\n",
 "# One output unit per column of the one-hot encoded labels.\n",
 "model.add(Dense(labels_onehot.shape[1]))\n",
 "model.add(Activation(\"softmax\"))\n",
 "print(model.output_shape)\n",
 "\n",
 "optimizer = keras.optimizers.SGD(lr=0.01, momentum=0.9, nesterov=True)\n",
 "#optimizer = keras.optimizers.Adam()#lr=1e-5)#, momentum=0.9, nesterov=True)\n",
 "model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n",
 "\n",
 "model.fit_generator(SampleLoader(train, batch_size=10), train.size, nb_epoch=20, **params)\n",
 "# Keep the validation and test results in separate variables so that both can be reported.\n",
 "loss_val = model.evaluate_generator(SampleLoader(val, batch_size=10), val.size, **params)\n",
 "loss_test = model.evaluate_generator(SampleLoader(test, batch_size=10), test.size, **params)\n",
 "\n",
 "loss_val, loss_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "### 2.3 Recurrent neural network" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## 3 Deep learning on extracted audio features\n",
 "\n",
 "Look at:\n",
 "* Pre-processing in Keras: https://github.com/keunwoochoi/kapre\n",
 "* Convolutional Recurrent Neural Networks for Music Classification: https://github.com/keunwoochoi/icassp_2017\n",
 "* Music Auto-Tagger: 
https://github.com/keunwoochoi/music-auto_tagging-keras\n",
 "* Pre-processor: https://github.com/bmcfee/pumpp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "### 3.1 ConvNet on MFCC\n",
 "\n",
 "* Architecture: [Automatic Musical Pattern Feature Extraction Using Convolutional Neural Network](http://www.iaeng.org/publication/IMECS2010/IMECS2010_pp546-550.pdf), Tom LH. Li, Antoni B. Chan and Andy HW. Chun\n",
 "* Missing: track segmentation and majority voting.\n",
 "* Best seen: 17.6%" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "class MfccLoader(utils.Loader):\n",
 "    \"\"\"Loader that decodes a track with ffmpeg and returns its MFCCs.\"\"\"\n",
 "\n",
 "    # Decoder for the raw waveform; a class attribute, hence shared by all instances.\n",
 "    raw_loader = utils.FfmpegLoader(sampling_rate=22050)\n",
 "    #shape = (13, 190)  # For segmented tracks.\n",
 "    # Declared shape of one sample; the first axis matches n_mfcc=13 used in load().\n",
 "    shape = (13, 2582)\n",
 "\n",
 "    def load(self, filename):\n",
 "        \"\"\"Return the MFCCs computed from the audio file `filename`.\"\"\"\n",
 "        # Imported here rather than at the top of the notebook, so librosa is only\n",
 "        # required when this loader is actually used.\n",
 "        import librosa\n",
 "        x = self.raw_loader.load(filename)\n",
 "        # Each MFCC frame spans 23ms on the audio signal with 50% overlap with the adjacent frames.\n",
 "        # `sr` must stay equal to the sampling rate given to `raw_loader` above.\n",
 "        # The signal is passed by keyword: recent librosa releases reject it as a positional argument.\n",
 "        mfcc = librosa.feature.mfcc(y=x, sr=22050, n_mfcc=13, n_fft=512, hop_length=256)\n",
 "        return mfcc\n",
 "\n",
 "loader = MfccLoader()\n",
 "SampleLoader = utils.build_sample_loader(AUDIO_DIR, labels_onehot, loader)\n",
 "# Sanity check: shape of the full MFCC matrix, expected to equal MfccLoader.shape.\n",
 "loader.load(utils.get_audio_path(AUDIO_DIR, 2)).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Drop any graph left over from a previously built model.\n",
 "keras.backend.clear_session()\n",
 "\n",
 "model = keras.models.Sequential()\n",
 "# Append a channel axis of size 1 so that the MFCC matrix can be fed to Conv2D.\n",
 "model.add(Reshape((*loader.shape, 1), input_shape=loader.shape))\n",
 "print(model.output_shape)\n",
 "\n",
 "model.add(Conv2D(3, 13, 10, subsample=(1, 4)))\n",
 "model.add(Activation(\"relu\"))\n",
 "print(model.output_shape)\n",
 "\n",
 "model.add(Conv2D(15, 1, 10, subsample=(1, 4)))\n",
 "model.add(Activation(\"relu\"))\n",
 "print(model.output_shape)\n",
 "\n",
 "model.add(Conv2D(65, 1, 10, subsample=(1, 4)))\n",
 "model.add(Activation(\"relu\"))\n",
 "print(model.output_shape)\n",
 "\n",
 "model.add(Flatten())\n",
 "print(model.output_shape)\n",
 "# One output unit per column of the one-hot encoded labels.\n",
 "model.add(Dense(labels_onehot.shape[1]))\n",
 "model.add(Activation(\"softmax\"))\n",
 "print(model.output_shape)\n",
 "\n",
 "optimizer = keras.optimizers.SGD(1e-3)#lr=0.01, momentum=0.9, nesterov=True)\n",
 "#optimizer = keras.optimizers.Adam()#lr=1e-5)#\n",
 "model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n",
 "\n",
 "model.fit_generator(SampleLoader(train, batch_size=16), train.size, nb_epoch=20, **params)\n",
 "# Keep the validation and test results in separate variables so that both can be reported.\n",
 "loss_val = model.evaluate_generator(SampleLoader(val, batch_size=16), val.size, **params)\n",
 "loss_test = model.evaluate_generator(SampleLoader(test, batch_size=16), test.size, **params)\n",
 "#Y = model.predict_generator(loader, test.size, pickle_safe=True, nb_worker=NB_WORKER, max_q_size=5)\n",
 "\n",
 "loss_val, loss_test" ] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 1 }