{ "cells": [ { "cell_type": "markdown", "id": "10ae2b14", "metadata": {}, "source": [ "# Sentence Embeddings with Hugging Face Transformers, Sentence Transformers and Amazon SageMaker - Custom Inference for creating document embeddings with Hugging Face's Transformers\n" ] }, { "cell_type": "markdown", "id": "5a1644f1", "metadata": {}, "source": [ "Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and the Amazon SageMaker Python SDK to create a [real-time inference endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) running a Sentence Transformers model for document embeddings. Currently, the [SageMaker Hugging Face Inference Toolkit](https://github.com/aws/sagemaker-huggingface-inference-toolkit) supports the [pipeline feature](https://huggingface.co/transformers/main_classes/pipelines.html) from Transformers for zero-code deployment. This means you can run compatible Hugging Face Transformer models without providing pre- & post-processing code. Therefore we only need to provide the environment variables `HF_TASK` and `HF_MODEL_ID` when creating our endpoint, and the Inference Toolkit will take care of the rest. This is a great feature if you are working with existing [pipelines](https://huggingface.co/transformers/main_classes/pipelines.html).\n", "\n", "If you want to run other tasks, such as creating document embeddings, you can provide the pre- and post-processing code yourself via an `inference.py` script. The Hugging Face Inference Toolkit allows the user to override the default methods of the `HuggingFaceHandlerService`.\n", "\n", "The custom module can override the following methods:\n", "\n", "- `model_fn(model_dir)` overrides the default method for loading a model. The return value `model` will be used in the `predict_fn` for predictions.\n", " - `model_dir` is the path to your unzipped `model.tar.gz`.\n", "- `input_fn(input_data, content_type)` overrides the default method for pre-processing. The return value `data` will be used in `predict_fn` for predictions. The inputs are:\n", " - `input_data` is the raw body of your request.\n", " - `content_type` is the content type from the request header.\n", "- `predict_fn(processed_data, model)` overrides the default method for predictions. The return value `predictions` will be used in `output_fn`.\n", " - `model` is the returned value from the `model_fn` method.\n", " - `processed_data` is the returned value from the `input_fn` method.\n", "- `output_fn(prediction, accept)` overrides the default method for post-processing. The return value `result` will be the response to your request (e.g. `JSON`). The inputs are:\n", " - `prediction` is the result from `predict_fn`.\n", " - `accept` is the accept type from the HTTP request, e.g. `application/json`.\n", "\n", "In this example, we are going to use Sentence Transformers to create sentence embeddings using a mean pooling layer on the raw representation.\n", "\n", "*NOTE: You can run this demo in SageMaker Studio, on your local machine, or on SageMaker Notebook Instances.*" ] }, { "cell_type": "markdown", "id": "e1e723ab", "metadata": {}, "source": [ "## Development Environment and Permissions\n", "\n", "### Installation \n" ] }, { "cell_type": "code", "execution_count": null, "id": "69c59d90", "metadata": {}, "outputs": [], "source": [ "%pip install sagemaker --upgrade" ] }, { "cell_type": "markdown", "id": "ce0ef431", "metadata": {}, "source": [ "Install `git` and `git-lfs`." ] }, { "cell_type": "code", "execution_count": null, "id": "96d8dfea", "metadata": {}, "outputs": [], "source": [ "# For notebook instances (Amazon Linux)\n", "!sudo yum update -y \n", "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash\n", "!sudo yum install git-lfs git -y\n", "# For other environments (Ubuntu)\n", "!sudo apt-get update -y \n", "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash\n", "!sudo apt-get install 
git-lfs git -y" ] }, { "cell_type": "markdown", "id": "9e4386d9", "metadata": {}, "source": [ "### Permissions\n", "\n", "_If you are going to use SageMaker in a local environment (not SageMaker Studio or Notebook Instances), you need access to an IAM Role with the required permissions for SageMaker. You can find out more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html)._" ] }, { "cell_type": "code", "execution_count": 7, "id": "1c22e8d5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Couldn't call 'get_role' to get Role ARN from role name philippschmid to get Role path.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "sagemaker role arn: arn:aws:iam::558105141721:role/sagemaker_execution_role\n", "sagemaker bucket: sagemaker-us-east-1-558105141721\n", "sagemaker session region: us-east-1\n" ] } ], "source": [ "import sagemaker\n", "import boto3\n", "sess = sagemaker.Session()\n", "# sagemaker session bucket -> used for uploading data, models and logs\n", "# sagemaker will automatically create this bucket if it does not exist\n", "sagemaker_session_bucket=None\n", "if sagemaker_session_bucket is None and sess is not None:\n", " # set to default bucket if a bucket name is not given\n", " sagemaker_session_bucket = sess.default_bucket()\n", "\n", "try:\n", " role = sagemaker.get_execution_role()\n", "except ValueError:\n", " iam = boto3.client('iam')\n", " role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n", "\n", "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n", "\n", "print(f\"sagemaker role arn: {role}\")\n", "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n", "print(f\"sagemaker session region: {sess.boto_region_name}\")" ] }, { "cell_type": "markdown", "id": "0b0fc22f", "metadata": {}, "source": [ "## Create a custom `inference.py` script\n", "\n", "To use the custom inference script, you need to create an `inference.py` script. 
In our example, we are going to overwrite the `model_fn` to load our sentence transformer correctly and the `predict_fn` to apply mean pooling.\n", "\n", "We are going to use the [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model. It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search." ] }, { "cell_type": "code", "execution_count": 17, "id": "b4246c06", "metadata": {}, "outputs": [], "source": [ "!mkdir code" ] }, { "cell_type": "code", "execution_count": 38, "id": "3ce41529", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting code/inference.py\n" ] } ], "source": [ "%%writefile code/inference.py\n", "\n", "from transformers import AutoTokenizer, AutoModel\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "# Helper: Mean Pooling - Take attention mask into account for correct averaging\n", "def mean_pooling(model_output, attention_mask):\n", " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n", " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n", " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n", "\n", "\n", "def model_fn(model_dir):\n", " # Load model from HuggingFace Hub\n", " tokenizer = AutoTokenizer.from_pretrained(model_dir)\n", " model = AutoModel.from_pretrained(model_dir)\n", " return model, tokenizer\n", "\n", "def predict_fn(data, model_and_tokenizer):\n", " # destruct model and tokenizer\n", " model, tokenizer = model_and_tokenizer\n", " \n", " # Tokenize sentences\n", " sentences = data.pop(\"inputs\", data)\n", " encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')\n", "\n", " # Compute token embeddings\n", " with torch.no_grad():\n", " model_output = 
model(**encoded_input)\n", "\n", " # Perform pooling\n", " sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n", "\n", " # Normalize embeddings\n", " sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n", " \n", " # return a dictionary, which will be JSON serializable\n", " return {\"vectors\": sentence_embeddings[0].tolist()}" ] }, { "cell_type": "markdown", "id": "144d8ccb", "metadata": {}, "source": [ "## Create `model.tar.gz` with inference script and model \n", "\n", "To use our `inference.py` we need to bundle it into a `model.tar.gz` archive with all our model artifacts, e.g. `pytorch_model.bin`. The `inference.py` script will be placed into a `code/` folder. We will use `git` and `git-lfs` to easily download our model from hf.co/models and upload it to Amazon S3 so we can use it when creating our SageMaker endpoint." ] }, { "cell_type": "code", "execution_count": 39, "id": "952983b5", "metadata": {}, "outputs": [], "source": [ "repository = \"sentence-transformers/all-MiniLM-L6-v2\"\n", "model_id=repository.split(\"/\")[-1]\n", "s3_location=f\"s3://{sess.default_bucket()}/custom_inference/{model_id}/model.tar.gz\"" ] }, { "cell_type": "markdown", "id": "374ff630", "metadata": {}, "source": [ "1. Download the model from hf.co/models with `git clone`." 
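The masked mean pooling and L2 normalization performed in `inference.py` can be illustrated with a toy example in plain Python. The token vectors and attention mask below are made-up values, not real model outputs:

```python
# Toy illustration of the pooling logic in inference.py, without torch.
# Padding positions (mask = 0) are excluded from the average, mirroring
# how mean_pooling uses the attention mask.
import math

def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors, counting only positions where the mask is 1."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    count = 0
    for vec, mask in zip(token_embeddings, attention_mask):
        if mask:
            count += 1
            for i in range(dim):
                summed[i] += vec[i]
    return [s / max(count, 1) for s in summed]

def l2_normalize(vec):
    """Scale the vector to unit length, like F.normalize(..., p=2, dim=1)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

# Two real tokens and one padding token (mask = 0)
tokens = [[1.0, 2.0, 2.0], [3.0, 0.0, 0.0], [9.0, 9.0, 9.0]]
mask = [1, 1, 0]

pooled = mean_pool(tokens, mask)     # [2.0, 1.0, 1.0] - padding token ignored
embedding = l2_normalize(pooled)     # roughly [0.816, 0.408, 0.408]
print(embedding)
```

Because the padding token is masked out, it does not distort the sentence embedding, and the final vector always has unit length.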
] }, { "cell_type": "code", "execution_count": 40, "id": "b8452981", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updated git hooks.\n", "Git LFS initialized.\n", "Cloning into 'all-MiniLM-L6-v2'...\n", "remote: Enumerating objects: 25, done.\u001b[K\n", "remote: Counting objects: 100% (25/25), done.\u001b[K\n", "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", "remote: Total 25 (delta 3), reused 0 (delta 0)\u001b[K.00 KiB/s\n", "Unpacking objects: 100% (25/25), 308.60 KiB | 454.00 KiB/s, done.\n" ] } ], "source": [ "!git lfs install\n", "!git clone https://huggingface.co/$repository" ] }, { "cell_type": "markdown", "id": "09a6f330", "metadata": {}, "source": [ "2. Copy `inference.py` into the `code/` directory of the model directory." ] }, { "cell_type": "code", "execution_count": 41, "id": "6146af09", "metadata": {}, "outputs": [], "source": [ "!cp -r code/ $model_id/code/" ] }, { "cell_type": "markdown", "id": "04e1395a", "metadata": {}, "source": [ "3. Create a `model.tar.gz` archive with all the model artifacts and the `inference.py` script.\n" ] }, { "cell_type": "code", "execution_count": 42, "id": "e65fd56e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/Users/philipp/.Trash/all-MiniLM-L6-v2/all-MiniLM-L6-v2\n", "a 1_Pooling\n", "a 1_Pooling/config.json\n", "a README.md\n", "a code\n", "a code/inference.py\n", "a config.json\n", "a config_sentence_transformers.json\n", "a data_config.json\n", "a modules.json\n", "a pytorch_model.bin\n", "a sentence_bert_config.json\n", "a special_tokens_map.json\n", "a tokenizer.json\n", "a tokenizer_config.json\n", "a train_script.py\n", "a vocab.txt\n" ] } ], "source": [ "%cd $model_id\n", "!tar zcvf model.tar.gz *" ] }, { "cell_type": "markdown", "id": "1c858560", "metadata": {}, "source": [ "4. 
Upload the `model.tar.gz` to Amazon S3:\n" ] }, { "cell_type": "code", "execution_count": 43, "id": "c581bc40", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "upload: ./model.tar.gz to s3://sagemaker-us-east-1-558105141721/custom_inference/all-MiniLM-L6-v2/model.tar.gz\n" ] } ], "source": [ "!aws s3 cp model.tar.gz $s3_location\n" ] }, { "cell_type": "markdown", "id": "0a146346", "metadata": {}, "source": [ "## Create a custom `HuggingFaceModel` \n", "\n", "After we have created and uploaded our `model.tar.gz` archive to Amazon S3, we can create a custom `HuggingFaceModel` class. This class will be used to create and deploy our SageMaker endpoint." ] }, { "cell_type": "code", "execution_count": 44, "id": "1c5ba990", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-----------!" ] } ], "source": [ "from sagemaker.huggingface.model import HuggingFaceModel\n", "\n", "\n", "# create Hugging Face Model Class\n", "huggingface_model = HuggingFaceModel(\n", " model_data=s3_location, # path to your model and script\n", " role=role, # iam role with permissions to create an Endpoint\n", " transformers_version=\"4.26\", # transformers version used\n", " pytorch_version=\"1.13\", # pytorch version used\n", " py_version='py39', # python version used\n", ")\n", "\n", "# deploy the endpoint\n", "predictor = huggingface_model.deploy(\n", " initial_instance_count=1,\n", " instance_type=\"ml.g4dn.xlarge\"\n", " )" ] }, { "cell_type": "markdown", "id": "b6b3812f", "metadata": {}, "source": [ "## Request Inference Endpoint using the `HuggingFacePredictor`\n", "\n", "The `.deploy()` call returns a `HuggingFacePredictor` object which can be used to request inference." 
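Because `predict_fn` L2-normalizes the embeddings, comparing two documents returned by the endpoint reduces to a dot product (cosine similarity). A minimal sketch, where the two short vectors are made-up stand-ins for real `res["vectors"]` responses:

```python
# Cosine similarity between two (hypothetical) embedding vectors.
# For unit-length vectors, cosine similarity equals the plain dot product.
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def l2_normalize(vec):
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

# Made-up 3-dimensional embeddings; real responses are 384-dimensional.
doc_a = l2_normalize([0.2, 0.1, 0.9])
doc_b = l2_normalize([0.1, 0.2, 0.8])

similarity = dot(doc_a, doc_b)
print(round(similarity, 4))
```

Values close to 1.0 indicate semantically similar documents, which is the basis for clustering and semantic search with this model.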
] }, { "cell_type": "code", "execution_count": 45, "id": "51c5366b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'vectors': [0.005078191868960857, -0.0036594511475414038, 0.016988741233944893, -0.0015786211006343365, 0.030203675851225853, 0.09331899881362915, -0.0235157310962677, 0.011795195750892162, 0.03421774506568909, -0.027907833456993103, -0.03260169178247452, 0.0679800882935524, 0.015223750844597816, 0.025948498398065567, -0.07854384928941727, -0.0023915462661534548, 0.10089637339115143, 0.0014981384156271815, -0.017778029665350914, 0.005812637507915497, 0.02445339597761631, -0.0710371807217598, 0.04755859822034836, 0.026360979303717613, -0.05716250091791153, -0.0940014198422432, 0.047949012368917465, 0.008600219152867794, 0.03297032043337822, -0.06984368711709976, -0.0552142858505249, -0.03234352916479111, -0.0003443364112172276, 0.012479404918849468, -0.07419367134571075, 0.08545409888029099, 0.019597113132476807, 0.005851477384567261, -0.08256848156452179, 0.010150186717510223, 0.028275227174162865, -0.0016121627995744348, 0.04174523428082466, -0.009756717830896378, 0.03546829894185066, -0.0673336461186409, 0.013293622992932796, -0.047809384763240814, -0.02249010093510151, 0.028243854641914368, -0.08043544739484787, -0.01009676605463028, -0.03514788672327995, -0.021383730694651604, -0.002246067626401782, -0.015066167339682579, 0.04234122484922409, -0.040479838848114014, 0.00787312351167202, -0.04465996101498604, 0.010779906995594501, 0.0038497159257531166, -0.027719097211956978, -0.007967316545546055, 0.02942546270787716, -0.012327964417636395, 0.0050182887353003025, 0.06450540572404861, 0.03108026832342148, 0.042792391031980515, 0.023805316537618637, -0.01616135612130165, 0.02578461915254593, -0.08669176697731018, -0.044727668166160583, 7.097257184796035e-05, -0.10924965143203735, -0.10867254436016083, -0.03139006346464157, -0.03511088714003563, 0.08570166677236557, -0.134019672870636, -0.0005924605648033321, 
0.029533952474594116, 0.012721308507025242, 0.02152288891375065, 0.0707324892282486, -0.11056605726480484, -0.1083742305636406, 0.0982309952378273, -0.039475709199905396, -0.05996376648545265, -0.10398901998996735, 0.03040657937526703, -0.03018292225897312, -0.03471128270030022, -0.06378458440303802, 0.016372960060834885, 0.0583597756922245, 0.012307470664381981, 0.04363206401467323, -0.031246762722730637, -0.09203378111124039, -0.0062785972841084, 0.015498220920562744, -0.07184164226055145, 0.012648160569369793, 0.014564670622348785, -0.08191244304180145, 0.023379981517791748, -0.011096887290477753, 0.0394676998257637, -0.033372823148965836, 0.041654154658317566, 0.0863155946135521, 0.015705395489931107, 0.01734650880098343, 0.08271384239196777, 0.022032614797353745, 0.03559378534555435, 0.12214990705251694, 0.032827410846948624, 0.026021108031272888, -0.019847815856337547, 0.010051277466118336, -0.04892867058515549, -0.0174998976290226, -1.4977462088666326e-33, -0.01998828910291195, -0.020090218633413315, 0.009214007295668125, 0.029388802126049995, 0.01617312990128994, 0.003455288475379348, -0.07258066534996033, 0.049684278666973114, -0.06154271960258484, 0.05080917105078697, 0.05352963134646416, -0.011941409669816494, -0.0028067785315215588, -0.041576843708753586, -0.010775507427752018, 0.00046661923988722265, 0.004454561043530703, 0.030003147199749947, -0.0516991950571537, -0.030697643756866455, -0.07532348483800888, 0.05465441197156906, -0.0385969914495945, -0.04381357878446579, -0.03235914930701256, 0.017494583502411842, 0.005240216851234436, 0.06198848783969879, -0.03355488181114197, 0.011264801025390625, -0.02115759812295437, 0.00838891975581646, -0.058978889137506485, -0.00011408641876187176, 0.05079993978142738, 0.015300493687391281, -0.07043343037366867, -0.07872467488050461, 0.09050456434488297, 0.03952907398343086, -0.07477521151304245, 0.03615942969918251, -0.058201417326927185, 0.0326484851539135, -0.03198658302426338, 0.11224830150604248, 
-0.016622459515929222, 0.0504615381360054, -0.04651995375752449, 0.1277347207069397, 0.03776664286851883, 0.05948572978377342, 0.09149560332298279, -0.009857898578047752, 0.004627745598554611, 0.03188807889819145, 0.062271688133478165, -0.0659433975815773, 0.0032127737067639828, -0.13898129761219025, 0.026403773576021194, 0.08804035186767578, -0.05001967027783394, 0.05326379835605621, -0.02196440100669861, 0.07656972110271454, 0.013867619447410107, -0.016544628888368607, -0.009327870793640614, 0.021883144974708557, -0.1560947597026825, -0.07534021139144897, -0.01896633207798004, 0.012034989893436432, -0.07331383228302002, -0.04332052916288376, -0.03353505954146385, 0.007872307673096657, 0.16191385686397552, -0.058967869728803635, 0.024201923981308937, 0.011731469072401524, -0.002475024200975895, -0.060298558324575424, -0.023722389712929726, -0.04882300645112991, 0.000707246595993638, -0.018090907484292984, 0.07239993661642075, 0.07933493703603745, 0.054174549877643585, -0.03342485427856445, -0.007864750921726227, 0.06494550406932831, -0.08771026879549026, 1.13459770849573e-33, 0.06040865182876587, 0.006845973432064056, -0.09519106149673462, -0.004926742985844612, 0.02894597128033638, -0.0077415574342012405, -0.05669841915369034, -0.034497782588005066, 0.09411472827196121, 0.0011957630049437284, -0.03672650456428528, 0.023257385939359665, -0.029259465634822845, -0.004881837405264378, -0.034621454775333405, -0.1123257502913475, 0.041878167539834976, 0.01935793086886406, 0.019774673506617546, 0.0033800536766648293, 0.04810955002903938, -0.043293364346027374, -0.019849350675940514, -0.024460462853312492, 0.011674574576318264, 0.028871286660432816, -0.04594291001558304, -0.009591681882739067, -0.020649896934628487, -0.0767439752817154, 0.06008455529808998, -0.07102784514427185, -0.03325150907039642, -0.07066744565963745, -0.07285013049840927, 0.06852841377258301, 0.032675426453351974, -0.015307767316699028, -0.03120141103863716, -0.0008060619584284723, 
-0.012935955077409744, 0.01687614619731903, 0.010606919415295124, 0.05316408351063728, -0.016209596768021584, 0.05059502646327019, -0.016619250178337097, -0.003106643445789814, -0.09400973469018936, 0.02362005040049553, -0.1493453085422516, 0.03363995999097824, -0.013002770021557808, -0.0411999374628067, -0.03762894868850708, 0.01735512912273407, -0.02544626034796238, -0.015723178163170815, 0.007998578250408173, 0.04340173304080963, 0.006307568401098251, -0.031614888459444046, -0.03868135064840317, -0.11168476939201355, 0.04688170179724693, 0.02938792295753956, 0.007106451783329248, -0.023254472762346268, 0.006188348866999149, 0.032097551971673965, 0.02284681424498558, -0.020912854000926018, -0.016115304082632065, 0.006232560612261295, -0.06727242469787598, 0.0027730280999094248, -0.04707656428217888, -0.03735049441456795, 0.026144297793507576, -0.013619091361761093, -0.005712081212550402, -0.04333459213376045, -0.008567489683628082, -0.0026371825952082872, -0.04714951291680336, 0.1506747603416443, 0.060538701713085175, 0.015910591930150986, 0.0021603393834084272, 0.09120813012123108, 0.10193410515785217, 0.04816991090774536, 0.07890739291906357, -0.05583663284778595, -0.02227107249200344, -2.478202887346015e-08, -0.08490563929080963, 0.04434036836028099, 0.02475418709218502, -0.024806825444102287, 0.00536795100197196, -0.06101489067077637, 0.014922979287803173, 0.04093354195356369, 0.03936637192964554, 0.04489367827773094, 0.012824231758713722, -0.03051156736910343, 0.0662570372223854, 0.04904399439692497, 0.004838698077946901, 0.07400422543287277, 0.03470872715115547, 0.037787146866321564, -0.043043263256549835, 0.04372495785355568, 0.023403732106089592, 0.057728372514247894, 0.034502316266298294, -0.049777042120695114, -0.0041667199693620205, 0.06382499635219574, -0.007370579522103071, -0.002130263252183795, -0.04700297489762306, 0.10623563826084137, -5.87037175137084e-05, -0.012606821022927761, 0.03633716702461243, 0.024944987148046494, -0.06500178575515747, 
0.07670733332633972, 0.01752745360136032, 0.019638163968920708, 0.05920606851577759, 0.021030694246292114, 0.033589065074920654, 0.014452814124524593, 0.030615368857979774, 0.13622330129146576, 0.0162414088845253, 0.07696809619665146, 0.10586545616388321, 0.06321518868207932, -0.06497083604335785, 0.0035124991554766893, 0.03836303576827049, -0.049263447523117065, -0.0939357802271843, 0.04310446232557297, 0.047002870589494705, 0.02352922037243843, 0.06475073844194412, 0.12606267631053925, -0.03936544433236122, 0.0033126939088106155, -0.005963532254099846, 0.01087606605142355, -0.006803632713854313, 0.05783495306968689]}\n" ] } ], "source": [ "data = {\n", " \"inputs\": \"the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .\",\n", "}\n", "\n", "res = predictor.predict(data=data)\n", "print(res)\n" ] }, { "cell_type": "markdown", "id": "cb10007d", "metadata": {}, "source": [ "### Delete model and endpoint\n", "\n", "To clean up, we can delete the model and endpoint." ] }, { "cell_type": "code", "execution_count": 46, "id": "1e6fb7b8", "metadata": {}, "outputs": [], "source": [ "predictor.delete_model()\n", "predictor.delete_endpoint()" ] } ], "metadata": { "interpreter": { "hash": "c281c456f1b8161c8906f4af2c08ed2c40c50136979eaae69688b01f70e9f4a9" }, "kernelspec": { "display_name": "conda_pytorch_p39", "language": "python", "name": "conda_pytorch_p39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 }