{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Evolutionary parameter search with a single neural mass model\n", "This notebook provides a simple example of how to use the evolutionary optimization framework built into the library. Under the hood, the evolutionary algorithm is powered by `deap`, and `pypet` takes care of the parallelization and storage of the simulation data for us.\n", "\n", "We want to optimize for a simple target, namely finding a parameter configuration that produces activity whose power spectrum peaks at 25 Hz.\n", "\n", "In this notebook, we will also plot the evolutionary genealogy tree to visualize how the population evolves over generations." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# change to the root directory of the project\n", "import os\n", "if os.getcwd().split(\"/\")[-1] == \"examples\":\n", "    os.chdir('..')\n", "\n", "# This will reload all imports as soon as the code changes\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "try:\n", "    import matplotlib.pyplot as plt\n", "except ImportError:\n", "    import sys\n", "    !{sys.executable} -m pip install matplotlib seaborn\n", "    import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "import logging\n", "\n", "from neurolib.models.aln import ALNModel\n", "from neurolib.utils.parameterSpace import ParameterSpace\n", "from neurolib.optimize.evolution import Evolution\n", "import neurolib.utils.functions as func\n", "\n", "import neurolib.optimize.evolution.deapUtils as deapUtils\n", "\n", "# a nice color map\n", "plt.rcParams['image.cmap'] = 'plasma'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model definition" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "aln = ALNModel()" ] }, { "cell_type": "code", 
"execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Here we define our evaluation function. This function will\n", "# be called repeatedly and perform a single simulation. The object\n", "# that is passed to the function, `traj`, is a pypet trajectory\n", "# and serves as a \"bridge\" to load the parameter set of this\n", "# particular trajectory and execute a run.\n", "# Then the power spectrum of the run is computed and the frequency\n", "# of its maximum is compared with the target peak frequency of 25 Hz.\n", "def evaluateSimulation(traj):\n", "    # The trajectory id is provided as an attribute\n", "    rid = traj.id\n", "    logging.info(\"Running run id {}\".format(rid))\n", "    # this function provides a model with the particular\n", "    # parameter set for this given run\n", "    model = evolution.getModelFromTraj(traj)\n", "    # parameters can also be modified after loading\n", "    model.params['dt'] = 0.1\n", "    model.params['duration'] = 2*1000.\n", "    # and the simulation is run\n", "    model.run()\n", "\n", "    # compute the power spectrum of the last 1000 ms of excitatory activity\n", "    frs, powers = func.getPowerSpectrum(model.rates_exc[:, -int(1000/model.params['dt']):], dt=model.params['dt'])\n", "    # find the peak frequency\n", "    domfr = frs[np.argmax(powers)]\n", "    # fitness evaluation: distance of the peak frequency from 25 Hz\n", "    fitness = abs(domfr - 25)\n", "    # deap needs a fitness *tuple*!\n", "    fitness_tuple = ()\n", "    # more fitness values could be added\n", "    fitness_tuple += (fitness, )\n", "    # we need to return the fitness tuple and the outputs of the model\n", "    return fitness_tuple, model.outputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize and run evolution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The evolutionary algorithm tries to find the optimal parameter set that will maximize (or minimize) a certain fitness function. 
\n", "\n", "This is achieved by seeding an initial population of size `POP_INIT_SIZE` that is randomly initialized in the parameter space `parameterSpace`. INIT: After simulating the initial population using `evalFunction`, only a subset of the individuals is kept, defined by `POP_SIZE`. \n", "\n", "START: Members of the remaining population are chosen based on their fitness (using rank selection) to mate and produce `offspring`. These `offspring` have parameters that are drawn from a normal distribution defined by the mean of the parameters between the two parents. Then the `offspring` population is evaluated and the process loops back to START. \n", "\n", "This process is repeated for `NGEN` generations." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Here we define the parameters and the range in which we want\n", "# to perform the evolutionary optimization.\n", "# Create a `ParameterSpace` \n", "pars = ParameterSpace(['mue_ext_mean', 'mui_ext_mean'], [[0.0, 4.0], [0.0, 4.0]])\n", "# Initialize evolution with\n", "# :evaluateSimulation: The function that returns a fitness, \n", "# :pars: The parameter space and its boundaries to optimize\n", "# :model: The model that should be passed to the evaluation function\n", "# :weightList: A list of optimization weights for the `fitness_tuple`,\n", "# positive values will lead to a maximization, negative \n", "# values to a minimization. The length of this list must\n", "# be the same as the length of the `fitness_tuple`.\n", "# \n", "# :POP_INIT_SIZE: The size of the initial population that will be \n", "# randomly sampled in the parameter space `pars`.\n", "# Should be higher than POP_SIZE. 50-200 might be a good\n", "# range to start experimenting with.\n", "# :POP_SIZE: Size of the population that should evolve. Must be an\n", "# even number. 20-100 might be a good range to start with.\n", "# :NGEN: Number of generations to simulate the evolution for. 
A good\n", "# range to start with might be 20-100.\n", "\n", "weightList = [-1.0]\n", "\n", "evolution = Evolution(evalFunction = evaluateSimulation, parameterSpace = pars, model = aln, weightList = weightList,\n", "                      POP_INIT_SIZE=4, POP_SIZE = 4, NGEN=2, filename=\"example-2.1.hdf\")\n", "# info: choose POP_INIT_SIZE=50, POP_SIZE = 20, NGEN=20 for real exploration, \n", "# values are lower here for testing" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Enabling `verbose = True` will print statistics and generate plots \n", "# of the current population for each generation.\n", "evolution.run(verbose = False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Population" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best individual [1.182184510022096, 0.29660620374273683, 0.4936712969767474, 0.07875430013351538] fitness (0.0,)\n" ] } ], "source": [ "# the current population is always accessible via\n", "pop = evolution.pop\n", "# we can also use the functions registered to deap\n", "# to select the best of the population:\n", "best_10 = evolution.toolbox.selBest(pop, k=10)\n", "# Remember, we performed a minimization so a fitness\n", "# of 0 is optimal\n", "print(\"Best individual\", best_10[0], \"fitness\", best_10[0].fitness)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can look at the current population by calling `evolution.dfPop()`, which returns a pandas dataframe with the parameters of each individual, its id, generation of birth, its outputs, and the fitness (called \"f0\" here)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
|  | mue_ext_mean | mui_ext_mean | score | id | gen | t | rates_exc | rates_inh | IA | f0 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.182185 | 0.296606 | 0.0 | 294 | 13 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 0.0 |
| 1 | 1.114270 | 0.240422 | 0.0 | 368 | 16 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 0.0 |
| 2 | 0.910558 | 0.075463 | 0.0 | 403 | 18 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 0.0 |
| 3 | 1.188440 | 0.356385 | -1.0 | 171 | 7 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 4 | 1.007371 | 0.113623 | -1.0 | 177 | 7 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 5 | 1.031484 | 0.120989 | -1.0 | 192 | 8 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 6 | 0.900787 | 0.038763 | -1.0 | 193 | 8 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 7 | 1.217021 | 0.213936 | -1.0 | 245 | 10 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 8 | 1.241895 | 0.365758 | -1.0 | 248 | 10 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 9 | 1.062928 | 0.265389 | -1.0 | 267 | 11 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 10 | 1.007366 | 0.110587 | -1.0 | 286 | 12 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 11 | 0.904612 | 0.123308 | -1.0 | 320 | 14 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 12 | 1.119281 | 0.188307 | -1.0 | 330 | 15 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 13 | 1.158463 | 0.227194 | -1.0 | 342 | 15 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 14 | 1.053327 | 0.281852 | -1.0 | 344 | 15 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 15 | 1.124747 | 0.318747 | -1.0 | 360 | 16 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 16 | 1.266317 | 0.360644 | -1.0 | 364 | 16 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 17 | 1.329988 | 0.388133 | -1.0 | 365 | 16 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 18 | 0.986030 | 0.189384 | -1.0 | 390 | 18 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
| 19 | 0.896915 | 0.125212 | -1.0 | 399 | 18 | [0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.60... | [[4.3201372225314945, 3.9353836030865286, 3.58... | [[10.920182546008649, 11.229353479381396, 11.5... | [[111.41612911461853, 111.36042105006122, 111.... | 1.0 |
|  | mue_ext_mean | mui_ext_mean | score | id | gen | f0 |
|---|---|---|---|---|---|---|
| 0 | 1.400310 | 1.209331 | -4.0 | 39 | 0 | 4.0 |
| 1 | 1.173593 | 0.662050 | -5.0 | 31 | 0 | 5.0 |
| 2 | 1.134601 | 0.809371 | -6.0 | 22 | 0 | 6.0 |
| 3 | 0.992049 | 0.694590 | -6.0 | 29 | 0 | 6.0 |
| 4 | 1.470708 | 1.073607 | -7.0 | 47 | 0 | 7.0 |
| ... | ... | ... | ... | ... | ... | ... |
| 395 | 1.881591 | 0.299691 | -24.0 | 425 | 19 | 24.0 |
| 396 | 0.681422 | 0.489003 | -8.0 | 426 | 19 | 8.0 |
| 397 | 1.430791 | 0.268028 | -24.0 | 427 | 19 | 24.0 |
| 398 | 1.275903 | 0.534227 | -3.0 | 428 | 19 | 3.0 |
| 399 | 0.870652 | 0.326687 | -5.0 | 429 | 19 | 5.0 |
400 rows × 6 columns
\n", "