{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "afb6fda4-ffde-4831-88a5-ae41144492b2",
   "metadata": {},
   "source": [
    "# RAG-on-GKE Application\n",
    "\n",
    "This is a Python notebook for generating the vector embeddings used by the RAG on GKE application. For full details, see the GitHub documentation [here](https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/applications/rag/README.md).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00b1aff4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace these with your settings\n",
    "# Navigate to https://www.kaggle.com/settings/account and generate an API token to be used to set up the env variables. See https://www.kaggle.com/docs/api#authentication for how to create one.\n",
    "KAGGLE_USERNAME = \"\"\n",
    "KAGGLE_KEY = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a814e91b-3afe-4c28-a3d6-fe087c7af552",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install \"ray[default]==2.9.3\" kaggle==1.6.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e26faef-9e2e-4793-b8af-0e18470b482d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Only export the credentials when both are filled in: the Kaggle CLI reads any KAGGLE_* variable, so exporting\n",
    "# empty strings can mask credentials supplied another way (e.g. ~/.kaggle/kaggle.json).\n",
    "if KAGGLE_USERNAME and KAGGLE_KEY:\n",
    "    os.environ['KAGGLE_USERNAME'] = KAGGLE_USERNAME\n",
    "    os.environ['KAGGLE_KEY'] = KAGGLE_KEY\n",
    "\n",
    "# Download the zip file to local storage and then extract the desired contents directly to the GKE GCS CSI mounted bucket. The bucket is mounted at the \"/data\" path in the jupyter pod.\n",
    "!kaggle datasets download -d shivamb/netflix-shows -p ~/data --force\n",
    "!mkdir -p /data/netflix-shows\n",
    "!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "050f2c66-b92e-4ca6-a3b7-b7448d066f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a directory to package the contents that the Ray workers need to download\n",
    "!mkdir -p rag-app"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c82cdcad-c74c-4196-9aa0-2e6bb49f4b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile rag-app/job.py\n",
    "# Comment out the above line if you want to see notebook print out, but the line is required for the actual ray job (the job.py is downloaded by the ray workers)\n",
    "\n",
    "import os\n",
    "import uuid\n",
    "import ray\n",
    "from langchain.document_loaders import ArxivLoader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from typing import Any, Dict, List\n",
    "import torch\n",
    "from datasets import load_dataset_builder, load_dataset, Dataset\n",
    "from huggingface_hub import snapshot_download\n",
    "from google.cloud.sql.connector import Connector, IPTypes\n",
    "import sqlalchemy\n",
    "\n",
    "# initialize parameters\n",
    "INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n",
    "print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n",
    "DB_NAME = \"pgvector-database\"\n",
    "\n",
    "def read_secret(path):\n",
    "    \"\"\"Return the contents of a credential file from the mounted secret volume.\n",
    "\n",
    "    Surrounding whitespace (e.g. a trailing newline left by the tool that created the\n",
    "    secret) is stripped so it does not end up inside the database user name or password.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\") as secret_file:\n",
    "        return secret_file.read().strip()\n",
    "\n",
    "DB_USER = read_secret(\"/etc/secret-volume/username\")\n",
    "DB_PASS = read_secret(\"/etc/secret-volume/password\")\n",
    "\n",
    "# initialize Connector object\n",
    "connector = Connector()\n",
    "\n",
    "# function to return the database connection object\n",
    "def getconn():\n",
    "    conn = connector.connect(\n",
    "        INSTANCE_CONNECTION_NAME,\n",
    "        \"pg8000\",\n",
    "        user=DB_USER,\n",
    "        password=DB_PASS,\n",
    "        db=DB_NAME,\n",
    "        ip_type=IPTypes.PRIVATE\n",
    "    )\n",
    "    return conn\n",
    "\n",
    "# create connection pool with 'creator' argument to our connection object function\n",
    "pool = sqlalchemy.create_engine(\n",
    "    \"postgresql+pg8000://\",\n",
    "    creator=getconn,\n",
    ")\n",
    "\n",
    "SHARED_DATA_BASEPATH='/data/rag/st'\n",
    "SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n",
    "SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n",
    "SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n",
    "SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n",
    "\n",
    "# the dataset has been pre-downloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.\n",
    "SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n",
    "REVIEWS_FILE_NAME=\"netflix_titles.csv\"\n",
    "\n",
    "BATCH_SIZE = 100\n",
    "CHUNK_SIZE = 1000 # text chunk sizes which will be converted to vector embeddings\n",
    "CHUNK_OVERLAP = 10\n",
    "TABLE_NAME = 'netflix_reviews_db' # CloudSQL table name\n",
    "DIMENSION = 384 # Embeddings size\n",
    "ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function\n",
    "\n",
    "class Embed:\n",
    "    \"\"\"Ray Data batch callable: splits each text into chunks and embeds every chunk.\n",
    "\n",
    "    Used with an ActorPoolStrategy below, so the sentence transformer model is loaded\n",
    "    in __init__ once per actor rather than once per batch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        print(\"torch cuda version\", torch.version.cuda)\n",
    "        device=\"cpu\"\n",
    "        if torch.cuda.is_available():\n",
    "            print(\"device cuda found\")\n",
    "            device=\"cuda\"\n",
    "\n",
    "        print(\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n",
    "        self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n",
    "        self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n",
    "\n",
    "    def __call__(self, text_batch: Dict[str, Any]) -> Dict[str, List[Any]]:\n",
    "        # text_batch maps column names to column values; the \"item\" column holds the raw texts\n",
    "        text = text_batch[\"item\"]\n",
    "        chunks = []\n",
    "        for data in text:\n",
    "            splits = self.splitter.split_text(data)\n",
    "            chunks.extend(splits)\n",
    "\n",
    "        embeddings = self.transformer.encode(\n",
    "            chunks,\n",
    "            batch_size=BATCH_SIZE\n",
    "        ).tolist()\n",
    "        print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n",
    "        # one (chunk_text, embedding) pair per chunk\n",
    "        return {'results':list(zip(chunks, embeddings))}\n",
    "\n",
    "\n",
    "# prepare the persistent shared directory to store artifacts needed for the ray workers\n",
    "os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n",
    "\n",
    "# One time download of the sentence transformer model to a shared persistent storage available to the ray workers\n",
    "snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)\n",
    "\n",
    "# Process the dataset first, wrap the csv file contents into a Ray dataset\n",
    "ray_ds = ray.data.read_csv(SHARED_DATASET_BASE_PATH + REVIEWS_FILE_NAME)\n",
    "# Dataset.schema is a method; call it so the schema itself is printed\n",
    "print(ray_ds.schema())\n",
    "\n",
    "# Distributed flat map to extract the raw text fields.\n",
    "ds_batch = ray_ds.flat_map(lambda row: [{\n",
    "    'item': \"This is a \" + str(row[\"type\"]) + \" in \" + str(row[\"country\"]) + \" called \" + str(row[\"title\"]) +\n",
    "    \" added at \" + str(row[\"date_added\"]) + \" whose director is \" + str(row[\"director\"]) +\n",
    "    \" and with cast: \" + str(row[\"cast\"]) + \" released at \" + str(row[\"release_year\"]) +\n",
    "    \". Its rating is: \" + str(row['rating']) + \". Its duration is \" + str(row[\"duration\"]) +\n",
    "    \". Its description is \" + str(row['description']) + \".\"\n",
    "}])\n",
    "print(ds_batch.schema())\n",
    "\n",
    "# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n",
    "ds_embed = ds_batch.map_batches(\n",
    "    Embed,\n",
    "    compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n",
    "    batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n",
    "    num_gpus=1, # 1 GPU for each actor.\n",
    "    # num_cpus=1,\n",
    ")\n",
    "\n",
    "# Use this block for debug purpose to inspect the embeddings and raw text\n",
    "# print(\"Embeddings ray dataset\", ds_embed.schema())\n",
    "# for output in ds_embed.iter_rows():\n",
    "#     # restrict the text string to be less than 65535\n",
    "#     data_text = output[\"results\"][0][:65535]\n",
    "#     # vector data pass in needs to be a string\n",
    "#     data_emb = \",\".join(map(str, output[\"results\"][1]))\n",
    "#     data_emb = \"[\" + data_emb + \"]\"\n",
    "#     print(\"raw text:\", data_text, \", embeddings:\", data_emb)\n",
    "\n",
    "with pool.connect() as db_conn:\n",
    "    db_conn.execute(\n",
    "        sqlalchemy.text(\n",
    "            \"CREATE EXTENSION IF NOT EXISTS vector;\"\n",
    "        )\n",
    "    )\n",
    "    db_conn.commit()\n",
    "\n",
    "    # size the vector column from DIMENSION so it stays in sync with the embedding model's output\n",
    "    create_table_query = f\"CREATE TABLE IF NOT EXISTS {TABLE_NAME} ( id VARCHAR(255) NOT NULL, text TEXT NOT NULL, text_embedding vector({DIMENSION}) NOT NULL, PRIMARY KEY (id));\"\n",
    "    db_conn.execute(\n",
    "        sqlalchemy.text(create_table_query)\n",
    "    )\n",
    "    # commit transaction (SQLAlchemy v2.X.X is commit as you go)\n",
    "    db_conn.commit()\n",
    "    print(\"Created table=\", TABLE_NAME)\n",
    "\n",
    "    query_text = \"INSERT INTO \" + TABLE_NAME + \" (id, text, text_embedding) VALUES (:id, :text, :text_embedding)\"\n",
    "    insert_stmt = sqlalchemy.text(query_text)\n",
    "    for output in ds_embed.iter_rows():\n",
    "        # restrict the text string to be less than 65535\n",
    "        data_text = output[\"results\"][0][:65535]\n",
    "        # vector data pass in needs to be a string\n",
    "        data_emb = \"[\" + \",\".join(map(str, output[\"results\"][1])) + \"]\"\n",
    "        # use the string form of the UUID to match the VARCHAR id column (and avoid shadowing the builtin id)\n",
    "        row_id = str(uuid.uuid4())\n",
    "        db_conn.execute(insert_stmt, parameters={\"id\": row_id, \"text\": data_text, \"text_embedding\": data_emb})\n",
    "\n",
    "    # batch commit transactions\n",
    "    db_conn.commit()\n",
    "\n",
    "    # sanity check: count the stored rows instead of fetching every row (and its embedding) back\n",
    "    row_count = db_conn.execute(sqlalchemy.text(\"SELECT COUNT(*) FROM \" + TABLE_NAME)).scalar()\n",
    "    print(\"Rows in table\", TABLE_NAME, \"=\", row_count)\n",
    "\n",
    "    # verify results: embed a sample query with the same pinned model snapshot that produced the stored embeddings\n",
    "    transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH)\n",
    "    query_text = \"A documentary about wildlife and nature\"\n",
    "    query_emb = \"[\" + \",\".join(map(str, transformer.encode(query_text).tolist())) + \"]\"\n",
    "    # the query vector is bound as a parameter and cast server-side instead of being spliced into the SQL text\n",
    "    query_request = sqlalchemy.text(\n",
    "        \"SELECT id, text, 1 - (CAST(:query_emb AS vector) <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME +\n",
    "        \" ORDER BY cosine_similarity DESC LIMIT 5;\"\n",
    "    )\n",
    "    query_results = db_conn.execute(query_request, parameters={\"query_emb\": query_emb}).fetchall()\n",
    "    db_conn.commit()\n",
    "    print(\"print query_results, the 1st one is the hit\")\n",
    "    for row in query_results:\n",
    "        print(row)\n",
    "\n",
    "# cleanup connector object\n",
    "connector.close()\n",
    "print(\"end job\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeeb7b7a-23d8-4c6a-8165-7ce5516d2a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.job_submission import JobSubmissionClient\n",
    "client = JobSubmissionClient(\"ray://ray-cluster-kuberay-head-svc:10001\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ba6c3ff-a25a-4f4d-b58e-68f7fe7d33df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Port forward to the Ray dashboard and go to `localhost:8265` in a browser to see job status:\n",
    "#   kubectl port-forward -n <namespace> service/ray-cluster-kuberay-head-svc 8265:8265\n",
    "import os\n",
    "import time\n",
    "\n",
    "from ray.job_submission import JobStatus\n",
    "\n",
    "start_time = time.time()\n",
    "job_id = client.submit_job(\n",
    "    entrypoint=\"python job.py\",\n",
    "    runtime_env={\n",
    "        # upload the local rag-app directory (which contains the entrypoint job.py) to the ray workers;\n",
    "        # resolved against the notebook's working directory, where the cells above created it\n",
    "        \"working_dir\": os.path.abspath(\"rag-app\"),\n",
    "    }\n",
    ")\n",
    "\n",
    "# The Ray job typically takes 5m-10m to complete.\n",
    "print(\"Job submitted with ID:\", job_id)\n",
    "while True:\n",
    "    status = client.get_job_status(job_id)\n",
    "    print(\"Job status:\", status)\n",
    "    print(\"Job info:\", client.get_job_info(job_id).message)\n",
    "    if status.is_terminal():\n",
    "        break\n",
    "    time.sleep(30)\n",
    "\n",
    "job_duration = time.time() - start_time\n",
    "print(f\"Job finished with status {status} in {job_duration:.0f} seconds.\")\n",
    "if status != JobStatus.SUCCEEDED:\n",
    "    # show the driver logs so a failed or stopped job can be diagnosed from the notebook\n",
    "    print(client.get_job_logs(job_id))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}