{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Peter Norvig, Decembers 2016–2023\n", "\n", "# Advent of Code Utilities\n", "\n", "Stuff I might need for [Advent of Code](https://adventofcode.com). \n", "\n", "First, some imports that I have used in past AoC years:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from collections import Counter, defaultdict, namedtuple, deque, abc\n", "from dataclasses import dataclass, field\n", "from itertools import permutations, combinations, cycle, chain, islice\n", "from itertools import count as count_from, product as cross_product\n", "from typing import *\n", "from statistics import mean, median\n", "from math import ceil, floor, factorial, gcd, log, log2, log10, sqrt, inf\n", "\n", "import matplotlib.pyplot as plt\n", "import ast\n", "import fractions\n", "import functools\n", "import heapq\n", "import operator\n", "import pathlib\n", "import re\n", "import string\n", "import sys\n", "import time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Daily Workflow\n", "\n", "Each day's work will consist of three tasks, denoted by three sections in the notebook:\n", "- **Input**: Parse the day's input file with the function `parse`.\n", "- **Part 1**: Understand the day's instructions and:\n", " - Write code to compute the answer to Part 1.\n", " - Once I have computed the answer and submitted it to the AoC site to verify it is correct, I record it with the `answer` class.\n", "- **Part 2**: Repeat the above steps for Part 2.\n", "- Occasionally I'll introduce a **Part 3** where I explore beyond the official instructions.\n", "\n", "# Parsing Input Files\n", "\n", "The function `parse` is meant to handle each day's input. A call `parse(day, parser, sections)` does the following:\n", " - Reads the input file for `day`.\n", " - Breaks the file into *sections*. 
By default, the sections are lines, but you can use `paragraphs` or pass in a custom function.\n", " - Applies `parser` to each section and returns the results as a tuple of records.\n", " - Useful parser functions include `ints`, `digits`, `atoms`, `words`, and the built-ins `int` and `str`.\n", " - Prints the first few input lines and output records. This is useful to me as a debugging tool, and to the reader.\n", " - The defaults are `parser=str, sections=lines`, so by default `parse(day)` gives a tuple of the lines of the input file for *day*." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "current_year = 2023 # Subdirectory name for input files\n", "\n", "lines = str.splitlines # By default, split input text into lines\n", "\n", "def paragraphs(text): \"Split text into paragraphs\"; return text.split('\\n\\n')\n", "\n", "def parse(day_or_text:Union[int, str], parser=str, sections=lines, show=8) -> tuple:\n", " \"\"\"Split the input text into `sections`, and apply `parser` to each.\n", " The first argument is either the text itself, or the day number of a text file.\"\"\"\n", " if isinstance(day_or_text, str) and show == 8: \n", " show = 0 # By default, don't show lines when parsing example text.\n", " text = get_text(day_or_text)\n", " show_items('Puzzle input', text.splitlines(), show)\n", " records = mapt(parser, sections(text.rstrip()))\n", " if parser != str or sections != lines:\n", " show_items('Parsed representation', records, show)\n", " return records\n", "\n", "def get_text(day_or_text: Union[int, str]) -> str:\n", " \"\"\"The text used as input to the puzzle: either a string or the day number,\n", " which denotes the file 'AOC/{current_year}/input{day}.txt'.\"\"\"\n", " if isinstance(day_or_text, str):\n", " return day_or_text\n", " else:\n", " filename = f'AOC/{current_year}/input{day_or_text}.txt'\n", " return pathlib.Path(filename).read_text()\n", "\n", "def show_items(source, items, show:int, 
hr=\"─\"*100):\n", " \"\"\"Show the first few items, in a pretty format.\"\"\"\n", " if show:\n", " types = Counter(map(type, items))\n", " counts = ', '.join(f'{n} {t.__name__}{\"\" if n == 1 else \"s\"}' for t, n in types.items())\n", " print(f'{hr}\\n{source} ➜ {counts}:\\n{hr}')\n", " for line in items[:show]:\n", " print(truncate(line))\n", " if show < len(items):\n", " print('...')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Functions that can be used as the `parser` argument to `parse` (also, consider `str.split` to split the line on whitespace): " ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "Char = str # Intended as the type of a one-character string\n", "Atom = Union[str, float, int] # The type of a string or number\n", "\n", "def ints(text: str) -> Tuple[int]:\n", " \"\"\"A tuple of all the integers in text, ignoring non-number characters.\"\"\"\n", " return mapt(int, re.findall(r'-?[0-9]+', text))\n", "\n", "def positive_ints(text: str) -> Tuple[int]:\n", " \"\"\"A tuple of all the non-negative integers in text, ignoring minus signs and other non-digit characters.\"\"\"\n", " return mapt(int, re.findall(r'[0-9]+', text))\n", "\n", "def digits(text: str) -> Tuple[int]:\n", " \"\"\"A tuple of all the digits in text (as ints 0–9), ignoring non-digit characters.\"\"\"\n", " return mapt(int, re.findall(r'[0-9]', text))\n", "\n", "def words(text: str) -> Tuple[str]:\n", " \"\"\"A tuple of all the alphabetic words in text, ignoring non-letters.\"\"\"\n", " return tuple(re.findall(r'[a-zA-Z]+', text))\n", "\n", "def atoms(text: str) -> Tuple[Atom]:\n", " \"\"\"A tuple of all the atoms (numbers or identifiers) in text. 
Skip punctuation.\"\"\"\n", " return mapt(atom, re.findall(r'[+-]?\\d+\\.?\\d*|\\w+', text))\n", "\n", "def atom(text: str) -> Atom:\n", " \"\"\"Parse text into a single float or int or str.\"\"\"\n", " try:\n", " x = float(text)\n", " return round(x) if x.is_integer() else x\n", " except ValueError:\n", " return text.strip()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Daily Answers\n", "\n", "Here is the `answer` class, which verifies that a computation produces the expected solution (or reports the discrepancy if it does not), times how long the computation took, and stores the result in the dict `answers`." ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 0.1: .0000 seconds, answer unknown \n", " 0.2: .0000 seconds, answer 549755813888 ok\n", " 0.3: .0000 seconds, answer 549755813889 WRONG; expected answer is 549755813888\n", "10.4: .0000 seconds, answer 4 WRONG; expected answer is unknown\n" ] } ], "source": [ "answers = {} # `answers` is a dict of {puzzle_number: answer}\n", "\n", "unknown = 'unknown'\n", "\n", "class answer:\n", " \"\"\"Verify that calling `code` computes the `solution` to `puzzle`. 
\n", " Record results in the dict `answers`.\"\"\"\n", " def __init__(self, puzzle: float, solution, code: Callable = lambda: unknown):\n", " self.puzzle, self.solution, self.code = puzzle, solution, code\n", " answers[puzzle] = self\n", " self.check()\n", " \n", " def check(self) -> bool:\n", " \"\"\"Check if the code computes the correct solution; record run time.\"\"\"\n", " start = time.time()\n", " self.got = self.code()\n", " self.secs = time.time() - start\n", " self.ok = (self.got == self.solution)\n", " return self.ok\n", " \n", " def __repr__(self) -> str:\n", " \"\"\"The repr of an answer shows what happened.\"\"\"\n", " secs = f'{self.secs:7.4f}'.replace(' 0.', ' .')\n", " comment = (f'' if self.got == unknown else\n", " f' ok' if self.ok else \n", " f' WRONG; expected answer is {self.solution}')\n", " return f'Puzzle {self.puzzle:4.1f}: {secs} seconds, answer {str(self.got):<15}{comment}'\n", "\n", "def test_answer():\n", " print(answer(0.1, unknown))\n", " print(answer(0.2, 2**39, lambda: 2**39))\n", " print(answer(0.3, 2**39, lambda: 2**39+1))\n", " print(answer(10.4, unknown, lambda: 2 + 2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Additional utility functions \n", "\n", "All of the following have been used in solutions to multiple puzzles in the past, so I pulled them all in here:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class multimap(defaultdict):\n", " \"\"\"A mapping of {key: [val1, val2, ...]}.\"\"\"\n", " def __init__(self, pairs:Iterable[tuple]=(), symmetric=False):\n", " \"\"\"Given (key, val) pairs, return {key: [val, ...], ...}.\n", " If `symmetric` is True, treat (key, val) as (key, val) plus (val, key).\"\"\"\n", " self.default_factory = list\n", " for (key, val) in pairs:\n", " self[key].append(val)\n", " if symmetric:\n", " self[val].append(key)\n", "\n", "def prod(numbers) -> float: # Same as math.prod in Python 3.8+\n", " \"\"\"The product formed by 
multiplying 
`numbers` together.\"\"\"\n", " result = 1\n", " for x in numbers:\n", " result *= x\n", " return result\n", "\n", "def T(matrix: Sequence[Sequence]) -> List[Tuple]:\n", " \"\"\"The transpose of a matrix: T([(1,2,3), (4,5,6)]) == [(1,4), (2,5), (3,6)]\"\"\"\n", " return list(zip(*matrix))\n", "\n", "def total(counter: Counter) -> int: \n", " \"\"\"The sum of all the counts in a Counter.\"\"\"\n", " return sum(counter.values())\n", "\n", "def minmax(numbers) -> Tuple[int, int]:\n", " \"\"\"A tuple of the (minimum, maximum) of numbers.\"\"\"\n", " numbers = list(numbers)\n", " return min(numbers), max(numbers)\n", "\n", "def cover(*integers) -> range:\n", " \"\"\"A `range` that covers all the given integers, and any in between them.\n", " cover(lo, hi) is an inclusive (or closed) range, equal to range(lo, hi + 1).\n", " The same range results from cover(hi, lo) or cover([hi, lo]).\"\"\"\n", " if len(integers) == 1: integers = the(integers)\n", " return range(min(integers), max(integers) + 1)\n", "\n", "def the(sequence) -> object:\n", " \"\"\"Return the one item in a sequence. 
Raise error if not exactly one.\"\"\"\n", " for i, item in enumerate(sequence, 1):\n", " if i > 1: raise ValueError(f'Expected exactly one item in the sequence.')\n", " return item\n", "\n", "def split_at(sequence, i) -> Tuple[Sequence, Sequence]:\n", " \"\"\"The sequence split into two pieces: (before position i, and i-and-after).\"\"\"\n", " return sequence[:i], sequence[i:]\n", "\n", "def ignore(*args) -> None: \"Just return None.\"; return None\n", "\n", "def is_int(x) -> bool: \"Is x an int?\"; return isinstance(x, int) \n", "\n", "def sign(x) -> int: \"0, +1, or -1\"; return (0 if x == 0 else +1 if x > 0 else -1)\n", "\n", "def lcm(i, j) -> int: \"Least common multiple\"; return i * j // gcd(i, j)\n", "\n", "def union(sets) -> set: \"Union of several sets\"; return set().union(*sets)\n", "\n", "def intersection(sets):\n", " \"Intersection of several sets; error if no sets.\"\n", " first, *rest = sets\n", " return set(first).intersection(*rest)\n", "\n", "def range_intersection(range1, range2) -> range:\n", " \"\"\"Return a range that is the intersection of these two (step 1) ranges.\"\"\"\n", " return range(max(range1.start, range2.start), min(range1.stop, range2.stop))\n", " \n", "def naked_plot(points, marker='o', size=(10, 10), invert=True, square=False, **kwds):\n", " \"\"\"Plot `points` without any axis lines or tick marks.\n", " Optionally specify size, whether square or not, and whether to invert the y axis.\"\"\"\n", " if size: plt.figure(figsize=((size, size) if is_int(size) else size))\n", " plt.plot(*T(points), marker, **kwds)\n", " if square: plt.axis('square')\n", " plt.axis('off')\n", " if invert: plt.gca().invert_yaxis()\n", " \n", "def clock_mod(i, m) -> int:\n", " \"\"\"i % m, but replace a result of 0 with m\"\"\"\n", " # This is like a clock, where 24 mod 12 is 12, not 0.\n", " return (i % m) or m\n", "\n", "def invert_dict(dic) -> dict:\n", " \"\"\"Invert a dict, e.g. 
{1: 'a', 2: 'b'} -> {'a': 1, 'b': 2}.\"\"\"\n", " return {dic[x]: x for x in dic}\n", "\n", "def walrus(name, value):\n", " \"\"\"If you're using a Python version before 3.8, where you can't write `x := val`,\n", " then you can use `walrus('x', val)` instead, provided `x` is global.\"\"\"\n", " globals()[name] = value\n", " return value\n", "\n", "def truncate(object, width=100, ellipsis=' ...') -> str:\n", " \"\"\"Use `ellipsis` to truncate `str(object)` to `width` characters, if necessary.\"\"\"\n", " string = str(object)\n", " return string if len(string) <= width else string[:width-len(ellipsis)] + ellipsis\n", "\n", "def mapt(function: Callable, *sequences) -> tuple:\n", " \"\"\"`map`, with the result as a tuple.\"\"\"\n", " return tuple(map(function, *sequences))\n", "\n", "def mapl(function: Callable, *sequences) -> list:\n", " \"\"\"`map`, with the result as a list.\"\"\"\n", " return list(map(function, *sequences))\n", "\n", "def cat(things: Collection) -> str:\n", " \"\"\"Concatenate the things.\"\"\"\n", " return ''.join(map(str, things))\n", " \n", "cache = functools.lru_cache(None)\n", "Ø = frozenset() # empty set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Itertools Recipes\n", "\n", "The Python docs for the `itertools` module have some [\"recipes\"](https://docs.python.org/3/library/itertools.html#itertools-recipes) that I include here (some I have slightly modified):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def quantify(iterable, pred=bool) -> int:\n", " \"\"\"Count the number of items in iterable for which pred is true.\"\"\"\n", " return sum(1 for item in iterable if pred(item))\n", "\n", "def dotproduct(vec1, vec2):\n", " \"\"\"The dot product of two vectors.\"\"\"\n", " return sum(map(operator.mul, vec1, vec2))\n", "\n", "def powerset(iterable) -> Iterable[tuple]:\n", " \"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)\"\n", " s = list(iterable)\n", " return flatten(combinations(s, r) for r 
in range(len(s) + 1))\n", "\n", "flatten = chain.from_iterable # Yield items from each sequence in turn\n", "\n", "def append(sequences) -> Sequence: \"Append into a list\"; return list(flatten(sequences))\n", "\n", "def batched(iterable, n) -> Iterable[tuple]:\n", " \"Batch data into non-overlapping tuples of length n. The last batch may be shorter.\"\n", " # batched('ABCDEFG', 3) --> ABC DEF G\n", " it = iter(iterable)\n", " while True:\n", " batch = tuple(islice(it, n))\n", " if batch:\n", " yield batch\n", " else:\n", " return\n", "\n", "def sliding_window(sequence, n) -> Iterable[Sequence]:\n", " \"\"\"All length-n subsequences of sequence.\"\"\"\n", " return (sequence[i:i+n] for i in range(len(sequence) + 1 - n))\n", "\n", "def first(iterable, default=None) -> Optional[object]: \n", " \"\"\"The first element in an iterable, or the default if iterable is empty.\"\"\"\n", " return next(iter(iterable), default)\n", "\n", "def last(iterable) -> Optional[object]: \n", " \"\"\"The last element in an iterable.\"\"\"\n", " for item in iterable:\n", " pass\n", " return item\n", "\n", "def nth(iterable, n, default=None):\n", " \"Returns the nth item or a default value\"\n", " return next(islice(iterable, n, None), default)\n", "\n", "def first_true(iterable, default=False):\n", " \"\"\"Returns the first true value in the iterable.\n", " If no true value is found, returns `default`.\"\"\"\n", " return next((x for x in iterable if x), default)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Points in Space\n", "\n", "Many puzzles involve points; usually two-dimensional points on a plane. A few puzzles involve three-dimensional points, and perhaps one might involve non-integers, so I'll try to make my `Point` implementation flexible in a duck-typing way. 
A point can also be considered a `Vector`; that is, `(1, 0)` can be a `Point` that means \"this is location x=1, y=0 in the plane\" and it also can be a `Vector` that means \"move East (+1 along the x axis).\" First we'll define points and vectors:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "Point = Tuple[int, ...] # Type for points\n", "Vector = Point # E.g., (1, 0) can be a point, or can be a direction, a Vector\n", "Zero = (0, 0)\n", "\n", "directions4 = East, South, West, North = ((1, 0), (0, 1), (-1, 0), (0, -1))\n", "diagonals = SE, NE, SW, NW = ((1, 1), (1, -1), (-1, 1), (-1, -1))\n", "directions8 = directions4 + diagonals\n", "directions5 = directions4 + (Zero,)\n", "directions9 = directions8 + (Zero,)\n", "arrow_direction = {'^': North, 'v': South, '>': East, '<': West, '.': Zero,\n", " 'U': North, 'D': South, 'R': East, 'L': West}\n", "\n", "def X_(point) -> int: \"X coordinate of a point\"; return point[0]\n", "def Y_(point) -> int: \"Y coordinate of a point\"; return point[1]\n", "def Z_(point) -> int: \"Z coordinate of a point\"; return point[2]\n", "\n", "def Xs(points) -> Tuple[int]: \"X coordinates of a collection of points\"; return mapt(X_, points)\n", "def Ys(points) -> Tuple[int]: \"Y coordinates of a collection of points\"; return mapt(Y_, points)\n", "def Zs(points) -> Tuple[int]: \"Z coordinates of a collection of points\"; return mapt(Z_, points)\n", "\n", "def add(p: Point, q: Point) -> Point: return mapt(operator.add, p, q)\n", "def sub(p: Point, q: Point) -> Point: return mapt(operator.sub, p, q)\n", "def neg(p: Point) -> Vector: return mapt(operator.neg, p)\n", "def mul(p: Point, k: float) -> Vector: return tuple(k * c for c in p)\n", "\n", "def distance(p: Point, q: Point) -> float:\n", " \"\"\"Euclidean (L2) distance between two points.\"\"\"\n", " d = sum((pi - qi) ** 2 for pi, qi in zip(p, q)) ** 0.5\n", " return int(d) if d.is_integer() else d\n", "\n", "def slide(points: 
Set[Point], delta: Vector) -> Set[Point]: \n", " \"\"\"Slide all the points in the set of points by the amount delta.\"\"\"\n", " return {add(p, delta) for p in points}\n", "\n", "def make_turn(facing:Vector, turn:str) -> Vector:\n", " \"\"\"Turn 90 degrees left or right. `turn` can be 'L' or 'Left' or 'R' or 'Right' or lowercase.\"\"\"\n", " (x, y) = facing\n", " return (y, -x) if turn[0] in ('L', 'l') else (-y, x)\n", "\n", "# Profiling found that `add` and `taxi_distance` were speed bottlenecks; \n", "# I define below versions that are specialized for 2D points only.\n", "\n", "def add2(p: Point, q: Point) -> Point: \n", " \"\"\"Specialized version of point addition for 2D Points only. Faster.\"\"\"\n", " return (p[0] + q[0], p[1] + q[1])\n", "\n", "def taxi_distance(p: Point, q: Point) -> int:\n", " \"\"\"Manhattan (L1) distance between two 2D Points.\"\"\"\n", " return abs(p[0] - q[0]) + abs(p[1] - q[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Points on a Grid\n", "\n", "Many puzzles seem to involve a two-dimensional rectangular grid with integer coordinates. A `Grid` is a rectangular array of (integer, integer) points, where each point holds some contents. Important things to know:\n", "- `Grid` is a subclass of `dict`.\n", "- Usually the contents will be a character or an integer, but that's not specified or restricted. \n", "- A Grid can be initialized three ways:\n", " - With another dict (or `Grid`) of `{point: contents}`.\n", " - With an iterable of strings, each depicting a row (e.g. `[\"#..\", \"..#\"]`).\n", " - With a single string, which will be split on newlines.\n", "- Contents that are a member of `skip` will be skipped. (For example, you could do `skip=[' ']` to not store any point that has a space as its contents.)\n", "- There is a `grid.neighbors(point)` method. 
By default it returns the 4 orthogonal neighbors, but you can get all 8 adjacent squares, or some other set, by specifying the `directions` keyword argument in the `Grid` constructor.\n", "- By default (`default=None`), grids have bounded size: `neighbors` returns only points on the grid, and looking up a point off the grid returns `None`; specify `default=KeyError` to raise a `KeyError` instead. But some grids extend in all directions without limit; you can implement that by specifying, say, `default='.'` to make `'.'` the contents of every point off the grid." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class Grid(dict):\n", " \"\"\"A 2D grid, implemented as a mapping of {(x, y): cell_contents}.\"\"\"\n", " def __init__(self, grid=(), directions=directions4, skip=(), default=None):\n", " \"\"\"Initialize one of four ways: \n", " `Grid({(0, 0): '#', (1, 0): '.', ...})`\n", " `Grid(another_grid)`\n", " `Grid([\"#..\", \"..#\"])`\n", " `Grid(\"#..\\n..#\")`.\"\"\"\n", " self.directions = directions\n", " self.skip = skip\n", " self.default = default\n", " if isinstance(grid, abc.Mapping): \n", " self.update(grid) \n", " self.size = (len(cover(Xs(self))), len(cover(Ys(self))))\n", " else:\n", " if isinstance(grid, str): \n", " grid = grid.splitlines()\n", " self.size = (max(map(len, grid)), len(grid))\n", " self.update({(x, y): val \n", " for y, row in enumerate(grid) \n", " for x, val in enumerate(row)\n", " if val not in skip})\n", " \n", " def __missing__(self, point): \n", " \"\"\"If asked for a point off the grid, either return default or raise error.\"\"\"\n", " if self.default == KeyError:\n", " raise KeyError(point)\n", " else:\n", " return self.default\n", "\n", " def in_range(self, point) -> bool:\n", " \"\"\"Is the point within the range of the grid's size?\"\"\"\n", " return (0 <= X_(point) < X_(self.size) and\n", " 0 <= Y_(point) < Y_(self.size))\n", "\n", " def copy(self): \n", " return Grid(self, 
directions=self.directions, skip=self.skip, default=self.default)\n", " \n", " def neighbors(self, point) -> List[Point]:\n", " 
\"\"\"Points on the grid that neighbor `point`.\"\"\"\n", " return [add2(point, Δ) for Δ in self.directions \n", " if add2(point, Δ) in self or self.default not in (KeyError, None)]\n", " \n", " def neighbor_contents(self, point) -> Iterable:\n", " \"\"\"The contents of the neighboring points.\"\"\"\n", " return (self[p] for p in self.neighbors(point))\n", "\n", " def findall(self, contents: Collection) -> List[Point]:\n", " \"\"\"All points that contain one of the given contents, e.g. grid.findall('#').\"\"\"\n", " return [p for p in self if self[p] in contents]\n", " \n", " def to_rows(self, xrange=None, yrange=None) -> List[List[object]]:\n", " \"\"\"The contents of the grid, as a rectangular list of lists.\n", " You can define a window with an xrange and yrange; or they default to the whole grid.\"\"\"\n", " xrange = xrange or cover(Xs(self))\n", " yrange = yrange or cover(Ys(self))\n", " default = ' ' if self.default in (KeyError, None) else self.default\n", " return [[self.get((x, y), default) for x in xrange] \n", " for y in yrange]\n", "\n", " def print(self, sep='', xrange=None, yrange=None):\n", " \"\"\"Print a representation of the grid.\"\"\"\n", " for row in self.to_rows(xrange, yrange):\n", " print(*row, sep=sep)\n", " \n", " def plot(self, markers={'#': 's', '.': ','}, figsize=(14, 14), **kwds):\n", " \"\"\"Plot a representation of the grid.\"\"\"\n", " plt.figure(figsize=figsize)\n", " plt.gca().invert_yaxis()\n", " for m in markers:\n", " plt.plot(*T(p for p in self if self[p] == m), markers[m], **kwds)\n", " \n", "def neighbors(point, directions=directions4) -> List[Point]:\n", " \"\"\"Neighbors of this point, in the given directions.\n", " (This function can be used outside of a Grid class.)\"\"\"\n", " return [add(point, Δ) for Δ in directions]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# A* Search\n", "\n", "Many puzzles involve searching over a branching tree of possibilities. For many puzzles, an ad-hoc solution is fine. 
Different problems require different things from a search: \n", "- Some just need to know the final goal state.\n", "- Some need to know the sequence of actions that led to the final state.\n", "- Some need to know the sequence of intermediate states. \n", "- Some need to know the number of steps (or the total cost) to get to the final state.\n", "\n", "But sometimes you need all of that (or you think you might need it in Part 2), and sometimes you have a good heuristic estimate of the distance to a goal state, and you want to make sure to use it. If so, my `SearchProblem` class and `A_star_search` function may be appropriate." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def A_star_search(problem, h=None):\n", " \"\"\"Search nodes with minimum f(n) = path_cost(n) + h(n) value first.\"\"\"\n", " h = h or problem.h\n", " return best_first_search(problem, f=lambda n: n.path_cost + h(n))\n", "\n", "def best_first_search(problem, f) -> 'Node':\n", " \"Search nodes with minimum f(node) value first.\"\n", " node = Node(problem.initial)\n", " frontier = PriorityQueue([node], key=f)\n", " reached = {problem.initial: node}\n", " while frontier:\n", " node = frontier.pop()\n", " if problem.is_goal(node.state):\n", " return node\n", " for child in expand(problem, node):\n", " s = child.state\n", " if s not in reached or child.path_cost < reached[s].path_cost:\n", " reached[s] = child\n", " frontier.add(child)\n", " return search_failure" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class SearchProblem:\n", " \"\"\"The abstract class for a search problem. 
A new domain subclasses this,\n", " overriding `actions` and perhaps other methods.\n", " The default heuristic is 0 and the default action cost is 1 for all states.\n", " When you create an instance of a subclass, specify `initial` and `goal` states \n", " (or give an `is_goal` method) and perhaps other keyword args for the subclass.\"\"\"\n", "\n", " def __init__(self, initial=None, goal=None, **kwds): \n", " self.__dict__.update(initial=initial, goal=goal, **kwds) \n", " \n", " def __str__(self):\n", " return '{}({!r}, {!r})'.format(type(self).__name__, self.initial, self.goal)\n", " \n", " def actions(self, state): raise NotImplementedError\n", " def result(self, state, action): return action # Simplest case: action is result state\n", " def is_goal(self, state): return state == self.goal\n", " def action_cost(self, s, a, s1): return 1\n", " def h(self, node): return 0 # Never overestimate!\n", " \n", "class GridProblem(SearchProblem):\n", " \"\"\"Problem for searching a grid from a start to a goal location.\n", " A state is just an (x, y) location in the grid; pass the grid as the `grid` keyword argument.\"\"\"\n", " def actions(self, loc): return self.grid.neighbors(loc)\n", " def result(self, loc1, loc2): return loc2\n", " def h(self, node): return taxi_distance(node.state, self.goal) \n", "\n", "class Node:\n", " \"A Node in a search tree.\"\n", " def __init__(self, state, parent=None, action=None, path_cost=0):\n", " self.__dict__.update(state=state, parent=parent, action=action, path_cost=path_cost)\n", "\n", " def __repr__(self): return f'Node({self.state}, path_cost={self.path_cost})'\n", " def __len__(self): return 0 if self.parent is None else (1 + len(self.parent))\n", " def __lt__(self, other): return self.path_cost < other.path_cost\n", " \n", "search_failure = Node('failure', path_cost=inf) # Indicates an algorithm couldn't find a solution.\n", " \n", "def expand(problem, node):\n", " \"Expand a node, generating the children nodes.\"\n", " s = node.state\n", " for action in 
problem.actions(s):\n", " s2 = problem.result(s, action)\n", " cost = node.path_cost + problem.action_cost(s, action, s2)\n", " yield Node(s2, node, action, cost)\n", " \n", "def path_actions(node):\n", " \"The sequence of actions to get to this node.\"\n", " if node.parent is None:\n", " return [] \n", " return path_actions(node.parent) + [node.action]\n", "\n", "def path_states(node):\n", " \"The sequence of states to get to this node.\"\n", " if node in (search_failure, None): \n", " return []\n", " return path_states(node.parent) + [node.state]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Other Data Structures\n", "\n", "Here I define a few data types:\n", "- The priority queue, which is needed for A* search.\n", "- Hashable versions of dicts and Counters. These can be used in sets or as keys in dicts. Beware: unlike the `frozenset`, these are not safe: if you modify one after inserting it in a set or dict, it probably will not be found.\n", "- Graphs of `{node: [neighboring_node, ...]}`.\n", "- An `AttrCounter`, which is just like a `Counter`, but can be accessed with, say, `ctr.name` as well as `ctr['name']`. 
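
The warning about mutating a hashable dict or Counter is easy to demonstrate. The small sketch below is an illustration added here, not one of the utilities: it copies the `Hdict` definition given below, puts an instance into a set, and then mutates it. Because `__hash__` is computed from the current items, a later membership test computes a different hash and (barring a hash collision) fails, even though the object is still stored in the set.

```python
class Hdict(dict):
    '''A dict, but it is hashable (same definition as below).'''
    def __hash__(self): return hash(tuple(sorted(self.items())))

d = Hdict(a=1)
s = {d}
found_before = d in s                  # True: the hash matches the one stored at insertion

d['b'] = 2                             # mutate after insertion, which changes the hash
found_after = d in s                   # False (barring a hash collision): lookup uses the new hash
still_there = any(x is d for x in s)   # True: iterating shows the object is still in the set
```

The same applies to `HCounter`. A `frozenset` avoids the problem because it cannot be mutated at all.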
" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class PriorityQueue:\n", " \"\"\"A queue in which the item with minimum key(item) is always popped first.\"\"\"\n", "\n", " def __init__(self, items=(), key=lambda x: x): \n", " self.key = key\n", " self.items = [] # a heap of (score, item) pairs\n", " for item in items:\n", " self.add(item)\n", " \n", " def add(self, item):\n", " \"\"\"Add item to the queue.\"\"\"\n", " pair = (self.key(item), item)\n", " heapq.heappush(self.items, pair)\n", "\n", " def pop(self):\n", " \"\"\"Pop and return the item with min f(item) value.\"\"\"\n", " return heapq.heappop(self.items)[1]\n", " \n", " def top(self): return self.items[0][1]\n", "\n", " def __len__(self): return len(self.items)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class Hdict(dict):\n", " \"\"\"A dict, but it is hashable.\"\"\"\n", " def __hash__(self): return hash(tuple(sorted(self.items())))\n", " \n", "class HCounter(Counter):\n", " \"\"\"A Counter, but it is hashable.\"\"\"\n", " def __hash__(self): return hash(tuple(sorted(self.items())))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class Graph(dict):\n", " \"\"\"A graph of {node: [neighboring_nodes...]}. 
\n", " Can store other kwd attributes on it (which you can't do with a dict).\"\"\"\n", " def __init__(self, contents, **kwds):\n", " self.update(contents)\n", " self.__dict__.update(**kwds)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "class AttrCounter(Counter):\n", " \"\"\"A Counter, but `ctr['name']` and `ctr.name` are the same.\"\"\"\n", " def __getattr__(self, attr):\n", " return self[attr]\n", " def __setattr__(self, attr, value):\n", " self[attr] = value" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tests" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def tests():\n", " \"\"\"Run tests on utility functions. Also serves as usage examples.\"\"\"\n", " \n", " # PARSER\n", "\n", " assert parse(\"hello\\nworld\", show=0) == ('hello', 'world')\n", " assert parse(\"123\\nabc7\", digits, show=0) == ((1, 2, 3), (7,))\n", " assert truncate('hello world', 99) == 'hello world'\n", " assert truncate('hello world', 8) == 'hell ...'\n", "\n", " assert atoms('hello, cruel_world! 24-7') == ('hello', 'cruel_world', 24, -7)\n", " assert words('hello, cruel_world! 24-7') == ('hello', 'cruel', 'world')\n", " assert digits('hello, cruel_world! 24-7') == (2, 4, 7)\n", " assert ints('hello, cruel_world! 24-7') == (24, -7)\n", " assert positive_ints('hello, cruel_world! 
24-7') == (24, 7)\n", "\n", " # UTILITIES\n", "\n", " assert multimap(((i % 3), i) for i in range(9)) == {0: [0, 3, 6], 1: [1, 4, 7], 2: [2, 5, 8]}\n", " assert prod([2, 3, 5]) == 30\n", " assert total(Counter('hello, world')) == 12\n", " assert cover(3, 1, 4, 1, 5) == range(1, 6)\n", " assert minmax([3, 1, 4, 1, 5, 9]) == (1, 9)\n", " assert T([(1, 2, 3), (4, 5, 6)]) == [(1, 4), (2, 5), (3, 6)]\n", " assert the({1}) == 1\n", " assert split_at('hello, world', 6) == ('hello,', ' world')\n", " assert is_int(-42) and not is_int('one')\n", " assert sign(-42) == -1 and sign(0) == 0 and sign(42) == +1\n", " assert union([{1, 2}, {3, 4}, {5, 6}]) == {1, 2, 3, 4, 5, 6}\n", " assert intersection([{1, 2, 3}, {2, 3, 4}, {2, 4, 6, 8}]) == {2}\n", " assert clock_mod(24, 12) == 12 and 24 % 12 == 0\n", " assert cat(['hello', 'world']) == 'helloworld'\n", "\n", " # ITERTOOL RECIPES\n", "\n", " assert quantify(words('This is a test'), str.islower) == 3\n", " assert dotproduct([1, 2, 3, 4], [1000, 100, 10, 1]) == 1234\n", " assert list(flatten([{1, 2, 3}, (4, 5, 6), [7, 8, 9]])) == [1, 2, 3, 4, 5, 6, 7, 8, 9]\n", " assert append(([1, 2], [3, 4], [5, 6])) == [1, 2, 3, 4, 5, 6]\n", " assert list(batched(range(11), 3)) == [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10)]\n", " assert list(sliding_window('abcdefghi', 3)) == ['abc', 'bcd', 'cde', 'def', 'efg', 'fgh', 'ghi']\n", " assert first('abc') == 'a'\n", " assert first('') == None\n", " assert last('abc') == 'c'\n", " assert first_true([0, None, False, 42, 99]) == 42\n", " assert first_true([0, None, '', 0.0]) == False\n", "\n", " # POINTS\n", "\n", " p, q = (0, 3), (4, 0)\n", " assert Y_(p) == 3 and X_(q) == 4\n", " assert distance(p, q) == 5\n", " assert taxi_distance(p, q) == 7\n", " assert add(p, q) == (4, 3)\n", " assert sub(p, q) == (-4, 3)\n", " assert add(North, South) == (0, 0)\n", " \n", "tests()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.15" } }, "nbformat": 4, "nbformat_minor": 4 }