{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fast embedding retrieval\n",
    "\n",
    "What's a fast, portable way of retrieving embeddings?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hashlib\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import sqlite3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10000  # number of embeddings to generate\n",
    "c = 1000  # number of embeddings retrieved by each storage benchmark\n",
    "dim = 1000  # number of float32 values per embedding\n",
    "\n",
    "# Seeded generator, so that Restart & Run All regenerates the same data\n",
    "rng = np.random.default_rng(0)\n",
    "alphabet = list(\"abcdefghijklmnopqrstuvwxyz0123456789\")\n",
    "\n",
    "data = []\n",
    "for _ in range(n):\n",
    "    text = \"\".join(rng.choice(alphabet, 100))\n",
    "    data.append({\n",
    "        \"text\": text,\n",
    "        \"hash\": hashlib.sha256(text.encode()).hexdigest(),\n",
    "        \"embeddings\": rng.random(dim).astype(np.float32)\n",
    "    })"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Serialization formats\n",
    "\n",
    "Time to decode 10K embeddings (1,000 `float32` values each):\n",
    "\n",
    "- NumPy's `.tobytes()` and `np.frombuffer()` are the fastest (1.9 ms). The raw bytes carry no dtype, so pass the original `dtype` to `np.frombuffer()`\n",
    "- Pickling NumPy arrays is slower (32 ms)\n",
    "- Pickling lists is slower still (268 ms)\n",
    "- JSON is the slowest (2,380 ms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.92 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "blobs = [d[\"embeddings\"].tobytes() for d in data]\n",
    "# np.frombuffer defaults to float64, which would misread the float32 bytes as 500 garbage values\n",
    "assert np.array_equal(np.frombuffer(blobs[0], dtype=np.float32), data[0][\"embeddings\"])\n",
    "%timeit [np.frombuffer(p, dtype=np.float32) for p in blobs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32.3 ms ± 409 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "blobs = [pickle.dumps(d[\"embeddings\"]) for d in data]\n",
    "%timeit [pickle.loads(p) for p in blobs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "268 ms ± 3.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "blobs = [pickle.dumps(d[\"embeddings\"].tolist()) for d in data]\n",
    "%timeit [pickle.loads(p) for p in blobs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.38 s ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "blobs = [json.dumps(d[\"embeddings\"].tolist()) for d in data]\n",
    "%timeit [np.array(json.loads(p)) for p in blobs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Storage formats\n",
    "\n",
    "Time to retrieve 1K of the 10K embeddings by text:\n",
    "\n",
    "- SQLite is faster (8.6 ms in a single query, 34 ms with one query per record)\n",
    "- Filesystem is slower (46 ms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33.8 ms ± 298 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "8.63 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "conn = sqlite3.connect(\".embeddings.db\")\n",
    "cur = conn.cursor()\n",
    "cur.execute(\"DROP TABLE IF EXISTS embeddings\")\n",
    "cur.execute(\"CREATE TABLE embeddings (text TEXT, data BLOB)\")\n",
    "cur.executemany(\n",
    "    \"INSERT INTO embeddings (text, data) VALUES (?, ?)\",\n",
    "    [(d[\"text\"], d[\"embeddings\"].tobytes()) for d in data],\n",
    ")\n",
    "# Index the column that the lookups below filter on\n",
    "cur.execute(\"CREATE INDEX idx_embeddings_text ON embeddings (text)\")\n",
    "conn.commit()\n",
    "conn.close()\n",
    "\n",
    "conn = sqlite3.connect(\".embeddings.db\")\n",
    "cur = conn.cursor()\n",
    "# Sample without replacement: duplicate texts would make the IN (...) query\n",
    "# return fewer rows than the one-by-one loop it is compared against\n",
    "sample = rng.choice(n, size=c, replace=False)\n",
    "texts = [data[i][\"text\"] for i in sample]\n",
    "select_one = \"SELECT data FROM embeddings WHERE text = ?\"\n",
    "# One placeholder per text. c must stay within SQLite's bound-parameter limit\n",
    "# (999 by default before SQLite 3.32)\n",
    "select_many = \"SELECT data FROM embeddings WHERE text IN ({})\".format(\",\".join([\"?\"] * c))\n",
    "# The stored bytes must decode back to the original float32 embedding\n",
    "blob = cur.execute(select_one, (texts[0],)).fetchone()[0]\n",
    "assert np.array_equal(np.frombuffer(blob, dtype=np.float32), data[sample[0]][\"embeddings\"])\n",
    "assert len(cur.execute(select_many, texts).fetchall()) == c\n",
    "# Retrieve embeddings one by one\n",
    "%timeit embeddings = [np.frombuffer(cur.execute(select_one, (text,)).fetchone()[0], dtype=np.float32) for text in texts]\n",
    "# Retrieve all embeddings in a single query\n",
    "%timeit embeddings = [np.frombuffer(row[0], dtype=np.float32) for row in cur.execute(select_many, texts)]\n",
    "conn.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "46.1 ms ± 493 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "os.makedirs(\".embeddings\", exist_ok=True)\n",
    "\n",
    "for row in data:\n",
    "    path = os.path.join(\".embeddings\", row[\"hash\"])\n",
    "    with open(path, \"wb\") as f:\n",
    "        pickle.dump(row[\"embeddings\"], f)\n",
    "\n",
    "\n",
    "def load_embedding(text):\n",
    "    \"\"\"Return the embedding pickled under the SHA-256 hash of `text`.\n",
    "\n",
    "    Only unpickle files that this notebook wrote: pickle can execute arbitrary code.\n",
    "    \"\"\"\n",
    "    path = os.path.join(\".embeddings\", hashlib.sha256(text.encode()).hexdigest())\n",
    "    # Close each file explicitly instead of leaving 1,000 handles per pass to the garbage collector\n",
    "    with open(path, \"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "texts = [data[i][\"text\"] for i in rng.choice(n, size=c, replace=False)]\n",
    "assert np.array_equal(load_embedding(data[0][\"text\"]), data[0][\"embeddings\"])\n",
    "%timeit embeddings = [load_embedding(text) for text in texts]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gramex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}