{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Create custom aggregate functions (UDAs)\n", "\n", "Build reusable aggregation logic for group-by queries and analytics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Problem\n", "\n", "You need aggregations beyond built-ins such as `sum`, `count`, `mean`, `min`, and `max`: for example, collecting values into a list, concatenating strings, or computing custom statistics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solution\n", "\n", "**What's in this recipe:**\n", "\n", "- Define a UDA (User-Defined Aggregate) with the `@pxt.uda` decorator\n", "- Use UDAs in `group_by` queries\n", "- Create UDAs with multiple inputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:25.608051Z", "iopub.status.busy": "2025-12-12T02:37:25.607898Z", "iopub.status.idle": "2025-12-12T02:37:28.188918Z", "shell.execute_reply": "2025-12-12T02:37:28.188442Z" } }, "outputs": [], "source": [ "%pip install -qU pixeltable" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:28.205757Z", "iopub.status.busy": "2025-12-12T02:37:28.205599Z", "iopub.status.idle": "2025-12-12T02:37:29.601267Z", "shell.execute_reply": "2025-12-12T02:37:29.600947Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Connected to Pixeltable database at: postgresql+psycopg://postgres:@/pixeltable?host=/Users/pjlb/.pixeltable/pgdata\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Created directory 'uda_demo'.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pixeltable as pxt\n", "\n", "pxt.drop_dir('uda_demo', force=True)\n", "pxt.create_dir('uda_demo')" ] }, { "cell_type": "markdown", 
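"metadata": {}, "source": [ "### Sketch: the aggregator lifecycle\n", "\n", "Before defining real UDAs, it can help to see the lifecycle they follow. The next cell is a plain-Python sketch with no Pixeltable dependency. It assumes, for illustration only, that a `group_by` query creates one aggregator instance per group, calls `update()` once for each row in that group, and calls `value()` once at the end. `MeanAgg` and `run_grouped` are hypothetical names, not part of the Pixeltable API, and the rows are toy data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Plain-Python sketch of the __init__ / update / value lifecycle.\n", "# MeanAgg and run_grouped are hypothetical helpers, not Pixeltable API.\n", "\n", "\n", "class MeanAgg:\n", "    # Mean of non-null values, shaped like a UDA class\n", "\n", "    def __init__(self):\n", "        # Fresh state for one group\n", "        self.total = 0.0\n", "        self.count = 0\n", "\n", "    def update(self, val):\n", "        # Called once per row; missing values are skipped\n", "        if val is not None:\n", "            self.total += val\n", "            self.count += 1\n", "\n", "    def value(self):\n", "        # Called once per group after all of its rows are seen\n", "        return self.total / self.count if self.count else None\n", "\n", "\n", "def run_grouped(rows, key, col, agg_cls):\n", "    # Assumed evaluation model: one aggregator instance per group key\n", "    aggs = {}\n", "    for row in rows:\n", "        group = row[key]\n", "        if group not in aggs:\n", "            aggs[group] = agg_cls()\n", "        aggs[group].update(row[col])\n", "    return {group: agg.value() for group, agg in aggs.items()}\n", "\n", "\n", "# Toy rows for illustration only (one amount is missing)\n", "rows = [\n", "    {'region': 'North', 'amount': 100.0},\n", "    {'region': 'North', 'amount': 250.0},\n", "    {'region': 'South', 'amount': 200.0},\n", "    {'region': 'South', 'amount': None},\n", "]\n", "run_grouped(rows, 'region', 'amount', MeanAgg)" ] }, { "cell_type": "markdown", 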
"metadata": {}, "source": [ "### Create sample data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:29.603063Z", "iopub.status.busy": "2025-12-12T02:37:29.602907Z", "iopub.status.idle": "2025-12-12T02:37:30.641689Z", "shell.execute_reply": "2025-12-12T02:37:30.641240Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Created table 'sales'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `sales`: 0 rows [00:00, ? rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `sales`: 6 rows [00:00, 609.56 rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Inserted 6 rows with 0 errors.\n" ] }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
<table>\n", "<tr><th>region</th><th>product</th><th>amount</th><th>quantity</th></tr>\n", "<tr><td>North</td><td>Widget</td><td>100.0</td><td>5</td></tr>\n", "<tr><td>North</td><td>Gadget</td><td>250.0</td><td>2</td></tr>\n", "<tr><td>North</td><td>Widget</td><td>150.0</td><td>8</td></tr>\n", "<tr><td>South</td><td>Widget</td><td>200.0</td><td>10</td></tr>\n", "<tr><td>South</td><td>Gadget</td><td>175.0</td><td>3</td></tr>\n", "<tr><td>East</td><td>Widget</td><td>125.0</td><td>6</td></tr>\n", "</table>" ], "text/plain": [ " region product amount quantity\n", "0 North Widget 100.0 5\n", "1 North Gadget 250.0 2\n", "2 North Widget 150.0 8\n", "3 South Widget 200.0 10\n", "4 South Gadget 175.0 3\n", "5 East Widget 125.0 6" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sales = pxt.create_table(\n", "    'uda_demo/sales',\n", "    {\n", "        'region': pxt.String,\n", "        'product': pxt.String,\n", "        'amount': pxt.Float,\n", "        'quantity': pxt.Int,\n", "    },\n", ")\n", "\n", "sales.insert(\n", "    [\n", "        {\n", "            'region': 'North',\n", "            'product': 'Widget',\n", "            'amount': 100.0,\n", "            'quantity': 5,\n", "        },\n", "        {\n", "            'region': 'North',\n", "            'product': 'Gadget',\n", "            'amount': 250.0,\n", "            'quantity': 2,\n", "        },\n", "        {\n", "            'region': 'North',\n", "            'product': 'Widget',\n", "            'amount': 150.0,\n", "            'quantity': 8,\n", "        },\n", "        {\n", "            'region': 'South',\n", "            'product': 'Widget',\n", "            'amount': 200.0,\n", "            'quantity': 10,\n", "        },\n", "        {\n", "            'region': 'South',\n", "            'product': 'Gadget',\n", "            'amount': 175.0,\n", "            'quantity': 3,\n", "        },\n", "        {\n", "            'region': 'East',\n", "            'product': 'Widget',\n", "            'amount': 125.0,\n", "            'quantity': 6,\n", "        },\n", "    ]\n", ")\n", "\n", "sales.collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Variance UDA (not built-in)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.643789Z", "iopub.status.busy": "2025-12-12T02:37:30.643396Z", "iopub.status.idle": "2025-12-12T02:37:30.646641Z", "shell.execute_reply": "2025-12-12T02:37:30.646221Z" } }, "outputs": [], "source": [ "# A UDA is a class that inherits from pxt.Aggregator\n", "# It must implement: __init__, update, and value\n", "\n", "\n", "@pxt.uda\n", "class variance(pxt.Aggregator):\n", "    \"\"\"Compute population variance using Welford's online algorithm.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.count = 0\n", "        self.mean = 0.0\n", "        self.m2 = 0.0  # Sum of squared differences from mean\n", "\n", "    def update(self, val: float) -> None:\n", "        if val is not None:\n", "            self.count += 1\n", "            delta = val - self.mean\n", "            self.mean += delta / self.count\n", "            delta2 = val - self.mean\n", "            self.m2 += delta * delta2\n", "\n", "    def value(self) -> float:\n", "        if self.count < 1:\n", "            return 0.0\n", "        return self.m2 / self.count  # Population variance" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.649016Z", "iopub.status.busy": "2025-12-12T02:37:30.648886Z", "iopub.status.idle": "2025-12-12T02:37:30.659119Z", "shell.execute_reply": "2025-12-12T02:37:30.658727Z" } }, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<tr><th>variance</th></tr>\n", "<tr><td>2430.556</td></tr>\n", "</table>
" ], "text/plain": [ " variance\n", "0 2430.555556" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use like any built-in aggregate\n", "sales.select(variance(sales.amount)).collect()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.660943Z", "iopub.status.busy": "2025-12-12T02:37:30.660829Z", "iopub.status.idle": "2025-12-12T02:37:30.682695Z", "shell.execute_reply": "2025-12-12T02:37:30.682254Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
<table>\n", "<tr><th>region</th><th>amount_variance</th></tr>\n", "<tr><td>East</td><td>0.0</td></tr>\n", "<tr><td>North</td><td>3888.889</td></tr>\n", "<tr><td>South</td><td>156.25</td></tr>\n", "</table>" ], "text/plain": [ " region amount_variance\n", "0 East 0.000000\n", "1 North 3888.888889\n", "2 South 156.250000" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use in group_by queries\n", "sales.group_by(sales.region).select(\n", "    sales.region, amount_variance=variance(sales.amount)\n", ").collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### String concatenation UDA" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.684665Z", "iopub.status.busy": "2025-12-12T02:37:30.684532Z", "iopub.status.idle": "2025-12-12T02:37:30.687096Z", "shell.execute_reply": "2025-12-12T02:37:30.686706Z" } }, "outputs": [], "source": [ "@pxt.uda\n", "class string_agg(pxt.Aggregator):\n", "    \"\"\"Concatenate strings with a comma separator.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.values = []\n", "\n", "    def update(self, val: str) -> None:\n", "        if val is not None:\n", "            self.values.append(val)\n", "\n", "    def value(self) -> str:\n", "        return ', '.join(self.values)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.689250Z", "iopub.status.busy": "2025-12-12T02:37:30.689059Z", "iopub.status.idle": "2025-12-12T02:37:30.710540Z", "shell.execute_reply": "2025-12-12T02:37:30.710167Z" } }, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<tr><th>region</th><th>products</th></tr>\n", "<tr><td>East</td><td>Widget</td></tr>\n", "<tr><td>North</td><td>Widget, Gadget, Widget</td></tr>\n", "<tr><td>South</td><td>Widget, Gadget</td></tr>\n", "</table>" ], "text/plain": [ " region products\n", "0 East Widget\n", "1 North Widget, Gadget, Widget\n", "2 South Widget, Gadget" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# List all products sold in each region\n", "sales.group_by(sales.region).select(\n", "    sales.region, products=string_agg(sales.product)\n", ").collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Collect values into a list" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.712624Z", "iopub.status.busy": "2025-12-12T02:37:30.712468Z", "iopub.status.idle": "2025-12-12T02:37:30.715683Z", "shell.execute_reply": "2025-12-12T02:37:30.715217Z" } }, "outputs": [], "source": [ "@pxt.uda\n", "class collect_list(pxt.Aggregator):\n", "    \"\"\"Collect all values into a list.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.items = []\n", "\n", "    def update(self, val: float) -> None:\n", "        if val is not None:\n", "            self.items.append(val)\n", "\n", "    def value(self) -> list[float]:\n", "        return self.items" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.717618Z", "iopub.status.busy": "2025-12-12T02:37:30.717522Z", "iopub.status.idle": "2025-12-12T02:37:30.740090Z", "shell.execute_reply": "2025-12-12T02:37:30.739456Z" } }, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<tr><th>region</th><th>amounts</th></tr>\n", "<tr><td>East</td><td>[125.0]</td></tr>\n", "<tr><td>North</td><td>[100.0, 250.0, 150.0]</td></tr>\n", "<tr><td>South</td><td>[200.0, 175.0]</td></tr>\n", "</table>" ], "text/plain": [ " region amounts\n", "0 East [125.0]\n", "1 North [100.0, 250.0, 150.0]\n", "2 South [200.0, 175.0]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get all amounts per region as a list\n", "sales.group_by(sales.region).select(\n", "    sales.region, amounts=collect_list(sales.amount)\n", ").collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Weighted average UDA" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.742132Z", "iopub.status.busy": "2025-12-12T02:37:30.742001Z", "iopub.status.idle": "2025-12-12T02:37:30.744677Z", "shell.execute_reply": "2025-12-12T02:37:30.744360Z" } }, "outputs": [], "source": [ "@pxt.uda\n", "class weighted_avg(pxt.Aggregator):\n", "    \"\"\"Compute weighted average: sum(value * weight) / sum(weight).\"\"\"\n", "\n", "    def __init__(self):\n", "        self.weighted_sum = 0.0\n", "        self.weight_sum = 0.0\n", "\n", "    def update(self, value: float, weight: float) -> None:\n", "        if value is not None and weight is not None:\n", "            self.weighted_sum += value * weight\n", "            self.weight_sum += weight\n", "\n", "    def value(self) -> float:\n", "        if self.weight_sum == 0:\n", "            return 0.0\n", "        return self.weighted_sum / self.weight_sum" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.747077Z", "iopub.status.busy": "2025-12-12T02:37:30.746685Z", "iopub.status.idle": "2025-12-12T02:37:30.772241Z", "shell.execute_reply": "2025-12-12T02:37:30.771753Z" } }, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<tr><th>region</th><th>avg_price</th></tr>\n", "<tr><td>East</td><td>125.0</td></tr>\n", "<tr><td>North</td><td>146.667</td></tr>\n", "<tr><td>South</td><td>194.231</td></tr>\n", "</table>" ], "text/plain": [ " region avg_price\n", "0 East 125.000000\n", "1 North 146.666667\n", "2 South 194.230769" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compute quantity-weighted average price per region\n", "sales.group_by(sales.region).select(\n", "    sales.region, avg_price=weighted_avg(sales.amount, sales.quantity)\n", ").collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Mode UDA (most frequent value)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.774750Z", "iopub.status.busy": "2025-12-12T02:37:30.774609Z", "iopub.status.idle": "2025-12-12T02:37:30.777623Z", "shell.execute_reply": "2025-12-12T02:37:30.777252Z" } }, "outputs": [], "source": [ "from collections import Counter\n", "\n", "\n", "@pxt.uda\n", "class mode(pxt.Aggregator):\n", "    \"\"\"Find the most frequent value in a group.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.counts = Counter()\n", "\n", "    def update(self, val: str) -> None:\n", "        if val is not None:\n", "            self.counts[val] += 1\n", "\n", "    def value(self) -> str:\n", "        if not self.counts:\n", "            return None  # Every value in the group was None\n", "        # On ties, most_common() returns the value counted first,\n", "        # so the winner can depend on the order rows are processed\n", "        return self.counts.most_common(1)[0][0]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:37:30.779537Z", "iopub.status.busy": "2025-12-12T02:37:30.779356Z", "iopub.status.idle": "2025-12-12T02:37:30.801502Z", "shell.execute_reply": "2025-12-12T02:37:30.801122Z" } }, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<tr><th>region</th><th>top_product</th></tr>\n", "<tr><td>East</td><td>Widget</td></tr>\n", "<tr><td>North</td><td>Widget</td></tr>\n", "<tr><td>South</td><td>Widget</td></tr>\n", "</table>" ], "text/plain": [ " region top_product\n", "0 East Widget\n", "1 North Widget\n", "2 South Widget" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Find most common product per region\n", "sales.group_by(sales.region).select(\n", "    sales.region, top_product=mode(sales.product)\n", ").collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explanation\n", "\n", "**UDA structure** (a minimal running-total example; substitute your own state, input type, and output type):\n", "\n", "```python\n", "import pixeltable as pxt\n", "\n", "\n", "@pxt.uda\n", "class my_aggregate(pxt.Aggregator):\n", "    def __init__(self):  # Initialize state\n", "        self.state = 0\n", "\n", "    def update(self, val: int) -> None:  # Called for each row\n", "        if val is not None:  # Update internal state with val\n", "            self.state += val\n", "\n", "    def value(self) -> int:  # Called at the end\n", "        return self.state\n", "```\n", "\n", "**Key points:**\n", "\n", "- Always handle `None` values in `update()`; every UDA in this recipe skips them.\n", "- The parameters of `update()` are the arguments you pass when calling the UDA, so multiple parameters enable multi-column aggregations, such as `weighted_avg(sales.amount, sales.quantity)`.\n", "- The return type annotation on `value()` determines the type of the output column." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## See also\n", "\n", "- [UDFs in Pixeltable](/platform/udfs-in-pixeltable) - Complete guide to custom functions\n", "- [Join tables](https://docs.pixeltable.com/howto/cookbooks/core/query-join-tables) - Combine data before aggregating" ] } ], "metadata": { "kernelspec": { "display_name": "pixeltable", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 2 }