{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Preamble" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import attr\n", "import funcy as fn\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from bidict import bidict\n", "from IPython.display import Image, display, SVG\n", "import networkx as nx\n", "import pydot\n", "import pandas as pd\n", "\n", "from collections import Counter\n", "\n", "import dfa\n", "from dfa.utils import find_subset_counterexample, find_equiv_counterexample\n", "from dfa_identify import find_dfa, find_dfas\n", "\n", "from diss.planners.product_mc import ProductMC\n", "from diss.concept_classes.dfa_concept import DFAConcept\n", "from diss.domains.gridworld_naive import GridWorldNaive as World\n", "from diss.domains.gridworld_naive import GridWorldState as State\n", "from diss import search, LabeledExamples, GradientGuidedSampler, ConceptIdException\n", "from pprint import pprint\n", "from itertools import combinations\n", "from tqdm import tqdm_notebook\n", "from tqdm.notebook import trange\n", "from IPython.display import clear_output\n", "from IPython.display import HTML as html_print\n", "from functools import reduce\n", "\n", "sns.set_context('paper')\n", "sns.set_style('darkgrid')\n", "sns.set_palette('Set2')\n", "np.set_printoptions(precision=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from diss.experiment import PartialDFAIdentifier, ignore_white, PARTIAL_DFA, BASE_EXAMPLES\n", "from diss.experiment import view_dfa" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from diss import diss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from diss import DemoPrefixTree as PrefixTree" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], 
"source": [ "def analyze(search, n_iters):\n", " concept2energy = {} # Explored concepts + associated energy\n", " partial_masses = []\n", " median_energies = []\n", " min_energies = []\n", " total_energies = []\n", " # Run Search and collect concepts, energy, and POI.\n", " for i, (data, concept, metadata) in zip(trange(n_iters, desc='DISS'), search):\n", " print(f'==========={i}================')\n", " print('size', concept.size)\n", " \n", " score = metadata['energy']\n", " print('energy', score)\n", " if 'grad' in metadata:\n", " print('surprisal', metadata.get('surprisal'))\n", " grad = metadata['grad']\n", " sns.set(rc={\"figure.figsize\":(10, 2)})\n", " sns.barplot(x=np.arange(len(grad)), y=np.array(grad) / np.abs(grad).max())\n", " plt.xticks(rotation=45)\n", " plt.show()\n", " weights = metadata['weights']\n", " sns.set(rc={\"figure.figsize\":(10, 2)})\n", " sns.barplot(x=np.arange(len(weights)), y=np.array(weights))\n", " plt.xticks(rotation=45)\n", " plt.show()\n", " print('pivot', metadata['pivot'])\n", "\n", " print(\"conjecture:\")\n", " print(f\"{metadata['conjecture']}\")\n", " print(f'data')\n", " data = metadata['data']\n", " data @= identifer.base_examples # Force labels of prior examples.\n", " buff = ''\n", " for lbl, split in [(True, data.positive), (False, data.negative)]:\n", " buff += f'------------- {lbl} --------------<br>'\n", " for word in sorted(split, key=len):\n", " obs = '\\n'.join(map(tile, word))\n", " buff += f'{obs}<br>'\n", " display(html_print(buff))\n", " \n", " \n", " concept2energy[concept] = metadata['energy']\n", " view_dfa(concept)\n", " energies = list(concept2energy.values())\n", " partial_masses.append(sum(np.exp(-x) for x in energies)) # Record unormalized mass\n", " \n", " median_energies.append(np.median(energies))\n", " min_energies.append(np.min(energies))\n", " total_energies.append(sum(energies))\n", "\n", " sorted_concepts = sorted(list(concept2energy), key=concept2energy.get)\n", " \n", " p = 0\n", " for c in 
sorted(concept2energy, key=concept2energy.get):\n", " p += np.exp(-concept2energy[c])\n", " print('energy', concept2energy[c])\n", " view_dfa(c)\n", " if p > 0.99:\n", " break\n", "\n", "\n", " \n", " df = pd.DataFrame(data={\n", " 'probability mass explored': partial_masses,\n", " 'median energies': median_energies,\n", " 'min energies': min_energies,\n", " 'cumulative energy': total_energies,\n", " 'iteration': list(range(1, len(total_energies) + 1)),\n", " })\n", " return df, sorted_concepts" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from diss.experiment.planner import GridWorldPlanner\n", "planner = GridWorldPlanner.from_string(\n", " buff=\"\"\"y....g..\n", " ........\n", " .b.b...r\n", " .b.b...r\n", " .b.b....\n", " .b.b....\n", " rrrrrr.r\n", " g.y.....\"\"\",\n", " start=(3, 5),\n", " slip_prob=1/32,\n", " horizon=15,\n", " policy_cache='diss_experiment.shelve',\n", ")\n", "SENSOR = planner.gw.sensor\n", "DYN = planner.gw.dyn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing Overlay\n", "\n", "This can all seem pretty abstract, so let's visualize the way the sensor sees the board." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import HTML as html_print\n", "\n", "COLOR_ALIAS = {\n", " 'white': 'white',\n", " 'yellow': '#ffff00', \n", " 'red': '#ff8b8b',\n", " 'blue': '#afafff', \n", " 'green' : '#8ff45d'\n", "}\n", "\n", "\n", "def tile(color='black'):\n", " \"\"\"Return an HTML snippet rendering one bordered tile filled with `color`.\n", " \n", " Names found in COLOR_ALIAS are mapped to their hex value; any other value is\n", " passed through unchanged as the CSS background-color.\n", " \"\"\"\n", " color = COLOR_ALIAS.get(color, color)\n", " s = ' '*4\n", " return f\"<text style='border: solid 1px;background-color:{color}'>{s}</text>\"\n", "\n", "\n", "def print_map():\n", " \"\"\"Scan the board row by row and print colored tiles.\"\"\"\n", " order = range(1, 9) # presumably the 1-indexed coordinates of the 8x8 planner board\n", " buffer = ''\n", " for y in order:\n", " chars = (tile(planner.gw.ap_at_state(x, y)) for x in order)\n", " buffer += ' '.join(chars) + '<br>'\n", " display(html_print(buffer))\n", " \n", "DYN_SENSE = DYN >> SENSOR\n", "\n", "\n", "def print_trc(trc, idx=0):\n", " \"\"\"Display trace `trc` as HTML, interleaving each action with the tile it observes.\n", " \n", " `trc[0]` is the start position and every later entry is a dict whose 'a' key\n", " holds the action; observations come from `planner.lift_path`. `idx` only\n", " labels the output line.\n", " \"\"\"\n", " obs = planner.lift_path(trc, flattened=False, compress=False)\n", " actions = [x['a'] for x in trc[1:]]\n", " obs = map(tile, obs)\n", " display(\n", " html_print(f'trc {idx}: ' + ''.join(''.join(x) for x in zip(actions, obs)) + '\\n')\n", " )\n", " \n", "print_map()\n", "\n", "TRC4 = [\n", " (3, 5),\n", " {'a': '↑', 'c': 0},\n", " {'a': '↑', 'c': 1},\n", " {'a': '↑', 'c': 1},\n", " {'a': '→', 'c': 1},\n", " {'a': '↑', 'c': 1},\n", " {'a': '↑', 'c': 1},\n", " {'a': '→', 'c': 1},\n", " {'a': '→', 'c': 1},\n", " {'a': '→', 'c': 1},\n", " {'a': '←', 'c': 1},\n", " {'a': '←', 'c': 1},\n", " {'a': '←', 'c': 1},\n", " {'a': '←', 'c': 1},\n", " {'a': '←', 'c': 1, 'EOE_ego': 1},\n", "]\n", "\n", "TRC5 = [\n", " (3, 5),\n", " {'a': '↑', 'c': 1},\n", " {'a': '↑', 'c': 1},\n", " #{'a': '↑', 'c': 1},\n", " #{'a': '↑', 'c': 1},\n", " {'a': '↑', 'c': 1},\n", " {'a': '←', 'c': 1},\n", " {'a': '←', 'c': 1, 'EOE_ego': 1},\n", "]\n", "\n", "\n", "print_trc(TRC4, 4)\n", "print_trc(TRC5, 5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [],
"source": [ "BASE_EXAMPLES" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "env_yellow = dfa.DFA(\n", " start=False,\n", " inputs={'blue', 'green', 'red', 'yellow'},\n", " outputs={True, False},\n", " label=lambda s: s,\n", " transition=lambda s, c: s | (c == 'yellow'),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "universal = dfa.DFA(\n", " start=True,\n", " inputs={'blue', 'green', 'red', 'yellow'},\n", " outputs={True, False},\n", " label=lambda s: s,\n", " transition=lambda s, c: True,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "MONOLITHIC = True\n", "\n", "identifer = PartialDFAIdentifier(\n", " partial = universal if MONOLITHIC else PARTIAL_DFA,\n", " base_examples = LabeledExamples(negative=[], positive=[]) if MONOLITHIC else BASE_EXAMPLES,\n", " try_reach_avoid=True,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def to_chain(c, t, psat):\n", " return planner.plan(c, t, psat, monolithic=MONOLITHIC, use_rationality=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_iters = 100\n", "to_demo = planner.to_demo" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "dfs = []\n", "for i in fn.chain(range(-10, 11, 1), [float('inf')]): \n", " print(f'beta = 2^{i}')\n", " for _ in range(5):\n", " dfa_search = diss(\n", " demos=[to_demo(TRC4), to_demo(TRC5)] if MONOLITHIC else [to_demo(TRC4[:-1])],\n", " to_concept=identifer,\n", " to_chain=to_chain,\n", " competency=lambda *_: 10,\n", " lift_path=planner.lift_path,\n", " n_iters=n_iters,\n", " reset_period=30,\n", " surprise_weight=1,\n", " 
size_weight=1/50,\n", " sgs_temp=2**i,\n", " example_drop_prob=1/20, #1e-2,\n", " synth_timeout=20,\n", " )\n", "\n", " df, found_concepts = analyze(dfa_search, n_iters)\n", " df['treatment'] = r'$\\beta = 2^{' + f'{i}' + '}$'\n", " df['logtemp'] = i\n", " df['iteration'] = df.index\n", " dfs.append(df)\n", "\n", "df = pd.concat(dfs, ignore_index=True)\n", "df['experiment'] = 'Monolithic' if monolithic else 'Incremental'\n", "df.to_json( f'experiment_{\"mono\" if MONOLITHIC else \"inc\"}_beta.json')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Enumeration Baselines" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from diss.experiment import concept_class" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "enumeration_dfs = []\n", "for n_iters, monolithic in [(100, True), (40, False)]:\n", " def to_chain(c, t, psat):\n", " return planner.plan(c, t, psat, monolithic=monolithic, use_rationality=True)\n", "\n", " pos_examples_mono = LabeledExamples(negative=[], positive=[('blue', 'green', 'yellow'), ('yellow',)])\n", " pos_examples_inc = BASE_EXAMPLES @ LabeledExamples(negative=[], positive=[('blue', 'green', 'yellow'), ('yellow',)])\n", "\n", " identifer = PartialDFAIdentifier(\n", " partial = universal if monolithic else PARTIAL_DFA,\n", " base_examples = pos_examples_mono if monolithic else pos_examples_inc\n", " )\n", "\n", " dfa_search = concept_class.enumerative_search(\n", " demos=[to_demo(TRC4), to_demo(TRC5)] if monolithic else [to_demo(TRC4[:-1])],\n", " identifer=identifer,\n", " to_chain=to_chain,\n", " competency=lambda *_: 0.8,\n", " n_iters=n_iters,\n", " surprise_weight=1, # Rescale surprise to make comparable to size.\n", " size_weight=1/50,\n", " )\n", "\n", " df3, _ = analyze(dfa_search, n_iters)\n", " df3['experiment'] = 'Monolithic' if monolithic else 'Incremental'\n", " df3['treatment'] = 'enumeration'\n", 
" df3['iteration'] = df3.index\n", " enumeration_dfs.append(df3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Analysis" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "df_mono = pd.read_json( f'experiment_mono_beta.json')\n", "df_inc = pd.read_json( f'experiment_inc_beta.json')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Normalize all energies between 0 and 1.\n", "for tmp1, tmp2 in zip([df_mono, df_inc], enumeration_dfs):\n", " U1, U2 = tmp1['min energies'], tmp2['min energies']\n", " U_min = min(U1.min(), U2.min())\n", " U_max = max(U1.max(), U2.max())\n", " \n", " for tmp in [tmp1, tmp2]:\n", " U = tmp['min energies']\n", " tmp['U'] = (U - U_min) / (U_max- U_min)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df_mono['experiment'] = 'Monolithic'\n", "df_inc['experiment'] = 'Incremental'\n", "diss_dfs = [df_mono, df_inc]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "enumeration_dfs[0]['experiment'] = 'Monolithic'\n", "enumeration_dfs[1]['experiment'] = 'Incremental'\n", "#df_enum = pd.concat(enumeration_dfs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "sns.set(rc={\"figure.figsize\":(10, 6)})\n", "\n", "for df_diss, df_enum, iters, experiment in zip(diss_dfs, enumeration_dfs, [80, 40], ['Monolithic', 'Incremental']):\n", " if experiment == 'Incremental':\n", " plot = plt.scatter(list(range(21)), list(range(-10, 11)), c=list(range(-10, 11)), cmap='coolwarm')\n", " plt.clf()\n", " cbar = plt.colorbar(plot, extend='max')\n", " cbar.ax.set_ylabel(r'$\\ln \\beta$', rotation=270)\n", " hdl = plt.plot(df_enum['iteration'], df_enum['U'], '--', c='black', label='enumerate')\n", " grid = sns.lineplot(\n", " data=df_diss, x='iteration', y='U',\n", " palette='coolwarm', 
hue='treatment', legend=False, \n", " estimator=np.median, ci=None,\n", " )\n", " plt.title(f'{experiment=}')\n", " plt.xlim(0, iters)\n", " plt.xlabel('Iteration')\n", " plt.ylabel('(normalized) minumum energy DFA found')\n", " plt.legend()\n", "\n", " plt.savefig(f'mass_{experiment}.pgf')\n", " plt.show()\n", "\n", " \n", "#\n", "#plt.savefig('mass_mono2.pgf')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.2" } }, "nbformat": 4, "nbformat_minor": 4 }