{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# This notebook compares the performance of different ML techniques for emotion recognition" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we import the libraries and load the two result files: the per-speaker Average Unweighted Recall (UWR) and the per-speaker confusion matrices of each model." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "DATA_DIR = \"files\"  # directory containing the result files\n", "\n", "# NOTE: despite their .csv extension, both result files are pandas pickles (hence pd.read_pickle).\n", "# Unpickling can execute arbitrary code, so only load files produced by this project.\n", "UWR = pd.read_pickle(f\"{DATA_DIR}/UWR.csv\")  # Average Unweighted Recall of each model for each speaker\n", "Conf_Mat = pd.read_pickle(f\"{DATA_DIR}/Conf_Mat.csv\")  # confusion matrix of each model for each speaker" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
speakersBaseline_GNBSVMRF
010.5775810.6028940.569329
120.6733770.667770.675786
230.6126960.6066810.607089
340.7087010.7015450.704852
450.6166610.6503860.625562
560.6505140.6575590.672715
670.5324160.5724040.593489
780.5748370.5798890.600976
890.731450.7765050.746929
9100.6218170.7158740.671381
\n", "
" ], "text/plain": [ " speakers Baseline_GNB SVM RF\n", "0 1 0.577581 0.602894 0.569329\n", "1 2 0.673377 0.66777 0.675786\n", "2 3 0.612696 0.606681 0.607089\n", "3 4 0.708701 0.701545 0.704852\n", "4 5 0.616661 0.650386 0.625562\n", "5 6 0.650514 0.657559 0.672715\n", "6 7 0.532416 0.572404 0.593489\n", "7 8 0.574837 0.579889 0.600976\n", "8 9 0.73145 0.776505 0.746929\n", "9 10 0.621817 0.715874 0.671381" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Per-speaker Average Unweighted Recall: one row per speaker, one column per model\n", "UWR" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
speakersBaseline_GNBSVMRF
01[[40, 10, 15, 6], [5, 53, 6, 4], [4, 11, 12, 3...[[42, 17, 12, 0], [2, 59, 3, 4], [5, 18, 11, 3...[[38, 18, 12, 3], [3, 59, 4, 2], [4, 22, 9, 31...
12[[33, 2, 4, 1], [8, 70, 16, 7], [1, 7, 13, 18]...[[37, 2, 1, 0], [7, 85, 8, 1], [6, 15, 10, 8],...[[37, 2, 0, 1], [6, 82, 8, 5], [3, 12, 11, 13]...
23[[26, 0, 7, 2], [11, 79, 12, 8], [9, 12, 22, 9...[[30, 1, 3, 1], [11, 91, 1, 7], [13, 23, 10, 6...[[30, 2, 2, 1], [8, 94, 1, 7], [3, 28, 13, 8],...
34[[27, 9, 4, 0], [26, 80, 7, 22], [3, 4, 42, 21...[[26, 13, 1, 0], [7, 111, 7, 10], [8, 16, 36, ...[[26, 14, 0, 0], [10, 107, 6, 12], [13, 10, 39...
45[[33, 1, 3, 1], [9, 79, 11, 3], [23, 3, 13, 14...[[34, 2, 2, 0], [2, 95, 2, 3], [23, 8, 12, 10]...[[31, 4, 2, 1], [3, 95, 2, 2], [23, 7, 13, 10]...
56[[57, 4, 17, 10], [19, 98, 14, 17], [22, 10, 3...[[66, 9, 5, 8], [5, 128, 5, 10], [18, 28, 29, ...[[67, 8, 7, 6], [8, 120, 8, 12], [26, 18, 28, ...
67[[58, 6, 35, 12], [6, 72, 9, 9], [1, 2, 0, 4],...[[67, 15, 17, 12], [4, 83, 3, 6], [0, 3, 0, 4]...[[61, 14, 22, 14], [3, 86, 1, 6], [0, 2, 1, 4]...
78[[34, 9, 16, 6], [14, 61, 17, 15], [7, 15, 22,...[[35, 17, 11, 2], [7, 78, 9, 13], [4, 31, 15, ...[[37, 17, 10, 1], [9, 72, 10, 16], [8, 25, 17,...
89[[32, 4, 6, 1], [1, 118, 3, 6], [11, 6, 30, 23...[[35, 3, 4, 1], [1, 121, 0, 6], [6, 12, 37, 15...[[36, 4, 2, 1], [1, 121, 1, 5], [10, 14, 28, 1...
910[[30, 5, 11, 3], [25, 92, 33, 13], [2, 4, 28, ...[[35, 5, 7, 2], [4, 132, 17, 10], [3, 5, 32, 8...[[34, 7, 7, 1], [11, 122, 20, 10], [3, 6, 30, ...
\n", "
" ], "text/plain": [ " speakers Baseline_GNB \\\n", "0 1 [[40, 10, 15, 6], [5, 53, 6, 4], [4, 11, 12, 3... \n", "1 2 [[33, 2, 4, 1], [8, 70, 16, 7], [1, 7, 13, 18]... \n", "2 3 [[26, 0, 7, 2], [11, 79, 12, 8], [9, 12, 22, 9... \n", "3 4 [[27, 9, 4, 0], [26, 80, 7, 22], [3, 4, 42, 21... \n", "4 5 [[33, 1, 3, 1], [9, 79, 11, 3], [23, 3, 13, 14... \n", "5 6 [[57, 4, 17, 10], [19, 98, 14, 17], [22, 10, 3... \n", "6 7 [[58, 6, 35, 12], [6, 72, 9, 9], [1, 2, 0, 4],... \n", "7 8 [[34, 9, 16, 6], [14, 61, 17, 15], [7, 15, 22,... \n", "8 9 [[32, 4, 6, 1], [1, 118, 3, 6], [11, 6, 30, 23... \n", "9 10 [[30, 5, 11, 3], [25, 92, 33, 13], [2, 4, 28, ... \n", "\n", " SVM \\\n", "0 [[42, 17, 12, 0], [2, 59, 3, 4], [5, 18, 11, 3... \n", "1 [[37, 2, 1, 0], [7, 85, 8, 1], [6, 15, 10, 8],... \n", "2 [[30, 1, 3, 1], [11, 91, 1, 7], [13, 23, 10, 6... \n", "3 [[26, 13, 1, 0], [7, 111, 7, 10], [8, 16, 36, ... \n", "4 [[34, 2, 2, 0], [2, 95, 2, 3], [23, 8, 12, 10]... \n", "5 [[66, 9, 5, 8], [5, 128, 5, 10], [18, 28, 29, ... \n", "6 [[67, 15, 17, 12], [4, 83, 3, 6], [0, 3, 0, 4]... \n", "7 [[35, 17, 11, 2], [7, 78, 9, 13], [4, 31, 15, ... \n", "8 [[35, 3, 4, 1], [1, 121, 0, 6], [6, 12, 37, 15... \n", "9 [[35, 5, 7, 2], [4, 132, 17, 10], [3, 5, 32, 8... \n", "\n", " RF \n", "0 [[38, 18, 12, 3], [3, 59, 4, 2], [4, 22, 9, 31... \n", "1 [[37, 2, 0, 1], [6, 82, 8, 5], [3, 12, 11, 13]... \n", "2 [[30, 2, 2, 1], [8, 94, 1, 7], [3, 28, 13, 8],... \n", "3 [[26, 14, 0, 0], [10, 107, 6, 12], [13, 10, 39... \n", "4 [[31, 4, 2, 1], [3, 95, 2, 2], [23, 7, 13, 10]... \n", "5 [[67, 8, 7, 6], [8, 120, 8, 12], [26, 18, 28, ... \n", "6 [[61, 14, 22, 14], [3, 86, 1, 6], [0, 2, 1, 4]... \n", "7 [[37, 17, 10, 1], [9, 72, 10, 16], [8, 25, 17,... \n", "8 [[36, 4, 2, 1], [1, 121, 1, 5], [10, 14, 28, 1... \n", "9 [[34, 7, 7, 1], [11, 122, 20, 10], [3, 6, 30, ... 
" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Per-speaker confusion matrices: one row per speaker, one column per model\n", "Conf_Mat" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average Unweighted Recall using Support Vector Machine with RBF kernel is 65.32%, using Random Forest is 64.68% and using Baseline_GNB is 63.00%\n" ] } ], "source": [ "# Mean of the per-speaker UWR values, in percent\n", "UWR_SVM = UWR[\"SVM\"].mean() * 100\n", "UWR_RF = UWR[\"RF\"].mean() * 100\n", "UWR_Baseline_GNB = UWR[\"Baseline_GNB\"].mean() * 100\n", "print(\"Average Unweighted Recall using Support Vector Machine with RBF kernel is {0:.2f}%, using Random Forest is {1:.2f}% and using Baseline_GNB is {2:.2f}%\".format(UWR_SVM, UWR_RF, UWR_Baseline_GNB))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we plot the UWR of each model for each speaker." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Use the speaker id as index so the x-axis shows the actual speaker ids (1-10 here) instead of the default 0-based row index\n", "ax = UWR.set_index(\"speakers\")[[\"Baseline_GNB\", \"SVM\", \"RF\"]].plot(kind='bar', title=\"ML Algorithm Comparison\", figsize=(15, 10), legend=True, fontsize=12)\n", "ax.set_xlabel(\"Speaker\", fontsize=12)\n", "ax.set_ylabel(\"Average Unweighted Recall\", fontsize=12)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results show that, on average, SVM and RF perform better than our baseline Gaussian Naive Bayes model. However, we need to test whether this improvement is statistically significant. 
We use a paired t-test over the speakers and claim significance if p < 0.05." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "p-value (paired t-test) between SVM and baseline is 0.04676954353813417 and between RF and baseline is 0.050187722660954615\n" ] } ], "source": [ "from scipy.stats import ttest_rel\n", "\n", "# Paired t-tests: every model is evaluated on the same speakers, so the per-speaker UWR values are paired\n", "stat_SVM, p_SVM = ttest_rel(UWR[\"Baseline_GNB\"].to_list(), UWR[\"SVM\"].to_list())  # comparison between GNB and SVM\n", "stat_RF, p_RF = ttest_rel(UWR[\"Baseline_GNB\"].to_list(), UWR[\"RF\"].to_list())  # comparison between GNB and RF\n", "\n", "print(\"p-value (paired t-test) between SVM and baseline is {} and between RF and baseline is {}\".format(p_SVM, p_RF))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Therefore, SVM performs significantly better than the baseline Gaussian NB (p = 0.0468 < 0.05). For RF, the improvement is not significant at the 0.05 level (p = 0.0502), although the p-value is very close to the threshold. Note that both models are compared against the same baseline without a multiple-comparison correction (e.g. Bonferroni), so the SVM result should also be treated as borderline." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4337881248285045\n" ] } ], "source": [ "# Performance of SVM vs RF (paired t-test over speakers)\n", "stat_RF_SVM, p_RF_SVM = ttest_rel(UWR[\"SVM\"].to_list(), UWR[\"RF\"].to_list())  # comparison between SVM and RF\n", "print(p_RF_SVM)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Therefore, we cannot conclude that the difference between SVM and RF performance is statistically significant (p = 0.434)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Next, we will focus on the emotion-specific performance of one of the models. 
We choose SVM, since it has the highest average UWR." ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# One SVM confusion matrix per speaker\n", "SVM_conf = Conf_Mat[\"SVM\"].to_list()" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Pool the per-speaker confusion matrices into one matrix over all speakers.\n", "# Summing every entry (instead of a hard-coded range(10)) stays correct if the number of speakers changes;\n", "# the counts are stored as floats, as displayed below.\n", "matrix = np.sum(SVM_conf, axis=0).astype(float)\n", "\n", "# Row/column order of the confusion matrices -- assumed to match the label encoding used in training (TODO confirm)\n", "EMOTIONS = ['Anger', 'Happy', 'Neutral', 'Sad']\n", "assert matrix.shape == (len(EMOTIONS), len(EMOTIONS))" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[407., 84., 63., 26.],\n", " [ 50., 983., 55., 70.],\n", " [ 86., 159., 192., 124.],\n", " [ 22., 106., 66., 431.]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "matrix" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We write a helper function that prints and plots a confusion matrix, optionally normalized per true class so that the diagonal shows the per-class recall." ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def confusion_matrix_plot(cm, classes,\n", "                          normalize=False,\n", "                          title=None,\n", "                          cmap=plt.cm.Blues):\n", "    \"\"\"Print and plot a confusion matrix, optionally normalized per true class.\n", "\n", "    Parameters\n", "    ----------\n", "    cm : array-like of shape (n_classes, n_classes)\n", "        Confusion matrix with true labels on the rows and predicted labels on the columns.\n", "    classes : sequence of str\n", "        Class names used as tick labels on both axes.\n", "    normalize : bool, default False\n", "        If True, divide each row by its sum (so the diagonal is the per-class recall)\n", "        and round to 2 decimals; rows without any samples are left as zeros.\n", "    title : str, optional\n", "        Plot title; if omitted, a default depending on `normalize` is used.\n", "    cmap : matplotlib colormap, default plt.cm.Blues\n", "        Colormap for the heat map.\n", "\n", "    Returns\n", "    -------\n", "    matplotlib.axes.Axes\n", "        The axes holding the plot.\n", "    \"\"\"\n", "    cm = np.asarray(cm)\n", "    if not title:\n", "        title = 'Normalized confusion matrix' if normalize else 'Confusion matrix, without normalization'\n", "\n", "    if normalize:\n", "        row_sums = cm.sum(axis=1, keepdims=True)\n", "        # Classes with no true samples would otherwise yield 0/0 = NaN rows\n", "        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)\n", "        cm = np.around(cm, 2)\n", "        print(\"Normalized confusion matrix\")\n", "    else:\n", "        print('Confusion matrix, without normalization')\n", "\n", "    print(cm)\n", "\n", "    fig, ax = plt.subplots()\n", "    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)\n", "    ax.figure.colorbar(im, ax=ax)\n", "    # Show every tick and label it with the corresponding class name\n", "    ax.set(xticks=np.arange(cm.shape[1]),\n", "           yticks=np.arange(cm.shape[0]),\n", "           xticklabels=classes, yticklabels=classes,\n", "           title=title,\n", "           ylabel='True label',\n", "           xlabel='Predicted label')\n", "\n", "    # Rotate the tick labels and set their alignment.\n", "    plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n", "             rotation_mode=\"anchor\")\n", "\n", "    # Annotate every cell: 2 decimals for rates, whole numbers for counts;\n", "    # cells above half the maximum get white text so it stays readable on the darker colors.\n", "    fmt = '.2f' if normalize else '.0f'\n", "    thresh = cm.max() / 2.\n", "    for i in range(cm.shape[0]):\n", "        for j in range(cm.shape[1]):\n", "            ax.text(j, i, format(cm[i, j], fmt),\n", "                    ha=\"center\", va=\"center\",\n", "                    color=\"white\" if cm[i, j] > thresh else \"black\")\n", "    fig.tight_layout()\n", "    return ax\n" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Confusion matrix, without normalization\n", "[[407. 84. 63. 26.]\n", " [ 50. 983. 55. 70.]\n", " [ 86. 159. 192. 124.]\n", " [ 22. 106. 66. 431.]]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "confusion_matrix_plot(matrix, EMOTIONS, normalize=False)\n", "plt.show()" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Normalized confusion matrix\n", "[[0.7 0.14 0.11 0.04]\n", " [0.04 0.85 0.05 0.06]\n", " [0.15 0.28 0.34 0.22]\n", " [0.04 0.17 0.11 0.69]]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "confusion_matrix_plot(matrix, EMOTIONS, normalize=True)\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The results show that Happy is the most reliably recognized emotion (recall 0.85), whereas performance is poor for Neutral (recall 0.34): Neutral samples are most often misclassified as Happy (0.28) or Sad (0.22)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }