{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# A new statistical method to analyze Morris Water Maze data using Dirichlet distribution\n", "\n", "This notebook shows how to reproduce the results presented in our paper and how to apply the Dirichlet test to your own data using Python. To analyze your own data, you only have to change the names of the files that are loaded.\n", "\n", "Follow the instructions alongside the code and run the cells in order to obtain your results!\n", "\n", "*Remember*: you need a working installation of Python (2 or 3) and Jupyter." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load modules" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import sys\n", "# Add the parent directory to the import path so that the local dirichlet package can be found\n", "sys.path.insert(0, '../')\n", "from dirichlet import dirichlet\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load data\n", "Data must be passed as a 2D numpy array in which each row [TQ,AQ1,OQ,AQ2] represents one sample (note that each row will be normalized so that it sums to 1). If your data is in a `.csv` file, you can load it with either `pandas` or `numpy`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### With pandas (if this doesn't work, just use numpy)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "data_3tg = pd.read_csv('3Tg.csv')\n", "data_wt = pd.read_csv('wt.csv')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>TQ</th>\n", " <th>AQ1</th>\n", " <th>OQ</th>\n", " <th>AQ2</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>45.3</td>\n", " <td>15.7</td>\n", " <td>27.6</td>\n", " <td>11.4</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>21.2</td>\n", " <td>21.9</td>\n", " <td>24.4</td>\n", " <td>32.5</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>26.2</td>\n", " <td>31.3</td>\n", " <td>21.5</td>\n", " <td>21.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>37.0</td>\n", " <td>24.5</td>\n", " <td>12.3</td>\n", " <td>26.3</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>30.2</td>\n", " <td>32.7</td>\n", " <td>16.8</td>\n", " <td>20.3</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>28.9</td>\n", " <td>20.9</td>\n", " <td>18.8</td>\n", " <td>31.5</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>22.8</td>\n", " <td>30.0</td>\n", " <td>29.1</td>\n", " <td>18.2</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " TQ AQ1 OQ AQ2\n", "0 45.3 15.7 27.6 11.4\n", "1 21.2 21.9 24.4 32.5\n", "2 26.2 31.3 21.5 21.0\n", "3 37.0 24.5 12.3 26.3\n", "4 30.2 32.7 16.8 20.3\n", "5 28.9 20.9 18.8 31.5\n", "6 22.8 
30.0 29.1 18.2" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_3tg" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Pandas data frames need to be converted to numpy arrays\n", "data_3tg = data_3tg.to_numpy()\n", "data_wt = data_wt.to_numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### With numpy" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Load .csv files...\n", "data_3tg = np.loadtxt('3Tg.csv', delimiter=',', skiprows=1)\n", "data_wt = np.loadtxt('wt.csv', delimiter=',', skiprows=1)\n", "\n", "# or simple text files ...\n", "data_3tg = np.loadtxt('3Tg.txt')\n", "data_wt = np.loadtxt('wt.txt')\n", "\n", "# or create 2d arrays manually\n", "data_3tg = np.array([[45.3, 15.7, 27.6, 11.4],\n", " [21.2, 21.9, 24.4, 32.5],\n", " [26.2, 31.3, 21.5, 21. ],\n", " [37. , 24.5, 12.3, 26.3],\n", " [30.2, 32.7, 16.8, 20.3],\n", " [28.9, 20.9, 18.8, 31.5],\n", " [22.8, 30. , 29.1, 18.2]])\n", "data_wt = np.array([[44.3, 12.9, 26.2, 16.7],\n", " [28.2, 26.3, 27.6, 18. ],\n", " [41.1, 15.2, 13.9, 29.9],\n", " [57.5, 13.3, 13.2, 16.1],\n", " [52.9, 9.3, 10.3, 27.5],\n", " [41.9, 35.3, 14.7, 8.1],\n", " [30.9, 26.9, 17.4, 24.8]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`data_3tg` should be a numpy 2D array at this point." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[45.3, 15.7, 27.6, 11.4],\n", " [21.2, 21.9, 24.4, 32.5],\n", " [26.2, 31.3, 21.5, 21. ],\n", " [37. , 24.5, 12.3, 26.3],\n", " [30.2, 32.7, 16.8, 20.3],\n", " [28.9, 20.9, 18.8, 31.5],\n", " [22.8, 30. 
, 29.1, 18.2]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_3tg" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Apply the test\n", "\n", "You can call the `test_uniform` function on its own. It returns four values. The first is the likelihood-ratio statistic $\\Lambda$, with our Bartlett correction applied if `do_MWM_correction` is set to `True`. The second is the $p$-value. The third and fourth are the estimated $\\alpha_i$ parameters under the best-fit (alternative) and uniformity (null) hypotheses, respectively." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Result of the Dirichlet uniformity test for group 3Tg:\n", "# likelihood-ratio statistic (with MWM correction) = 3.95865\n", "# p-value = 0.265964\n", "# MLE params under null hypothesis (uniformity) :[8.40402434 8.40402434 8.40402434 8.40402434]\n", "# MLE params under alternative hypothesis :[12.55460935 10.6084696 9.02098363 9.49037119]\n", "\n", "Result of the Dirichlet uniformity test for group wt:\n", "# likelihood-ratio statistic (with MWM correction) = 14.621\n", "# p-value = 0.00217089\n", "# MLE params under null hypothesis (uniformity) :[3.01649146 3.01649146 3.01649146 3.01649146]\n", "# MLE params under alternative hypothesis :[11.42349023 5.25568502 4.89513099 5.44866848]\n" ] } ], "source": [ "stat, pval, alpha_best, alpha_uni = dirichlet.test_uniform(data_3tg, label='3Tg', do_MWM_correction=True, verbose=True)\n", "stat, pval, alpha_best, alpha_uni = dirichlet.test_uniform(data_wt, label='wt', do_MWM_correction=True, verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, use the `plot` function to draw the figure for a group. If `do_test_uniform` is `True`, the uniformity test is run as well. `save_figure` takes the name of the file in which to save the figure." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Result of the Dirichlet uniformity test for group 3Tg:\n", "# likelihood-ratio statistic (with MWM correction) = 3.95865\n", "# p-value = 0.265964\n", "# MLE params under null hypothesis (uniformity) :[8.40402434 8.40402434 8.40402434 8.40402434]\n", "# MLE params under alternative hypothesis :[12.55460935 10.6084696 9.02098363 9.49037119]\n", "\n", "Result of the Dirichlet uniformity test for group wt:\n", "# likelihood-ratio statistic (with MWM correction) = 14.621\n", "# p-value = 0.00217089\n", "# MLE params under null hypothesis (uniformity) :[3.01649146 3.01649146 3.01649146 3.01649146]\n", "# MLE params under alternative hypothesis :[11.42349023 5.25568502 4.89513099 5.44866848]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 144x216 with 1 Axes>" ] }, "metadata": { "image/png": { "height": 239, "width": 162 }, "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 144x216 with 1 Axes>" ] }, "metadata": { "image/png": { "height": 239, "width": 162 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "dirichlet.plot(data_3tg, label='3Tg', do_test_uniform=True, do_MWM_correction=True, verbose=True, save_figure='3Tg.png')\n", "dirichlet.plot(data_wt, label='wt', do_test_uniform=True, do_MWM_correction=True, verbose=True, save_figure='wt.png')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 4 }