# Fast embedding retrieval

What's a fast, portable way of retrieving embeddings?

In [1]:
import hashlib
import json
import numpy as np
import os
import pickle
import sqlite3

In [2]:
# Benchmark sizes: `n` records are generated, `c` of them are looked up later.
n = 10000
c = 1000
dim = 1000      # number of float32 values per embedding
text_len = 100  # number of characters per random text key

# Seed NumPy's global RNG so that the synthetic data, and the np.random-based
# sampling in later cells, is identical after Restart Kernel -> Run All.
np.random.seed(0)

# Build the alphabet once instead of re-creating the list on every iteration.
alphabet = np.array(list("abcdefghijklmnopqrstuvwxyz0123456789"))

# Each record holds a random text key, the SHA-256 hex digest of that text
# (used later as a file name), and a random float32 embedding vector.
data = []
for _ in range(n):
    text = "".join(np.random.choice(alphabet, text_len))
    data.append({
        "text": text,
        "hash": hashlib.sha256(text.encode()).hexdigest(),
        "embeddings": np.random.rand(dim).astype(np.float32),
    })

## Serialization formats

Time to deserialize all 10,000 embeddings:

- NumPy's `.tobytes()` and `np.frombuffer()` are the fastest (1.9 ms)
- Pickling NumPy arrays is slower (32 ms)
- Pickling lists is slower still (268 ms)
- JSON is the slowest (2,380 ms)

In [3]:
blobs = [d["embeddings"].tobytes() for d in data]
%timeit [np.frombuffer(p) for p in blobs]

1.92 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
blobs = [pickle.dumps(d["embeddings"]) for d in data]
%timeit [pickle.loads(p) for p in blobs]

32.3 ms ± 409 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
blobs = [pickle.dumps(d["embeddings"].tolist()) for d in data]
%timeit [pickle.loads(p) for p in blobs]

268 ms ± 3.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
blobs = [json.dumps(d["embeddings"].tolist()) for d in data]
%timeit [np.array(json.loads(p)) for p in blobs]

2.38 s ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Storage formats

- SQLite is faster (9 ms / 1K records in a single query; 34 ms when fetched one by one)
- Filesystem is slower (46 ms / 1K records)

In [7]:
conn = sqlite3.connect(".embeddings.db")
cur = conn.cursor()
cur.execute("DROP TABLE IF EXISTS embeddings")
cur.execute("CREATE TABLE embeddings (text TEXT, data BLOB)")
cur.executemany(
 "INSERT INTO embeddings (text, data) VALUES (?, ?)",
 [(d["text"], d["embeddings"].tobytes()) for d in data],
)
cur.execute("CREATE INDEX idx_embeddings_text ON embeddings (text)")
conn.commit()
conn.close()

conn = sqlite3.connect(".embeddings.db")
cur = conn.cursor()
texts = [data[np.random.randint(0, n)]["text"] for _ in range(c)]
# Retrieve embeddings one by one
%timeit embeddings = [np.frombuffer(cur.execute("SELECT data FROM embeddings WHERE text = ?", (text,)).fetchone()[0]) for text in texts]
# Retrieve all embeddings in a single query
%timeit embeddings = [np.frombuffer(row[0]) for row in cur.execute("SELECT data FROM embeddings WHERE text IN ({})".format(",".join(["?"]*c)), texts)]
conn.close()

33.8 ms ± 298 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8.63 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
if not os.path.exists(".embeddings"):
 os.makedirs(".embeddings")

for row in data:
 path = os.path.join(".embeddings", row["hash"])
 with open(path, "wb") as f:
 pickle.dump(row["embeddings"], f)

texts = [data[np.random.randint(0, n)]["text"] for _ in range(c)]
%timeit embeddings = [pickle.load(open(os.path.join(".embeddings", hashlib.sha256(text.encode()).hexdigest()), "rb")) for text in texts]

46.1 ms ± 493 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
