{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|default_exp xtras" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "from __future__ import annotations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "from fastcore.imports import *\n", "from fastcore.foundation import *\n", "from fastcore.basics import *\n", "from importlib import import_module\n", "from functools import wraps\n", "import string,time\n", "from contextlib import contextmanager,ExitStack\n", "from datetime import datetime, timezone\n", "from time import sleep,time,perf_counter\n", "from os.path import getmtime" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "from fastcore.test import *\n", "from nbdev.showdoc import *\n", "from fastcore.nb_imports import *\n", "\n", "import shutil,tempfile,pickle,random\n", "from dataclasses import dataclass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Utility functions\n", "\n", "> Utility functions used in the fastai library" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## File Functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Utilities (other than extensions to Pathlib.Path) for dealing with IO." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def walk(\n",
 "    path:Path|str, # path to start searching\n",
 "    symlinks:bool=True, # follow symlinks?\n",
 "    keep_file:callable=ret_true, # function that returns True for wanted files\n",
 "    keep_folder:callable=ret_true, # function that returns True for folders whose files should be returned\n",
 "    skip_folder:callable=ret_false, # function that returns True for folders to skip (not entered at all)\n",
 "    func:callable=os.path.join, # function to apply to each matched file\n",
 "    ret_folders:bool=False # return folders, not just files\n",
 "):\n",
 "    \"Generator version of `os.walk`, using functions to filter files and folders\"\n",
 "    for root,dirs,files in os.walk(path, followlinks=symlinks):\n",
 "        if keep_folder(root,''):\n",
 "            if ret_folders: yield func(root, '')\n",
 "            yield from (func(root, name) for name in files if keep_file(root,name))\n",
 "        # Prune `dirs` in place (slice assignment) so `os.walk` does not descend into skipped folders\n",
 "        dirs[:] = [name for name in dirs if not skip_folder(root,name)]"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def globtastic(\n",
 "    path:Path|str, # path to start searching\n",
 "    recursive:bool=True, # search subfolders\n",
 "    symlinks:bool=True, # follow symlinks?\n",
 "    file_glob:str=None, # Only include files matching glob\n",
 "    file_re:str=None, # Only include files matching regex\n",
 "    folder_re:str=None, # Only include files in folders whose path matches regex\n",
 "    skip_file_glob:str=None, # Skip files matching glob\n",
 "    skip_file_re:str=None, # Skip files matching regex\n",
 "    skip_folder_re:str=None, # Skip folders matching regex\n",
 "    func:callable=os.path.join, # function to apply to each matched file\n",
 "    ret_folders:bool=False # return folders, not just files\n",
 ")->L: # Paths to matched files\n",
 "    \"A more powerful `glob`, including regex matches, symlink handling, and skip parameters\"\n",
 "    from fnmatch import fnmatch\n",
 "    path = Path(path)\n",
 "    if path.is_file(): return L([path])\n",
 "    if not recursive: skip_folder_re='.'  # matches every folder name, so no subfolder is entered\n",
 "    file_re,folder_re = compile_re(file_re),compile_re(folder_re)\n",
 "    skip_file_re,skip_folder_re = compile_re(skip_file_re),compile_re(skip_folder_re)\n",
 "    def _keep_file(root, name):\n",
 "        return (not file_glob or fnmatch(name, file_glob)) and (\n",
 "                not file_re or file_re.search(name)) and (\n",
 "                not skip_file_glob or not fnmatch(name, skip_file_glob)) and (\n",
 "                not skip_file_re or not skip_file_re.search(name))\n",
 "    def _keep_folder(root, name): return not folder_re or folder_re.search(os.path.join(root,name))\n",
 "    def _skip_folder(root, name): return skip_folder_re and skip_folder_re.search(name)\n",
 "    return L(walk(path, symlinks=symlinks, keep_file=_keep_file, keep_folder=_keep_folder, skip_folder=_skip_folder,\n",
 "                  func=func, ret_folders=ret_folders))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#5) ['./fastcore/docments.py','./fastcore/dispatch.py','./fastcore/basics.py','./fastcore/docscrape.py','./fastcore/script.py']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "globtastic('.', skip_folder_re='^[_.]', folder_re='core', file_glob='*.*py*', file_re='c')"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@contextmanager\n",
 "def maybe_open(f, mode='r', **kwargs):\n",
 "    \"Context manager: open `f` if it is a path (and close on exit)\"\n",
 "    if isinstance(f, (str,os.PathLike)):\n",
 "        with open(f, mode, **kwargs) as f: yield f\n",
 "    else: yield f  # caller-supplied handles are passed through and left open"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "This is useful for functions where you want to accept a path *or* file. `maybe_open` will not close your file handle if you pass one in."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def _f(fn):\n",
 "    with maybe_open(fn) as f: return f.encoding\n",
 "\n",
 "fname = '00_test.ipynb'\n",
 "sys_encoding = 'cp1252' if sys.platform == 'win32' else 'UTF-8'\n",
 "test_eq(_f(fname), sys_encoding)\n",
 "with open(fname) as fh: test_eq(_f(fh), sys_encoding)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "For example, we can use this to reimplement [`imghdr.what`](https://docs.python.org/3/library/imghdr.html#imghdr.what) from the Python standard library, which is [written in Python 3.9](https://github.com/python/cpython/blob/3.9/Lib/imghdr.py#L11) as:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "from fastcore import imghdr"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def what(file, h=None):\n",
 "    f = None\n",
 "    try:\n",
 "        if h is None:\n",
 "            if isinstance(file, (str,os.PathLike)):\n",
 "                f = open(file, 'rb')\n",
 "                h = f.read(32)\n",
 "            else:\n",
 "                location = file.tell()\n",
 "                h = file.read(32)\n",
 "                file.seek(location)\n",
 "        for tf in imghdr.tests:\n",
 "            res = tf(h, f)\n",
 "            if res: return res\n",
 "    finally:\n",
 "        if f: f.close()\n",
 "    return None"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "Here's an example of the use of this function:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'jpeg'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "fname = 'images/puppy.jpg'\n",
 "what(fname)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "With `maybe_open`, `Self`, and `L.map_first`, we can rewrite this in a much more concise and (in our opinion) clear way:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def what(file, h=None):\n",
 "    if h is None:\n",
 "        with maybe_open(file, 'rb') as f: h = f.peek(32)\n",
 "    return L(imghdr.tests).map_first(Self(h,file))"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "...and we can check that it still works:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "test_eq(what(fname), 'jpeg')"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "...along with the version passing a file handle:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "with open(fname,'rb') as f: test_eq(what(f), 'jpeg')"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "...along with the `h` parameter version:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "with open(fname,'rb') as f: test_eq(what(None, h=f.read(32)), 'jpeg')"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def mkdir(path, exist_ok=False, parents=False, overwrite=False, **kwargs):\n",
 "    \"Creates and returns a directory defined by `path`, optionally removing previous existing directory if `overwrite` is `True`\"\n",
 "    import shutil\n",
 "    path = Path(path)\n",
 "    if path.exists() and overwrite: shutil.rmtree(path)  # remove the old tree before recreating it\n",
 "    path.mkdir(exist_ok=exist_ok, parents=parents, **kwargs)\n",
 "    return path"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "with tempfile.TemporaryDirectory() as d:\n",
 "    path = Path(os.path.join(d, 'new_dir'))\n",
 "    new_dir = mkdir(path)\n",
 "    assert new_dir.exists()\n",
 "    test_eq(new_dir, path)\n",
 "\n",
 "    # test overwrite\n",
 "    with open(new_dir/'test.txt', 'w') as f: f.writelines('test')\n",
 "    test_eq(len(list(walk(new_dir))), 1) # assert file is present\n",
 "    new_dir = mkdir(new_dir, overwrite=True)\n",
 "    test_eq(len(list(walk(new_dir))), 0) # assert file was deleted"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs":
[], "source": [
 "#|export\n",
 "def image_size(fn):\n",
 "    \"Tuple of (w,h) for png, gif, or jpg; `None` otherwise\"\n",
 "    from fastcore import imghdr\n",
 "    import struct\n",
 "    def _jpg_size(f):\n",
 "        size,ftype = 2,0\n",
 "        while not 0xc0 <= ftype <= 0xcf:\n",
 "            f.seek(size, 1)\n",
 "            byte = f.read(1)\n",
 "            while ord(byte) == 0xff: byte = f.read(1)\n",
 "            ftype = ord(byte)\n",
 "            size = struct.unpack('>H', f.read(2))[0] - 2\n",
 "        f.seek(1, 1)  # `precision'\n",
 "        h,w = struct.unpack('>HH', f.read(4))\n",
 "        return w,h\n",
 "\n",
 "    # These read the header themselves; assumes `imghdr.what` left the position at the image start (as the stdlib version does)\n",
 "    def _gif_size(f):\n",
 "        head = f.read(10)\n",
 "        return struct.unpack('<HH', head[6:10])  # little-endian width, height after the 6-byte `GIF8xa` signature\n",
 "\n",
 "    def _png_size(f):\n",
 "        head = f.read(24)\n",
 "        assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a  # last 4 bytes of the PNG signature\n",
 "        return struct.unpack('>ii', head[16:24])  # big-endian width, height from the IHDR chunk\n",
 "\n",
 "    d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)\n",
 "    with maybe_open(fn, 'rb') as f:\n",
 "        fmt = imghdr.what(f)\n",
 "        return d[fmt](f) if fmt in d else None"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "test_eq(image_size(fname), (1200,803))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def bunzip(fn):\n",
 "    \"bunzip `fn`, raising exception if output already exists\"\n",
 "    fn = Path(fn)\n",
 "    assert fn.exists(), f\"{fn} doesn't exist\"\n",
 "    out_fn = fn.with_suffix('')\n",
 "    assert not out_fn.exists(), f\"{out_fn} already exists\"\n",
 "    import bz2\n",
 "    with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:\n",
 "        for d in iter(lambda: src.read(1024*1024), b''): dst.write(d)  # copy in 1MB chunks"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "f = Path('files/test.txt')\n",
 "if f.exists(): f.unlink()\n",
 "bunzip('files/test.txt.bz2')\n",
 "with f.open() as fh: t = fh.readlines()\n",
 "test_eq(len(t),1)\n",
 "test_eq(t[0], 'test\\n')\n",
 "f.unlink()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def loads(s, **kw):\n",
 "    \"Same as `json.loads`, but handles `None`\"\n",
 "    if not s: return {}\n",
 "    try: import ujson as json\n",
 "    except ModuleNotFoundError: import json\n",
 "    return json.loads(s, **kw)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def loads_multi(s:str):\n",
 "    \"Generator of >=0 decoded json dicts, possibly with non-json ignored text at start and end\"\n",
 "    import json\n",
 "    _dec = json.JSONDecoder()\n",
 "    # Decode in place starting at each `{` (rather than re-slicing `s`), so long inputs aren't copied repeatedly\n",
 "    pos = s.find('{')\n",
 "    while pos>=0:\n",
 "        obj,pos = _dec.raw_decode(s, pos)\n",
 "        yield obj\n",
 "        pos = s.find('{', pos)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "tst = \"\"\"\n",
 "# ignored\n",
 "{ \"a\":1 }\n",
 "hello\n",
 "{\n",
 "\"b\":2\n",
 "}\n",
 "\"\"\"\n",
 "\n",
 "test_eq(list(loads_multi(tst)), [{'a': 1}, {'b': 2}])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def dumps(obj, **kw):\n",
 "    \"Same as `json.dumps`, but uses `ujson` if available\"\n",
 "    try: import ujson as json\n",
 "    except ModuleNotFoundError: import json\n",
 "    else: kw['escape_forward_slashes']=False  # match stdlib `json`, which doesn't escape `/`\n",
 "    return json.dumps(obj, **kw)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def _unpack(fname, out):\n",
 "    \"Unpack archive `fname` into `out`; return the single extracted item, or `out` if there are several\"\n",
 "    import shutil\n",
 "    shutil.unpack_archive(str(fname), str(out))\n",
 "    ls = out.ls()\n",
 "    return ls[0] if len(ls) == 1 else out"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def untar_dir(fname, dest, rename=False, overwrite=False):\n",
 "    \"untar `fname` into `dest`, creating a directory if the root contains more than one item\"\n",
 "    import tempfile,shutil\n",
 "    dest = Path(dest)  # also accept `str` destinations\n",
 "    with tempfile.TemporaryDirectory() as d:\n",
 "        out = Path(d)/remove_suffix(Path(fname).stem, '.tar')\n",
 "        out.mkdir()\n",
 "        if rename: dest = dest/out.name\n",
 "        else:\n",
 "            src = _unpack(fname, out)\n",
 "            dest = dest/src.name\n",
 "        if dest.exists():\n",
 "            if overwrite: shutil.rmtree(dest) if dest.is_dir() else dest.unlink()\n",
 "            else: return dest\n",
 "        if rename: src = _unpack(fname, out)\n",
 "        shutil.move(str(src), dest)\n",
 "        return dest"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def test_untar(foldername, rename=False, **kwargs):\n",
 "    with tempfile.TemporaryDirectory() as d:\n",
 "        nm = os.path.join(d, 'a')\n",
 "        shutil.make_archive(nm, 'gztar', **kwargs)\n",
 "        with tempfile.TemporaryDirectory() as d2:\n",
 "            d2 = Path(d2)\n",
 "            untar_dir(nm+'.tar.gz', d2, rename=rename)\n",
 "            test_eq(d2.ls(), [d2/foldername])"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "If the contents of `fname` contain just one file or directory, it is placed directly in `dest`:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# using `base_dir` in `make_archive` results in `images` directory included in file names\n",
 "test_untar('images', base_dir='images')"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "If `rename` then the directory created is named based on the archive, without extension:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "test_untar('a', base_dir='images', rename=True)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "If the contents of `fname` contain multiple files and directories, a new folder in `dest` is created with the same name as `fname` (but without extension):"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# using `root_dir` in `make_archive` results in `images` directory *not* included in file names\n",
 "test_untar('a', root_dir='images')"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def repo_details(url):\n",
 "    \"List of `[owner,name]` from ssh or https git repo `url`\"\n",
 "    res = remove_suffix(url.strip(), '.git')\n",
 "    res = res.split(':')[-1]  # drop the `https` scheme or the `git@host` prefix\n",
 "    return res.split('/')[-2:]"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "test_eq(repo_details('https://github.com/fastai/fastai.git'), ['fastai', 'fastai'])\n",
 "test_eq(repo_details('git@github.com:fastai/nbdev.git\\n'), ['fastai', 'nbdev'])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def run(cmd, *rest, same_in_win=False, ignore_ex=False, as_bytes=False, stderr=False):\n",
 "    \"Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails\"\n",
 "    # With `same_in_win`, Windows still needs the unchanged command to be run via `cmd /c`\n",
 "    import subprocess\n",
 "    if rest:\n",
 "        if sys.platform == 'win32' and same_in_win:\n",
 "            cmd = ('cmd', '/c', cmd, *rest)\n",
 "        else:\n",
 "            cmd = (cmd,)+rest\n",
 "    elif isinstance(cmd, str):\n",
 "        if sys.platform == 'win32' and same_in_win: cmd = 'cmd /c ' + cmd\n",
 "        import shlex\n",
 "        cmd = shlex.split(cmd)\n",
 "    elif isinstance(cmd, list):\n",
 "        if sys.platform == 'win32' and same_in_win: cmd = ['cmd', '/c'] + cmd\n",
 "    res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
 "    stdout = res.stdout\n",
 "    if stderr and res.stderr: stdout += b' ;; ' + res.stderr\n",
 "    if not as_bytes: stdout = stdout.decode().strip()\n",
 "    if ignore_ex: return (res.returncode, stdout)\n",
 "    if res.returncode: raise IOError(stdout)\n",
 "    return stdout"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "You can pass a string (which will be split based on standard shell rules), a list, or pass args directly:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'pip 23.3.1 from /Users/jhoward/miniconda3/lib/python3.11/site-packages/pip (python 3.11)'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "run('echo', same_in_win=True)\n",
 "run('pip', '--version', same_in_win=True)\n",
 "run(['pip', '--version'], same_in_win=True)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "if sys.platform == 'win32':\n",
 "    assert 'ipynb' in run('cmd /c dir /p')\n",
 "    assert 'ipynb' in run(['cmd', '/c', 'dir', '/p'])\n",
 "    assert 'ipynb' in run('cmd', '/c', 'dir', '/p')\n",
 "else:\n",
 "    assert 'ipynb' in run('ls -ls')\n",
 "    assert 'ipynb' in run(['ls', '-l'])\n",
 "    assert 'ipynb' in run('ls', '-l')"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "Some commands fail in non-error situations, like `grep`. Use `ignore_ex` in those cases, which will return a tuple of returncode and stdout:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "if sys.platform == 'win32':\n",
 "    test_eq(run('cmd /c findstr asdfds 00_test.ipynb', ignore_ex=True)[0], 1)\n",
 "else:\n",
 "    test_eq(run('grep asdfds 00_test.ipynb', ignore_ex=True)[0], 1)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "`run` automatically decodes returned bytes to a `str`. Use `as_bytes` to skip that:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "if sys.platform == 'win32':\n",
 "    test_eq(run('cmd /c echo hi'), 'hi')\n",
 "else:\n",
 "    test_eq(run('echo hi', as_bytes=True), b'hi\\n')"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def open_file(fn, mode='r', **kwargs):\n",
 "    \"Open a file, with optional compression if gz, bz2, or zip suffix\"\n",
 "    if isinstance(fn, io.IOBase): return fn  # already-open file objects are returned as-is\n",
 "    import bz2,gzip,zipfile\n",
 "    fn = Path(fn)\n",
 "    if fn.suffix=='.bz2': return bz2.BZ2File(fn, mode, **kwargs)\n",
 "    elif fn.suffix=='.gz' : return gzip.GzipFile(fn, mode, **kwargs)\n",
 "    elif fn.suffix=='.zip': return zipfile.ZipFile(fn, mode, **kwargs)\n",
 "    else: return open(fn,mode, **kwargs)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def save_pickle(fn, o):\n",
 "    \"Save a pickle file, to a file name or opened file\"\n",
 "    import pickle\n",
 "    with open_file(fn, 'wb') as f: pickle.dump(o, f)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def load_pickle(fn):\n",
 "    \"Load a pickle file from a file name or opened file\"\n",
 "    import pickle\n",
 "    # NB: unpickling can execute arbitrary code, so only load files from trusted sources\n",
 "    with open_file(fn, 'rb') as f: return pickle.load(f)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "for suf in '.pkl','.bz2','.gz':\n",
 "    # delete=False is added for Windows\n",
 "    # https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file\n",
 "    with tempfile.NamedTemporaryFile(suffix=suf, delete=False) as f:\n",
 "        fn = Path(f.name)\n",
 "        save_pickle(fn, 't')\n",
 "        t = load_pickle(fn)\n",
 "    f.close()\n",
 "    fn.unlink()  # delete=False means the temp file must be removed by hand\n",
 "    test_eq(t,'t')"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#| export\n",
 "def parse_env(s:str=None, fn:Union[str,Path]=None) -> dict:\n",
 "    \"Parse a shell-style environment string or file\"\n",
 "    assert bool(s)^bool(fn), \"Must pass exactly one of `s` or `fn`\"\n",
 "    if fn: s = Path(fn).read_text()\n",
 "    def _f(line):\n",
 "        m = re.match(r'^\\s*(?:export\\s+)?(\\w+)\\s*=\\s*([\"\\']?)(.*?)(\\2)\\s*(?:#.*)?$', line)\n",
 "        # Report the offending line, rather than an opaque `AttributeError` on `None`\n",
 "        if m is None: raise ValueError(f\"Can't parse environment line: {line!r}\")\n",
 "        return m[1], m[3]  # variable name and value (without surrounding quotes)\n",
 "\n",
 "    # skip blank lines and full-line comments\n",
 "    return dict(_f(o.strip()) for o in s.splitlines() if o.strip() and not re.match(r'\\s*#', o))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "testf = \"\"\"# comment\n",
 " # another comment\n",
 " export FOO=\"bar#baz\"\n",
 "BAR=thing # comment \"ok\"\n",
 " baz='thong'\n",
 "QUX=quux\n",
 "export ZAP = \"zip\" # more comments\n",
 " FOOBAR = 42 # trailing space and comment\"\"\"\n",
 "\n",
 "exp = dict(FOO='bar#baz', BAR='thing', baz='thong', QUX='quux', ZAP='zip', FOOBAR='42')\n",
 "\n",
 "test_eq(parse_env(testf), exp)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#| export\n",
 "def expand_wildcards(code):\n",
 "    \"Expand all wildcard imports in the given code string.\"\n",
 "    import ast,importlib\n",
 "    tree = ast.parse(code)\n",
 "\n",
 "    def _replace_node(code, old_node, new_node):\n",
 "        \"Replace `old_node` in the source `code` with `new_node`.\"\n",
 "        lines = code.splitlines()\n",
 "        lnum = old_node.lineno\n",
 "        indent = ' ' * (len(lines[lnum-1]) - len(lines[lnum-1].lstrip()))\n",
 "        new_lines = [indent+line for line in ast.unparse(new_node).splitlines()]\n",
 "        lines[lnum-1 : old_node.end_lineno] = new_lines\n",
 "        # `splitlines` drops a final newline, so restore it\n",
 "        return '\\n'.join(lines) + ('\\n' if code.endswith('\\n') else '')\n",
 "\n",
 "    def _expand_import(node, mod, existing):\n",
 "        \"Create expanded import `node` in `tree` from wildcard import of `mod`.\"\n",
 "        mod_all = getattr(mod, '__all__', None)\n",
 "        available_names = set(mod_all) if mod_all is not None else set(dir(mod))\n",
 "        used_names = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name) and n.id in available_names} - existing\n",
 "        if not used_names: return node\n",
 "        names = [ast.alias(name=name, asname=None) for name in sorted(used_names)]\n",
 "        return ast.ImportFrom(module=node.module, names=names, level=node.level)\n",
 "\n",
 "    # Names already bound by explicit imports (the `as` alias when there is one, since that's the bound name)\n",
 "    existing = set()\n",
 "    for node in ast.walk(tree):\n",
 "        if isinstance(node, ast.ImportFrom) and node.names[0].name != '*': existing.update(n.asname or n.name for n in node.names)\n",
 "        elif isinstance(node, ast.Import): existing.update(n.asname or n.name.split('.')[0] for n in node.names)\n",
 "    for node in ast.walk(tree):\n",
 "        # Relative imports (`level>0`) can't be resolved without a package context, so they are left as-is\n",
 "        if isinstance(node, ast.ImportFrom) and not node.level and any(n.name == '*' for n in node.names):\n",
 "            new_import = _expand_import(node, importlib.import_module(node.module), existing)\n",
 "            code = _replace_node(code, node, new_import)\n",
 "    return code"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "inp = \"\"\"from math import *\n",
 "from os import *\n",
 "from random import *\n",
 "def func(): return sin(pi) + path.join('a', 'b') + randint(1, 10)\"\"\"\n",
 "\n",
 "exp = \"\"\"from math import pi, sin\n",
 "from os import path\n",
 "from random import randint\n",
 "def func(): return sin(pi) + path.join('a', 'b') + randint(1, 10)\"\"\"\n",
 "\n",
 "test_eq(expand_wildcards(inp), exp)\n",
 "\n",
 "inp = \"\"\"from itertools import *\n",
 "def func(): pass\"\"\"\n",
 "test_eq(expand_wildcards(inp), inp)\n",
 "\n",
 "inp = \"\"\"def outer():\n",
 "    from math import *\n",
 "    def inner():\n",
 "        from os import *\n",
 "        return sin(pi) + path.join('a', 'b')\"\"\"\n",
 "\n",
 "exp = \"\"\"def outer():\n",
 "    from math import pi, sin\n",
 "    def inner():\n",
 "        from os import path\n",
 "        return sin(pi) + path.join('a', 'b')\"\"\"\n",
 "\n",
 "test_eq(expand_wildcards(inp), exp)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## Collections"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def dict2obj(d, list_func=L, dict_func=AttrDict):\n",
 "    \"Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`\"\n",
 "    # Pass `list_func`/`dict_func` down so nested levels use them too, not just the top level\n",
 "    if isinstance(d, (L,list)): return list_func([dict2obj(o, list_func, dict_func) for o in d])\n",
 "    if not isinstance(d, dict): return d\n",
 "    return dict_func(**{k:dict2obj(v, list_func, dict_func) for k,v in d.items()})"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "This is a convenience to give you \"dotted\" access to (possibly nested) dictionaries, e.g.:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "d1 = dict(a=1, b=dict(c=2,d=3))\n",
 "d2 = dict2obj(d1)\n",
 "test_eq(d2.b.c, 2)\n",
 "test_eq(d2.b['c'], 2)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "It can also be used on lists of dicts."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "_list_of_dicts = [d1, d1]\n",
 "ds = dict2obj(_list_of_dicts)\n",
 "test_eq(ds[0].b.c, 2)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def obj2dict(d):\n",
 "    \"Convert (possibly nested) AttrDicts (or lists of AttrDicts) to `dict`\"\n",
 "    if isinstance(d, (L,list)): return list(L(d).map(obj2dict))\n",
 "    if not isinstance(d, dict): return d\n",
 "    return {k:obj2dict(v) for k,v in d.items()}  # a dict literal (unlike `dict(**...)`) also allows non-str keys"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "`obj2dict` can be used to reverse what is done by `dict2obj`:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "test_eq(obj2dict(d2), d1)\n",
 "test_eq(obj2dict(ds), _list_of_dicts)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def _repr_dict(d, lvl):\n",
 "    \"Bulleted representation of `d`, indented by `lvl` nesting levels\"\n",
 "    if isinstance(d,dict):\n",
 "        its = [f\"{k}: {_repr_dict(v,lvl+1)}\" for k,v in d.items()]\n",
 "    elif isinstance(d,(list,L)): its = [_repr_dict(o,lvl+1) for o in d]\n",
 "    else: return str(d)\n",
 "    return '\\n' + '\\n'.join([\" \"*(lvl*2) + \"- \" + o for o in its])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def repr_dict(d):\n",
 "    \"Bulleted string representation of nested dicts and lists, such as returned by `dict2obj`\"\n",
 "    return _repr_dict(d,0).strip()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "- a: 1\n", "- b: \n", "  - c: 2\n", "  - d: 3\n" ] } ], "source": [
 "print(repr_dict(d2))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def is_listy(x):\n",
 "    \"`isinstance(x, (tuple,list,L,slice,Generator))`\"\n",
 "    return isinstance(x, (tuple,list,L,slice,Generator))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "assert is_listy((1,))\n",
 "assert is_listy([1])\n",
 "assert is_listy(L([1]))\n",
 "assert is_listy(slice(2))\n",
 "assert not is_listy(array([1]))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "def mapped(f, it):\n",
 "    \"map `f` over `it`, unless it's not listy, in which case return `f(it)`\"\n",
 "    return L(it).map(f) if is_listy(it) else f(it)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def _f(x,a=1): return x-a\n",
 "\n",
 "test_eq(mapped(_f,1),0)\n",
 "test_eq(mapped(_f,[1,2]),[0,1])\n",
 "test_eq(mapped(_f,(1,)),(0,))"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## Extensions to Pathlib.Path"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "The following methods are added to the standard Python library [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#basic-use)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@patch\n",
 "def readlines(self:Path, hint=-1, encoding='utf8'):\n",
 "    \"Read the content of `self` as a list of lines (`hint` as in `io.IOBase.readlines`)\"\n",
 "    with self.open(encoding=encoding) as f: return f.readlines(hint)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@patch\n",
 "def read_json(self:Path, encoding=None, errors=None):\n",
 "    \"Same as `read_text` followed by `loads`\"\n",
 "    return loads(self.read_text(encoding=encoding, errors=errors))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@patch\n",
 "def mk_write(self:Path, data, encoding=None, errors=None, mode=0o777):\n",
 "    \"Make all parent dirs of `self`, and write `data`\"\n",
 "    self.parent.mkdir(exist_ok=True, parents=True, mode=mode)\n",
 "    self.write_text(data, encoding=encoding, errors=errors)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@patch\n",
 "def relpath(self:Path, start=None):\n",
 "    \"Same as `os.path.relpath`, but returns a `Path`, and resolves symlinks\"\n",
 "    # Like `os.path.relpath`, default to the current directory (`Path(None)` would raise a `TypeError`)\n",
 "    start = Path.cwd() if start is None else Path(start)\n",
 "    return Path(os.path.relpath(self.resolve(), start.resolve()))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('/Users/jhoward/Documents/GitHub/fastcore/fastcore')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "p = Path('../fastcore/').resolve()\n",
 "p"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('../fastcore')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "p.relpath(Path.cwd())"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "test_eq(p.relpath(), p.relpath(Path.cwd()))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@patch\n",
 "def ls(self:Path, n_max=None, file_type=None, file_exts=None):\n",
 "    \"Contents of path as a list\"\n",
 "    import mimetypes\n",
 "    extns=L(file_exts)\n",
 "    if file_type: extns += L(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))\n",
 "    # Filter whenever a type or extensions were requested, so an unknown `file_type` matches nothing rather than everything\n",
 "    no_filter = not file_type and not extns\n",
 "    res = (o for o in self.iterdir() if no_filter or o.suffix in extns)\n",
 "    if n_max is not None: res = itertools.islice(res, n_max)\n",
 "    return L(res)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "We add an `ls()` method to `pathlib.Path` which is simply defined as `list(Path.iterdir())`, mainly for convenience in REPL environments such as notebooks."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('000_tour.ipynb')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "path = Path()\n",
 "t = path.ls()\n",
 "assert len(t)>0\n",
 "t1 = path.ls(10)\n",
 "test_eq(len(t1), 10)\n",
 "t2 = path.ls(file_exts='.ipynb')\n",
 "assert len(t)>len(t2)\n",
 "t[0]"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "You can also pass an optional `file_type` MIME prefix and/or a list of file extensions."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Path('../fastcore/shutil.py'), Path('000_tour.ipynb'))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "lib_path = (path/'../fastcore')\n",
 "txt_files=lib_path.ls(file_type='text')\n",
 "assert len(txt_files) > 0 and txt_files[0].suffix=='.py'\n",
 "ipy_files=path.ls(file_exts=['.ipynb'])\n",
 "assert len(ipy_files) > 0 and ipy_files[0].suffix=='.ipynb'\n",
 "txt_files[0],ipy_files[0]"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|hide\n",
 "path = Path()\n",
 "pkl = pickle.dumps(path)\n",
 "p2 = pickle.loads(pkl)\n",
 "test_eq(path.ls()[0], p2.ls()[0])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@patch\n",
 "def __repr__(self:Path):\n",
 "    # Show the path relative to `Path.BASE_PATH` when that is set and contains `self`\n",
 "    b = getattr(Path, 'BASE_PATH', None)\n",
 "    if b:\n",
 "        try: self = self.relative_to(b)\n",
 "        except (ValueError, TypeError): pass  # not under `b` (or `b` isn't path-like): keep the full path\n",
 "    return f\"Path({self.as_posix()!r})\""
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "fastai also updates the `repr` of `Path` such that, if `Path.BASE_PATH` is defined, all paths are printed relative to that path (as long as they are contained in `Path.BASE_PATH`):"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "t = ipy_files[0].absolute()\n",
 "try:\n",
 "    Path.BASE_PATH = t.parent.parent\n",
 "    test_eq(repr(t), f\"Path('nbs/{t.name}')\")\n",
 "finally: Path.BASE_PATH = None"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@patch\n",
 "def delete(self:Path):\n",
 "    \"Delete a file, symlink, or directory tree\"\n",
 "    # Check for a symlink first: `exists`/`is_dir` follow links, which would skip dangling links and\n",
 "    # send links to directories to `rmtree` (which refuses symlinks). Unlinking removes only the link itself.\n",
 "    if self.is_symlink(): self.unlink()\n",
 "    elif self.is_dir():\n",
 "        import shutil\n",
 "        shutil.rmtree(self)\n",
 "    elif self.exists(): self.unlink()"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "if sys.platform != 'win32':  # creating symlinks on Windows may need extra privileges\n",
 "    with tempfile.TemporaryDirectory() as d:\n",
 "        d = Path(d)\n",
 "        (d/'sub').mkdir()\n",
 "        (d/'link').symlink_to(d/'sub', target_is_directory=True)\n",
 "        (d/'link').delete()\n",
 "        assert not (d/'link').is_symlink() and (d/'sub').exists()\n",
 "        (d/'dangling').symlink_to(d/'missing')\n",
 "        (d/'dangling').delete()\n",
 "        assert not (d/'dangling').is_symlink()"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "## Reindexing Collections"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "#|hide\n",
 "class IterLen:\n",
 "    \"Base class to add iteration to anything supporting `__len__` and `__getitem__`\"\n",
 "    def __iter__(self): return (self[i] for i in range_of(self))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "#|export\n",
 "@docs\n",
 "class ReindexCollection(GetAttr, IterLen):\n",
 "    \"Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`\"\n",
 "    _default='coll'\n",
 "    def __init__(self, coll, idxs=None, cache=None, tfm=noop):\n",
 "        if idxs is None: idxs = L.range(coll)\n",
 "        store_attr()\n",
 "        if cache is not None: self._get = functools.lru_cache(maxsize=cache)(self._get)\n",
 "\n",
 "    def _get(self, i): return self.tfm(self.coll[i])\n",
 "    def __getitem__(self, i): return self._get(self.idxs[i])\n",
 "    def __len__(self): return len(self.coll)\n",
 "    def reindex(self, idxs): self.idxs = idxs\n",
 "    def shuffle(self):\n",
 "        import random\n",
 "        random.shuffle(self.idxs)\n",
 "    def cache_clear(self): self._get.cache_clear()\n",
 "    def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}\n",
 "    def __setstate__(self, s):\n",
 "        self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']\n",
 "        # The per-instance LRU wrapper isn't part of the pickled state, so rebuild it (else `cache_clear` breaks)\n",
 "        if self.cache is not None: self._get = functools.lru_cache(maxsize=self.cache)(self._get)\n",
 "\n",
 "    _docs = dict(reindex=\"Replace `self.idxs` with idxs\",\n",
 "                 shuffle=\"Randomly shuffle indices\",\n",
 "                 cache_clear=\"Clear LRU cache\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### ReindexCollection\n", "\n", "> ReindexCollection (coll, idxs=None, cache=None, tfm=)\n", "\n", "*Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`*" ],
"text/plain": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L380){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### ReindexCollection\n", "\n", "> ReindexCollection (coll, idxs=None, cache=None, tfm=)\n", "\n", "*Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`*" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(ReindexCollection, title_level=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is useful when constructing batches or organizing data in a particular manner (i.e. for deep learning). This class is primarly used in organizing data for language models in fastai." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can supply a custom index upon instantiation with the `idxs` argument, or you can call the `reindex` method to supply a new index for your collection.\n", "\n", "Here is how you can reindex a list such that the elements are reversed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['e', 'd', 'c', 'b', 'a']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'], idxs=[4,3,2,1,0])\n", "list(rc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, you can use the `reindex` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L391){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "###### ReindexCollection.reindex\n", "\n", "> ReindexCollection.reindex (idxs)\n", "\n", "*Replace `self.idxs` with idxs*" ], "text/plain": [ "---\n", "\n", 
"[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L391){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "###### ReindexCollection.reindex\n", "\n", "> ReindexCollection.reindex (idxs)\n", "\n", "*Replace `self.idxs` with idxs*" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(ReindexCollection.reindex, title_level=6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['e', 'd', 'c', 'b', 'a']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'])\n", "rc.reindex([4,3,2,1,0])\n", "list(rc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can optionally specify a LRU cache, which uses [functools.lru_cache](https://docs.python.org/3/library/functools.html#functools.lru_cache) upon instantiation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sz = 50\n", "t = ReindexCollection(L.range(sz), cache=2)\n", "\n", "#trigger a cache hit by indexing into the same element multiple times\n", "t[0], t[0]\n", "t._get.cache_info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can optionally clear the LRU cache by calling the `cache_clear` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L395){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "##### ReindexCollection.cache_clear\n", "\n", "> ReindexCollection.cache_clear ()\n", "\n", "*Clear LRU cache*" ], "text/plain": [ "---\n", "\n", 
"[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L395){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "##### ReindexCollection.cache_clear\n", "\n", "> ReindexCollection.cache_clear ()\n", "\n", "*Clear LRU cache*" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(ReindexCollection.cache_clear, title_level=5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CacheInfo(hits=0, misses=0, maxsize=2, currsize=0)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sz = 50\n", "t = ReindexCollection(L.range(sz), cache=2)\n", "\n", "#trigger a cache hit by indexing into the same element multiple times\n", "t[0], t[0]\n", "t.cache_clear()\n", "t._get.cache_info()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L392){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "##### ReindexCollection.shuffle\n", "\n", "> ReindexCollection.shuffle ()\n", "\n", "*Randomly shuffle indices*" ], "text/plain": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L392){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "##### ReindexCollection.shuffle\n", "\n", "> ReindexCollection.shuffle ()\n", "\n", "*Randomly shuffle indices*" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(ReindexCollection.shuffle, title_level=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that an ordered index is automatically constructed for the data structure even if one is not supplied." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['a', 'h', 'f', 'b', 'c', 'g', 'e', 'd']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rc=ReindexCollection(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])\n", "rc.shuffle()\n", "list(rc)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sz = 50\n", "t = ReindexCollection(L.range(sz), cache=2)\n", "test_eq(list(t), range(sz))\n", "test_eq(t[sz-1], sz-1)\n", "test_eq(t._get.cache_info().hits, 1)\n", "t.shuffle()\n", "test_eq(t._get.cache_info().hits, 1)\n", "test_ne(list(t), range(sz))\n", "test_eq(set(t), set(range(sz)))\n", "t.cache_clear()\n", "test_eq(t._get.cache_info().hits, 0)\n", "test_eq(t.count(0), 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|hide\n", "#Test ReindexCollection pickles\n", "t1 = pickle.loads(pickle.dumps(t))\n", "test_eq(list(t), list(t1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other Helpers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _is_type_dispatch(x): return type(x).__name__ == \"TypeDispatch\"\n", "def _unwrapped_type_dispatch_func(x): return x.first() if _is_type_dispatch(x) else x\n", "\n", "def _is_property(x): return type(x)==property\n", "def _has_property_getter(x): return _is_property(x) and hasattr(x, 'fget') and hasattr(x.fget, 'func')\n", "def _property_getter(x): return x.fget.func if _has_property_getter(x) else x\n", "\n", "def _unwrapped_func(x):\n", " x = _unwrapped_type_dispatch_func(x)\n", " x = _property_getter(x)\n", " return x\n", "\n", "\n", "def get_source_link(func):\n", " \"Return link to `func` in source code\"\n", " import inspect\n", " func = _unwrapped_func(func)\n", " try: line = inspect.getsourcelines(func)[1]\n", " except Exception: return ''\n", " mod 
= inspect.getmodule(func)\n", " module = mod.__name__.replace('.', '/') + '.py'\n", " try:\n", " nbdev_mod = import_module(mod.__package__.split('.')[0] + '._nbdev')\n", " return f\"{nbdev_mod.git_url}{module}#L{line}\"\n", " except: return f\"{module}#L{line}\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`get_source_link` allows you to get a link to the source code of an object. For [nbdev](https://github.com/fastai/nbdev) related projects such as fastcore, we can get the full link to a GitHub repo. For `nbdev` projects, be sure to properly set `git_url` in `settings.ini` (it is derived from `lib_name` and `branch`, on top of a prefix you may need to adapt) so that those links are correct.\n", "\n", "For example, below we get the link to `fastcore.test.test_eq`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastcore.test import test_eq" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'https://github.com/fastai/fastcore/tree/master/fastcore/test.py#L35'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "assert 'fastcore/test.py' in get_source_link(test_eq)\n", "assert get_source_link(test_eq).startswith('https://github.com/fastai/fastcore')\n", "get_source_link(test_eq)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:\n", " \"Truncate `s` to length `maxlen`, adding suffix `suf` if truncated\"\n", " return s[:maxlen-len(suf)]+suf if len(s)+len(space)>maxlen else s+space" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w = 'abacadabra'\n", "test_eq(truncstr(w, 10), w)\n", "test_eq(truncstr(w, 5), 'abac…')\n", "test_eq(truncstr(w, 5, suf=''), 'abaca')\n", "test_eq(truncstr(w, 11, space='_'), w+\"_\")\n",
"test_eq(truncstr(w, 10, space='_'), w[:-1]+'…')\n", "test_eq(truncstr(w, 5, suf='!!'), 'aba!!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "spark_chars = '▁▂▃▅▆▇'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def _ceil(x, lim=None): return x if (not lim or x <= lim) else lim\n", "\n", "def _sparkchar(x, mn, mx, incr, empty_zero):\n", " if x is None or (empty_zero and not x): return ' '\n", " if incr == 0: return spark_chars[0]\n", " res = int((_ceil(x,mx)-mn)/incr-0.5)\n", " return spark_chars[res]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def sparkline(data, mn=None, mx=None, empty_zero=False):\n", " \"Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column\"\n", " valid = [o for o in data if o is not None]\n", " if not valid: return ' '\n", " mn,mx,n = ifnone(mn,min(valid)),ifnone(mx,max(valid)),len(spark_chars)\n", " res = [_sparkchar(x=o, mn=mn, mx=mx, incr=(mx-mn)/n, empty_zero=empty_zero) for o in data]\n", " return ''.join(res)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "without \"empty_zero\": ▅▂ ▁▂▁▃▇▅\n", " with \"empty_zero\": ▅▂ ▁▂ ▃▇▅\n" ] } ], "source": [ "data = [9,6,None,1,4,0,8,15,10]\n", "print(f'without \"empty_zero\": {sparkline(data, empty_zero=False)}')\n", "print(f' with \"empty_zero\": {sparkline(data, empty_zero=True )}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can set a maximum and minimum for the y-axis of the sparkline with the arguments `mn` and `mx` respectively:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'▂▅▇▇'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"sparkline([1,2,3,400], mn=0, mx=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def modify_exception(\n", " e:Exception, # An exception\n", " msg:str=None, # A custom message\n", " replace:bool=False, # Whether to replace e.args with [msg]\n", ") -> Exception:\n", " \"Modifies `e` with a custom message attached\"\n", " e.args = [f'{e.args[0]} {msg}'] if not replace and len(e.args) > 0 else [msg]\n", " return e" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "msg = \"This is my custom message!\"\n", "\n", "test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception(), None)), contains='')\n", "test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception(), msg)), contains=msg)\n", "test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception(\"The first message\"), msg)), contains=\"The first message This is my custom message!\")\n", "test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception(\"The first message\"), msg, True)), contains=\"This is my custom message!\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def round_multiple(x, mult, round_down=False):\n", " \"Round `x` to nearest multiple of `mult`\"\n", " def _f(x_): return (int if round_down else round)(x_/mult)*mult\n", " res = L(x).map(_f)\n", " return res if is_listy(x) else res[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(round_multiple(63,32), 64)\n", "test_eq(round_multiple(50,32), 64)\n", "test_eq(round_multiple(40,32), 32)\n", "test_eq(round_multiple( 0,32), 0)\n", "test_eq(round_multiple(63,32, round_down=True), 32)\n", "test_eq(round_multiple((63,40),32), (64,32))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def set_num_threads(nt):\n", " \"Get numpy (and 
others) to use `nt` threads\"\n", " try: import mkl; mkl.set_num_threads(nt)\n", " except: pass\n", " try: import torch; torch.set_num_threads(nt)\n", " except: pass\n", " os.environ['IPC_ENABLE']='1'\n", " for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:\n", " os.environ[o] = str(nt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This sets the number of threads consistently for many tools, by:\n", "\n", "1. Set the following environment variables equal to `nt`: `OPENBLAS_NUM_THREADS`,`NUMEXPR_NUM_THREADS`,`OMP_NUM_THREADS`,`MKL_NUM_THREADS`\n", "2. Sets `nt` threads for numpy and pytorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def join_path_file(file, path, ext=''):\n", " \"Return `path/file` if file is a string or a `Path`, file otherwise\"\n", " if not isinstance(file, (str, Path)): return file\n", " path.mkdir(parents=True, exist_ok=True)\n", " return path/f'{file}{ext}'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = Path.cwd()/'_tmp'/'tst'\n", "f = join_path_file('tst.txt', path)\n", "assert path.exists()\n", "test_eq(f, path/'tst.txt')\n", "with open(f, 'w') as f_: assert join_path_file(f_, path) == f_\n", "shutil.rmtree(Path.cwd()/'_tmp')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def autostart(g):\n", " \"Decorator that automatically starts a generator\"\n", " @functools.wraps(g)\n", " def f():\n", " r = g()\n", " next(r)\n", " return r\n", " return f" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class EventTimer:\n", " \"An event timer with history of `store` items of time `span`\"\n", "\n", " def __init__(self, store=5, span=60):\n", " import collections\n", " self.hist,self.span,self.last = 
collections.deque(maxlen=store),span,perf_counter()\n", " self._reset()\n", "\n", " def _reset(self): self.start,self.events = self.last,0\n", "\n", " def add(self, n=1):\n", " \"Record `n` events\"\n", " if self.duration>self.span:\n", " self.hist.append(self.freq)\n", " self._reset()\n", " self.events +=n\n", " self.last = perf_counter()\n", "\n", " @property\n", " def duration(self): return perf_counter()-self.start\n", " @property\n", " def freq(self): return self.events/self.duration" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L502){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### EventTimer\n", "\n", "> EventTimer (store=5, span=60)\n", "\n", "*An event timer with history of `store` items of time `span`*" ], "text/plain": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L502){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### EventTimer\n", "\n", "> EventTimer (store=5, span=60)\n", "\n", "*An event timer with history of `store` items of time `span`*" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(EventTimer, title_level=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add events with `add`, and get number of `events` and their frequency (`freq`)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Num Events: 3, Freq/sec: 316.2\n", "Most recent: ▇▁▂▃▁ 288.7 227.7 246.5 256.5 217.9\n" ] } ], "source": [ "# Random wait function for testing\n", "def _randwait(): yield from (sleep(random.random()/200) for _ in range(100))\n", "\n", "c = EventTimer(store=5, span=0.03)\n", "for o in _randwait(): c.add(1)\n", "print(f'Num Events: {c.events}, Freq/sec: {c.freq:.01f}')\n", "print('Most recent: ', sparkline(c.hist), *L(c.hist).map('{:.01f}'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "_fmt = string.Formatter()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def stringfmt_names(s:str)->list:\n", " \"Unique brace-delimited names in `s`\"\n", " return uniqueify(o[1] for o in _fmt.parse(s) if o[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s = '/pulls/{pull_number}/reviews/{review_id}'\n", "test_eq(stringfmt_names(s), ['pull_number','review_id'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class PartialFormatter(string.Formatter):\n", " \"A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args\"\n", " def __init__(self):\n", " self.missing = set()\n", " super().__init__()\n", "\n", " def get_field(self, nm, args, kwargs):\n", " try: return super().get_field(nm, args, kwargs)\n", " except KeyError:\n", " self.missing.add(nm)\n", " return '{'+nm+'}',nm\n", "\n", " def check_unused_args(self, used, args, kwargs):\n", " self.xtra = filter_keys(kwargs, lambda o: o not in used)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", 
"[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L534){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### PartialFormatter\n", "\n", "> PartialFormatter ()\n", "\n", "*A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args*" ], "text/plain": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L534){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### PartialFormatter\n", "\n", "> PartialFormatter ()\n", "\n", "*A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args*" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(PartialFormatter, title_level=4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def partial_format(s:str, **kwargs):\n", " \"string format `s`, ignoring missing field errors, returning missing and extra fields\"\n", " fmt = PartialFormatter()\n", " res = fmt.format(s, **kwargs)\n", " return res,list(fmt.missing),fmt.xtra" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is a tuple of `(formatted_string,missing_fields,extra_fields)`, e.g:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res,missing,xtra = partial_format(s, pull_number=1, foo=2)\n", "test_eq(res, '/pulls/1/reviews/{review_id}')\n", "test_eq(missing, ['review_id'])\n", "test_eq(xtra, {'foo':2})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def utc2local(dt:datetime)->datetime:\n", " \"Convert `dt` from UTC to local time\"\n", " return dt.replace(tzinfo=timezone.utc).astimezone(tz=None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2000-01-01 
12:00:00 UTC is 2000-01-01 22:00:00+10:00 local time\n" ] } ], "source": [ "dt = datetime(2000,1,1,12)\n", "print(f'{dt} UTC is {utc2local(dt)} local time')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def local2utc(dt:datetime)->datetime:\n", " \"Convert `dt` from local to UTC time\"\n", " return dt.replace(tzinfo=None).astimezone(tz=timezone.utc)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2000-01-01 12:00:00 local is 2000-01-01 02:00:00+00:00 UTC time\n" ] } ], "source": [ "print(f'{dt} local is {local2utc(dt)} UTC time')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def trace(f):\n", " \"Add `set_trace` to an existing function `f`\"\n", " from pdb import set_trace\n", " if getattr(f, '_traced', False): return f\n", " def _inner(*args,**kwargs):\n", " set_trace()\n", " return f(*args,**kwargs)\n", " _inner._traced = True\n", " return _inner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can add a breakpoint to an existing function, e.g:\n", "\n", "```python\n", "Path.cwd = trace(Path.cwd)\n", "Path.cwd()\n", "```\n", "\n", "Now, when the function is called it will drop you into the debugger. Note, you must issue the `s` command when you begin to step into the function that is being traced." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "@contextmanager\n", "def modified_env(*delete, **replace):\n", " \"Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`\"\n", " prev = dict(os.environ)\n", " try:\n", " os.environ.update(replace)\n", " for k in delete: os.environ.pop(k, None)\n", " yield\n", " finally:\n", " os.environ.clear()\n", " os.environ.update(prev)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# USER isn't in Cloud Linux Environments\n", "env_test = 'USERNAME' if sys.platform == \"win32\" else 'SHELL'\n", "oldusr = os.environ[env_test]\n", "\n", "replace_param = {env_test: 'a'}\n", "with modified_env('PATH', **replace_param):\n", " test_eq(os.environ[env_test], 'a')\n", " assert 'PATH' not in os.environ\n", "\n", "assert 'PATH' in os.environ\n", "test_eq(os.environ[env_test], oldusr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "class ContextManagers(GetAttr):\n", "    \"Wrapper for `contextlib.ExitStack` which enters a collection of context managers\"\n", "    def __init__(self, mgrs): self.default,self.stack = L(mgrs),ExitStack()\n", "    def __enter__(self):\n", "        # Enter on a temporary stack: if entering any manager fails, those already entered are exited;\n", "        # on success `pop_all` moves them onto `self.stack`, which `__exit__` unwinds\n", "        with ExitStack() as stack:\n", "            self.default.map(stack.enter_context)\n", "            self.stack = stack.pop_all()\n", "    # Propagate the stack's result so an exception suppressed by an inner manager is also suppressed here\n", "    def __exit__(self, *args, **kwargs): return self.stack.__exit__(*args, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "---\n", "\n", "[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L591){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### ContextManagers\n", "\n", "> ContextManagers (mgrs)\n", "\n", "*Wrapper for `contextlib.ExitStack` which enters a collection of context managers*" ], "text/plain": [ "---\n", "\n",
"[source](https://github.com/fastai/fastcore/blob/master/fastcore/xtras.py#L591){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "#### ContextManagers\n", "\n", "> ContextManagers (mgrs)\n", "\n", "*Wrapper for `contextlib.ExitStack` which enters a collection of context managers*" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "show_doc(ContextManagers, title_level=4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def shufflish(x, pct=0.04):\n", " \"Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location\"\n", " n = len(x)\n", " import random\n", " return L(x[i] for i in sorted(range_of(x), key=lambda o: o+n*(1+random.random()*pct)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "def console_help(\n", " libname:str): # name of library for console script listing\n", " \"Show help for all console scripts from `libname`\"\n", " from fastcore.style import S\n", " from pkg_resources import iter_entry_points as ep\n", " for e in ep('console_scripts'): \n", " if e.module_name == libname or e.module_name.startswith(libname+'.'): \n", " nm = S.bold.light_blue(e.name)\n", " print(f'{nm:45}{e.load().__doc__}')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "def hl_md(s, lang='xml', show=True):\n", " \"Syntax highlight `s` using `lang`.\"\n", " md = f'```{lang}\\n{s}\\n```'\n", " if not show: return md\n", " try:\n", " from IPython import display\n", " return display.Markdown(md)\n", " except ImportError: print(s)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we display code in a notebook, it's nice to highlight it, so we create a function to simplify that:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": 
[
       "```xml\n",
       "a child\n",
       "```"
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hl_md('a child')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def type2str(typ:type)->str:\n",
    "    \"Stringify `typ`\"\n",
    "    if typ is None or typ is NoneType: return 'None'\n",
    "    if hasattr(typ, '__origin__'):\n",
    "        # Generic alias such as `List[int]` or `Optional[float]`: stringify its arguments recursively\n",
    "        args = \", \".join(type2str(arg) for arg in typ.__args__)\n",
    "        if typ.__origin__ is Union: return f\"Union[{args}]\"\n",
    "        return f\"{typ.__origin__.__name__}[{args}]\"\n",
    "    elif isinstance(typ, type): return typ.__name__\n",
    "    return str(typ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(type2str(Optional[float]), 'Union[float, None]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def dataclass_src(cls):\n",
    "    \"Source code for a `@dataclass` definition declaring the fields of dataclass `cls`\"\n",
    "    import dataclasses\n",
    "    src = f\"@dataclass\\nclass {cls.__name__}:\\n\"\n",
    "    for f in dataclasses.fields(cls):\n",
    "        # A `default_factory` must be rendered too, otherwise the field would look required\n",
    "        if f.default is not dataclasses.MISSING: d = f\" = {f.default!r}\"\n",
    "        elif f.default_factory is not dataclasses.MISSING:\n",
    "            d = f\" = field(default_factory={getattr(f.default_factory, '__name__', repr(f.default_factory))})\"\n",
    "        else: d = \"\"\n",
    "        src += f\"    {f.name}: {type2str(f.type)}{d}\\n\"\n",
    "    return src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import make_dataclass, dataclass, field"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@dataclass\n",
      "class DC:\n",
      "    x: int\n",
      "    y: Union[float, None] = None\n",
      "    z: float = None\n",
      "\n"
     ]
    }
   ],
   "source": [
    "DC = make_dataclass('DC', [('x', int), ('y', Optional[float], None), ('z', float, None)])\n",
    "print(dataclass_src(DC))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DCF = make_dataclass('DCF', [('x', list, field(default_factory=list))])\n",
    "test_eq(dataclass_src(DCF), '@dataclass\\nclass DCF:\\n    x: list = field(default_factory=list)\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def nullable_dc(cls):\n",
    "    \"Like `dataclass`, but default of `None` added to fields without defaults\"\n",
    "    from dataclasses import dataclass, field\n",
    "    for k,v in get_annotations_ex(cls)[0].items():\n",
    "        if not hasattr(cls,k): setattr(cls, k, field(default=None))\n",
    "    return dataclass(cls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Person(name='Bob', age=None, city='Unknown')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@nullable_dc\n",
    "class Person: name: str; age: int; city: str = \"Unknown\"\n",
    "Person(name=\"Bob\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def make_nullable(clas):\n",
    "    \"Make the fields of dataclass `clas` that have no default optional, defaulting them to `None`; returns `clas`\"\n",
    "    from dataclasses import dataclass, fields, MISSING\n",
    "    # Look in `clas` itself rather than using `hasattr`, so a subclass of a nullable class is still patched\n",
    "    if '_nullable' in vars(clas): return clas\n",
    "    clas._nullable = True\n",
    "\n",
    "    original_init = clas.__init__\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        # Only `init=True` fields are parameters of the generated `__init__`\n",
    "        flds = [f for f in fields(clas) if f.init]\n",
    "        passed = {f.name for f in flds[:len(args)]} # fields already supplied positionally\n",
    "        for f in flds:\n",
    "            if f.name not in passed and f.name not in kwargs and f.default is None and f.default_factory is MISSING:\n",
    "                kwargs[f.name] = None\n",
    "        original_init(self, *args, **kwargs)\n",
    "\n",
    "    clas.__init__ = __init__\n",
    "\n",
    "    # Record `None` as the default so `fields` reports it; this does not change the generated `__init__`, hence the wrapper above\n",
    "    for f in fields(clas):\n",
    "        if f.default is MISSING and f.default_factory is MISSING: f.default = None\n",
    "\n",
    "    return clas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Person(name='Bob', age=None, city='NY')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@dataclass\n",
    "class Person: name: str; age: int; city: str = \"Unknown\"\n",
    "\n",
    "make_nullable(Person)\n",
    "Person(\"Bob\", city='NY')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_is(make_nullable(Person), Person) # calling it again leaves `Person` as is and still returns it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Person(name='Bob', age=None, city='Unknown')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
"source": [
    "Person(name=\"Bob\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Person(name='Bob', age=34, city='Unknown')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Person(\"Bob\", 34)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def mk_dataclass(cls):\n",
    "    \"Make `cls` a dataclass in place, giving fields without a default a default of `None`\"\n",
    "    from dataclasses import dataclass, field, is_dataclass, MISSING\n",
    "    if is_dataclass(cls): return make_nullable(cls)\n",
    "    for k,v in get_annotations_ex(cls)[0].items():\n",
    "        if not hasattr(cls,k) or getattr(cls,k) is MISSING:\n",
    "            setattr(cls, k, field(default=None))\n",
    "    # `dataclass` updates `cls` in place and returns it; returning it makes both branches return the class\n",
    "    return dataclass(cls, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Person(name='Bob', age=None, city='Unknown')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Person: name: str; age: int; city: str = \"Unknown\"\n",
    "\n",
    "mk_dataclass(Person)\n",
    "Person(name=\"Bob\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def flexicache(*funcs, maxsize=128):\n",
    "    \"Like `lru_cache`, but customisable with policy `funcs`\"\n",
    "    import asyncio\n",
    "    def _f(func):\n",
    "        cache = {} # key -> (result, policy states); insertion order tracks recency, least-recently used first\n",
    "        def _cache_logic(key, execute_func):\n",
    "            entry = cache.pop(key, None)\n",
    "            if entry is not None:\n",
    "                result,states = entry\n",
    "                if not any(f(state) for f,state in zip(funcs, states)):\n",
    "                    cache[key] = entry # re-insert at the end: now the most-recently used entry\n",
    "                    return result\n",
    "            try: newres = execute_func()\n",
    "            except Exception:\n",
    "                if entry is None: raise\n",
    "                cache[key] = entry # recomputing failed, so keep serving the previously cached result\n",
    "                return entry[0]\n",
    "            cache[key] = (newres, [f(None) for f in funcs])\n",
    "            # `dict.popitem` would drop the entry just inserted; evict the least-recently used (first) key instead\n",
    "            if len(cache) > maxsize: del cache[next(iter(cache))]\n",
    "            return newres\n",
    "\n",
    "        @wraps(func)\n",
    "        def wrapper(*args, **kwargs):\n",
    "            return _cache_logic(f\"{args} // {kwargs}\", lambda: func(*args, **kwargs))\n",
    "\n",
    "        # NOTE: the future itself is cached, so an exception raised by the coroutine is re-raised until the entry expires\n",
    "        @wraps(func)\n",
    "        async def async_wrapper(*args, **kwargs):\n",
    "            return await _cache_logic(f\"{args} // {kwargs}\", lambda: asyncio.ensure_future(func(*args, **kwargs)))\n",
    "\n",
    "        return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a flexible LRU cache function that you can pass a list of functions to. Those functions define the cache eviction policy. For instance, `time_policy` is provided for time-based cache eviction, and `mtime_policy` evicts based on a file's modified-time changing. Each policy function is passed the last value it returned (initially `None`), and returns a new (truthy) value to indicate the cache has expired, or `None` otherwise. When the cache expires, all functions are called with `None` to force getting new values. If recomputing an expired entry raises an exception, the previously cached value is returned instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def time_policy(seconds):\n",
    "    \"A `flexicache` policy that expires cached items after `seconds` have passed\"\n",
    "    def policy(last_time):\n",
    "        now = time()\n",
    "        return now if last_time is None or now-last_time>seconds else None\n",
    "    return policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def mtime_policy(filepath):\n",
    "    \"A `flexicache` policy that expires cached items after `filepath` modified-time changes\"\n",
    "    def policy(mtime):\n",
    "        current_mtime = getmtime(filepath)\n",
    "        # `!=` rather than `>`: a file replaced by one with an older modified-time has changed too\n",
    "        return current_mtime if mtime is None or current_mtime!=mtime else None\n",
    "    return policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@flexicache(time_policy(10), mtime_policy('000_tour.ipynb'))\n",
    "def cached_func(x, y): return x+y\n",
    "\n",
    "cached_func(1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@flexicache(time_policy(10), mtime_policy('000_tour.ipynb'))\n",
    "async def cached_func(x, y): return x+y\n",
    "\n",
    "await cached_func(1,2)\n",
    "await cached_func(1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "calls = []\n",
    "@flexicache(maxsize=2)\n",
    "def lru_func(x): calls.append(x); return x\n",
    "\n",
    "lru_func(1); lru_func(2); lru_func(1)\n",
    "lru_func(3) # the cache is full, so `2` (the least-recently used entry) is evicted\n",
    "lru_func(1); lru_func(3)\n",
    "test_eq(calls, [1,2,3])\n",
    "lru_func(2)\n",
    "test_eq(calls, [1,2,3,2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fail = False\n",
    "@flexicache(lambda last: True) # a policy that always reports the entry as expired\n",
    "def flaky_func():\n",
    "    if fail: raise ValueError('refresh failed')\n",
    "    return 1\n",
    "\n",
    "test_eq(flaky_func(), 1)\n",
    "fail = True\n",
    "test_eq(flaky_func(), 1) # recomputing raised, so the previously cached value is returned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def timed_cache(seconds=60, maxsize=128):\n",
    "    \"Like `lru_cache`, but also with time-based eviction\"\n",
    "    return flexicache(time_policy(seconds), maxsize=maxsize)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function is a small convenience wrapper for using `flexicache` with `time_policy`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@timed_cache(seconds=0.05, maxsize=2)\n",
    "def cached_func(x): return x * 2, time()\n",
    "\n",
    "# basic caching\n",
    "result1, time1 = cached_func(2)\n",
    "test_eq(result1, 4)\n",
    "sleep(0.001)\n",
    "result2, time2 = cached_func(2)\n",
    "test_eq(result2, 4)\n",
    "test_eq(time1, time2)\n",
    "\n",
    "# caching different values\n",
    "result3, _ = cached_func(3)\n",
    "test_eq(result3, 6)\n",
    "\n",
    "# maxsize\n",
    "_, time4 = cached_func(4)\n",
    "_, time2_new = cached_func(2)\n",
    "test_close(time2, time2_new, eps=0.1)\n",
    "_, time3_new = cached_func(3)\n",
    "test_ne(time3_new, time())\n",
    "\n",
    "# time expiration\n",
    "sleep(0.05)\n",
    "_, time4_new = cached_func(4)\n",
    "test_ne(time4_new, time())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
"source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }