{ "cells": [
 { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# All notebook imports. The repository root (two levels up) is put first on\n",
  "# sys.path, presumably so a local neuronunit checkout takes precedence.\n",
  "import copy\n",
  "import os\n",
  "import pickle\n",
  "import sys\n",
  "from collections import OrderedDict\n",
  "\n",
  "nu_russell = os.path.realpath(os.path.join(os.getcwd(), '..', '..'))\n",
  "sys.path.insert(0, nu_russell)\n",
  "\n",
  "import dask.bag as db\n",
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "import quantities as pq\n",
  "import seaborn as sns\n",
  "from IPython.display import HTML, display\n",
  "\n",
  "import neuronunit\n",
  "from neuronunit.optimization import optimization_management as om"
 ] },
 { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OrderedDict([('RS', {'b': -2, 'd': 100, 'a': 0.03, 'vPeak': 35, 'vr': -60, 'c': -50, 'vt': -40, 'k': 0.7, 'C': 100}), ('IB', {'b': 5, 'd': 130, 'a': 0.01, 'vPeak': 50, 'vr': -75, 'c': -56, 'vt': -45, 'k': 1.2, 'C': 150}), ('LTS', {'b': 8, 'd': 20, 'a': 0.03, 'vPeak': 40, 'vr': -56, 'c': -53, 'vt': -42, 'k': 1.0, 'C': 100}), ('TC', {'b': 15, 'd': 10, 'a': 0.01, 'vPeak': 35, 'vr': -60, 'c': -60, 'vt': -50, 'k': 1.6, 'C': 200}), ('TC_burst', {'b': 15, 'd': 10, 'a': 0.01, 'vPeak': 35, 'vr': -60, 'c': -60, 'vt': -50, 'k': 1.6, 'C': 200})])\n" ] } ], "source": [
  "# Izhikevich (2007) reduced-model parameter sets, transcribed from\n",
  "# http://www.physics.usyd.edu.au/teach_res/mp/mscripts/ns_izh002.m\n",
  "#\n",
  "# Fast spiking cannot be reproduced as it requires modifications to the standard\n",
  "# Izhikevich equation, which are expressed in this mod file:\n",
  "# https://github.com/OpenSourceBrain/IzhikevichModel/blob/master/NEURON/izhi2007b.mod\n",
  "\n",
  "# Model parameter names, in the column order of the tuples below.\n",
  "PARAM_NAMES = ['C', 'k', 'vr', 'vt', 'vPeak', 'a', 'b', 'c', 'd']\n",
  "\n",
  "reduced2007 = OrderedDict([\n",
  "    #               C    k     vr   vt   vpeak a      b   c    d    celltype\n",
  "    ('RS',        (100, 0.7,  -60, -40, 35,  0.03,  -2, -50, 100, 1)),\n",
  "    ('IB',        (150, 1.2,  -75, -45, 50,  0.01,   5, -56, 130, 2)),\n",
  "    ('LTS',       (100, 1.0,  -56, -42, 40,  0.03,   8, -53,  20, 4)),\n",
  "    ('TC',        (200, 1.6,  -60, -50, 35,  0.01,  15, -60,  10, 6)),\n",
  "    ('TC_burst',  (200, 1.6,  -60, -50, 35,  0.01,  15, -60,  10, 6)),\n",
  "    ('RTN',       ( 40, 0.25, -65, -45,  0,  0.015, 10, -55,  50, 7)),\n",
  "    ('RTN_burst', ( 40, 0.25, -65, -45,  0,  0.015, 10, -55,  50, 7))])\n",
  "\n",
  "# One list of values per parameter across all cell types; the trailing\n",
  "# celltype column is not a model parameter and is never read.\n",
  "reduced_dict = OrderedDict(\n",
  "    (name, [row[i] for row in reduced2007.values()])\n",
  "    for i, name in enumerate(PARAM_NAMES))\n",
  "\n",
  "# Optimizer search bounds: (min, max) of each parameter over all cell types.\n",
  "explore_param = {k: (np.min(v), np.max(v)) for k, v in reduced_dict.items()}\n",
  "param_ranges = OrderedDict(explore_param)\n",
  "\n",
  "# Named parameter dicts for the cell types used in this notebook. Looked up by\n",
  "# name (not by position); zip stops at len(PARAM_NAMES), dropping celltype.\n",
  "reduced_cells = OrderedDict(\n",
  "    (name, dict(zip(PARAM_NAMES, reduced2007[name])))\n",
  "    for name in ['RS', 'IB', 'LTS', 'TC', 'TC_burst'])\n",
  "\n",
  "print(reduced_cells)\n",
  "cells = reduced_cells"
 ] },
 { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "from neuronunit.models.reduced import ReducedModel\n",
  "from neuronunit.optimization import get_neab\n",
  "from neuronunit.optimization.model_parameters import model_params, path_params\n",
  "\n",
  "LEMS_MODEL_PATH = path_params['model_path']\n",
  "\n",
  "# Ground-truth model: the TC parameter set.\n",
  "model = ReducedModel(LEMS_MODEL_PATH, name='vanilla', backend='RAW')\n",
  "model.set_attrs(cells['TC'])"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'value': array(74.921875) * pA}\n" ] } ], "source": [
  "tests_, all_tests, observation, suite = get_neab.get_tests()\n",
  "# all_tests[0] is treated as the rheobase test throughout this notebook.\n",
  "rheobase = all_tests[0].generate_prediction(model)\n",
  "print(rheobase)"
 ] }, {
"cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[, , , , , , , ]\n" ] } ], "source": [
  "print(all_tests)"
 ] },
 { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[array(74.921875) * pA, array(74.921875) * pA, array(74.921875) * pA]\n", "[array(74.921875) * pA, array(-10.0) * pA, array(-10.0) * pA, array(-10.0) * pA, array(-10.0) * pA]\n" ] } ], "source": [
  "import quantities as pq\n",
  "\n",
  "# Stimulus protocols. Passive tests get a small hyperpolarizing step; the\n",
  "# rheobase and spike-shape tests get a step at the model's rheobase current.\n",
  "PASSIVE_DELAY, PASSIVE_DURATION = 200.0*pq.ms, 500.0*pq.ms\n",
  "PASSIVE_AMPLITUDE = -10*pq.pA\n",
  "SPIKE_DELAY, SPIKE_DURATION = 100.0*pq.ms, 1000.0*pq.ms\n",
  "\n",
  "\n",
  "def square_current(delay, duration, amplitude):\n",
  "    \"\"\"Return a new ``{'injected_square_current': {...}}`` params dict.\"\"\"\n",
  "    return {'injected_square_current': {'delay': delay,\n",
  "                                        'duration': duration,\n",
  "                                        'amplitude': amplitude}}\n",
  "\n",
  "\n",
  "def format_iparams(all_tests, rheobase):\n",
  "    \"\"\"Attach square-current stimulus parameters to every test, in place.\n",
  "\n",
  "    Index layout assumed from get_neab.get_tests() (consistent with the units of\n",
  "    the predictions below): 0 = rheobase, 1-4 = passive properties,\n",
  "    last three = spike shape. TODO confirm against get_neab.\n",
  "\n",
  "    Parameters\n",
  "    ----------\n",
  "    all_tests : list of tests; each receives its own fresh ``params`` dict.\n",
  "    rheobase : prediction dict whose ``'value'`` is the rheobase current.\n",
  "\n",
  "    Returns the same ``all_tests`` list.\n",
  "    \"\"\"\n",
  "    for t in all_tests[1:5]:\n",
  "        t.params = square_current(PASSIVE_DELAY, PASSIVE_DURATION, PASSIVE_AMPLITUDE)\n",
  "    for t in all_tests[-3:]:\n",
  "        t.params = square_current(SPIKE_DELAY, SPIKE_DURATION, rheobase['value'])\n",
  "    # The rheobase test uses the spike protocol but gets its own dict; it used to\n",
  "    # alias the last test's params, so mutating one silently changed the other.\n",
  "    all_tests[0].params = square_current(SPIKE_DELAY, SPIKE_DURATION, rheobase['value'])\n",
  "    return all_tests\n",
  "\n",
  "\n",
  "pt = format_iparams(all_tests, rheobase)\n",
  "print([t.params['injected_square_current']['amplitude'] for t in pt[-3:]])\n",
  "print([t.params['injected_square_current']['amplitude'] for t in pt[0:5]])"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Get predictions from the ground-truth model\n",
  "\n",
  "* Generate fake NeuroElectro-style observations from the model's own predictions.\n",
  "* Use them for round-trip testing: the optimizer should recover the parameters that produced them.\n"
 ] },
 { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'value': array(74.921875) * pA}, {'value': array(31739282.99824782) * kg*m**2/(s**3*A**2)}, {'value': array(0.0034631849083750876) * s}, {'value': array(1.0911352057216523e-10) * s**4*A**2/(kg*m**2)}, {'mean': array(-0.060317405613048956) * V, 'std': array(0.00019622743842309925) * V}, {'mean': array(0.000675) * s, 'std': array(0.0) * s, 'n': 1}, {'mean': array(0.05802236247240052) * V, 'std': array(0.0) * V, 'n': 1}, {'mean': array(-0.02302236247240052) * V, 'std': array(0.0) * V, 'n': 1}]\n" ] } ], "source": [
  "# The rheobase has been obtained separately and cannot be dask-mapped:\n",
  "# nested dask mappings don't work (daemonic processes cannot spawn children).\n",
  "ptbag = db.from_sequence(pt[1:])\n",
  "\n",
  "\n",
  "def obtain_predictions(t):\n",
  "    \"\"\"Return test ``t``'s prediction for a freshly built TC model.\n",
  "\n",
  "    The model is constructed inside the task rather than reusing the global\n",
  "    ``model``, so each dask task works on its own instance.\n",
  "    \"\"\"\n",
  "    model = ReducedModel(LEMS_MODEL_PATH, name='vanilla', backend='RAW')\n",
  "    model.set_attrs(cells['TC'])\n",
  "    return t.generate_prediction(model)\n",
  "\n",
  "\n",
  "predictions = list(ptbag.map(obtain_predictions).compute())\n",
  "predictions.insert(0, rheobase)\n",
  "print(predictions)"
 ] },
 { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{'C': 200,\n", " 'a': 0.01,\n", " 'b': 15,\n", " 'c': -60,\n", " 'd': 10,\n", " 'k': 1.6,\n", " 'vPeak': 35,\n", " 'vr': -60,\n", " 'vt': -50}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "cells['TC']"
 ] },
 { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'value': array(74.921875) * pA}, {'value': array(31739282.99824782) * kg*m**2/(s**3*A**2)}, {'value': array(0.0034631849083750876) * s}, {'value': array(1.0911352057216523e-10) * s**4*A**2/(kg*m**2)}, {'std': array(0.00019622743842309925) * V, 'value': array(-0.060317405613048956) * V}, {'std': array(0.0) * s, 'value': array(0.000675) * s, 'n': 1}, {'std': array(0.0) * V, 'value': array(0.05802236247240052) * V, 'n': 1}, {'std': array(0.0) * V, 'value': array(-0.02302236247240052) * V, 'n': 1}]\n" ] } ], "source": [
  "# Predictions carry either a 'mean' or a 'value' key, which makes them awkward\n",
  "# to iterate over. Demoting 'mean' to 'value' is less invasive than promoting\n",
  "# every 'value' to 'mean', so rename 'mean' -> 'value' in place.\n",
  "for p in predictions:\n",
  "    if 'mean' in p:\n",
  "        p['value'] = p.pop('mean')\n",
  "print(predictions)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Make new tests from model-generated data\n",
  "\n",
  "Build tests whose observations are the TC model's own predictions, as opposed to experimental data."
] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "import copy\n",
  "import pickle\n",
  "\n",
  "# Deep-copy so that overwriting observations below leaves the experimental\n",
  "# observations of the tests in `all_tests` untouched (copy.copy only copied\n",
  "# the list, so the very same test objects were being mutated).\n",
  "TC_tests = copy.deepcopy(all_tests)\n",
  "for ind, t in enumerate(TC_tests):\n",
  "    if 'mean' in t.observation:\n",
  "        t.observation['value'] = t.observation.pop('mean')\n",
  "    pred = predictions[ind]['value']\n",
  "    try:\n",
  "        pred = pred.rescale(t.units)\n",
  "    except Exception:\n",
  "        # Keep the prediction in its own units when it cannot be rescaled\n",
  "        # to the test's units.\n",
  "        pass\n",
  "    # The model's prediction becomes the (ground-truth) observation.\n",
  "    t.observation['value'] = pred\n",
  "    t.observation['mean'] = pred\n",
  "    # A zero std is replaced by a nominal 5 (in the observation's own units),\n",
  "    # presumably so std-normalized scores stay finite -- TODO confirm.\n",
  "    if float(t.observation['std']) == 0.0:\n",
  "        t.observation['std'] = 5.0*t.observation['mean'].units\n",
  "\n",
  "with open('thalamo_cortical_tests.p', 'wb') as f:\n",
  "    pickle.dump(TC_tests, f)"
 ] },
 { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [
  "## Call the genetic-algorithm optimizer\n"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "import os\n",
  "import pickle\n",
  "\n",
  "from neuronunit.optimization import optimization_management as om\n",
  "\n",
  "# Parameters the GA may vary; every other parameter is held constant (hc)\n",
  "# at its ground-truth TC value.\n",
  "free_params = ['a', 'b', 'vr', 'vt', 'k']\n",
  "hc = {k: v for k, v in cells['TC'].items() if k not in free_params}\n",
  "\n",
  "# Set RERUN_GA = False to reuse a previous run cached in GA_CACHE.\n",
  "RERUN_GA = True\n",
  "GA_CACHE = 'chatter_ga_out_nsga.p'\n",
  "\n",
  "with open('thalamo_cortical_tests.p', 'rb') as f:\n",
  "    TC_tests = pickle.load(f)\n",
  "\n",
  "if not RERUN_GA and os.path.exists(GA_CACHE):\n",
  "    with open(GA_CACHE, 'rb') as f:\n",
  "        ga_out_nsga = pickle.load(f)\n",
  "else:\n",
  "    ga_out_nsga, _ = om.run_ga(explore_param, 25, TC_tests, free_params=free_params,\n",
  "                               hc=hc, NSGA=True)\n",
  "    with open(GA_CACHE, 'wb') as f:\n",
  "        pickle.dump(ga_out_nsga, f)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"pickle.dump(ga_out_nsga,open('chatter_ga_out_nsga.p','wb'))\n", "print(ga_out_nsga['hardened'][0].dtc.attrs)\n", "print(cells['TC'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "\n", "for item in ['avg','max','min']:\n", " plt.plot([x[item] for x in ga_out_nsga['log']],label=item)\n", "plt.legend()\n", "plt.show()\n", "\n", "history = ga_out_nsga['history']\n", "hof = ga_out_nsga['hof']\n", "temp = [ v.dtc for k,v in history.genealogy_history.items() ]\n", "temp = [ i for i in temp if type(i) is not type(None)]\n", "temp = [ i for i in temp if len(list(i.attrs.values())) ]\n", "true_history = [ (v, np.sum(list(v.scores.values()))) for v in temp ]\n", "plt.plot([i for i,j in enumerate(true_history)],[ i[1] for i in true_history ])\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "true_history = ga_out_nsga['hardened']\n", "true_mins = sorted(true_history, key=lambda h: h[1])\n", "\n", "print(true_mins[0][1], true_mins[0][0])\n", "try:\n", " if true_mins[0][1] < np.sum(list(hof[0].dtc.scores.values())):\n", " #print('history unreliable')\n", " hof = [i[0] for i in true_mins]\n", " best = hof[0]\n", " best_attrs = best.attrs\n", " else:\n", " best = ga_out_nsga['dhof'][0]\n", " best_attrs = ga_out_nsga['dhof'][0].attrs\n", " ga_out_nsga['dhof'][0].scores\n", "except:\n", " best = ga_out_nsga['dhof'][0]\n", " best_attrs = ga_out_nsga['dhof'][0].attrs\n", " ga_out_nsga['dhof'][0].scores\n", "#true_mins " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "best = true_mins[0][0]\n", "best_attrs = true_mins[0][0].attrs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { 
"collapsed": false }, "outputs": [], "source": [
  "print(best.rheobase)\n",
  "print('best', best.get_ss())"
 ] },
 { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [
  "### With five free parameters and 20-25 generations, a very good fit can be found"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "import quantities as pq\n",
  "\n",
  "# Square-current stimulus for the comparison traces (same delay/duration as\n",
  "# the spike-shape test protocol). The amplitude is filled in per model.\n",
  "iparams = {'injected_square_current': {'delay': 100.0*pq.ms,\n",
  "                                       'duration': 1000.0*pq.ms,\n",
  "                                       'amplitude': None}}\n",
  "# Current at which the ground-truth TC model was observed to fire.\n",
  "GROUND_TRUTH_AMPLITUDE = 75.36800000000001*pq.pA\n",
  "\n",
  "\n",
  "def plot_vs_ground_truth(best_attrs, rheobase, show_optimized=True,\n",
  "                         filename='high_resolution.png'):\n",
  "    \"\"\"Simulate the optimized and ground-truth models and plot their traces.\n",
  "\n",
  "    The optimized model (``best_attrs``) is driven at ``rheobase`` and the\n",
  "    ground-truth model (``cells['TC']``) at GROUND_TRUTH_AMPLITUDE, both via the\n",
  "    module-level ``iparams``. The top panel shows membrane potential, the\n",
  "    bottom panel the injected current; the figure is saved to ``filename``.\n",
  "\n",
  "    Returns ``(model1, model2)`` = (optimized model, ground-truth model).\n",
  "    \"\"\"\n",
  "    stim = iparams['injected_square_current']\n",
  "\n",
  "    model1 = ReducedModel(LEMS_MODEL_PATH, name='vanilla', backend='RAW')\n",
  "    model1.attrs.update(best_attrs)\n",
  "    stim['amplitude'] = rheobase\n",
  "    model1.inject_square_current(iparams)\n",
  "\n",
  "    model2 = ReducedModel(LEMS_MODEL_PATH, name='vanilla', backend='RAW')\n",
  "    model2.set_attrs(cells['TC'])\n",
  "    stim['amplitude'] = GROUND_TRUTH_AMPLITUDE\n",
  "    model2.inject_square_current(iparams)\n",
  "    print(model2.attrs)\n",
  "\n",
  "    vm1 = model1.get_membrane_potential()\n",
  "    times = vm1.times\n",
  "    fig, (ax_v, ax_i) = plt.subplots(2, 1, sharex=True)\n",
  "    if show_optimized:\n",
  "        ax_v.plot(times, vm1, c='red', label='optimizer result')\n",
  "    ax_v.plot(times, model2.get_membrane_potential(), label='ground truth')\n",
  "    ax_v.set_ylabel('membrane potential')\n",
  "    ax_v.legend()\n",
  "\n",
  "    # Rebuild the injected current on the simulation's own time axis so the\n",
  "    # current trace always has the same length as `times`.\n",
  "    t_ms = np.asarray(times.rescale(pq.ms)).ravel()\n",
  "    onset = float(stim['delay'].rescale(pq.ms))\n",
  "    offset = onset + float(stim['duration'].rescale(pq.ms))\n",
  "    iext = np.where((t_ms >= onset) & (t_ms < offset), float(stim['amplitude']), 0.0)\n",
  "    ax_i.plot(times, iext, label='current')\n",
  "    ax_i.set_ylabel('injected current (pA)')\n",
  "    fig.savefig(filename)\n",
  "    return model1, model2\n",
  "\n",
  "\n",
  "best_rheobase = best.rheobase\n",
  "model1, model2 = plot_vs_ground_truth(best_attrs, best_rheobase)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Ground-truth trace only, saved to its own file instead of overwriting the\n",
  "# comparison figure.\n",
  "model1, model2 = plot_vs_ground_truth(best_attrs, best.rheobase, show_optimized=False,\n",
  "                                      filename='ground_truth.png')"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Colour-scale bounds for the 3-D plots: worst and best summed scores.\n",
  "abs_max = true_mins[-1][1]\n",
  "abs_min = true_mins[0][1]\n",
  "pop = ga_out_nsga['pop']\n",
  "print(ga_out_nsga['hof'][0].dtc.get_ss())\n",
  "ga_out_nsga['hardened'][0]"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection\n",
  "import matplotlib\n",
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "\n",
  "matplotlib.rc('font', weight='bold', size=16)\n",
  "plt.rc('xtick', labelsize=5)  # fontsize of the tick labels\n",
  "plt.rc('ytick', labelsize=5)\n",
  "\n",
  "gen_vs_pop = ga_out_nsga['gen_vs_pop']\n",
  "\n",
  "\n",
  "def plot_population_3d(gen_vs_pop, x, y, z, title, frame_prefix=''):\n",
  "    \"\"\"Draw each generation's population in a 3-D slice of parameter space.\n",
  "\n",
  "    ``x``, ``y`` and ``z`` index the model attribute names (in the key order of\n",
  "    the first individual's ``dtc.attrs``). Points are coloured by summed score\n",
  "    (``dtc.get_ss()``) on the shared [abs_min, abs_max] scale. Earlier\n",
  "    generations are drawn faintly behind the current one, and the hall-of-fame\n",
  "    best is added on the final frame. Frame ``i`` is saved to\n",
  "    ``<frame_prefix><i>.png``.\n",
  "    \"\"\"\n",
  "    keys = list(gen_vs_pop[0][0].dtc.attrs.keys())\n",
  "    lims = [(np.min(explore_param[k]), np.max(explore_param[k])) for k in keys]\n",
  "    best_dtc = ga_out_nsga['hof'][0].dtc\n",
  "    best_xyz = [best_dtc.attrs[k] for k in keys]\n",
  "    best_error = best_dtc.get_ss()\n",
  "    style = dict(cmap='jet', marker='o', s=100, vmin=abs_min, vmax=abs_max)\n",
  "\n",
  "    def project(points):\n",
  "        # Split a list of attribute vectors into the three plotted coordinates.\n",
  "        return ([p[x] for p in points], [p[y] for p in points], [p[z] for p in points])\n",
  "\n",
  "    history = []  # (points, errors) of every generation already drawn\n",
  "    last = len(gen_vs_pop) - 1\n",
  "    for gen, population in enumerate(gen_vs_pop):\n",
  "        points = [[ind.dtc.attrs[k] for k in keys] for ind in population]\n",
  "        errors = [ind.dtc.get_ss() for ind in population]\n",
  "\n",
  "        fig = plt.figure()\n",
  "        ax = fig.add_subplot(111, projection='3d')\n",
  "        ax.set_xlim(lims[x])\n",
  "        ax.set_ylim(lims[y])\n",
  "        ax.set_zlim(lims[z])\n",
  "        ax.set_title(title + str(gen), fontsize='large')\n",
  "\n",
  "        for old_points, old_errors in history:\n",
  "            ax.scatter3D(*project(old_points), c=old_errors, alpha=0.0925, **style)\n",
  "        current = ax.scatter3D(*project(points), c=errors, **style)\n",
  "        if gen == last:\n",
  "            ax.scatter3D([best_xyz[x]], [best_xyz[y]], [best_xyz[z]], c=[best_error], **style)\n",
  "\n",
  "        cb = fig.colorbar(current, ax=ax)\n",
  "        cb.set_label('summed scores')\n",
  "        ax.set_xlabel(keys[x], fontsize=20)\n",
  "        ax.set_ylabel(keys[y], fontsize=20)\n",
  "        ax.set_zlabel(keys[z], fontsize=20)\n",
  "\n",
  "        fig.savefig(frame_prefix + str(gen) + '.png')\n",
  "        history.append((points, errors))\n",
  "        plt.show()\n",
  "        plt.close(fig)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "plot_population_3d(gen_vs_pop, 0, 1, 2, 'Model Sample Evolution in 3D space, frame:')\n",
  "plot_population_3d(gen_vs_pop, 2, 3, 4, 'Particle Movement in 3D space, frame:',\n",
  "                   frame_prefix='xyz_2_3_4_')\n",
  "\n",
  "# To assemble one set of frames into an animation, e.g.:\n",
  "#   ls -v [0-9]*.png > sorted.txt\n",
  "#   convert -delay 100 -loop 0 @sorted.txt animation.mp4"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "fig, ax = plt.subplots()\n",
  "for stat in ['avg', 'max', 'min']:\n",
  "    ax.plot([record[stat] for record in ga_out_nsga['log']], label=stat)\n",
  "ax.legend()\n",
  "\n",
  "ga_out_nsga.keys()"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "nsga_log = ga_out_nsga['log']\n",
  "avg = np.array([record['avg'] for record in nsga_log])\n",
  "std = np.array([record['std'] for record in nsga_log])\n",
  "fig, ax = plt.subplots()\n",
  "ax.plot(avg + std, label='avg + std')\n",
  "ax.plot(avg, label='avg')\n",
  "ax.plot(avg - std, label='avg - std')\n",
  "ax.set(xlabel='generation', title='NSGA mean score +/- one standard deviation')\n",
  "ax.legend();"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false },
"outputs": [], "source": [
  "fig, ax = plt.subplots()\n",
  "ax.plot([record['std'] for record in ga_out_nsga['log']], label='std')\n",
  "ax.legend();"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "import matplotlib.pyplot as plt\n",
  "import pandas as pd\n",
  "import seaborn as sns\n",
  "from IPython.display import display\n",
  "\n",
  "\n",
  "def injected_scores(dtc):\n",
  "    \"\"\"Return the scores of the injected-current tests in ``dtc.scores``.\n",
  "\n",
  "    Only keys whose text contains 'Injected' are kept; each is shortened by\n",
  "    dropping its first 15 and last 4 characters (a fixed name prefix/suffix --\n",
  "    TODO confirm these offsets against the actual test names).\n",
  "    \"\"\"\n",
  "    return {k[15:-4]: v for k, v in dtc.scores.items() if 'Injected' in str(k)}\n",
  "\n",
  "\n",
  "# All (unfiltered) scores of the first three entries of true_history.\n",
  "df = pd.DataFrame([dict(dhof[0].scores) for dhof in true_history])[0:3]\n",
  "df"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "sns.set()\n",
  "# Injected-current scores of the first three entries of true_history.\n",
  "df = pd.DataFrame([injected_scores(dhof[0]) for dhof in true_history])[0:3]\n",
  "f, ax = plt.subplots(figsize=(9, 6))\n",
  "sns.heatmap(df, annot=True, linewidths=.5, ax=ax, vmin=0, vmax=1);"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Same view for the first three individuals of the first generation.\n",
  "gen_vs_pop = ga_out_nsga['gen_vs_pop']\n",
  "df = pd.DataFrame([injected_scores(ind.dtc) for ind in gen_vs_pop[0]])[0:3]\n",
  "f, ax = plt.subplots(figsize=(9, 6))\n",
  "sns.heatmap(df, annot=True, linewidths=.5, ax=ax, vmin=0, vmax=1);"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Reference example from the seaborn gallery: a heatmap with the numeric\n",
  "# value annotated in each cell (flights dataset, converted to wide form).\n",
  "flights_long = sns.load_dataset('flights')\n",
  "flights = flights_long.pivot(index='month', columns='year', values='passengers')\n",
  "f, ax = plt.subplots(figsize=(9, 6))\n",
  "sns.heatmap(flights, annot=True, fmt='d', linewidths=.5, ax=ax);"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Injected-current scores of the last four entries of true_history, shaded\n",
  "# green by magnitude.\n",
  "df = pd.DataFrame([injected_scores(dhof[0]) for dhof in true_history])[-4:]\n",
  "dfg = df.reset_index(drop=True)\n",
  "green_cmap = sns.light_palette('green', as_cmap=True)\n",
  "display(dfg.style.background_gradient(cmap=green_cmap))"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "f, ax = plt.subplots(figsize=(9, 6))\n",
  "sns.heatmap(dfg, annot=True, linewidths=.5, ax=ax);"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Judge the optimized model (model1, built in the trace-comparison cells above)\n",
  "# on the first test of tests_.\n",
  "score = tests_[0].judge(model1)\n",
  "score.summarize()\n",
  "score.sort_key"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "import pickle\n",
  "\n",
  "from neuronunit.optimization import optimization_management as om\n",
  "\n",
  "# Same ground-truth problem with only three free parameters, optimized with\n",
  "# NSGA=False for comparison with the NSGA run.\n",
  "free_params_sbest = ['a', 'b', 'vr']\n",
  "hc_sbest = {k: v for k, v in cells['TC'].items() if k not in free_params_sbest}\n",
  "with open('thalamo_cortical_tests.p', 'rb') as f:\n",
  "    TC_tests = pickle.load(f)\n",
  "ga_out_sbest, _ = om.run_ga(explore_param, 20, TC_tests, free_params=free_params_sbest,\n",
  "                            hc=hc_sbest, NSGA=False)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "import numpy as np\n",
  "\n",
  "fig, ax = plt.subplots()\n",
  "for stat in ['avg', 'max', 'min']:\n",
  "    ax.plot([record[stat] for record in ga_out_sbest['log']], label=stat)\n",
  "ax.legend()\n",
  "\n",
  "best_sbest = ga_out_sbest['dhof'][0]\n",
  "best_attrs_sbest = best_sbest.attrs\n",
  "print('best nsga', np.sum(list(ga_out_nsga['dhof'][0].scores.values())))\n",
  "print('best BPO select best: ', np.sum(list(best_sbest.scores.values())))"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "fig, ax = plt.subplots()\n",
  "ax.plot([record['std'] for record in ga_out_sbest['log']], label='std')\n",
  "ax.legend();"
 ] },
 { "cell_type": "code", "execution_count": null,
"metadata": { "collapsed": false }, "outputs": [], "source": [
  "def hack_judge(test_and_models):\n",
  "    \"\"\"Score one test against a freshly built model, bypassing ``judge``.\n",
  "\n",
  "    ``test_and_models`` is a ``(test, attrs)`` pair: ``attrs`` is applied to a\n",
  "    new RAW-backend ReducedModel, the test's prediction is generated once and\n",
  "    scored against its observation. Prints the observed and predicted\n",
  "    ``'value'`` (or ``'mean'`` when there is no ``'value'`` key) and returns\n",
  "    the score.\n",
  "    \"\"\"\n",
  "    test, attrs = test_and_models\n",
  "    obs = test.observation\n",
  "    model = ReducedModel(LEMS_MODEL_PATH, name='vanilla', backend='RAW')\n",
  "    model.set_attrs(attrs)\n",
  "    pred = test.generate_prediction(model)\n",
  "    score = test.compute_score(obs, pred)\n",
  "    try:\n",
  "        print(obs['value'], pred['value'])\n",
  "    except (KeyError, TypeError):\n",
  "        print(obs['mean'], pred['mean'])\n",
  "    return score"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Sanity check: the ground-truth TC parameters scored on their own tests.\n",
  "scores = [hack_judge((t, cells['TC'])) for t in TC_tests]\n",
  "print(scores[0].norm_score)\n",
  "print(scores[0])"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "print([s.norm_score for s in scores])\n",
  "print([s.score for s in scores])\n",
  "\n",
  "score = hack_judge((TC_tests[-3], cells['TC']))\n",
  "print(score)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "# Contrast: the RS parameter set scored against the TC ground-truth tests.\n",
  "scores = [hack_judge((t, cells['RS'])) for t in TC_tests]\n",
  "print(scores[0].norm_score)\n",
  "print(scores[0])\n",
  "print([s.norm_score for s in scores])"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "import dask.bag as db\n",
  "from itertools import repeat\n",
  "\n",
  "# The rheobase test is scored serially: it cannot be dask-mapped because nested\n",
  "# dask mappings don't work (daemonic processes cannot spawn children).\n",
  "test_a_models = list(zip(TC_tests[1:], repeat(cells['RS'])))\n",
  "tc_bag = db.from_sequence(test_a_models)\n",
  "scores = list(tc_bag.map(hack_judge).compute())\n",
  "scores.insert(0, hack_judge((TC_tests[0], cells['RS'])))\n",
  "print(scores)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
  "score = TC_tests[0].judge(model, stop_on_error=False, deep_error=True)\n",
  "print(score.prediction)"
 ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }