{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Separation\n",
    "\n",
    "Let's think about the essential pieces here:\n",
    "- a periodic 2d domain\n",
    "    - we need to map positions back to the central coordinate range\n",
    "- a group of particles\n",
    "    - particle positions need to be tracked\n",
    "    - particles need to be moved around\n",
    "    - we need to diagnose center of mass\n",
    "    - we need to diagnose moment of inertia\n",
    "- a random number generator\n",
    "    - we want to initialize it in a reproducible way (see rule 6 from [Sandve et al. (2013)](https://doi.org/10.1371/journal.pcbi.1003285))\n",
    "    - we want to draw arrays of random numbers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tech preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The spatial domain\n",
    "\n",
    "The domain stores its size (`length_x`, `length_y`) and provides one method that returns these sizes and one that maps given positions back into the base interval, `[0, length_x)` in x and `[0, length_y)` in y."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PeriodicSpace:\n",
    "    def __init__(self, length_x=10, length_y=20):\n",
    "        self.length_y = length_y\n",
    "        self.length_x = length_x\n",
    "    \n",
    "    def get_sizes(self):\n",
    "        return self.length_x, self.length_y\n",
    "        \n",
    "    def normalize_positions(self, x, y):\n",
    "        return np.mod(x, self.length_x), np.mod(y, self.length_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_space = PeriodicSpace()\n",
    "\n",
    "# ensure that positions within the base interval\n",
    "# are not changed\n",
    "lx, ly = test_space.get_sizes()\n",
    "centered_x, centered_y = lx / 2, ly / 2\n",
    "assert (centered_x, centered_y) == test_space.normalize_positions(x=centered_x, y=centered_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A group of particles\n",
    "\n",
    "- track state of all particles (positions `x`, `y`)\n",
    "- method to move particles\n",
    "- methods to diagnose\n",
    "  - center of mass\n",
    "  - moment of inertia\n",
    "- method to wrap diagnostics in a pandas dataframe\n",
    "- method to wrap positions in a pandas dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Particles:\n",
    "    def __init__(\n",
    "        self,\n",
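    "        rng=None,\n",
    "        space=None,\n",
    "        x=None, y=None,\n",
    "        step_length=0.5\n",
    "    ):\n",
    "        # create defaults per instance; default arguments are evaluated only\n",
    "        # once at definition time and would be shared between all instances\n",
    "        self.rng = np.random.RandomState() if rng is None else rng\n",
    "        self.space = PeriodicSpace() if space is None else space\n",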
    "        self.x, self.y = x, y\n",
    "        self.step_length = step_length\n",
    "        self.steps_done = 0\n",
    "\n",
    "    def move(self):\n",
    "        self.x += self.step_length * self.rng.normal(size=self.x.shape)\n",
    "        self.y += self.step_length * self.rng.normal(size=self.y.shape)\n",
    "        \n",
    "        self.x, self.y = self.space.normalize_positions(self.x, self.y)\n",
    "        \n",
    "        self.steps_done += 1\n",
    "\n",
    "    def center_of_mass(self):\n",
    "        return self.x.mean(), self.y.mean()\n",
    "    \n",
    "    def moment_of_inertia(self):\n",
    "        return self.x.var() + self.y.var()\n",
    "\n",
    "    def diagnostics(self):\n",
    "        com = self.center_of_mass()\n",
    "        mi = self.moment_of_inertia()\n",
    "        return pd.DataFrame(\n",
    "            {\n",
    "                \"center_of_mass_x\": com[0],\n",
    "                \"center_of_mass_y\": com[1],\n",
    "                \"moment_of_inertia\": mi\n",
    "            },\n",
    "            index=[self.steps_done, ],        \n",
    "        )\n",
    "    \n",
    "    def positions(self):\n",
    "        return pd.DataFrame(\n",
    "            {\n",
    "                \"x\": self.x,\n",
    "                \"y\": self.y\n",
    "            }\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random number generator\n",
    "\n",
    "NumPy already provides an object with everything we need: `numpy.random.RandomState`. We'll use it as our RNG; passing it a fixed seed makes the drawn numbers reproducible."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Drive the experiment\n",
    "\n",
    "Parameters we want to be able to control:\n",
    "- number of particles\n",
    "- number of steps to run the model\n",
    "- size of the spatial domain\n",
    "- step size in the updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_random_walk(\n",
    "    length_x=10,\n",
    "    length_y=20,\n",
    "    number_particles=100,\n",
    "    number_steps=100,\n",
    "    step_length=0.5,\n",
    "    seed=None\n",
    "):\n",
    "    # initialize spatial domain\n",
    "    space = PeriodicSpace(\n",
    "        length_x=length_x,\n",
    "        length_y=length_y\n",
    "    )\n",
    "    \n",
    "    # initialize random number generator;\n",
    "    # pass an integer `seed` to make a run reproducible\n",
    "    rng = np.random.RandomState(seed)\n",
    "    \n",
    "    # initialize group of particles and register spatial domain and rng\n",
    "    particles = Particles(\n",
    "        space=space,\n",
    "        rng=rng,\n",
    "        x=np.ones((number_particles, )) * length_x / 2.0,\n",
    "        y=np.ones((number_particles, )) * length_y / 2.0,        \n",
    "        step_length=step_length\n",
    "    )\n",
    "    \n",
    "    # track initial positions\n",
    "    initial_positions = particles.positions()\n",
    "    \n",
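    "    # collect diagnostics of the initial state\n",
    "    diags = [particles.diagnostics(), ]\n",
    "    \n",
    "    # run `number_steps - 1` updates (the initial state is the first of\n",
    "    # `number_steps` rows) and accumulate diagnostics\n",
    "    for step in range(1, number_steps):\n",
    "        particles.move()\n",
    "        diags.append(particles.diagnostics())\n",
    "    \n",
    "    # combine into one dataframe (`DataFrame.append` was removed in pandas 2.0)\n",
    "    diags = pd.concat(diags, ignore_index=True)\n",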
    "    \n",
    "    # track final positions\n",
    "    final_positions = particles.positions()\n",
    "        \n",
    "    return diags, initial_positions, final_positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    diags,\n",
    "    initial_positions,\n",
    "    final_positions\n",
    ") = run_random_walk(number_particles=100, number_steps=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data and visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     x     y\n",
       "0  5.0  10.0\n",
       "1  5.0  10.0\n",
       "2  5.0  10.0\n",
       "3  5.0  10.0\n",
       "4  5.0  10.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>center_of_mass_x</th>\n",
       "      <th>center_of_mass_y</th>\n",
       "      <th>moment_of_inertia</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   center_of_mass_x  center_of_mass_y  moment_of_inertia\n",
       "0               5.0              10.0                0.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>9.491506</td>\n",
       "      <td>18.354704</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.822715</td>\n",
       "      <td>13.159536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>5.388253</td>\n",
       "      <td>0.309439</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>4.807517</td>\n",
       "      <td>12.536099</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.886733</td>\n",
       "      <td>10.880113</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          x          y\n",
       "0  9.491506  18.354704\n",
       "1  3.822715  13.159536\n",
       "2  5.388253   0.309439\n",
       "3  4.807517  12.536099\n",
       "4  1.886733  10.880113"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>center_of_mass_x</th>\n",
       "      <th>center_of_mass_y</th>\n",
       "      <th>moment_of_inertia</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>999</td>\n",
       "      <td>5.015665</td>\n",
       "      <td>9.934202</td>\n",
       "      <td>39.863907</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     center_of_mass_x  center_of_mass_y  moment_of_inertia\n",
       "999          5.015665          9.934202          39.863907"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(initial_positions.head(5))\n",
    "display(diags.head(1))\n",
    "\n",
    "display(final_positions.head(5))\n",
    "display(diags.tail(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(initial_positions)\n",
    "display(diags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_positions.plot.scatter(x=\"x\", y=\"y\")\n",
    "final_positions.plot.scatter(x=\"x\", y=\"y\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diags.plot();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:climnum-exercise]",
   "language": "python",
   "name": "conda-env-climnum-exercise-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}