{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a YOLO classifier with 3LC metrics collection\n",
    "\n",
    "Train a YOLO classifier on existing 3LC tables.\n",
    "\n",
    "![](../images/yolo-cls-cifar.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install 3lc\n",
    "%pip install 3lc-ultralytics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tlc\n",
    "from tlc_ultralytics import YOLO, Settings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Project setup\n",
    "\n",
    "The next cell is tagged `parameters`, so tools such as papermill can override these values when executing the notebook programmatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "PROJECT_NAME = \"3LC Tutorials - CIFAR-10\"\n",
    "MODEL_NAME = \"yolov8n-cls.pt\"\n",
    "IMAGE_COLUMN = \"Image\"\n",
    "LABEL_COLUMN = \"Label\"\n",
    "EPOCHS = 5\n",
    "NUM_WORKERS = 0\n",
    "BATCH_SIZE = 32\n",
    "IMAGE_SIZE = 32\n",
    "DOWNLOAD_PATH = \"../../transient_data\"  # passed to Ultralytics as `project` (run output directory)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load tables\n",
    "\n",
    "Load the `initial` revision of the CIFAR-10 train and validation tables from the project. These tables must already exist; this notebook does not create them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_table = tlc.Table.from_names(\"initial\", \"CIFAR-10-train\", PROJECT_NAME)\n",
    "val_table = tlc.Table.from_names(\"initial\", \"CIFAR-10-val\", PROJECT_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the classifier\n",
    "\n",
    "`Settings` configures the 3LC side of the run (project, run name, column names and metrics collection options). The other keyword arguments to `model.train` (`batch`, `imgsz`, `epochs`, `workers`, `project`) are standard Ultralytics training arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = YOLO(MODEL_NAME)\n",
    "\n",
    "settings = Settings(\n",
    "    project_name=PROJECT_NAME,\n",
    "    run_name=\"Train YOLO Classifier\",\n",
    "    image_embeddings_dim=2,\n",
    "    conf_thres=0.2,\n",
    "    sampling_weights=True,\n",
    "    exclude_zero_weight_training=True,\n",
    "    exclude_zero_weight_collection=False,\n",
    "    image_column_name=IMAGE_COLUMN,\n",
    "    label_column_name=LABEL_COLUMN,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train(\n",
    "    tables={\n",
    "        \"train\": train_table,\n",
    "        \"val\": val_table,\n",
    "    },\n",
    "    settings=settings,\n",
    "    batch=BATCH_SIZE,\n",
    "    imgsz=IMAGE_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    workers=NUM_WORKERS,\n",
    "    project=DOWNLOAD_PATH,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  },
  "test_marks": [
   "slow"
  ]
 },
 "nbformat": 4,
 "nbformat_minor": 2
}