{ "cells": [ { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%%capture\n", "# may take a while to build font cache\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import radd\n", "from radd import build, vis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dependent Process (SS Race) Model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Initial state of Stop process (red) depends on current strength of \n", "# Go activation (green) assumes Stop signal efficacy at later SSDs diminishes \n", "# as the state of the Go process approaches the execution threshold (upper bound). \n", "# (pink lines denote t=SSD, blue is trial deadline)\n", "radd.load_dpm_animation()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxCondttypechoiceresponseaccrtssd
028bslgogo110.59851000
128bslgogo110.52021000
228bslgogo110.54511000
328bslgogo110.57161000
428bslgogo110.50521000
\n", "
" ], "text/plain": [ " idx Cond ttype choice response acc rt ssd\n", "0 28 bsl go go 1 1 0.5985 1000\n", "1 28 bsl go go 1 1 0.5202 1000\n", "2 28 bsl go go 1 1 0.5451 1000\n", "3 28 bsl go go 1 1 0.5716 1000\n", "4 28 bsl go go 1 1 0.5052 1000" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# read data into pandas DataFrame (http://pandas.pydata.org/)\n", "# example_data contains data from 15 subjects in the \n", "# Reactive Stop-Signal task discussed in Dunovan et al., (2015)\n", "data = radd.load_example_data()\n", "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Formatting data for radd\n", "## Required columns\n", "* **idx**: Subject ID number\n", "* **ttype**: Trial-Type ('go' if no stop-signal, 'stop' if stop-signal trial)\n", "* **response**: Trial Response (1 if response recorded, 0 if no response)\n", "* **acc**: Choice Accuracy (1 if correct, 0 if error)\n", "* **rt**: Choice Response-Time (in seconds, can be any value no-response trials)\n", "* **ssd**: Stop-Signal Delay (in milliseconds, 1000 on go trials)\n", "\n", "## Optional columns\n", "* input dataframe can contain columns for experimental conditions of interest (choose any name)\n", "* in the dataframe above, the **Cond** column contains **'bsl'** and **'pnl'** \n", " * in the **'bsl'** or **\"Baseline\"** condition, errors on **go** and **stop** trials are equally penalized \n", " * in the **'pnl'** or **\"Caution\"** condition, penalties are doubled for **stop** trial errors (e.g., response=1)\n", "* See below for fitting models with conditional parameter dependencies \n", " * e.g., drift-rate depends on levels of 'Cond'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Building a model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxCondacc200250300350400c10c20...c90e10e20e30e40e50e60e70e80e90
028bsl0.99171.01.00.950.600.000.50510.5252...0.59820.49610.51850.53170.53180.53190.54490.54580.55840.5674
128pnl0.97521.01.00.950.800.100.51770.5318...0.61190.51980.53240.54520.55420.55860.56380.57180.58510.6021
229bsl0.99171.01.01.000.900.000.52500.5377...0.59840.52680.54500.54510.54890.55850.55850.57090.58480.5902
329pnl0.96691.01.01.000.750.350.53140.5452...0.62500.53140.53220.54480.54500.55170.56290.57510.58510.5980
430bsl0.94211.01.01.000.800.250.52980.5585...0.63840.53610.54520.56060.58420.58540.59850.60980.61180.6208
\n", "

5 rows × 26 columns

\n", "
" ], "text/plain": [ " idx Cond acc 200 250 300 350 400 c10 c20 ... \\\n", "0 28 bsl 0.9917 1.0 1.0 0.95 0.60 0.00 0.5051 0.5252 ... \n", "1 28 pnl 0.9752 1.0 1.0 0.95 0.80 0.10 0.5177 0.5318 ... \n", "2 29 bsl 0.9917 1.0 1.0 1.00 0.90 0.00 0.5250 0.5377 ... \n", "3 29 pnl 0.9669 1.0 1.0 1.00 0.75 0.35 0.5314 0.5452 ... \n", "4 30 bsl 0.9421 1.0 1.0 1.00 0.80 0.25 0.5298 0.5585 ... \n", "\n", " c90 e10 e20 e30 e40 e50 e60 e70 e80 \\\n", "0 0.5982 0.4961 0.5185 0.5317 0.5318 0.5319 0.5449 0.5458 0.5584 \n", "1 0.6119 0.5198 0.5324 0.5452 0.5542 0.5586 0.5638 0.5718 0.5851 \n", "2 0.5984 0.5268 0.5450 0.5451 0.5489 0.5585 0.5585 0.5709 0.5848 \n", "3 0.6250 0.5314 0.5322 0.5448 0.5450 0.5517 0.5629 0.5751 0.5851 \n", "4 0.6384 0.5361 0.5452 0.5606 0.5842 0.5854 0.5985 0.6098 0.6118 \n", "\n", " e90 \n", "0 0.5674 \n", "1 0.6021 \n", "2 0.5902 \n", "3 0.5980 \n", "4 0.6208 \n", "\n", "[5 rows x 26 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = build.Model(data=data, kind='xdpm', depends_on={'v':'Cond'})\n", "model.observedDF.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Header of observed dataframe (model.observedDF)\n", "* **idx**: subject ID\n", "* **Cond**: Baseline(bsl)/Caution(pnl) (could be any experimental condition of interest) \n", "* **Acc**: Accuracy on \"go\" trials\n", "* **sacc**: Mean accuracy on \"stop\" trials (mean condition SSD used during simulations)\n", "* **c10 - c90**: 10th - 90th RT quantiles for correct responses\n", "* **e10 - e90**: 10th - 90th RT quantiles for error responses" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Bounded Global & Local Optimization\n", "\n", "## Global Optimization (Basinhopping w/ bounds)\n", "tams()** method gives control over low-level parameters used for global opt\n", "* **xtol = ftol = tol**: error tolerance of global optimization (default=1e-20)\n", "\n", "* **stepsize** (default=.05): set basinhopping initial step-size\n", " * see HopStep class in radd.fit for more details\n", " * see get_stepsize_scalars() in radd.theta for parameter-specific step-sizes\n", "\n", "* **nsamples** (default=3000): number of parameter subsets to sample\n", " * number of individual parameter subsets $\\theta_i \\in \\Theta$ to sample and evaluate before initializing global opt\n", " \n", "$$\\Theta = \\{\\theta_1, \\theta_2 \\dots \\theta_{nsamples}\\}$$\n", " \n", " * For each sampled parameter subset $\\theta_i = \\{a_G, v_G, tr_G, \\dots\\, v_S\\}$ (see table below for description of parameters), the vector of observed data $Y$ (accuracy and correct & error RT quantiles) is compared to an equal length vector of model-predicted data $f(\\theta_i)$ via the weighted cost function:\n", " \n", "$$\\chi^2(\\theta_i) = \\sum \\omega * [ Y - f(\\theta_i) ]^2$$\n", " \n", " * The parameter set (or *sets* - see **ninits** below) that yield the lowest cost function error ($\\chi^2(\\theta_i)$) are then used to initialize the model for global optimization ($\\theta_{init}$)\n", " \n", "$$\\theta_{init} = \\operatorname*{argmin}_{\\theta}\\chi^2(\\theta_i)$$ \n", "\n", "\n", "| $$\\theta$$ | Description | str id | Go/Stop |\n", "|:----------:|:----------------|:------:|:-------:|\n", "| $$a_{G}$$ | Threshold | 'a' | Go |\n", "| $$tr_{G}$$ | Onset-Delay | 'tr' | Go |\n", "| $$v_{G}$$ | Drift-Rate | 'v' | Go | \n", "| $$xb_{G}$$ | Dynamic Gain | 'xb' | Go |\n", "| $$v_{S}$$ | SS Drift-Rate | 'ssv' | Stop |\n", "| $$so_{S}$$ | SS Onset-Delay | 'sso' | Stop 
|\n", "\n", "\n", "* **ninits** (default=3): number of initial parameter sets to perform global optimization on\n", "\n", " * if ninits is 1 global optimization is performed once, using sampled parameter set with the lowest cost error (as described above in **nsamples**)\n", "\n", " * if ninits is greater than 1 then global optimization is performed $n$ separate times, one for each each $p_{i} \\in P_{inits}$ where $P_{inits}$ is the rank ordered set of parameters subsets corresponding to $n-{th}$ lowest cost error\n", " \n", " * The optimized parameters corresponding to the lowest global minimum across all iterations of basinhopping are then selected and passed to the next stage in the fitting routine (local gradient-based optimization)\n", "\n", "* **nsuccess** (default=60): criterion number of successful steps without finding new global minimum to exit basinhopping\n", "\n", "* **interval** (default=10): number of steps before adaptively updating the stepsize \n", "\n", "* **T** (default=1.0): set the basinhopping \"temperature\"\n", " * higher T will result in accepted steps with larger changes in function value (larger changes in model error)\n", " \n", "### Using set_basinparams() to control global opt. params\n", "```python\n", "model.set_basinparams(tol=1e-15, nsamples=4000, ninits=6, nsuccess=70)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Sampling Distributions for Init Parameters" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# sample model.basinparams['nsamples'] from each dist (default=3000)\n", "# and initialize model with best model.basinparams['ninits'] (default=3)\n", "vis.plot_param_distributions(p=model.inits)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Local Optimization (Nelder-Mead Simplex w/ bounds)\n", "* **set_fitparams()** method gives control over low-level parameters used for local opt. Local optimization polishes parameter estimates passed from global optimization step.\n", "\n", " * **method** (default='nelder'): optimization algorithm \n", " * (see [here](https://lmfit.github.io/lmfit-py/fitting.html#choosing-different-fitting-methods) for list of available methods)\n", "\n", " * **xtol = ftol = tol** (default=1e-30): error tolerance of optimization\n", "\n", " * **maxfev** (default=2000): max number of func evaluations to perform\n", "\n", " * **ntrials** (default=20000): num. simulated trials per condition\n", "\n", "### Using set_fitparams() to control local optimization parameters\n", "\n", "```python\n", "model.set_fitparams(method='tnc', tol=1e-35, ntrials=30000, maxfev=2500)\n", "```\n", "\n", "### Using set_fitparams() to set/access low-level model attributes\n", "\n", "* **set_fitparams()** also allows you to control low-level attributes of the model, including...\n", "\n", " * **quantiles** (default=np.array([.10, .20, ... .90]): quantiles of RT distribution\n", "\n", " * **kind** (default='dpm'): model kind (currently only irace and dpm) \n", "\n", " * **depends_on** (dict): {parameter_id : condition_name}\n", "\n", " * **tb** (float): trial duration (timewindow for response) in seconds\n", "\n", " * **nlevels** (int): number of levels in depends_on data[condition_name] \n", "\n", " * **fit_on (default='average')**: by default, models are fit to the 'average' \n", " data (across subjects). 
"cell_type": "markdown", "metadata": {}, "source": [ "# Fitting Flat Models\n", "- All models are initially fit by optimizing the full set of parameters to the \"flattened\" data (flat meaning the average data, collapsing across all conditions of interest).\n", "\n", "### Steps in fitting routine:\n", "\n", "1. Global optimization on flat data (average values collapsing across any/all conditions)\n", "2. Local optimization using parameters passed from global optimizer as starting values\n", "\n", "\n", "- Flat model fits are performed by identifying the full set of parameter values that minimize the following cost function:\n", "\n", " $$\\chi^2 = \\sum [\\omega * (\\hat{Y} - Y)]^2$$\n", " \n", "- *$Y$* is an array of observed data (e.g., accuracy, RT quantiles, etc.)\n", "- *$\\hat{Y}$* is an equal-length array of corresponding model-predicted values, given by the parameterized model $f(\\theta)$\n", "- The error between the predicted and the observed data (**$\\hat{Y} - Y$**) is weighted by an equal-length array of scalars **$\\omega$** proportional to the inverse of the variance in each value of **$Y$** (see the numeric sketch following this section).\n", "\n", "### Accessing flat weights ($\\omega$) and data ($Y$) vectors\n", "\n", "```python\n", "flat_data = model.observedDF.mean()\n", "flat_wts = model.wtsDF.mean()\n", "```" ] }, {
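"cell_type": "markdown", "metadata": {}, "source": [ "### Sketch: computing the flat cost by hand\n", "A minimal numeric sketch of the cost function above, using toy values in place of the real observed/predicted vectors:\n", "\n", "```python\n", "import numpy as np\n", "y = np.array([0.98, 0.52, 0.55, 0.60])     # observed: acc + RT quantiles (toy)\n", "yhat = np.array([0.97, 0.53, 0.55, 0.61])  # model-predicted values (toy)\n", "wts = np.array([2.0, 1.0, 1.0, 1.0])       # inverse-variance weights (toy)\n", "chi2 = np.sum((wts * (yhat - y))**2)       # sum([w * (yhat - y)]^2)\n", "```" ] }, {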
"```\n", "\n", "\n", "### generating and saving plots\n", "* By default (plotfits=True), the optimize function plots the model-predicted stop accuracy and correct/error RT quantiles over the observed data. \n", "* You can opt to save the plot in the model's output directory by setting saveplot to True (defaults to False)\n", "\n", "```python\n", "# save fit plot to /Users/kyle/Dropbox//\n", "m.optimize(custompath='Dropbox', saveplot=True)\n", "```\n", "\n", "### progress bars\n", "\n", "* when optimizing a model, you can get feedback about the global optimization process by setting progress to True\n", "\n", "```python\n", "model.optimize(progress=True)\n", "```\n", "\n", "* The green bar tracks the initial parameter sets (current set / ninits). \n", "* The red bar gives feedback about the basinhopping run (current-step / global-minimum).\n", " * If the global minimum stays the same for \"nsuccess\" steps, global optimization is terminated for the current init set and the green bar advances (meaning a new round of global optimization has begun with the model initialized with the next init parameter set).\n", " * If a new global minimum is found, the red bar resets and the nsuccess count begins again from 0." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# build a model with no conditional dependencies (a \"flat\" model)\n", "model = build.Model(data=data, kind='xdpm')\n", "\n", "# NOTE: fits in the binder demo will be much slower than when run locally\n", "# set_testing_params() sets more liberal fit criteria to speed things up a bit\n", "# comment this line out to run fits using the default optimization criteria\n", "model.set_testing_params()\n", "\n", "model.optimize(progress=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Accessing model fit information\n", "\n", "## parameter estimates and GOF stats\n", "* critical information about the model fit is contained in the models **fitDF** attribute (pandas DataFrame). \n", "* **fitDF** includes optimized parameter estimates as well as goodness-of-fit (GOF) statistics like AIC, BIC, standard (chi) and reduced (rchi) chi-square\n", "\n", "```python\n", "# display the model's fitDF\n", "model.fitDF\n", "```\n", "\n", "## model predictions\n", "* optimize also generates a **yhatDF** attribute, a pandas DataFrame containing the model-predicted stop-accuracy and correct/error RT quantiles (same column structure as model's observedDF)\n", "\n", "```python\n", "# display the model's yhatDF\n", "model.yhatDF\n", "```\n", "\n", "## save results post-optimization\n", "\n", "* if you set saveresults to False when running optimize and later decide to save the fitDF and yhatDF \n", "\n", "```python\n", "# to save yhatDF as csv file: \n", "model.fitDF.to_csv(\"path/to/save/fitDF.csv\", index=False)\n", "model.yhatDF.to_csv(\"path/to/save/yhatDF.csv\", index=False)\n", "```" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxassvtrvxbnfevnvarydfchirchilogpAICBICcnvrg
0average0.5849-1.9990.071670.98061.2545065190.0013797.259e-05-228.7-218.7-223.90
\n", "
" ], "text/plain": [ " idx a ssv tr v xb nfev nvary df chi \\\n", "0 average 0.5849 -1.999 0.07167 0.9806 1.254 506 5 19 0.001379 \n", "\n", " rchi logp AIC BIC cnvrg \n", "0 7.259e-05 -228.7 -218.7 -223.9 0 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Parameter estimates and GOF stats\n", "model.fitDF" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxflatacc200250300350400c10c20...c90e10e20e30e40e50e60e70e80e90
0averageflat0.98221.01.00.96750.5720.1090.51670.5317...0.61170.50670.52170.53170.54170.54670.55670.56670.58170.5867
\n", "

1 rows × 26 columns

\n", "
" ], "text/plain": [ " idx flat acc 200 250 300 350 400 c10 c20 \\\n", "0 average flat 0.9822 1.0 1.0 0.9675 0.572 0.109 0.5167 0.5317 \n", "\n", " ... c90 e10 e20 e30 e40 e50 e60 e70 \\\n", "0 ... 0.6117 0.5067 0.5217 0.5317 0.5417 0.5467 0.5567 0.5667 \n", "\n", " e80 e90 \n", "0 0.5817 0.5867 \n", "\n", "[1 rows x 26 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model Predictions (same header as model.observedDF)\n", "model.yhatDF" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Fitting Conditional Models\n", "\n", "- Conditional models can be fit in which all parameters from **flat** model fit are held constant except for one or more designated **conditional** parameters which is free to vary across levels of an experimental condition of interest. \n", "\n", "### Steps in fitting routine: \n", "\n", "1. Global optimzation on flat data (average values collapsing across experimental conditions)\n", "2. Local optimzation using parameters passed from global optimizer as starting values \n", "3. Global optimzation of conditional parameters \n", "4. Local optimzation of conditional parameters passed from global optimizer\n", "\n", "\n", "- Conditional model fits are performed by holding all parameters constant except one or more **conditional** parameter(s) and minimizing following cost-function: \n", "\n", "$$\\chi^2 = \\sum_{i=0}^{N_c} [\\omega_i * (\\hat{Y_i} - Y_i)]^2$$\n", " \n", "- where $\\sum[\\omega_i*(\\hat{Y_i} - Y_i)]^2$ gives the cost ($\\chi^2$) for level $i$ of condition $C$\n", "- $\\chi^2$ is equal to the summed and squared error across all $N_c$ levels of that condition\n", "\n", "- Specifying parameter dependencies is done by providing the model with a **depends_on** dictionary when initializing the model with the format: **{parameter_id : condition_name}**.\n", "- For instance, in Dunovan et al., (2015) subjects performed two versions of a stop-signal task \n", " * **Baseline (\"bsl\")** condition: errors on **go** and **stop** trials are equally penalized \n", " * **Caution (\"pnl\")** condition: penalties 2x higher for **stop** trial errors (e.g., response=1)\n", "- To test the hypothesis that observed behavioral differences between penalty conditions was a result of a change Go drift-rate...\n", "\n", "``` py\n", "# define the model allowing Go drift-rate to vary across 'Cond'\n", "model_v = build.Model(kind='xdpm', depends_on={'v': 'Cond'})\n", "# run optimize to fit the full model (steps 1 & 2)\n", "model_v.optimize(progress=True)\n", "```\n", "\n", "### Typically...\n", "\n", "* ...you'll have multiple alternative hypotheses about which parameters will depend on various task conditions\n", "\n", "* For instance, instead of modulating the Go drift-rate, assymetric stop/go penalties might alter behavior by changing the height of the decision threshold **(a)**. 
\n", "\n", "* To implement this model:\n", "\n", "``` py\n", "# define the model allowing threshold to vary across 'Cond'\n", "model_a = build.Model(kind='xdpm', depends_on={'a': 'Cond'})\n", "# run optimize to fit the full model (steps 1 & 2)\n", "model_a.optimize(progress=True)\n", "```\n", "\n", "### compare model fits\n", "\n", "* To test the hypothesis: threshold (a) better than drift-rate (v)\n", "\n", "```python\n", "# If True, threshold model provides a better fit\n", "model_a.finfo['AIC'] < model_v.finfo['AIC']\n", "```\n", "\n", "### How to access conditional weights ($\\omega_i$) and data ($Y_i$) vectors\n", "```python\n", "# replace 'Cond' with whatever your condition is called in the input dataframe\n", "cond_data = model.observedDF.groupby('Cond').mean()\n", "cond_wts = model.wtsDF.groupby('Cond').mean() \n", "```" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# build conditional model and optimize with drift-rate free across levels of Cond\n", "model = build.Model(data=data, kind='xdpm', depends_on={'v':'Cond'})\n", "\n", "# NOTE: fits in the binder demo will be much slower than when run locally\n", "# uncomment line below to speed up the demo fits (at the expense of fit quality)\n", "# model.set_testing_params()\n", "\n", "model.optimize(progress=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Nested optimization of alternative models\n", "\n", "* Typically you'll have multiple competing hypotheses about which parameters are influenced by various task conditions\n", "\n", "* Nested optimization allows alternative models to be optimized using a single initial parameter set.\n", "\n", "```python \n", "model = build.Model(kind='xdpm', data=data)\n", "free_params = ['v', 'a', 'tr']\n", "model.nested_optimize(free=free_params, cond='Cond')\n", "```\n", "\n", "* After fitting the model with **depends_on={'v': 'Cond'}**, **'v'** is replaced by the next parameter in the **nested_models** list **('a' or boundary height in this case)** and the model is optimized with this dependency using the same init params as the original model\n", "\n", "* As a result, model selection is less likely to be biased by the initial parameter set\n", "\n", "* Also, because Step 1 takes significantly longer than Step 2, nested optimization of multiple models is significantly faster than optimizing each model individually through boths steps\n", "\n", "* Run the cell below and go take a leisurely coffee break" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = build.Model(kind='xdpm', data=data)\n", "nested_models = ['v', 'a', 'tr']\n", "\n", "# NOTE: fits in the binder demo will be much slower than when run locally\n", "# uncomment line below to speed up the demo fits (at the expense of fit quality)\n", "# model.set_testing_params()\n", "\n", "model.nested_optimize(free=nested_models, cond='Cond', progress=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Examine Nested Model Fits" ] }, { 
"cell_type": "code", "execution_count": 49, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxassvtrvxbbslpnlnfevnvarydfchirchilogpAICBICcnvrg
0v0.59775-1.19520.11453[0.934995901191,...1.97310.9350.9014415452460.00419349.116e-05-446.54-442.54-441.97False
1a[0.585008262263,...-1.19520.114530.922041.97310.585010.616012132460.00550810.00011974-433.45-429.45-428.88True
2tr0.59775-1.1952[0.11000000514, ...0.922041.97310.110.1200214762460.00492960.00010716-438.78-434.78-434.21False
\n", "
" ], "text/plain": [ " idx a ssv tr v \\\n", "0 v 0.59775 -1.1952 0.11453 [0.934995901191,... \n", "1 a [0.585008262263,... -1.1952 0.11453 0.92204 \n", "2 tr 0.59775 -1.1952 [0.11000000514, ... 0.92204 \n", "\n", " xb bsl pnl nfev nvary df chi rchi logp \\\n", "0 1.9731 0.935 0.90144 1545 2 46 0.0041934 9.116e-05 -446.54 \n", "1 1.9731 0.58501 0.61601 213 2 46 0.0055081 0.00011974 -433.45 \n", "2 1.9731 0.11 0.12002 1476 2 46 0.0049296 0.00010716 -438.78 \n", "\n", " AIC BIC cnvrg \n", "0 -442.54 -441.97 False \n", "1 -429.45 -428.88 True \n", "2 -434.78 -434.21 False " ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Parameter estimates and GOF stats\n", "model.fitDF" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxCondacc200250300350400c10c20...c90e10e20e30e40e50e60e70e80e90
0vbsl0.989001.00.99950.95050.51000.08850.524530.52953...0.604530.509530.524530.529530.539530.544530.554530.559530.574530.58953
1vpnl0.974651.01.00000.97350.62200.14200.529530.53953...0.614530.524530.529530.539530.544530.554530.559530.574530.579530.59453
2absl0.990651.01.00000.95500.52800.10150.519530.53453...0.604530.514530.519530.534530.539530.549530.554530.564530.569530.58453
3apnl0.983251.01.00000.97350.61850.13000.529530.54453...0.614530.524530.529530.544530.554530.559530.564530.574530.579530.59453
4trbsl0.990201.01.00000.95400.52050.09300.520000.53500...0.605000.515000.520000.535000.540000.550000.555000.565000.570000.58500
5trpnl0.985701.01.00000.97900.62250.13750.530020.54502...0.615020.525020.530020.535020.550020.550020.565020.575020.580020.59502
\n", "

6 rows × 26 columns

\n", "
" ], "text/plain": [ " idx Cond acc 200 250 300 350 400 c10 c20 \\\n", "0 v bsl 0.98900 1.0 0.9995 0.9505 0.5100 0.0885 0.52453 0.52953 \n", "1 v pnl 0.97465 1.0 1.0000 0.9735 0.6220 0.1420 0.52953 0.53953 \n", "2 a bsl 0.99065 1.0 1.0000 0.9550 0.5280 0.1015 0.51953 0.53453 \n", "3 a pnl 0.98325 1.0 1.0000 0.9735 0.6185 0.1300 0.52953 0.54453 \n", "4 tr bsl 0.99020 1.0 1.0000 0.9540 0.5205 0.0930 0.52000 0.53500 \n", "5 tr pnl 0.98570 1.0 1.0000 0.9790 0.6225 0.1375 0.53002 0.54502 \n", "\n", " ... c90 e10 e20 e30 e40 e50 e60 \\\n", "0 ... 0.60453 0.50953 0.52453 0.52953 0.53953 0.54453 0.55453 \n", "1 ... 0.61453 0.52453 0.52953 0.53953 0.54453 0.55453 0.55953 \n", "2 ... 0.60453 0.51453 0.51953 0.53453 0.53953 0.54953 0.55453 \n", "3 ... 0.61453 0.52453 0.52953 0.54453 0.55453 0.55953 0.56453 \n", "4 ... 0.60500 0.51500 0.52000 0.53500 0.54000 0.55000 0.55500 \n", "5 ... 0.61502 0.52502 0.53002 0.53502 0.55002 0.55002 0.56502 \n", "\n", " e70 e80 e90 \n", "0 0.55953 0.57453 0.58953 \n", "1 0.57453 0.57953 0.59453 \n", "2 0.56453 0.56953 0.58453 \n", "3 0.57453 0.57953 0.59453 \n", "4 0.56500 0.57000 0.58500 \n", "5 0.57502 0.58002 0.59502 \n", "\n", "[6 rows x 26 columns]" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model Predictions (same header as model.observedDF)\n", "model.yhatDF" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AIC likes v model\n", "BIC likes v model\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluate all nested fits and plot fit stats: AIC & BIC (Lower is better)\n", "# According to AIC & BIC, the model with go drift-rate (v_G) free across levels of Cond \n", "# provides a better fit than models with free threshold (a_G) or onset delay (tr_G)\n", "gof = vis.compare_nested_models(model.fitDF, verbose=True, model_ids=nested_models)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Testing model identity (parameter recovery)\n", "* **recover_model()** samples a large number of parameter sets and chooses one that yields a synthetic data vector close to the actual observed data vectors in **model.observed_flat** & **model.observed**\n", "* the model then goes through the fitting routine, optimizing an independent set of parameters to the synthetic data\n", "* when the model is finished optimizing, the recovered parameters are compared and plotted against the generative set of parameters. \n", "* if the values are close, this is evidence that the model is identifiable " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [], "source": [ "# build a model and sample parameters\n", "model = build.Model(data=data, kind='xdpm', depends_on={'v':'Cond'})\n", "# fit synthetic data and compare fitted vs. 
init params\n", "model.recover_model(progress=True, plotparams=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Troubleshooting Ugly Fits\n", "\n", "## Fit to individual subjects\n", "```python\n", "model = build.Model(data=data, fit_on='subjects')\n", "```\n", "## Other \"kinds\" of models...\n", "\n", "\n", "* Currently only Dependent Process Model **(kind='dpm')** and Independent Race Model **(kind='irace')**\n", "\n", "\n", "* Tell model to include a Dynamic Bias Signal **('xb')** by adding an **'x'** to the front of model **kind**\n", "\n", "\n", "* To implement the **Dependent Process Model**...\n", "\n", "```python\n", "#... with dynamic bias: \n", "model = build.Model(data=data, kind='xdpm')\n", "#...and without: \n", "model = build.Model(data=data, kind='dpm')\n", "```\n", "\n", "\n", "* To implement the **Independent Race Model**... \n", "\n", "\n", "```python\n", "#... with dynamic bias:\n", "model = build.Model(data=data, kind='xirace')\n", "#... and without:\n", "model = build.Model(data=data, kind='irace')\n", "```\n", "\n", "\n", "\n", "## Optimization parameters and cost weights...\n", "\n", "* set more conservative fit criteria\n", "\n", "```python\n", "model.set_basinparams(nsuccess=100, tol=1e-30, ninits=10, nsamples=10000) \n", "model.set_fitparams(maxfev=5000, tol=1e-35)\n", "```\n", "\n", "* Inspect cost function weights for extreme vals\n", "\n", "```python\n", "print(model.cond_wts)\n", "print(model.flat_wts)\n", "```\n", "* If the wts look suspect try re-running the fits with an unweighted model (all wts = 1) \n", "\n", "\n", "```python\n", " model = build.Model(data=data, weighted=False)\n", "```\n", "\n", "* Keep in mind that error RTs can be particularly troublesome, sometimes un-shootably so..." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Notebook Theme: Grade3\n", "(https://www.github.com/dunovank/jupyter-themes)\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import radd\n", "radd.style_notebook()" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.11" } }, "nbformat": 4, "nbformat_minor": 0 }