# SageMaker + Astra DB, integration example

Use an LLM and an Embedding model from Amazon SageMaker and a Vector Store from [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) to run a simple RAG-based application.

In this notebook, you will:
- deploy an embedding model and an LLM in SageMaker (or connect to existing ones) and see them in action;
- connect to Astra DB and create a vector store in it;
- populate it with example "pretend entomology" information;
- run an AI-powered entomology assistant that helps identify insects observed in the field.

> Note: this notebook is designed to run within Amazon SageMaker Studio. See [this page](https://awesome-astra.github.io/docs/pages/aiml/aws/aws-sagemaker/) for more information and references.

## General setup

_Note: you may see some dependency-resolution errors in the output from `pip` here. You can safely ignore them: the rest of this notebook will work just fine._

In [None]:
!pip install --upgrade pip
!pip install --quiet \
 "langchain-astradb>=0.3.3" \
 "langchain>=0.2,<0.3" \
 "sagemaker>=2,<3" \
 "datasets>=2.16.1" \
 "jupyter-ai-magics>=2.19.1" \
 "faiss-cpu"

# (don't worry about the last two lines above: these fix a temporary version-clash issue with JupyterLab's preinstalled image)

In [None]:
import json
import getpass
from typing import Dict, List, Optional, Any

import boto3

from datasets import load_dataset

from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.session import Session
from sagemaker import image_uris, model_uris
from sagemaker.predictor import Predictor
from sagemaker.model import Model
from sagemaker.utils import name_from_base
from sagemaker.base_serializers import JSONSerializer
from sagemaker.base_deserializers import JSONDeserializer

In [None]:
from langchain_core.callbacks.manager import CallbackManagerForLLMRun

from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

from langchain_astradb import AstraDBVectorStore

In [None]:
boto3_sm_client = boto3.client('runtime.sagemaker')
region_name = boto3.Session().region_name

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()

#### Define a custom predictor with JSON ser/des

Prepare a function that specializes the default SageMaker `Predictor`. It will come in handy a few times when working with the `Model` objects.

> In some cases a `Model` can be used out of the box, but for these models you want to specify
> JSON serialization/deserialization when interacting with the endpoints.

In [None]:
def my_json_predictor(*pargs, **kwargs):
    # A Predictor that sends JSON request bodies and parses JSON responses
    return Predictor(
        *pargs,
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer(),
        **kwargs,
    )

## Embedding model, setup

Here you can choose between a model already deployed through the UI and a programmatic deployment through the SageMaker SDK.

In [None]:
emb_endpoint_supplied = False

emb_endpoint_name = input("Enter the *embedding model* endpoint name if already deployed (leave empty if deploying with SDK):").strip()

if emb_endpoint_name == "":
    print(f"\n{'*' * 101}")
    print("*** INFO: the embedding model will be deployed programmatically, as no endpoint name was provided. **")
    print("*** Re-run this cell and supply the endpoint name if this is incorrect. **")
    print(f"{'*' * 101}")
else:
    emb_endpoint_supplied = True

The following cell will perform a programmatic deployment of a JumpStart model through the SageMaker SDK.

If instead the embedding model endpoint name was supplied (that is, the model was deployed beforehand through the UI), the cell just prints a message and does nothing else.

#### This is the actual deploy step.

(That is, in case of programmatic deployment.)

> _Note: this cell may take up to **ten minutes** to complete. You can monitor progress in the SageMaker Studio 'Endpoints' tab while it runs._

In [None]:
if not emb_endpoint_supplied:
    emb_model_id = "huggingface-textembedding-gpt-j-6b"
    emb_endpoint_name = name_from_base(emb_model_id)
    emb_instance_type = "ml.g5.24xlarge"
    emb_model_version = "1.0.1"
    emb_model = JumpStartModel(
        model_id=emb_model_id,
        model_version=emb_model_version,
        instance_type=emb_instance_type,
        predictor_cls=my_json_predictor,
    )
    print(f"Deploying (endpoint name = '{emb_endpoint_name}') ...")
    emb_predictor = emb_model.deploy(
        initial_instance_count=1,
        instance_type=emb_instance_type,
        endpoint_name=emb_endpoint_name,
    )
    print("\nDeploy completed.")

    # a quick test that the model is behaving
    emb_test_result = json.dumps(emb_predictor.predict({"text_inputs": ["I am here!", "So am I."]}))
    print(f"\nEmb. model functional test resulted in: '{emb_test_result[:50]}...'")

    # take the name from the predictor, then report it
    emb_endpoint_name = emb_predictor.endpoint_name
    print(f"\nSetting the emb. endpoint name to '{emb_endpoint_name}'")
else:
    print("(nothing to do in this case)")
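The deploy call above blocks until the endpoint is ready. You may also want to check an endpoint's state from code instead of the Studio 'Endpoints' tab, for example one whose name you supplied. The small sketch below does that. It assumes a control-plane client created with `boto3.client("sagemaker")`, which is different from the `runtime.sagemaker` client used for invocations. The helper names are hypothetical.

```python
from typing import Any


def endpoint_status(sm_client: Any, endpoint_name: str) -> str:
    """Return the current status string of a SageMaker endpoint.

    Assumes `sm_client` is boto3.client("sagemaker") (hypothetical helper).
    """
    description = sm_client.describe_endpoint(EndpointName=endpoint_name)
    # Possible values include "Creating", "InService" and "Failed"
    return description["EndpointStatus"]


def is_ready(sm_client: Any, endpoint_name: str) -> bool:
    """True only when the endpoint can accept invocations."""
    return endpoint_status(sm_client, endpoint_name) == "InService"
```

For example, `endpoint_status(boto3.client("sagemaker"), emb_endpoint_name)`. To block until an endpoint is ready, boto3 also offers a waiter: `boto3.client("sagemaker").get_waiter("endpoint_in_service").wait(EndpointName=emb_endpoint_name)`.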

## Embedding model, LangChain setup

To be able to work with the shape of the input and output specific to _this_ embedding model, you need to create and supply a suitable `EmbeddingsContentHandler` when instantiating the LangChain abstraction for the SageMaker embedding:

In [None]:
class SageMakerGPTJ6BContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        # This endpoint expects a JSON body of the form {"text_inputs": [...], ...}
        input_encoded = json.dumps({
            "text_inputs": inputs,
            **model_kwargs,
        }).encode("utf-8")
        return input_encoded

    def transform_output(self, output: bytes) -> List[List[float]]:
        """
        `output` is actually a botocore.response.StreamingBody object in our case
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]


emb_content_handler = SageMakerGPTJ6BContentHandler()

embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=emb_endpoint_name,
    region_name=region_name,
    content_handler=emb_content_handler,
)
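To see the payload shapes without calling an endpoint, the stdlib-only sketch below mirrors the two methods of the handler above. The function names are hypothetical and the fake response is made up. `io.BytesIO` stands in for botocore's `StreamingBody`, assuming only its `.read()` method is needed.

```python
import io
import json
from typing import Dict, List


def encode_embedding_request(inputs: List[str], model_kwargs: Dict) -> bytes:
    # Same body shape as SageMakerGPTJ6BContentHandler.transform_input
    return json.dumps({"text_inputs": inputs, **model_kwargs}).encode("utf-8")


def decode_embedding_response(stream) -> List[List[float]]:
    # `stream` only needs a .read() method returning bytes
    return json.loads(stream.read().decode("utf-8"))["embedding"]


request = encode_embedding_request(["I am here!"], {})
fake_body = io.BytesIO(json.dumps({"embedding": [[0.6, 0.8]]}).encode("utf-8"))
vectors = decode_embedding_response(fake_body)

print(json.loads(request))  # {'text_inputs': ['I am here!']}
print(vectors)              # [[0.6, 0.8]]
```

Any extra `model_kwargs` end up as top-level keys next to `text_inputs`, exactly as in `transform_input`.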

### Embedding model, test invocation through LangChain

As a simple test, check that the model returns vectors with unit norm:

In [None]:
import math

vector1 = embeddings.embed_query("Hello, SageMaker")
vectors = embeddings.embed_documents(["Can you embed multiple sentences at once?", "Sure, you can."])

print(f"Vector dimensionality: {len(vector1)}")

# L2 norm: square root of the sum of the squared components
print(f"Norm of 'vector1': {math.sqrt(sum(x*x for x in vector1)):.4f}")

print("Norms of 'vectors'")
for i, v in enumerate(vectors):
    print(f" [{i}] norm = {math.sqrt(sum(x*x for x in v)):.4f}")
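Why does unit norm matter? For unit-norm vectors, cosine similarity reduces to a plain dot product, so both measures rank search results identically. The stdlib sketch below shows this on toy two-dimensional vectors, which are not real embeddings.

```python
import math
from typing import Sequence


def l2_norm(v: Sequence[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    # Dot product divided by the product of the norms
    return dot(a, b) / (l2_norm(a) * l2_norm(b))


# Two toy unit vectors: their norms are (up to rounding) 1,
# so the cosine similarity and the dot product coincide.
a = [0.6, 0.8]
b = [1.0, 0.0]
print(l2_norm(a), l2_norm(b))
print(cosine_similarity(a, b), dot(a, b))
```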

## LLM, setup

Here you can choose between a model already deployed through the UI and a programmatic deployment through the SageMaker SDK.

In [None]:
llm_endpoint_supplied = False

llm_endpoint_name = input("Enter the *LLM* endpoint name if already deployed (leave empty if deploying with SDK):").strip()

if llm_endpoint_name == "":
    print(f"\n{'*' * 89}")
    print("*** INFO: the LLM will be deployed programmatically, as no endpoint name was provided. **")
    print("*** Re-run this cell and supply the endpoint name if this is incorrect. **")
    print(f"{'*' * 89}")
else:
    llm_endpoint_supplied = True

The following cell works similarly to the embedding model deployment seen earlier:

#### This is the actual deploy step.

(That is, in case of programmatic deployment.)

> _Note: this cell may take about **twenty minutes** to complete. You can monitor progress in the SageMaker Studio 'Endpoints' tab while it runs._

In [None]:
if not llm_endpoint_supplied:
    llm_model_id = "meta-textgeneration-llama-2-70b-f"
    llm_endpoint_name = name_from_base(llm_model_id)
    llm_instance_type = "ml.g5.48xlarge"
    llm_model_version = "3.0.2"
    llm_model = JumpStartModel(
        model_id=llm_model_id,
        model_version=llm_model_version,
        instance_type=llm_instance_type,
        predictor_cls=my_json_predictor,
    )

    print(f"Deploying (endpoint name = '{llm_endpoint_name}') ...")
    llm_predictor = llm_model.deploy(
        initial_instance_count=1,
        instance_type=llm_instance_type,
        endpoint_name=llm_endpoint_name,
        accept_eula=True,
    )
    print("\nDeploy completed.")

    # a quick test that the model is behaving
    llm_test_result = json.dumps(llm_predictor.predict(
        {"inputs": "Write a short three-stanza poem about ichneumonid wasps.", "parameters": {"max_new_tokens": 256}},
    ))
    print(f"\nLLM model functional test resulted in: '{llm_test_result[:70]}...'")

    # take the name from the predictor, then report it
    llm_endpoint_name = llm_predictor.endpoint_name
    print(f"\nSetting the LLM endpoint name to '{llm_endpoint_name}'")
else:
    print("(nothing to do in this case)")

## LLM, LangChain setup

As with the embedding model, you need to provide a "Content Handler" tailored to the specific input/output format of this LLM.

In [None]:
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # This endpoint expects {"inputs": <prompt>, "parameters": {...}}
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # `output` is a botocore StreamingBody, as for the embedding handler
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["generated_text"]

content_handler = ContentHandler()

llm = SagemakerEndpoint(
    endpoint_name=llm_endpoint_name,
    # credentials_profile_name="credentials-profile-name",
    region_name=region_name,
    model_kwargs={"max_new_tokens": 3072, "top_p": 0.4, "temperature": 0.001},
    endpoint_kwargs={
        "CustomAttributes": "accept_eula=true",
    },
    content_handler=content_handler,
)

_A note about the `endpoint_kwargs` parameter._

As mentioned earlier, for this model each LLM call must carry a special header to signal acceptance of the EULA. This is accomplished,
at the LangChain level, by passing this parameter when creating the `SagemakerEndpoint` instance. For reference, you can check how this parameter
is used within the LangChain code ([check the code](https://github.com/langchain-ai/langchain/blob/7db6aabf65e70811e40ee6f2e1ba8e0425ba81c9/libs/langchain/langchain/llms/sagemaker_endpoint.py#L359C23-L359C39)).
Essentially the EULA acceptance flag is passed down to the underlying `boto3` library, whose `invoke_endpoint` method accepts the `CustomAttributes` parameter
([check the docs](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime/client/invoke_endpoint.html#invoke-endpoint)).
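In practice, the entries of `endpoint_kwargs` end up as extra keyword arguments of the `invoke_endpoint` call. The sketch below builds those keyword arguments for one LLM call without contacting AWS. The helper `build_invoke_kwargs` and the endpoint name `my-llm-endpoint` are hypothetical placeholders.

```python
import json
from typing import Any, Dict


def build_invoke_kwargs(endpoint_name: str, prompt: str,
                        parameters: Dict[str, Any],
                        accept_eula: bool = True) -> Dict[str, Any]:
    """Assemble keyword arguments for sagemaker-runtime `invoke_endpoint`.

    Hypothetical helper: it only builds the dict, it does not call AWS.
    """
    kwargs: Dict[str, Any] = {
        "EndpointName": endpoint_name,
        "Body": json.dumps({"inputs": prompt, "parameters": parameters}).encode("utf-8"),
        "ContentType": "application/json",
        "Accept": "application/json",
    }
    if accept_eula:
        # The same string passed through `endpoint_kwargs` above
        kwargs["CustomAttributes"] = "accept_eula=true"
    return kwargs


call_kwargs = build_invoke_kwargs("my-llm-endpoint", "Hello", {"max_new_tokens": 32})
# With a real runtime client: boto3_sm_client.invoke_endpoint(**call_kwargs)
```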

### LLM, test invocation through LangChain

In [None]:
print(llm.invoke("Summarize the differences between insects and scorpions in fewer than ten words.").strip())

## Vector store on Astra DB

In this section, first provide the credentials for your Astra DB instance. They are used later to create the LangChain vector store:

In [None]:
ASTRA_DB_API_ENDPOINT = input("Enter your Astra DB API endpoint ('https://...astra.datastax.com'):")
ASTRA_DB_APPLICATION_TOKEN = getpass.getpass("Enter your Astra DB Token ('AstraCS:...'):")
desired_keyspace = input("Enter your Astra DB namespace (leave empty if default):")
if desired_keyspace:
    ASTRA_DB_KEYSPACE = desired_keyspace
else:
    ASTRA_DB_KEYSPACE = None
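These values are typed by hand, so a quick sanity check can catch copy-paste mistakes before the vector store is created. The sketch below is a hypothetical helper. The expected prefixes (`https://` and `AstraCS:`) are taken only from the input prompts above, which is an assumption about the format of your credentials.

```python
from typing import Optional, Tuple


def check_astra_inputs(api_endpoint: str, token: str,
                       keyspace: Optional[str]) -> Tuple[str, str, Optional[str]]:
    """Light sanity checks on user-typed Astra DB settings (hypothetical helper)."""
    api_endpoint = api_endpoint.strip()
    token = token.strip()
    if not api_endpoint.startswith("https://"):
        raise ValueError("API endpoint should look like 'https://...astra.datastax.com'")
    if not token.startswith("AstraCS:"):
        raise ValueError("Application token should start with 'AstraCS:'")
    # Treat a whitespace-only keyspace the same as an empty one (use the default)
    keyspace = keyspace.strip() if keyspace else None
    return api_endpoint, token, keyspace or None
```

For example: `ASTRA_DB_API_ENDPOINT, ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_KEYSPACE = check_astra_inputs(ASTRA_DB_API_ENDPOINT, ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_KEYSPACE)`.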

Next, create the vector store:

In [None]:
astra_v_store = AstraDBVectorStore(
 token=ASTRA_DB_APPLICATION_TOKEN,
 api_endpoint=ASTRA_DB_API_ENDPOINT,
 namespace=ASTRA_DB_KEYSPACE,
 collection_name="sagemaker_demo_v_store",
 embedding=embeddings,
)

A small example dataset is loaded from Hugging Face. Print a sample item to get an idea of its structure.

In [None]:
sample_dataset = load_dataset("datastax/entomology")["train"]

def _shorten(dct): return {k: v if len(v) < 40 else v[:40]+"..." for k, v in dct.items()}

print(f"Loaded {len(sample_dataset)} entries")
print("Example entry:")
print("\n".join(
 f" {l}" for l in json.dumps(_shorten(sample_dataset[19]), indent=4).split("\n")
))

The dataset is prepared for insertion in the vector store:

_(Note: IDs are computed deterministically from the species names, so running the `add_texts` cell repeatedly does not create duplicate entries.)_

In [None]:
texts = [entry["description"] for entry in sample_dataset]
metadatas = [
 {
 "name": entry["name"],
 "order": entry["order"],
 }
 for entry in sample_dataset
]
ids = [entry["name"].lower().replace(" ", "_") for entry in sample_dataset]

print(f"Example from `texts`:\n \"{texts[19][:40]}...\"")
print(f"Example from `metadatas`:\n {metadatas[19]}")
print(f"Example from `ids`:\n \"{ids[19]}\"")
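The effect of deterministic IDs can be illustrated without Astra DB. The sketch below models a store keyed by document ID as a plain `dict`, where writing an existing ID overwrites the entry. This is a simplification of the vector store's behavior. The "insertion" is simulated twice, once with name-based IDs and once with random UUIDs. The species names are toy values.

```python
import uuid
from typing import Dict, List


def make_id(name: str) -> str:
    # Same rule as the cell above: lowercase, spaces become underscores
    return name.lower().replace(" ", "_")


def upsert(store: Dict[str, str], ids: List[str], texts: List[str]) -> None:
    # Stand-in for a store keyed by ID: an existing ID is overwritten
    for doc_id, text in zip(ids, texts):
        store[doc_id] = text


names = ["Common Darter", "Stag Beetle"]
texts_demo = ["A red dragonfly...", "A large beetle..."]

deterministic: Dict[str, str] = {}
random_ids: Dict[str, str] = {}
for _run in range(2):  # simulate running the insertion cell twice
    upsert(deterministic, [make_id(n) for n in names], texts_demo)
    upsert(random_ids, [str(uuid.uuid4()) for _n in names], texts_demo)

print(len(deterministic), len(random_ids))  # prints: 2 4
```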

This is where the writes take place (and the embedding vectors are calculated for each item in `texts`):

In [None]:
inserted_ids = astra_v_store.add_texts(texts=texts, metadatas=metadatas, ids=ids)

print(f"Inserted: {', '.join(inserted_ids)[:80]}... ({len(inserted_ids)} items)")

## Set up the full pipeline

### Retrieval part

Package the search part of the flow in a handy function:

In [None]:
def find_similar_entries(observation, k=3, order=None):
    # Optionally restrict the search to a given taxonomic order
    if order:
        md = {"order": order}
    else:
        md = {}
    documents = astra_v_store.similarity_search(observation, k=k, filter=md)
    return documents
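Conceptually, `similarity_search(..., filter=md)` restricts the candidates by metadata and then ranks them by vector similarity to the query. The brute-force sketch below illustrates only those semantics on toy data. It is not how Astra DB executes the query, since Astra DB relies on an index. It also assumes cosine similarity and exact-match filtering, which covers the `{"order": ...}` filter used here.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple

Doc = Tuple[str, Dict[str, str], List[float]]  # (text, metadata, vector)


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def brute_force_search(query_vec: Sequence[float], docs: List[Doc], k: int = 3,
                       flt: Optional[Dict[str, str]] = None) -> List[Doc]:
    flt = flt or {}
    # 1. keep only documents whose metadata matches every filter key/value
    candidates = [d for d in docs if all(d[1].get(key) == val for key, val in flt.items())]
    # 2. rank the survivors by similarity to the query, highest first
    candidates.sort(key=lambda d: cosine(query_vec, d[2]), reverse=True)
    return candidates[:k]


docs = [
    ("dragonfly A", {"order": "Odonata"}, [1.0, 0.0]),
    ("dragonfly B", {"order": "Odonata"}, [0.7, 0.7]),
    ("beetle", {"order": "Coleoptera"}, [0.99, 0.1]),
]
top = brute_force_search([1.0, 0.0], docs, k=2, flt={"order": "Odonata"})
print([d[0] for d in top])
```

Without the filter, the beetle would outrank "dragonfly B", because its toy vector is closer to the query.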

In [None]:
print(find_similar_entries("Long wings with brown spots, flies erratically, thin legs", k=2, order="Odonata"))

### Generation part

In [None]:
PROMPT_TEMPLATE = """
[INST] <<SYS>>
You are an expert entomologist tasked with helping specimen identification in the field.
You are given relevant excerpts from an invertebrate textbook along with my field observation.
Your task is to compare my observation with the textbook excerpts and come to an identification,
explaining why you came to that conclusion and giving the degree of certainty.
Only use the information provided in the user observation to come to your conclusion!
Be sure to provide, in your verdict, the species' Order together with the full Latin name.
Keep it short and informal, not like a letter, do not start with 'Dear User' or similar,
do not sign your communication.

TEXTBOOK CANDIDATE MATCHES:
{candidates}

<</SYS>>

Here is my observation:
{observation}

Please assist me in the identification. [/INST]
"""

In [None]:
def describe_candidates(matches):
    # One numbered block per retrieved document: name, order and description
    return "\n".join([
        f"Candidate species {i+1}: '{doc.metadata['name']}' (order: {doc.metadata['order']})\nDescription: {doc.page_content}\n"
        for i, doc in enumerate(matches)
    ])

def format_prompt(observation, candidates):
    return PROMPT_TEMPLATE.format(observation=observation, candidates=candidates)
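`describe_candidates` only reads `.metadata` and `.page_content`, so its output format can be checked offline with a stand-in object. The sketch below repeats the function so that the block runs on its own. It feeds the function a hypothetical `FakeDoc` dataclass, a minimal substitute for LangChain's `Document`, with toy content.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeDoc:
    # Only the two attributes that describe_candidates reads
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


def describe_candidates(matches: List[FakeDoc]) -> str:
    # Same formatting logic as the notebook function above
    return "\n".join([
        f"Candidate species {i+1}: '{doc.metadata['name']}' (order: {doc.metadata['order']})\nDescription: {doc.page_content}\n"
        for i, doc in enumerate(matches)
    ])


text = describe_candidates([
    FakeDoc("Slender body, clear wings.", {"name": "Example fly", "order": "Diptera"}),
])
print(text)
```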

In [None]:
candidates = describe_candidates(find_similar_entries("Long wings with brown spots, flies erratically, thin legs", k=2, order="Odonata"))
print(candidates)

In [None]:
print(format_prompt(observation="I saw a certain bug!", candidates=candidates))

In [None]:
def identify_and_suggest(observation, order=None):
    matches = find_similar_entries(observation, k=3, order=order)
    candidates_text = describe_candidates(matches)
    prompt = format_prompt(
        observation=observation,
        candidates=candidates_text,
    )
    return llm.invoke(prompt).strip()
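`identify_and_suggest` follows the retrieve, augment, generate pattern. The sketch below makes each step an injectable callable, so the flow can be exercised with toy stand-ins instead of Astra DB and SageMaker. The function `rag_answer` and the `toy_*` helpers are hypothetical.

```python
from typing import Any, Callable, List


def rag_answer(observation: str,
               retrieve: Callable[[str], List[Any]],
               build_prompt: Callable[[str, List[Any]], str],
               generate: Callable[[str], str]) -> str:
    matches = retrieve(observation)              # retrieval
    prompt = build_prompt(observation, matches)  # augmentation
    return generate(prompt).strip()              # generation


def toy_retrieve(observation: str) -> List[str]:
    return ["Passage about dragonflies."]


def toy_build_prompt(observation: str, passages: List[str]) -> str:
    return "CONTEXT:\n" + "\n".join(passages) + "\nOBSERVATION:\n" + observation


def toy_generate(prompt: str) -> str:
    # Echoes the prompt, padded with spaces to show the final .strip()
    return "  " + prompt + "  "


print(rag_answer("Long wings with brown spots", toy_retrieve, toy_build_prompt, toy_generate))
```

Wired to the notebook's objects, `rag_answer(obs, lambda o: find_similar_entries(o, k=3), lambda o, m: format_prompt(observation=o, candidates=describe_candidates(m)), llm.invoke)` should behave like `identify_and_suggest(obs)`.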

### Putting it all to the test

In [None]:
print(identify_and_suggest("A large butterfly with pointed wing tips and a yellow spot in the middle of each wing."))

In [None]:
print(identify_and_suggest("I found a nondescript brown bug with small wings, dark elytra and sturdy antennae in a meadow."))

In [None]:
print(identify_and_suggest("What looked like a leaf was in fact moving! It startled me greatly. But I'm not sure it's an insect, I did not see antennae. What was it?"))

### The "final app":

The loop below is a simple "app" to repeatedly interact with the entomology assistant:

- Try it with simple observations such as _I found a strange bug in the library, whose appearance was that of an old piece of paper. What was it?_
- Enter an empty input to end the cell.

In [None]:
while True:
    observation = input("\n=============================\nEnter your field observation: ").strip()
    if observation:
        print("-----------------------------")
        result = identify_and_suggest(observation)
        print(f"Result ==> {result}")
    else:
        print("(no input)")
        break

print("\n========\nGoodbye.")

## Appendix: non-LangChain model tests

The code below is not part of the main LangChain-based application. It shows how you can use the SageMaker endpoints at lower abstraction layers than LangChain, namely by calling boto3 or SageMaker SDK primitives directly. Note that in the latter case, if you deployed the model through the SageMaker UI, you have to construct a `Predictor` object manually.

_These non-LangChain idioms are important in themselves, as they open the way to a richer set of possibilities for integrating Astra DB with Amazon SageMaker._

### Embedding model, test invocation through boto3

In [None]:
encoded_body = json.dumps(
 {
 "text_inputs": [
 "Can you invoke a SageMaker embedding model from boto3 directly?",
 "Wait and see..."
 ]
 }
).encode("utf-8")

response = boto3_sm_client.invoke_endpoint(
 EndpointName=emb_endpoint_name,
 Body=encoded_body,
 ContentType='application/json',
 Accept='application/json',
)

response_body = response['Body']
read_body = response_body.read()
response_json = json.loads(read_body.decode())

# This is a list of 2 lists, each made of 4096 floats:
embedding_vectors = response_json['embedding']

print(f"Returned {len(embedding_vectors)} embedding vectors.")
print(f"Each is made of {len(embedding_vectors[0])} float values.")
print(f" The first one starts with: {str(embedding_vectors[0])[:80]}...")

### Embedding model, test invocation through SageMaker SDK

In [None]:
if emb_endpoint_supplied:
    emb_predictor = my_json_predictor(emb_endpoint_name)
else:
    # `emb_predictor` was already created as part of the deploy-from-code procedure
    pass

response_json = emb_predictor.predict(
 {"text_inputs": [
 "Can you show me how to use the SageMaker SDK directly for embeddings?",
 "Let me look at the docs..."
 ]
 }
)

# This is a list of 2 lists, each made of 4096 floats:
embedding_vectors = response_json["embedding"]

print(f"Returned {len(embedding_vectors)} embedding vectors.")
print(f"Each is made of {len(embedding_vectors[0])} float values.")
print(f" The first one starts with: {str(embedding_vectors[0])[:80]}...")

### LLM, test invocation through boto3

For this particular model, the `inputs` field is a string. Here it is just a plain piece of text. The specific encoding required for system/assistant/user exchanges is shown in the "Test inference" tab of your deployed endpoint, in the (Python) programmatic example.
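For reference, the sketch below builds a multi-turn prompt in the widely documented Llama-2-chat style (`[INST]`, `<<SYS>>` and `</s><s>` separators). This is an assumption: the "Test inference" tab remains the authoritative source for your endpoint. In particular, whether the leading `<s>` should appear literally in the text depends on how the endpoint tokenizes input. The notebook's own `PROMPT_TEMPLATE` omits it. The function name is hypothetical.

```python
from typing import List, Optional, Tuple

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def llama2_chat_prompt(turns: List[Tuple[str, Optional[str]]],
                       system: Optional[str] = None) -> str:
    """Build a Llama-2-chat style prompt string (hypothetical helper).

    `turns` holds (user_message, assistant_reply) pairs; the last pair
    normally has assistant_reply=None, meaning the model should answer it.
    """
    parts: List[str] = []
    for i, (user, assistant) in enumerate(turns):
        if i == 0 and system:
            # The system prompt is folded into the first user message
            user = B_SYS + system + E_SYS + user
        segment = f"<s>{B_INST} {user.strip()} {E_INST}"
        if assistant is not None:
            segment += f" {assistant.strip()} </s>"
        parts.append(segment)
    return "".join(parts)


print(llama2_chat_prompt([("Q1", "A1"), ("Q2", None)], system="Be brief."))
```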

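Note how the EULA acceptance is passed in this case, through the `CustomAttributes` parameter of `invoke_endpoint` ([reference](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime/client/invoke_endpoint.html#invoke-endpoint)).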

In [None]:
sample_question = ("Answer wittily and in fewer than 20 words: what would "
                   "Heidegger do if he were suddenly transported to the Moon?")

encoded_body = json.dumps({
 "inputs": sample_question,
 "parameters": {
 "max_new_tokens": 256,
 "top_p": 0.9,
 "temperature": 0.6
 },
}).encode("utf-8")

response = boto3_sm_client.invoke_endpoint(
 EndpointName=llm_endpoint_name,
 Body=encoded_body,
 ContentType='application/json',
 Accept='application/json',
 # This is required for each invocation of this model:
 CustomAttributes='accept_eula=true',
)
response_body = response['Body']
read_body = response_body.read()
response_json = json.loads(read_body.decode())

print(f"Full response:\n")
print(json.dumps(response_json, indent=4))

### LLM, test invocation through SageMaker SDK

Note how the EULA acceptance is passed in this case ([reference](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html)).

In [None]:
if llm_endpoint_supplied:
    llm_predictor = my_json_predictor(llm_endpoint_name)
else:
    # `llm_predictor` was already created as part of the deploy-from-code procedure
    pass

response_json = llm_predictor.predict(
 {
 "inputs": sample_question,
 "parameters": {
 "max_new_tokens": 256,
 "top_p": 0.9,
 "temperature": 0.6
 },
 },
 custom_attributes='accept_eula=true',
)

print(f"Full response:\n")
print(json.dumps(response_json, indent=4))

## (Optional) Astra DB cleanup

If you want to deallocate all resources used in the demo, besides going through the [AWS side of the operation](https://awesome-astra.github.io/docs/pages/aiml/aws/aws-sagemaker/#cleanup), you might want to delete the vector collection on Astra DB used throughout this example. To do so, simply run the following:

In [None]:
astra_v_store.delete_collection()
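On the AWS side, the linked page remains the reference. As a sketch, the endpoints created here could also be removed with boto3. The helper below is hypothetical and assumes a control-plane client from `boto3.client("sagemaker")`. It looks up the endpoint configuration and the served models first, then deletes the endpoint, its configuration and those models.

```python
from typing import Any, List


def delete_sagemaker_resources(sm_client: Any, endpoint_name: str) -> List[str]:
    """Delete an endpoint, its endpoint configuration and the models it serves.

    Hypothetical helper; `sm_client` is assumed to be boto3.client("sagemaker").
    Returns the names of the deleted models.
    """
    # Collect the names first: the config name comes from the endpoint description
    config_name = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointConfigName"]
    variants = sm_client.describe_endpoint_config(EndpointConfigName=config_name)["ProductionVariants"]
    model_names = [v["ModelName"] for v in variants if "ModelName" in v]

    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for model_name in model_names:
        sm_client.delete_model(ModelName=model_name)
    return model_names
```

For example, `delete_sagemaker_resources(boto3.client("sagemaker"), llm_endpoint_name)`. If you still have the `Predictor` objects, the SageMaker SDK's `llm_predictor.delete_model()` followed by `llm_predictor.delete_endpoint()` is an alternative.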