{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sample data for training and testing\n", "\n", "Create training, validation, and test splits with random or stratified sampling." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Problem\n", "\n", "You have a large dataset and need to create subsets for ML training: random samples for quick experiments, stratified samples for balanced classes, or reproducible splits for benchmarking.\n", "\n", "| Need | Method |\n", "|------|--------|\n", "| Quick experiment | Random sample of N rows |\n", "| Balanced classes | Stratified by label |\n", "| Reproducible | Fixed seed |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solution\n", "\n", "**What's in this recipe:**\n", "\n", "- Random sampling with `sample(n=...)`\n", "- Percentage-based sampling with `sample(fraction=...)`\n", "- Stratified sampling with `stratify_by=`\n", "\n", "Use `query.sample()` to create random subsets. Add `stratify_by=` to sample within each class: `fraction=` keeps the original class proportions, while `n_per_stratum=` draws the same number of rows from every class."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:46.716441Z", "iopub.status.busy": "2025-12-12T02:36:46.716262Z", "iopub.status.idle": "2025-12-12T02:36:49.530289Z", "shell.execute_reply": "2025-12-12T02:36:49.529757Z" } }, "outputs": [], "source": [ "%pip install -qU pixeltable" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:49.546595Z", "iopub.status.busy": "2025-12-12T02:36:49.546431Z", "iopub.status.idle": "2025-12-12T02:36:50.909915Z", "shell.execute_reply": "2025-12-12T02:36:50.909465Z" } }, "outputs": [], "source": [ "import pixeltable as pxt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:50.912109Z", "iopub.status.busy": "2025-12-12T02:36:50.911829Z", "iopub.status.idle": "2025-12-12T02:36:51.190499Z", "shell.execute_reply": "2025-12-12T02:36:51.190129Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Connected to Pixeltable database at: postgresql+psycopg://postgres:@/pixeltable?host=/Users/pjlb/.pixeltable/pgdata\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Created directory 'sampling_demo'.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a fresh directory\n", "pxt.drop_dir('sampling_demo', force=True)\n", "pxt.create_dir('sampling_demo')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create sample dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:51.192876Z", "iopub.status.busy": "2025-12-12T02:36:51.192742Z", "iopub.status.idle": "2025-12-12T02:36:52.184226Z", "shell.execute_reply": "2025-12-12T02:36:52.183759Z" } }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Created table 'data'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `data`: 0 rows [00:00, ? rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `data`: 10 rows [00:00, 857.13 rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Inserted 10 rows with 0 errors.\n" ] }, { "data": { "text/plain": [ "10 rows inserted, 20 values computed." ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a dataset with labels\n", "data = pxt.create_table(\n", " 'sampling_demo/data',\n", " {'text': pxt.String, 'label': pxt.String, 'score': pxt.Float},\n", ")\n", "\n", "# Insert sample data with imbalanced classes\n", "samples = [\n", " {'text': 'Great product!', 'label': 'positive', 'score': 0.9},\n", " {'text': 'Love it', 'label': 'positive', 'score': 0.85},\n", " {'text': 'Amazing quality', 'label': 'positive', 'score': 0.95},\n", " {'text': 'Best purchase ever', 'label': 'positive', 'score': 0.88},\n", " {'text': 'Highly recommend', 'label': 'positive', 'score': 0.92},\n", " {'text': 'Fantastic!', 'label': 'positive', 'score': 0.91},\n", " {'text': 'Terrible', 'label': 'negative', 'score': 0.1},\n", " {'text': 'Waste of money', 'label': 'negative', 'score': 0.15},\n", " {'text': 'It is okay', 'label': 'neutral', 'score': 0.5},\n", " {'text': 'Average product', 'label': 'neutral', 'score': 0.55},\n", "]\n", "data.insert(samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Random sampling" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:52.186965Z", "iopub.status.busy": "2025-12-12T02:36:52.186447Z", "iopub.status.idle": "2025-12-12T02:36:52.204370Z", "shell.execute_reply": "2025-12-12T02:36:52.203944Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
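<table>
<thead><tr><th>text</th><th>label</th><th>score</th></tr></thead>
<tbody>
<tr><td>Fantastic!</td><td>positive</td><td>0.91</td></tr>
<tr><td>It is okay</td><td>neutral</td><td>0.5</td></tr>
<tr><td>Average product</td><td>neutral</td><td>0.55</td></tr>
<tr><td>Highly recommend</td><td>positive</td><td>0.92</td></tr>
<tr><td>Great product!</td><td>positive</td><td>0.9</td></tr>
</tbody>
</table>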
" ], "text/plain": [ " text label score\n", "0 Fantastic! positive 0.91\n", "1 It is okay neutral 0.50\n", "2 Average product neutral 0.55\n", "3 Highly recommend positive 0.92\n", "4 Great product! positive 0.90" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sample exactly N rows\n", "data.sample(n=5, seed=42).collect()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:52.207576Z", "iopub.status.busy": "2025-12-12T02:36:52.207243Z", "iopub.status.idle": "2025-12-12T02:36:52.217378Z", "shell.execute_reply": "2025-12-12T02:36:52.216864Z" } }, "outputs": [], "source": [ "# Sample a percentage of rows\n", "sample_50pct = data.sample(fraction=0.5, seed=42).collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stratified sampling" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:52.219681Z", "iopub.status.busy": "2025-12-12T02:36:52.219550Z", "iopub.status.idle": "2025-12-12T02:36:52.240105Z", "shell.execute_reply": "2025-12-12T02:36:52.239427Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
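<table>
<thead><tr><th>text</th><th>label</th><th>score</th></tr></thead>
<tbody>
<tr><td>Terrible</td><td>negative</td><td>0.1</td></tr>
<tr><td>It is okay</td><td>neutral</td><td>0.5</td></tr>
<tr><td>Fantastic!</td><td>positive</td><td>0.91</td></tr>
<tr><td>Highly recommend</td><td>positive</td><td>0.92</td></tr>
<tr><td>Great product!</td><td>positive</td><td>0.9</td></tr>
</tbody>
</table>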
" ], "text/plain": [ " text label score\n", "0 Terrible negative 0.10\n", "1 It is okay neutral 0.50\n", "2 Fantastic! positive 0.91\n", "3 Highly recommend positive 0.92\n", "4 Great product! positive 0.90" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Stratified sampling: 50% from each class\n", "data.sample(fraction=0.5, stratify_by=data.label, seed=42).collect()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:52.242422Z", "iopub.status.busy": "2025-12-12T02:36:52.242300Z", "iopub.status.idle": "2025-12-12T02:36:52.258504Z", "shell.execute_reply": "2025-12-12T02:36:52.257992Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
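<table>
<thead><tr><th>text</th><th>label</th><th>score</th></tr></thead>
<tbody>
<tr><td>Terrible</td><td>negative</td><td>0.1</td></tr>
<tr><td>It is okay</td><td>neutral</td><td>0.5</td></tr>
<tr><td>Fantastic!</td><td>positive</td><td>0.91</td></tr>
</tbody>
</table>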
" ], "text/plain": [ " text label score\n", "0 Terrible negative 0.10\n", "1 It is okay neutral 0.50\n", "2 Fantastic! positive 0.91" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Equal allocation: N rows from each class\n", "data.sample(n_per_stratum=1, stratify_by=data.label, seed=42).collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sampling from filtered data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:52.260656Z", "iopub.status.busy": "2025-12-12T02:36:52.260540Z", "iopub.status.idle": "2025-12-12T02:36:52.274328Z", "shell.execute_reply": "2025-12-12T02:36:52.273752Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
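<table>
<thead><tr><th>text</th><th>label</th><th>score</th></tr></thead>
<tbody>
<tr><td>Fantastic!</td><td>positive</td><td>0.91</td></tr>
<tr><td>Highly recommend</td><td>positive</td><td>0.92</td></tr>
<tr><td>Great product!</td><td>positive</td><td>0.9</td></tr>
</tbody>
</table>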
" ], "text/plain": [ " text label score\n", "0 Fantastic! positive 0.91\n", "1 Highly recommend positive 0.92\n", "2 Great product! positive 0.90" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sample from filtered query (high-confidence predictions only)\n", "data.where(data.score > 0.8).sample(n=3, seed=42).collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Persist samples as tables" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2025-12-12T02:36:52.276469Z", "iopub.status.busy": "2025-12-12T02:36:52.276343Z", "iopub.status.idle": "2025-12-12T02:36:52.451157Z", "shell.execute_reply": "2025-12-12T02:36:52.450737Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Created table 'train'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `train`: 0 rows [00:00, ? rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `train`: 9 rows [00:00, 3080.27 rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Created table 'test'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `test`: 0 rows [00:00, ? 
rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Inserting rows into `test`: 3 rows [00:00, 1333.92 rows/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Define two samples to persist as tables.\n", "# Note: the samples are drawn independently, so they can share rows\n", "# (here 9 + 3 rows come from a 10-row table); this is not a disjoint split.\n", "train_sample = data.sample(fraction=0.8, seed=42)\n", "test_sample = data.sample(fraction=0.2, seed=43)\n", "\n", "# Persist as new tables\n", "train_table = pxt.create_table('sampling_demo/train', source=train_sample)\n", "test_table = pxt.create_table('sampling_demo/test', source=test_sample)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explanation\n", "\n", "**Sampling methods:**\n", "\n", "| Method | Parameter | Behavior |\n", "|--------|-----------|----------|\n", "| Fixed count | `n=100` | Exactly 100 rows |\n", "| Percentage | `fraction=0.1` | Roughly 10% of rows (the exact count can vary) |\n", "| Per-class | `n_per_stratum=10` | 10 rows from each class |\n", "\n", "**Stratification options:**\n", "\n", "| Use case | Parameters |\n", "|----------|------------|\n", "| Proportional | `fraction=0.1, stratify_by=col` |\n", "| Equal allocation | `n_per_stratum=10, stratify_by=col` |\n", "| Reproducible | Add `seed=42` |\n", "\n", "**Tips:**\n", "\n", "- Always set `seed` for reproducible experiments\n", "- Use stratified sampling for imbalanced datasets\n", "- Combine with `.where()` to sample from subsets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## See also\n", "\n", "- [Export for ML training](https://docs.pixeltable.com/howto/cookbooks/data/data-export-pytorch) - PyTorch DataLoader export\n", "- [Import Hugging Face datasets](https://docs.pixeltable.com/howto/cookbooks/data/data-import-huggingface) - Load pre-split datasets" ] } ], "metadata": { "kernelspec": { "display_name": "pixeltable", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }