{ "cells": [ { "cell_type": "markdown", "id": "5574f366-58e9-408b-aea4-1bf5b3351e4c", "metadata": {}, "source": [ "# RAG-on-GKE Application\n", "\n", "This is a Python notebook for generating the vector embeddings used by the RAG on GKE application. For full information, please checkout the GitHub documentation [here](https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/applications/rag/README.md).\n", "\n", "\n", "## Setup Kaggle Credentials\n", "\n", "First we will setup your Kaggle credentials and use the Kaggle CLI to download the NetFlix shows dataset to the GCS bucket. Replace the following with your own settings from the Kaggle web page. Navigate to https://www.kaggle.com/settings/account and generate an API token to be used to setup the env variable. See https://www.kaggle.com/docs/api#authentication how to create one." ] }, { "cell_type": "code", "execution_count": null, "id": "ffee2bec-804f-4e22-9ba0-8b1db5a5d7ec", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['KAGGLE_USERNAME'] = \"\"\n", "os.environ['KAGGLE_KEY'] = \"\"\n", "\n", "# Download the zip file to local storage and then extract the desired contents directly to the GKE GCS CSI mounted bucket. The bucket is mounted at the \"/persist-data\" path in the jupyter pod.\n", "!kaggle datasets download -d shivamb/netflix-shows -p ~/data --force\n", "!mkdir /data/netflix-shows -p\n", "!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows" ] }, { "cell_type": "markdown", "id": "c7ff518d-f4d2-481b-b408-2c2507565611", "metadata": {}, "source": [ "## Download the Data\n", "\n", "Let's now import required modules:" ] }, { "cell_type": "code", "execution_count": null, "id": "aff789e7-a32d-4dd7-afb8-d3a22c8f3cec", "metadata": {}, "outputs": [], "source": [ "import os\n", "import uuid\n", "import ray\n", "from langchain.document_loaders import ArxivLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "from sentence_transformers import SentenceTransformer\n", "from typing import List\n", "import torch\n", "from datasets import load_dataset_builder, load_dataset, Dataset\n", "from huggingface_hub import snapshot_download\n", "from google.cloud.sql.connector import Connector, IPTypes\n", "import sqlalchemy" ] }, { "cell_type": "markdown", "id": "2156a6bd-1100-46c2-8ad6-22a923b3d6ac", "metadata": {}, "source": [ "Next we'll setup some parameters for the dataset processing steps:" ] }, { "cell_type": "code", "execution_count": null, "id": "b24267d7-3fad-47f1-8aa7-2fbe21fe8fa1", "metadata": {}, "outputs": [], "source": [ "SHARED_DATA_BASEPATH='/data/rag/st'\n", "SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n", "SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n", "SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n", "SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n", "\n", "# the dataset has been pre-dowloaded to the GCS bucket as part of the notebook in the cell above. 
Ray workers will find the dataset readily mounted.\n", "SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n", "REVIEWS_FILE_NAME=\"netflix_titles.csv\"\n", "\n", "BATCH_SIZE = 500\n", "CHUNK_SIZE = 1000 # size of the text chunks which will be converted to vector embeddings\n", "CHUNK_OVERLAP = 10\n", "TABLE_NAME = 'netflix_reviews_db' # CloudSQL table name\n", "DIMENSION = 384 # Embeddings size\n", "ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function" ] }, { "cell_type": "markdown", "id": "3dc5bc85-dc3b-4622-99a2-f9fc269e753b", "metadata": {}, "source": [ "Now we will download the sentence transformer model to our GCS bucket:" ] }, { "cell_type": "code", "execution_count": null, "id": "b7a676be-56c6-4c76-8041-9ad05361dd3b", "metadata": {}, "outputs": [], "source": [ "# prepare the persistent shared directory to store artifacts needed for the ray workers\n", "os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n", "\n", "# One-time download of the sentence transformer model to shared persistent storage available to the ray workers\n", "snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)" ] }, { "cell_type": "markdown", "id": "f7304035-21a4-4017-bce9-aba7e9f81c90", "metadata": {}, "source": [ "## Generating Vector Embeddings\n", "\n", "We are ready to begin. Let's first define an `Embed` class that splits each text record into chunks and encodes the chunks into vector embeddings:" ] }, { "cell_type": "code", "execution_count": null, "id": "5bbb3750-7cd5-439f-a767-617cd5948e27", "metadata": {}, "outputs": [], "source": [ "class Embed:\n", "    def __init__(self):\n", "        print(\"torch cuda version\", torch.version.cuda)\n", "        device=\"cpu\"\n", "        if torch.cuda.is_available():\n", "            print(\"device cuda found\")\n", "            device=\"cuda\"\n", "\n", "        print(\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n", "        self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n", "        self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n", "\n", "    def __call__(self, text_batch: List[str]):\n", "        text = text_batch[\"item\"]\n", "        # print(\"type(text)=\", type(text), \"type(text_batch)=\", type(text_batch))\n", "        chunks = []\n", "        for data in text:\n", "            splits = self.splitter.split_text(data)\n", "            # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n", "            chunks.extend(splits)\n", "\n", "        embeddings = self.transformer.encode(\n", "            chunks,\n", "            batch_size=BATCH_SIZE\n", "        ).tolist()\n", "        print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n", "        return {'results':list(zip(chunks, embeddings))}" ] }, { "cell_type": "markdown", "id": "7263b9db-9504-4177-acd6-5e1aba2ee332", "metadata": {}, "source": [ "Next we will connect to the Ray cluster, which will execute the remote tasks:" ] }, { "cell_type": "code", "execution_count": null, "id": "ba9551ec-883e-4bde-9f12-f663aedc12e5", "metadata": {}, "outputs": [], "source": [ "import ray\n", "\n", "ray.init(\n", "    address=\"ray://ray-cluster-kuberay-head-svc:10001\",\n", "    runtime_env={\n", "        \"pip\": [\n", "            \"langchain==0.1.10\",\n", "            \"transformers==4.38.1\",\n", "            \"sentence-transformers==2.5.1\",\n", "            \"pyarrow\",\n", "            \"datasets==2.18.0\",\n", "            \"torch==2.0.1\",\n", "            \"huggingface_hub==0.21.3\",\n", "        ]\n", "    }\n", ")" ] }, { "cell_type": "markdown", "id": "9589048c-a0aa-4740-acde-5289cd4076f7", "metadata": {}, "source": [ "Generate vector embeddings using our Embed 
class above:" ] }, { "cell_type": "code", "execution_count": null, "id": "a392975f-3743-4b2c-8673-087b5633637e", "metadata": {}, "outputs": [], "source": [ "# Process the dataset first, wrap the csv file contents into a Ray dataset\n", "ray_ds = ray.data.read_csv(SHARED_DATASET_BASE_PATH + REVIEWS_FILE_NAME)\n", "print(ray_ds.schema)\n", "\n", "# Distributed flat map to extract the raw text fields.\n", "ds_batch = ray_ds.flat_map(lambda row: [{\n", " 'item': \"This is a \" + str(row[\"type\"]) + \" in \" + str(row[\"country\"]) + \" called \" + str(row[\"title\"]) + \n", " \" added at \" + str(row[\"date_added\"]) + \" whose director is \" + str(row[\"director\"]) + \n", " \" and with cast: \" + str(row[\"cast\"]) + \" released at \" + str(row[\"release_year\"]) + \n", " \". Its rating is: \" + str(row['rating']) + \". Its duration is \" + str(row[\"duration\"]) + \n", " \". Its description is \" + str(row['description']) + \".\"\n", "}])\n", "print(ds_batch.schema)\n", "\n", "# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n", "ds_embed = ds_batch.map_batches(\n", " Embed,\n", " compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n", " batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n", " num_gpus=1, # 1 GPU for each actor.\n", " # num_cpus=1,\n", ")" ] }, { "cell_type": "markdown", "id": "4697ac28-9815-409c-95ec-6ecdb862bb74", "metadata": {}, "source": [ "Retrieve the result data from Ray remote workers:" ] }, { "cell_type": "code", "execution_count": null, "id": "e0edbba2-8977-4afd-aaa2-2e6e3a298169", "metadata": {}, "outputs": [], "source": [ "@ray.remote\n", "def ray_data_task(ds_embed):\n", " results = []\n", " for row in ds_embed.iter_rows():\n", " data_text = row[\"results\"][0][:65535]\n", " data_emb = row[\"results\"][1]\n", "\n", " results.append((data_text, data_emb))\n", " \n", " return results\n", " \n", "results = ray.get(ray_data_task.remote(ds_embed))" ] }, { "cell_type": "markdown", "id": "5652832e-025d-4745-9fef-96615eea07e4", "metadata": {}, "source": [ "## Writing Results Back to MySQL\n", "\n", "Now that we have our vector embeddings, we can write our results back to the MySQL database:" ] }, { "cell_type": "code", "execution_count": null, "id": "07eb5ec7-c4f7-4312-b0ce-ea07160bef92", "metadata": {}, "outputs": [], "source": [ "from sqlalchemy.ext.declarative import declarative_base\n", "from sqlalchemy import Column, String, Text, text\n", "from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column\n", "from pgvector.sqlalchemy import Vector\n", "\n", "# initialize parameters\n", "\n", "INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n", "print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n", "DB_NAME = \"pgvector-database\"\n", "\n", "db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", "DB_USER = db_username_file.read()\n", "db_username_file.close()\n", "\n", "db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", "DB_PASS = db_password_file.read()\n", "db_password_file.close()\n", "\n", "# initialize Connector object\n", "connector = Connector()\n", "\n", "# function to return the database connection object\n", "def getconn():\n", " conn = connector.connect(\n", " INSTANCE_CONNECTION_NAME,\n", " \"pg8000\",\n", " user=DB_USER,\n", " password=DB_PASS,\n", " db=DB_NAME,\n", " ip_type=IPTypes.PRIVATE\n", " )\n", " return conn\n", "\n", "# create connection 
pool with 'creator' argument to our connection object function\n", "pool = sqlalchemy.create_engine(\n", "    \"postgresql+pg8000://\",\n", "    creator=getconn,\n", ")\n", "\n", "Base = declarative_base()\n", "DBSession = scoped_session(sessionmaker())\n", "\n", "class TextEmbedding(Base):\n", "    __tablename__ = TABLE_NAME\n", "    id = Column(String(255), primary_key=True)\n", "    text = Column(Text)\n", "    text_embedding = mapped_column(Vector(DIMENSION))\n", "\n", "# make sure the pgvector extension is enabled before creating the table\n", "with pool.connect() as conn:\n", "    conn.execute(text(\"CREATE EXTENSION IF NOT EXISTS vector\"))\n", "    conn.commit()\n", "\n", "DBSession.configure(bind=pool, autoflush=False, expire_on_commit=False)\n", "Base.metadata.drop_all(pool)\n", "Base.metadata.create_all(pool)\n", "\n", "rows = []\n", "for r in results:\n", "    id = str(uuid.uuid4())\n", "    rows.append(TextEmbedding(id=id, text=r[0], text_embedding=r[1]))\n", "\n", "DBSession.bulk_save_objects(rows)\n", "DBSession.commit()" ] }, { "cell_type": "markdown", "id": "b4b19b1c-a83b-4a83-94a9-5edf5ae7016a", "metadata": {}, "source": [ "Finally, let's verify that our embeddings were stored in the database correctly:" ] }, { "cell_type": "code", "execution_count": null, "id": "4cff4bbc-574d-4cc2-8c87-d0ff6d351626", "metadata": {}, "outputs": [], "source": [ "with pool.connect() as db_conn:\n", "    # verify the results with a sample query against the Netflix shows table\n", "    transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n", "    query_text = \"A feel-good romantic comedy about friends\"\n", "    query_emb = transformer.encode(query_text).tolist()\n", "    # pgvector's <=> operator returns cosine distance, so 1 - distance gives cosine similarity\n", "    query_request = \"SELECT id, text, text_embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\"\n", "    query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n", "    db_conn.commit()\n", "\n", "    print(\"Query results (most similar first):\")\n", "    for row in query_results:\n", "        print(row)\n", "\n", "# clean up the connector object\n", "connector.close()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 5 }