{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation of tables and figures of MRIQC paper\n",
    "\n",
    "This notebook is associated with the paper:\n",
    "\n",
    "Esteban O, Birman D, Schaer M, Koyejo OO, Poldrack RA, Gorgolewski KJ; MRIQC: Predicting Quality in Manual MRI Assessment Protocols Using No-Reference Image Quality Measures; PLoS ONE 12(9): e0184661; doi:[10.1371/journal.pone.0184661](https://doi.org/10.1371/journal.pone.0184661)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "import os.path as op\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pkg_resources import resource_filename as pkgrf\n",
    "\n",
    "from mriqc.viz import misc as mviz\n",
    "from mriqc.classifier.data import read_dataset, combine_datasets\n",
    "\n",
    "# Every location below can be overridden with an environment variable, so the\n",
    "# notebook can run on machines other than the one it was written on.\n",
    "# The defaults are the original hard-coded values.\n",
    "\n",
    "# Where the outputs should be saved\n",
    "outputs_path = os.getenv('MRIQC_PAPER_OUTPUTS', '../../mriqc-data/')\n",
    "\n",
    "# Path to ABIDE's BIDS structure\n",
    "abide_path = os.getenv('ABIDE_PATH', '/home/oesteban/Data/ABIDE/')\n",
    "# Path to DS030's BIDS structure\n",
    "ds030_path = os.getenv('DS030_PATH', '/home/oesteban/Data/ds030/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Read some data (from mriqc package)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Features (x) and ratings (y) shipped as CSV files within the mriqc package\n",
    "x_path = pkgrf('mriqc', 'data/csv/x_abide.csv')\n",
    "y_path = pkgrf('mriqc', 'data/csv/y_abide.csv')\n",
    "ds030_x_path = pkgrf('mriqc', 'data/csv/x_ds030.csv')\n",
    "ds030_y_path = pkgrf('mriqc', 'data/csv/y_ds030.csv')\n",
    "\n",
    "# Ratings are read as floats (np.isnan is applied to these columns later on)\n",
    "rater_types = {'rater_1': float, 'rater_2': float, 'rater_3': float}\n",
    "mdata = pd.read_csv(y_path, index_col=False, dtype=rater_types)\n",
    "\n",
    "# Sorted list of the unique values found in the 'site' column\n",
    "sites = sorted(set(mdata.site.values.ravel().tolist()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 1: artifacts in MRI\n",
    "Shows a couple of subpar datasets from the ABIDE dataset"
   ]
  },
{
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "out_file = op.join(outputs_path, 'figures', 'fig01-artifacts.svg')\n",
    "mviz.figure1(\n",
    "    op.join(abide_path, 'sub-50137', 'anat', 'sub-50137_T1w.nii.gz'),\n",
    "    op.join(abide_path, 'sub-50110', 'anat', 'sub-50110_T1w.nii.gz'),\n",
    "    out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Replace the '.svg' extension of the figure with '.pdf'\n",
    "out_file_pdf = op.splitext(out_file)[0] + '.pdf'\n",
    "!rsvg-convert -f pdf -o $out_file_pdf $out_file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 2: batch effects\n",
    "\n",
    "This code was used to generate the second figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from mriqc.classifier.sklearn import preprocessing as mcsp\n",
    "\n",
    "# Concatenate ABIDE & DS030\n",
    "fulldata = combine_datasets([\n",
    "    (x_path, y_path, 'ABIDE'),\n",
    "    (ds030_x_path, ds030_y_path, 'DS030'),\n",
    "])\n",
    "\n",
    "# Names of all features\n",
    "features = [\n",
    "    'cjv', 'cnr', 'efc', 'fber',\n",
    "    'fwhm_avg', 'fwhm_x', 'fwhm_y', 'fwhm_z',\n",
    "    'icvs_csf', 'icvs_gm', 'icvs_wm',\n",
    "    'inu_med', 'inu_range',\n",
    "    'qi_1', 'qi_2',\n",
    "    'rpve_csf', 'rpve_gm', 'rpve_wm',\n",
    "    'size_x', 'size_y', 'size_z',\n",
    "    'snr_csf', 'snr_gm', 'snr_total', 'snr_wm',\n",
    "    'snrd_csf', 'snrd_gm', 'snrd_total', 'snrd_wm',\n",
    "    'spacing_x', 'spacing_y', 'spacing_z',\n",
    "    'summary_bg_k', 'summary_bg_mad', 'summary_bg_mean', 'summary_bg_median', 'summary_bg_n', 'summary_bg_p05', 'summary_bg_p95', 'summary_bg_stdv',\n",
    "    'summary_csf_k', 'summary_csf_mad', 'summary_csf_mean', 'summary_csf_median', 'summary_csf_n', 'summary_csf_p05', 'summary_csf_p95', 'summary_csf_stdv',\n",
    "    'summary_gm_k', 'summary_gm_mad', 'summary_gm_mean', 'summary_gm_median', 'summary_gm_n', 'summary_gm_p05', 'summary_gm_p95', 'summary_gm_stdv',\n",
    "    'summary_wm_k', 'summary_wm_mad', 'summary_wm_mean', 'summary_wm_median', 'summary_wm_n', 'summary_wm_p05', 'summary_wm_p95', 'summary_wm_stdv',\n",
    "    'tpm_overlap_csf', 'tpm_overlap_gm', 'tpm_overlap_wm',\n",
    "    'wm2max'\n",
    "]\n",
    "\n",
    "# Names of features that can be normalized\n",
    "coi = [\n",
    "    'cjv', 'cnr', 'efc', 'fber', 'fwhm_avg', 'fwhm_x', 'fwhm_y', 'fwhm_z',\n",
    "    'snr_csf', 'snr_gm', 'snr_total', 'snr_wm', 'snrd_csf', 'snrd_gm', 'snrd_total', 'snrd_wm',\n",
    "    'summary_csf_mad', 'summary_csf_mean', 'summary_csf_median', 'summary_csf_p05', 'summary_csf_p95', 'summary_csf_stdv', 'summary_gm_k', 'summary_gm_mad', 'summary_gm_mean', 'summary_gm_median', 'summary_gm_p05', 'summary_gm_p95', 'summary_gm_stdv', 'summary_wm_k', 'summary_wm_mad', 'summary_wm_mean', 'summary_wm_median', 'summary_wm_p05', 'summary_wm_p95', 'summary_wm_stdv'\n",
    "]\n",
    "\n",
    "# Plot batches (panel a: features before site-wise scaling)\n",
    "fig = mviz.plot_batches(fulldata, cols=list(reversed(coi)),\n",
    "                        out_file=op.join(outputs_path, 'figures/fig02-batches-a.pdf'))\n",
    "\n",
    "# Apply new site-wise scaler (panel b: same features after scaling)\n",
    "scaler = mcsp.BatchRobustScaler(by='site', columns=coi)\n",
    "scaled = scaler.fit_transform(fulldata)\n",
    "fig = mviz.plot_batches(scaled, cols=coi, site_labels='right',\n",
    "                        out_file=op.join(outputs_path, 'figures/fig02-batches-b.pdf'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 3: Inter-rater variability\n",
    "\n",
    "In this figure we evaluate the inter-observer agreement between both raters on the 100 overlapping data points of ABIDE. Cohen's Kappa is also computed."
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import cohen_kappa_score\n",
    "\n",
    "# Keep only the rows in which neither rater_1 nor rater_2 is NaN\n",
    "overlap = mdata[np.all(~np.isnan(mdata[['rater_1', 'rater_2']]), axis=1)]\n",
    "y1 = overlap.rater_1.values.ravel().tolist()\n",
    "y2 = overlap.rater_2.values.ravel().tolist()\n",
    "# NOTE(review): the file name says 'fig02' although this section is Figure 3;\n",
    "# the name is kept so that existing consumers of the file do not break.\n",
    "fig = mviz.inter_rater_variability(y1, y2, out_file=op.join(outputs_path, 'figures', 'fig02-irv.pdf'))\n",
    "\n",
    "print(\"Cohen's Kappa %f\" % cohen_kappa_score(y1, y2))\n",
    "\n",
    "# Binarize the ratings by merging label 0 into label 1. This is done on copies:\n",
    "# `.values` may return a view, and writing through it would modify `overlap`.\n",
    "y1_bin = overlap.rater_1.values.copy()\n",
    "y1_bin[y1_bin == 0] = 1\n",
    "\n",
    "y2_bin = overlap.rater_2.values.copy()\n",
    "y2_bin[y2_bin == 0] = 1\n",
    "print(\"Cohen's Kappa (binarized): %f\" % cohen_kappa_score(y1_bin, y2_bin))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 5: Model selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sn\n",
    "\n",
    "\n",
    "def plot_model_scores(scores, site_names, score_name, ylabel, out_file,\n",
    "                      colors=('dodgerblue', 'orange', 'darkorange')):\n",
    "    \"\"\"\n",
    "    Plot the per-site scores of several models next to their distributions.\n",
    "\n",
    "    The left panel shows one bar per model for every site, the right panel\n",
    "    one violin per model. The average score of each model is drawn as a\n",
    "    dashed horizontal line in both panels, and its average and standard\n",
    "    deviation are printed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    scores : list of (str, list of float)\n",
    "        Model label and its scores (one per site), in plotting order.\n",
    "    site_names : list\n",
    "        Labels of the sites, used as tick labels of the bar plot.\n",
    "    score_name : str\n",
    "        Name of the score column of the long-format table that is plotted.\n",
    "    ylabel : str\n",
    "        Label of the y-axis of the bar plot.\n",
    "    out_file : str\n",
    "        Path where the figure is saved.\n",
    "    colors : sequence of str\n",
    "        One color per model.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The matplotlib figure that was saved.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a model does not have exactly one score per site, or if there\n",
    "        are fewer colors than models.\n",
    "    \"\"\"\n",
    "    nsites = len(site_names)\n",
    "    labels = [label for label, _ in scores]\n",
    "    allvals = [list(vals) for _, vals in scores]\n",
    "    dim = len(allvals)\n",
    "\n",
    "    if len(colors) < dim:\n",
    "        raise ValueError('%d colors given for %d models' % (len(colors), dim))\n",
    "    for label, vals in zip(labels, allvals):\n",
    "        if len(vals) != nsites:\n",
    "            raise ValueError('%s has %d scores but there are %d sites' % (label, len(vals), nsites))\n",
    "\n",
    "    # Long-format table: one row per (model, site) pair\n",
    "    df = pd.DataFrame({\n",
    "        'site': list(range(nsites)) * dim,\n",
    "        score_name: [val for vals in allvals for val in vals],\n",
    "        'Model': [label for label in labels for _ in range(nsites)],\n",
    "    })\n",
    "\n",
    "    x = np.arange(nsites)\n",
    "    dimw = 0.81 / dim  # Width of each bar within the group of a site\n",
    "    averages = [np.average(vals) for vals in allvals]\n",
    "\n",
    "    fig = plt.figure(figsize=(10, 3))\n",
    "\n",
    "    # Right panel: distribution of the scores of each model\n",
    "    ax2 = plt.subplot2grid((1, 4), (0, 3))\n",
    "    sn.violinplot(data=df, x='Model', y=score_name, ax=ax2,\n",
    "                  palette=list(colors[:dim]), bw=.1, linewidth=.7)\n",
    "    for i in range(dim):\n",
    "        ax2.axhline(averages[i], ls='--', color=colors[i], lw=.8)\n",
    "    ax2.yaxis.tick_right()\n",
    "    ax2.set_ylabel('')\n",
    "    ax2.set_xticklabels(ax2.get_xticklabels(), rotation=40)\n",
    "    ax2.set_ylim([0.0, 1.0])\n",
    "\n",
    "    # Left panel: scores of every model, grouped by site\n",
    "    ax1 = plt.subplot2grid((1, 4), (0, 0), colspan=3)\n",
    "    for i in range(dim):\n",
    "        ax1.bar(x + i * dimw, allvals[i], dimw, bottom=0.001, color=colors[i], alpha=.6)\n",
    "        print(averages[i], np.std(allvals[i]))\n",
    "        ax1.axhline(averages[i], ls='--', color=colors[i], lw=.8)\n",
    "\n",
    "    # Axis limits and ticks follow the number of sites\n",
    "    ax1.set_xlim([-0.2, nsites - 0.25])\n",
    "    ax1.grid(False)\n",
    "    ax1.set_xticks(x + 0.33)\n",
    "    ax1.set_xticklabels(site_names, rotation='vertical')\n",
    "    ax1.set_ylim([0.0, 1.0])\n",
    "    ax1.set_ylabel(ylabel)\n",
    "    fig.savefig(out_file, bbox_inches='tight', dpi=300)\n",
    "    return fig\n",
    "\n",
    "\n",
    "# Accuracy of each model, one value per entry of `sites`\n",
    "rfc_acc = [0.842, 0.815, 0.648, 0.609, 0.789, 0.761, 0.893, 0.833, 0.842, 0.767, 0.806, 0.850, 0.878, 0.798, 0.559, 0.881, 0.375]\n",
    "svc_lin_acc = [0.947, 0.667, 0.870, 0.734, 0.754, 0.701, 0.750, 0.639, 0.877, 0.767, 0.500, 0.475, 0.837, 0.768, 0.717, 0.050, 0.429]\n",
    "svc_rbf_acc = [0.947, 0.852, 0.500, 0.578, 0.772, 0.712, 0.821, 0.583, 0.912, 0.767, 0.500, 0.450, 0.837, 0.778, 0.441, 0.950, 0.339]\n",
    "\n",
    "fig = plot_model_scores(\n",
    "    [('RFC', rfc_acc), ('SVC_lin', svc_lin_acc), ('SVC_rbf', svc_rbf_acc)],\n",
    "    sites, 'accuracy', 'Accuracy (ACC)',\n",
    "    op.join(outputs_path, 'figures/fig05-acc.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Area under the ROC curve of each model, one value per entry of `sites`\n",
    "rfc_roc_auc = [0.597, 0.380, 0.857, 0.610, 0.698, 0.692, 0.963, 0.898, 0.772, 0.596, 0.873, 0.729, 0.784, 0.860, 0.751, 0.900, 0.489]\n",
    "svc_lin_roc_auc = [0.583, 0.304, 0.943, 0.668, 0.691, 0.754, 1.000, 0.778, 0.847, 0.590, 0.857, 0.604, 0.604, 0.838, 0.447, 0.650, 0.501]\n",
    "svc_rbf_roc_auc = [0.681, 0.217, 0.827, 0.553, 0.738, 0.616, 0.889, 0.813, 0.845, 0.658, 0.779, 0.493, 0.726, 0.510, 0.544, 0.500, 0.447]\n",
    "\n",
    "fig = plot_model_scores(\n",
    "    [('RFC', rfc_roc_auc), ('SVC_lin', svc_lin_roc_auc), ('SVC_rbf', svc_rbf_roc_auc)],\n",
    "    sites, 'auc', 'Area under the curve (AUC)',\n",
    "    op.join(outputs_path, 'figures/fig05-auc.pdf'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation on DS030\n",
    "\n",
    "This section deals with the results obtained on DS030.\n",
    "\n",
    "### Table 4: Confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "pred_file = op.abspath(op.join(\n",
    "    '..', 'mriqc/data/csv',\n",
    "    'mclf_run-20170724-191452_mod-rfc_ver-0.9.7-rc8_class-2_cv-loso_data-test_pred.csv'))\n",
    "\n",
    "pred_y = pd.read_csv(pred_file)\n",
    "true_y = pd.read_csv(ds030_y_path)\n",
    "# Flip the sign of the ratings and then zero out the negative values, so that\n",
    "# originally positive ratings become 0 and originally negative ones positive.\n",
    "true_y.rater_1 *= -1\n",
    "# `.loc` writes into `true_y` itself; chained indexing\n",
    "# (true_y.rater_1[mask] = 0) may write into a temporary copy instead.\n",
    "true_y.loc[true_y.rater_1 < 0, 'rater_1'] = 0\n",
    "print(confusion_matrix(true_y.rater_1.tolist(), pred_y.pred_y.values.ravel().tolist(), labels=[0, 1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Figure 6A: Feature importances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sn\n",
    "try:\n",
    "    from joblib import load as loadpkl\n",
    "except ImportError:\n",
    "    # Older environments only provide joblib bundled within scikit-learn\n",
    "    from sklearn.externals.joblib import load as loadpkl\n",
    "sn.set_style(\"white\")\n",
    "\n",
    "# Get the RFC\n",
    "# NOTE: joblib files are pickles, only load estimators from a trusted source\n",
    "estimator = loadpkl(pkgrf('mriqc', 'data/mclf_run-20170724-191452_mod-rfc_ver-0.9.7-rc8_class-2_cv-loso_data-train_estimator.pklz'))\n",
    "forest = estimator.named_steps['rfc']\n",
    "\n",
    "# Features selected in cross-validation\n",
    "features = [\n",
    "    'cjv', 'cnr', 'efc', 'fber', 'fwhm_avg', 'fwhm_x', 'fwhm_y', 'fwhm_z', 'icvs_csf', 'icvs_gm', 'icvs_wm',\n",
    "    'qi_1', 'qi_2', 'rpve_csf', 'rpve_gm', 'rpve_wm', 'snr_csf', 'snr_gm', 'snr_total', 'snr_wm', 'snrd_csf',\n",
    "    'snrd_gm', 'snrd_total', 'snrd_wm', 'summary_bg_k', 'summary_bg_stdv', 'summary_csf_k', 'summary_csf_mad',\n",
    "    'summary_csf_mean', 'summary_csf_median', 'summary_csf_p05', 'summary_csf_p95', 'summary_csf_stdv',\n",
    "    'summary_gm_k', 'summary_gm_mad', 'summary_gm_mean', 'summary_gm_median', 'summary_gm_p05', 'summary_gm_p95',\n",
    "    'summary_gm_stdv', 'summary_wm_k', 'summary_wm_mad', 'summary_wm_mean', 'summary_wm_median', 'summary_wm_p05',\n",
    "    'summary_wm_p95', 'summary_wm_stdv', 'tpm_overlap_csf', 'tpm_overlap_gm', 'tpm_overlap_wm']\n",
    "\n",
    "nft = len(features)\n",
    "\n",
    "# One row per tree of the forest, one column per feature\n",
    "tree_importances = np.array([tree.feature_importances_ for tree in forest.estimators_])\n",
    "if tree_importances.shape[1] != nft:\n",
    "    raise ValueError('The forest has %d features but %d names were given' % (\n",
    "        tree_importances.shape[1], nft))\n",
    "\n",
    "# Sort the features by decreasing median importance across trees\n",
    "importances = np.median(tree_importances, axis=0)\n",
    "indices = np.argsort(importances)[::-1]\n",
    "sorted_features = [features[i] for i in indices]\n",
    "\n",
    "# Long-format table: for each tree, the importance of every feature (sorted)\n",
    "df = pd.DataFrame({\n",
    "    'Feature': sorted_features * len(tree_importances),\n",
    "    'Importance': tree_importances[:, indices].ravel(),\n",
    "})\n",
    "\n",
    "fig = plt.figure(figsize=(20, 6))\n",
    "sn.boxplot(x='Feature', y='Importance', data=df, order=sorted_features, linewidth=1, notch=True)\n",
    "plt.xlabel('Features selected (%d)' % nft)\n",
    "plt.xticks(range(nft))\n",
    "plt.gca().set_xticklabels(sorted_features, rotation=90)\n",
    "plt.xlim([-1, nft])\n",
    "# Save before showing, so the file is written while the figure is still open\n",
    "fig.savefig(op.join(outputs_path, 'figures', 'fig06-exp2-fi.pdf'),\n",
    "            bbox_inches='tight', pad_inches=0, dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Figure 6B: Misclassified images of DS030"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Subject identifiers of the misclassified images\n",
    "# (presumably fn = false negatives and fp = false positives - TODO confirm)\n",
    "fn = ['10225', '10235', '10316', '10339', '10365', '10376',\n",
    "      '10429', '10460', '10506', '10527', '10530', '10624',\n",
    "      '10696', '10891', '10948', '10968', '10977', '11050',\n",
    "      '11052', '11142', '11143', '11149', '50004', '50005',\n",
    "      '50008', '50010', '50016', '50027', '50029', '50033',\n",
    "      '50034', '50036', '50043', '50047', '50049', '50053',\n",
    "      '50054', '50055', '50085', '60006', '60010', '60012',\n",
    "      '60014', '60016', '60021', '60046', '60052', '60072',\n",
    "      '60073', '60084', '60087', '70051', '70060', '70072']\n",
    "fp = ['10280', '10455', '10523', '11112', '50020', '50048',\n",
    "      '50052', '50061', '50073', '60077']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# (subject identifier, index along the last axis of the image) pairs to plot\n",
    "fn_clear = [\n",
    "    ('10316', 98),\n",
    "    ('10968', 122),\n",
    "    ('11050', 110),\n",
    "    ('11149', 111)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mriqc.viz.utils import plot_slice\n",
    "import nibabel as nb\n",
    "for im, z in fn_clear:\n",
    "    image_path = op.join(ds030_path, 'sub-%s' % im, 'anat', 'sub-%s_T1w.nii.gz' % im)\n",
    "    # get_data() is deprecated in nibabel; np.asanyarray(img.dataobj) is the\n",
    "    # replacement that keeps the same behavior\n",
    "    imdata = np.asanyarray(nb.load(image_path).dataobj)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    plot_slice(imdata[..., z], annotate=True)\n",
    "    fig.savefig(op.join(outputs_path, 'figures', 'fig-06_sub-%s_slice-%03d.svg' % (im, z)),\n",
    "                dpi=300, bbox_inches='tight')\n",
    "    plt.clf()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fp_clear = [\n",
    " ('10455', 140),\n",
    " ('50073', 162),\n",
    "]\n",
    "for im, z in fp_clear:\n",
    " image_path = op.join(ds030_path, 'sub-%s' % im, 'anat', 'sub-%s_T1w.nii.gz' % im)\n",
    " imdata = nb.load(image_path).get_data()\n",
    " \n",
    " fig, ax = plt.subplots()\n",
    " plot_slice(imdata[..., z], annotate=True)\n",
    " fig.savefig(op.join(outputs_path, 'figures', 'fig-06_sub-%s_slice-%03d.svg' % (im, z)),
 dpi=300, bbox_inches='tight')
 plt.clf()
 plt.close()