{
"cells": [
{
"cell_type": "markdown",
"id": "0c7f9134",
"metadata": {},
"source": [
"# Deploy Chronos-2 to AWS with Amazon SageMaker"
]
},
{
"cell_type": "markdown",
"id": "3053768c",
"metadata": {},
"source": [
"This notebook shows how to deploy **Chronos-2** to AWS using **Amazon SageMaker**.\n",
"\n",
"### Why Deploy to SageMaker?\n",
"Running models locally works for experimentation, but production use cases need reliability, scale, and integration into existing workflows. For example, you may need to generate forecasts for thousands of time series on a regular schedule, or integrate forecasts into applications that serve many users. SageMaker lets you deploy Chronos-2 to the cloud and access it from anywhere.\n",
"\n",
"### Deployment Options\n",
"This notebook covers three deployment modes on SageMaker:\n",
"\n",
"1. **[Real-time Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html)**\n",
"   - ✅ Highest throughput, consistently low latency, supports both GPU and CPU instances\n",
"   - ✅ Simple setup via JumpStart\n",
"   - ❌ By default, you pay for the time the endpoint is running (can be configured to [scale to zero](https://docs.aws.amazon.com/sagemaker/latest/dg/endpoint-auto-scaling-zero-instances.html))\n",
"\n",
"2. **[Serverless Inference (CPU only)](https://docs.aws.amazon.com/sagemaker/latest/dg/serverless-endpoints.html)**\n",
"   - ✅ Pay only for active inference time, no infrastructure management\n",
"   - ✅ Cost-efficient for intermittent or unpredictable traffic\n",
"   - ❌ Cold start latency on first request after idle, lowest throughput of all options\n",
"   - ❌ More complex setup (requires repackaging model artifacts)\n",
"\n",
"3. **[Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html)**\n",
"   - ✅ Pay only for active compute time, no persistent infrastructure\n",
"   - ✅ Cost-efficient for large-scale batch prediction jobs\n",
"   - ❌ Initialization takes several minutes for each job (not suitable for real-time use), requires data in S3\n",
"   - ❌ More complex setup (requires repackaging model artifacts)\n",
"\n",
"**Reference benchmark** on a dataset with 1M rows (2000 time series with 500 observations each) and prediction length of 28:\n",
"| Mode | Instance | Inference time (s) |\n",
"|------|----------|------|\n",
"| Real-time (GPU) | ml.g5.2xlarge | 18 |\n",
"| Real-time (CPU) | ml.c5.4xlarge | 50 |\n",
"| Serverless | 6GB memory | 120 |\n",
"| Batch Transform | ml.c5.4xlarge | 60 (+200s setup) |\n",
"\n",
"We recommend starting with **Real-time Inference** as it offers the simplest setup and highest throughput. Consider Serverless or Batch Transform when you need to optimize costs and don't require GPU acceleration.\n",
"\n",
"For a complete specification of all supported request parameters, see the **Endpoint API Reference** at the end of this notebook."
]
},
{
"cell_type": "markdown",
"id": "78b40323",
"metadata": {},
"source": [
"> ℹ️ **New to Chronos-2?**\n",
"> For an overview of Chronos-2 capabilities (univariate, multivariate, covariates), see the Chronos-2 Quick Start notebook."
]
},
{
"cell_type": "markdown",
"id": "a07155fb",
"metadata": {},
"source": [
"> ⚠️ **Looking for Chronos-Bolt or original Chronos?**\n",
"> This notebook covers Chronos-2, the latest and recommended model. For documentation on older models (Chronos-Bolt and original Chronos), see the legacy deployment walkthrough."
]
},
{
"cell_type": "markdown",
"id": "1f595c0a",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4907ecb7",
"metadata": {},
"outputs": [],
"source": [
"!pip install -U -q \"sagemaker<3\""
]
},
{
"cell_type": "markdown",
"id": "69fa28b0",
"metadata": {},
"source": [
"If running in a SageMaker Notebook with the correct execution role, `role` can be set to `None`. Otherwise, specify your IAM role ARN."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c135de1",
"metadata": {},
"outputs": [],
"source": [
"role = None # or \"arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole-XXXXXXXXXXXXXXX\""
]
},
{
"cell_type": "markdown",
"id": "9be17b67",
"metadata": {},
"source": [
"---\n",
"## Section 1: Real-time Inference\n",
"\n",
"Real-time inference is the simplest option. SageMaker keeps a dedicated instance running, ready to serve predictions with low latency.\n",
"\n",
"**When to use:**\n",
"- Interactive applications that need sub-second response times\n",
"- Consistent, predictable traffic\n",
"- When simplicity matters more than cost optimization"
]
},
{
"cell_type": "markdown",
"id": "1abbbaf5",
"metadata": {},
"source": [
"### Deploy the Model\n",
"\n",
"With SageMaker JumpStart, you configure the deployment with a few parameters:\n",
"\n",
"- `model_id`: The model to deploy. Use `pytorch-forecasting-chronos-2` for [Chronos-2](https://huggingface.co/amazon/chronos-2).\n",
"- `instance_type`: The AWS instance type for serving. Supported options:\n",
"  - **GPU**: `ml.g5.xlarge`, `ml.g5.2xlarge`, `ml.g6.xlarge`, `ml.g6.2xlarge`, `ml.g6e.xlarge`, `ml.g6e.2xlarge`, `ml.g4dn.xlarge`, `ml.g4dn.2xlarge`\n",
"  - **CPU**: `ml.m5.xlarge`, `ml.m5.2xlarge`, `ml.m5.4xlarge`, `ml.c5.xlarge`, `ml.c5.2xlarge`, `ml.c5.4xlarge`\n",
"\n",
"JumpStart automatically sets other attributes like `image_uri` based on your choices. See [SageMaker pricing](https://aws.amazon.com/sagemaker/ai/pricing/) for instance costs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba78fc0b",
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.jumpstart.model import JumpStartModel\n",
"\n",
"js_model = JumpStartModel(\n",
" model_id=\"pytorch-forecasting-chronos-2\",\n",
" instance_type=\"ml.g5.2xlarge\",\n",
" role=role,\n",
")\n",
"\n",
"predictor = js_model.deploy()"
]
},
{
"cell_type": "markdown",
"id": "f1ce1202",
"metadata": {},
"source": [
"> **Note:** After the endpoint is deployed, it will incur charges until you delete it with `predictor.delete_predictor()`."
]
},
{
"cell_type": "markdown",
"id": "9ef676fc",
"metadata": {},
"source": [
"To connect to an existing endpoint instead:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "159e1e83",
"metadata": {},
"outputs": [],
"source": [
"# from sagemaker.predictor import Predictor\n",
"# from sagemaker.serializers import JSONSerializer\n",
"# from sagemaker.deserializers import JSONDeserializer\n",
"#\n",
"# predictor = Predictor(\"NAME_OF_EXISTING_ENDPOINT\", serializer=JSONSerializer(), deserializer=JSONDeserializer())"
]
},
{
"cell_type": "markdown",
"id": "88c075c6",
"metadata": {},
"source": [
"### Query the Endpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c5d7cfd",
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"from pprint import pformat\n",
"\n",
"\n",
"def nested_round(data, decimals=2):\n",
"    \"\"\"Round numbers, including nested dicts and lists.\"\"\"\n",
"    if isinstance(data, float):\n",
"        return round(data, decimals)\n",
"    elif isinstance(data, list):\n",
"        return [nested_round(item, decimals) for item in data]\n",
"    elif isinstance(data, dict):\n",
"        return {key: nested_round(value, decimals) for key, value in data.items()}\n",
"    return data\n",
"\n",
"\n",
"def pretty_format(data):\n",
"    return pformat(nested_round(data), width=150, sort_dicts=False)"
]
},
{
"cell_type": "markdown",
"id": "e008ad07",
"metadata": {},
"source": [
"#### Univariate Forecasting"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "af030c65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'predictions': [{'mean': [-0.36, 4.02, 5.3, 2.45, -2.48, -5.14, -4.33, 0.06, 4.42, 5.14],\n",
" '0.1': [-1.68, 2.86, 4.01, 1.01, -3.77, -6.22, -5.39, -1.77, 2.6, 3.62],\n",
" '0.5': [-0.36, 4.02, 5.3, 2.45, -2.48, -5.14, -4.33, 0.06, 4.42, 5.14],\n",
" '0.9': [1.02, 5.02, 6.32, 3.82, -0.85, -3.92, -2.93, 1.83, 5.63, 6.44]}]}\n"
]
}
],
"source": [
"payload = {\n",
" \"inputs\": [\n",
" {\"target\": [0.0, 4.0, 5.0, 1.5, -3.0, -5.0, -3.0, 1.5, 5.0, 4.0, 0.0, -4.0, -5.0, -1.5, 3.0, 5.0, 3.0, -1.5, -5.0, -4.0]},\n",
" ],\n",
" \"parameters\": {\"prediction_length\": 10},\n",
"}\n",
"response = predictor.predict(payload)\n",
"print(pretty_format(response))"
]
},
{
"cell_type": "markdown",
"id": "a2e4d3a2",
"metadata": {},
"source": [
"#### Multiple Time Series with Metadata"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "cd8d2d7c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'predictions': [{'mean': [1.69, 1.94, 1.65, 1.54, 1.84],\n",
" '0.1': [0.28, 0.31, -0.07, -0.35, -0.18],\n",
" '0.5': [1.69, 1.94, 1.65, 1.54, 1.84],\n",
" '0.9': [3.09, 3.77, 3.62, 3.58, 4.23],\n",
" 'item_id': 'product_A',\n",
" 'start': '2024-01-01T10:00:00'},\n",
" {'mean': [-1.2, -1.41, -1.27, -1.37, -1.3],\n",
" '0.1': [-4.21, -5.83, -6.39, -7.58, -8.05],\n",
" '0.5': [-1.2, -1.41, -1.27, -1.37, -1.3],\n",
" '0.9': [2.01, 2.91, 3.55, 4.66, 5.66],\n",
" 'item_id': 'product_B',\n",
" 'start': '2024-02-02T10:00:00'}]}\n"
]
}
],
"source": [
"payload = {\n",
" \"inputs\": [\n",
" {\"target\": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0], \"item_id\": \"product_A\", \"start\": \"2024-01-01T01:00:00\"},\n",
" {\"target\": [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0], \"item_id\": \"product_B\", \"start\": \"2024-02-02T03:00:00\"},\n",
" ],\n",
" \"parameters\": {\"prediction_length\": 5, \"freq\": \"1h\", \"quantile_levels\": [0.1, 0.5, 0.9]},\n",
"}\n",
"response = predictor.predict(payload)\n",
"print(pretty_format(response))"
]
},
{
"cell_type": "markdown",
"id": "d6a84b17",
"metadata": {},
"source": [
"#### Forecasting with Covariates"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "0dcaa27a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'predictions': [{'mean': [1.73, 2.09, 1.74], '0.1': [0.35, 0.58, 0.17], '0.5': [1.73, 2.09, 1.74], '0.9': [3.11, 3.79, 3.52]}]}\n"
]
}
],
"source": [
"payload = {\n",
" \"inputs\": [\n",
" {\n",
" \"target\": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0],\n",
" \"past_covariates\": {\n",
" \"feat_1\": [3.0, 6.0, 9.0, 6.0, 1.5, 6.0, 9.0, 6.0, 3.0],\n",
" \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\", \"A\", \"B\"],\n",
" \"feat_3\": [10.0, 20.0, 30.0, 20.0, 5.0, 20.0, 30.0, 20.0, 10.0], # past-only\n",
" },\n",
" \"future_covariates\": {\"feat_1\": [2.5, 2.2, 3.3], \"feat_2\": [\"B\", \"A\", \"A\"]},\n",
" },\n",
" ],\n",
" \"parameters\": {\"prediction_length\": 3, \"quantile_levels\": [0.1, 0.5, 0.9]},\n",
"}\n",
"response = predictor.predict(payload)\n",
"print(pretty_format(response))"
]
},
{
"cell_type": "markdown",
"id": "98c40db2",
"metadata": {},
"source": [
"#### Multivariate Forecasting"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "9ca28aea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'predictions': [{'mean': [[3.66, 3.54, 3.5, 3.42], [2.01, 2.07, 2.2, 2.25], [3.33, 3.27, 3.25, 3.21]],\n",
" '0.1': [[1.98, 1.52, 1.16, 0.88], [0.84, 0.21, 0.03, -0.27], [2.49, 2.26, 2.08, 1.94]],\n",
" '0.5': [[3.66, 3.54, 3.5, 3.42], [2.01, 2.07, 2.2, 2.25], [3.33, 3.27, 3.25, 3.21]],\n",
" '0.9': [[5.76, 6.22, 6.59, 6.99], [3.8, 4.48, 4.89, 5.31], [4.38, 4.61, 4.79, 5.0]]}]}\n"
]
}
],
"source": [
"payload = {\n",
" \"inputs\": [\n",
" {\n",
" \"target\": [\n",
" [1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0], # Dimension 1\n",
" [5.0, 4.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0], # Dimension 2\n",
" [2.0, 2.5, 3.0, 2.5, 2.0, 2.5, 3.0, 3.5], # Dimension 3\n",
" ],\n",
" },\n",
" ],\n",
" \"parameters\": {\"prediction_length\": 4, \"quantile_levels\": [0.1, 0.5, 0.9]},\n",
"}\n",
"response = predictor.predict(payload)\n",
"print(pretty_format(response))"
]
},
{
"cell_type": "markdown",
"id": "e04e42b2",
"metadata": {},
"source": [
"### Working with Long-Format DataFrames\n",
"\n",
"Time series data is often stored in long-format DataFrames. The following helper functions convert between DataFrame and payload formats. You can skip this section if you prefer to construct payloads manually."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "fb8e4c95",
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def convert_df_to_payload(\n",
"    past_df,\n",
"    future_df=None,\n",
"    prediction_length=1,\n",
"    freq=\"D\",\n",
"    target=\"target\",\n",
"    id_column=\"item_id\",\n",
"    timestamp_column=\"timestamp\",\n",
"):\n",
"    \"\"\"\n",
"    Converts past and future DataFrames into JSON payload format for the Chronos endpoint.\n",
"\n",
"    Args:\n",
"        past_df: Historical data with target, timestamp_column, and id_column.\n",
"        future_df: Future covariates with timestamp_column and id_column.\n",
"        prediction_length: Number of future time steps to predict.\n",
"        freq: Pandas-compatible frequency of the time series.\n",
"        target: Column name(s) for target values (str for univariate, list for multivariate).\n",
"        id_column: Column name for item IDs.\n",
"        timestamp_column: Column name for timestamps.\n",
"\n",
"    Returns:\n",
"        dict: JSON payload formatted for the Chronos endpoint.\n",
"    \"\"\"\n",
"    past_df = past_df.sort_values([id_column, timestamp_column])\n",
"    if future_df is not None:\n",
"        future_df = future_df.sort_values([id_column, timestamp_column])\n",
"\n",
"    target_cols = [target] if isinstance(target, str) else target\n",
"    past_covariate_cols = list(past_df.columns.drop([*target_cols, id_column, timestamp_column]))\n",
"    future_covariate_cols = [] if future_df is None else [col for col in past_covariate_cols if col in future_df.columns]\n",
"\n",
"    inputs = []\n",
"    for item_id, past_group in past_df.groupby(id_column):\n",
"        if len(target_cols) > 1:\n",
"            target_values = [past_group[col].tolist() for col in target_cols]\n",
"            series_length = len(target_values[0])\n",
"        else:\n",
"            target_values = past_group[target_cols[0]].tolist()\n",
"            series_length = len(target_values)\n",
"\n",
"        if series_length < 5:\n",
"            raise ValueError(f\"Time series '{item_id}' has fewer than 5 observations.\")\n",
"\n",
"        series_dict = {\n",
"            \"target\": target_values,\n",
"            \"item_id\": str(item_id),\n",
"            \"start\": past_group[timestamp_column].iloc[0].isoformat(),\n",
"        }\n",
"\n",
"        if past_covariate_cols:\n",
"            series_dict[\"past_covariates\"] = past_group[past_covariate_cols].to_dict(orient=\"list\")\n",
"\n",
"        if future_covariate_cols:\n",
"            future_group = future_df[future_df[id_column] == item_id]\n",
"            if len(future_group) != prediction_length:\n",
"                raise ValueError(\n",
"                    f\"future_df must contain exactly {prediction_length=} values for each item_id from past_df \"\n",
"                    f\"(got {len(future_group)=}) for {item_id=}\"\n",
"                )\n",
"            series_dict[\"future_covariates\"] = future_group[future_covariate_cols].to_dict(orient=\"list\")\n",
"\n",
"        inputs.append(series_dict)\n",
"\n",
"    return {\n",
"        \"inputs\": inputs,\n",
"        \"parameters\": {\"prediction_length\": prediction_length, \"freq\": freq},\n",
"    }\n",
"\n",
"\n",
"def convert_response_to_df(response, freq=\"D\"):\n",
"    \"\"\"\n",
"    Converts a JSON response from the Chronos endpoint into a long-format DataFrame.\n",
"\n",
"    Args:\n",
"        response: JSON response containing forecasts.\n",
"        freq: Pandas-compatible frequency of the time series.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: Long-format DataFrame with timestamps, item_id, and forecasted values.\n",
"    \"\"\"\n",
"    dfs = []\n",
"    for forecast in response[\"predictions\"]:\n",
"        if isinstance(forecast[\"mean\"], list) and isinstance(forecast[\"mean\"][0], list):\n",
"            # Multivariate forecast\n",
"            timestamps = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast[\"mean\"][0]))\n",
"            for dim_idx in range(len(forecast[\"mean\"])):\n",
"                dim_data = {\"item_id\": forecast.get(\"item_id\"), \"timestamp\": timestamps, \"target\": f\"target_{dim_idx + 1}\"}\n",
"                for key, value in forecast.items():\n",
"                    if key not in [\"item_id\", \"start\"]:\n",
"                        dim_data[key] = value[dim_idx]\n",
"                dfs.append(pd.DataFrame(dim_data))\n",
"        else:\n",
"            # Univariate forecast\n",
"            forecast_df = pd.DataFrame(forecast).drop(columns=[\"start\"])\n",
"            forecast_df[\"timestamp\"] = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast_df))\n",
"            cols = [\"item_id\", \"timestamp\"] + [c for c in forecast_df.columns if c not in [\"item_id\", \"timestamp\"]]\n",
"            forecast_df = forecast_df[cols]\n",
"            dfs.append(forecast_df)\n",
"\n",
"    return pd.concat(dfs, ignore_index=True)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "ed20630a",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"df = pd.read_csv(\n",
" \"https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv\",\n",
" parse_dates=[\"timestamp\"],\n",
")\n",
"\n",
"prediction_length = 8\n",
"target_col = \"unit_sales\"\n",
"freq = pd.infer_freq(df[df.item_id == df.item_id[0]][\"timestamp\"])\n",
"\n",
"past_df = df.groupby(\"item_id\").head(-prediction_length)\n",
"future_df = df.groupby(\"item_id\").tail(prediction_length).drop(columns=[target_col])"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "e9528df2",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [
{
"data": {
"text/plain": [
" item_id timestamp scaled_price promotion_email promotion_homepage \\\n",
"0 1062_101 2018-01-01 0.879130 0.0 0.0 \n",
"1 1062_101 2018-01-08 0.994517 0.0 0.0 \n",
"2 1062_101 2018-01-15 1.005513 0.0 0.0 \n",
"3 1062_101 2018-01-22 1.000000 0.0 0.0 \n",
"4 1062_101 2018-01-29 0.883309 0.0 0.0 \n",
"\n",
" unit_sales \n",
"0 636.0 \n",
"1 123.0 \n",
"2 391.0 \n",
"3 339.0 \n",
"4 661.0 "
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"past_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "b3e58f08",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [
{
"data": {
"text/plain": [
" item_id timestamp scaled_price promotion_email promotion_homepage\n",
"23 1062_101 2018-06-11 1.005425 0.0 0.0\n",
"24 1062_101 2018-06-18 1.005454 0.0 0.0\n",
"25 1062_101 2018-06-25 1.000000 0.0 0.0\n",
"26 1062_101 2018-07-02 1.005513 0.0 0.0\n",
"27 1062_101 2018-07-09 1.000000 0.0 0.0"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"future_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "68aef8ee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" item_id timestamp mean 0.1 0.5 0.9\n",
"0 1062_101 2018-06-11 320.102539 186.102356 320.102539 486.852112\n",
"1 1062_101 2018-06-18 317.431396 174.692490 317.431396 495.592224\n",
"2 1062_101 2018-06-25 316.319000 169.798355 316.319000 507.396881\n",
"3 1062_101 2018-07-02 316.502472 170.463837 316.502472 505.163483\n",
"4 1062_101 2018-07-09 309.931396 164.362732 309.931396 505.276794"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"payload = convert_df_to_payload(past_df, future_df, prediction_length=prediction_length, freq=freq, target=\"unit_sales\")\n",
"response = predictor.predict(payload)\n",
"forecast_df = convert_response_to_df(response, freq=freq)\n",
"forecast_df.head()"
]
},
{
"cell_type": "markdown",
"id": "19467a6e",
"metadata": {},
"source": [
"### Clean Up\n",
"\n",
"The endpoint incurs charges until deleted. Alternatively, you can configure [scaling to zero](https://docs.aws.amazon.com/sagemaker/latest/dg/endpoint-auto-scaling-zero-instances.html) to save costs when the endpoint is idle."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2ab130b",
"metadata": {},
"outputs": [],
"source": [
"predictor.delete_predictor()"
]
},
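{
"cell_type": "markdown",
"id": "5ca1e0ze",
"metadata": {},
"source": [
"As an alternative to deleting the endpoint, scaling to zero is configured through the Application Auto Scaling API. The snippet below is only a sketch: it assumes an inference-component-based endpoint with a component named `chronos-2-ic` (a hypothetical name) — scale-to-zero applies to inference components rather than plain endpoint variants; see the documentation linked above for the full setup.\n",
"\n",
"```python\n",
"import boto3\n",
"\n",
"autoscaling = boto3.client(\"application-autoscaling\")\n",
"autoscaling.register_scalable_target(\n",
"    ServiceNamespace=\"sagemaker\",\n",
"    ResourceId=\"inference-component/chronos-2-ic\",  # hypothetical component name\n",
"    ScalableDimension=\"sagemaker:inference-component:DesiredCopyCount\",\n",
"    MinCapacity=0,  # allow scaling down to zero copies when idle\n",
"    MaxCapacity=1,\n",
")\n",
"```"
]
},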
{
"cell_type": "markdown",
"id": "bad58ba2",
"metadata": {},
"source": [
"---\n",
"## Setup for Serverless Inference and Batch Transform\n",
"\n",
"Serverless Inference and Batch Transform only support CPU instances. Unlike real-time inference with JumpStart, these modes require you to create a custom SageMaker Model with repackaged artifacts.\n",
"\n",
"The following section sets up a reusable model that you can use for both Serverless (Section 2) and Batch Transform (Section 3)."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "64a2da8d",
"metadata": {
"lines_to_next_cell": 1
},
"outputs": [],
"source": [
"import boto3\n",
"import json\n",
"import tempfile\n",
"import tarfile\n",
"from pathlib import Path\n",
"from sagemaker import Session\n",
"from sagemaker.model import Model\n",
"from sagemaker.jumpstart.model import JumpStartModel\n",
"from sagemaker.serializers import JSONSerializer\n",
"from sagemaker.deserializers import JSONDeserializer\n",
"\n",
"\n",
"def repackage_jumpstart_model(js_model, output_bucket, output_key):\n",
"    \"\"\"\n",
"    Repackages JumpStart model artifacts into a single tar.gz file for serverless/batch deployment.\n",
"\n",
"    Args:\n",
"        js_model: JumpStartModel instance with model_data configured.\n",
"        output_bucket: S3 bucket to store the repackaged model.\n",
"        output_key: S3 key for the output tar.gz file.\n",
"\n",
"    Returns:\n",
"        str: S3 URI of the repackaged model.\n",
"    \"\"\"\n",
"    s3 = boto3.client(\"s3\")\n",
"    s3_uri = js_model.model_data[\"S3DataSource\"][\"S3Uri\"].rstrip(\"/\") + \"/\"\n",
"    bucket, prefix = s3_uri.replace(\"s3://\", \"\").split(\"/\", 1)\n",
"\n",
"    with tempfile.TemporaryDirectory() as tmpdir:\n",
"        tmpdir = Path(tmpdir)\n",
"\n",
"        # Download all model artifacts\n",
"        for page in s3.get_paginator(\"list_objects_v2\").paginate(Bucket=bucket, Prefix=prefix):\n",
"            for obj in page.get(\"Contents\", []):\n",
"                if not obj[\"Key\"].endswith(\"/\"):\n",
"                    local_file = tmpdir / obj[\"Key\"][len(prefix):]\n",
"                    local_file.parent.mkdir(parents=True, exist_ok=True)\n",
"                    s3.download_file(bucket, obj[\"Key\"], str(local_file))\n",
"\n",
"        # Create tar.gz archive, skipping the archive file itself so the\n",
"        # partially written tarball is not recursively included\n",
"        tar_path = tmpdir / \"model.tar.gz\"\n",
"        with tarfile.open(tar_path, \"w:gz\") as tar:\n",
"            for item in tmpdir.iterdir():\n",
"                if item != tar_path:\n",
"                    tar.add(item, arcname=item.name)\n",
"\n",
"        s3.upload_file(str(tar_path), output_bucket, output_key)\n",
"\n",
"    return f\"s3://{output_bucket}/{output_key}\""
]
},
{
"cell_type": "markdown",
"id": "44cb4b63",
"metadata": {},
"source": [
"### Create the SageMaker Model\n",
"\n",
"This model can be used for both Serverless Inference and Batch Transform."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83b50172",
"metadata": {},
"outputs": [],
"source": [
"# Reuse the role defined in Setup, or define a new one\n",
"# role = None # or \"arn:aws:iam::...\"\n",
"\n",
"# Use JumpStart to get the model artifacts and container image\n",
"js_model = JumpStartModel(\n",
" model_id=\"pytorch-forecasting-chronos-2\",\n",
" instance_type=\"ml.c5.4xlarge\", # Important: use CPU instance to ensure that correct image_uri is used\n",
" role=role,\n",
")\n",
"\n",
"# Repackage model artifacts into a single tar.gz\n",
"session = Session()\n",
"bucket = session.default_bucket() # or \"your-bucket-name\"\n",
"s3_prefix = \"chronos-2\" # S3 prefix for model artifacts and data\n",
"\n",
"model_uri = repackage_jumpstart_model(js_model, bucket, output_key=f\"{s3_prefix}/model.tar.gz\")\n",
"print(f\"Repackaged model uploaded to: {model_uri}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89da3362",
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.predictor import Predictor\n",
"\n",
"chronos_model = Model(\n",
" name=\"chronos-2-cpu\", # Important: Model name should start with 'chronos-2'\n",
" model_data=model_uri,\n",
" image_uri=js_model.image_uri,\n",
" role=role,\n",
" predictor_cls=Predictor,\n",
")\n",
"chronos_model.create()"
]
},
{
"cell_type": "markdown",
"id": "129bc389",
"metadata": {},
"source": [
"Alternatively, you can load an existing model as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7cb0f14",
"metadata": {},
"outputs": [],
"source": [
"# model_info = boto3.client(\"sagemaker\").describe_model(ModelName=\"chronos-2-cpu\")\n",
"# model = Model(\n",
"# model_data=model_info[\"PrimaryContainer\"][\"ModelDataUrl\"],\n",
"# image_uri=model_info[\"PrimaryContainer\"][\"Image\"],\n",
"# role=model_info[\"ExecutionRoleArn\"],\n",
"# name=model_info[\"ModelName\"],\n",
"# )"
]
},
{
"cell_type": "markdown",
"id": "ba12b52d",
"metadata": {},
"source": [
"---\n",
"## Section 2: Serverless Inference\n",
"\n",
"[Serverless Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/serverless-endpoints.html) scales compute capacity based on traffic and scales to zero when idle, so you only pay for actual inference time.\n",
"\n",
"**When to use:**\n",
"- Sporadic or unpredictable traffic\n",
"- Cost-sensitive workloads with variable demand\n",
"- Development and testing environments\n",
"\n",
"**Limitations:**\n",
"- Cold start latency (first request after idle typically takes 30-60 seconds)\n",
"- Maximum memory: 6GB"
]
},
{
"cell_type": "markdown",
"id": "f6fc783c",
"metadata": {},
"source": [
"### Deploy Serverless Endpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6006c1a6",
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.serverless import ServerlessInferenceConfig\n",
"\n",
"serverless_predictor = chronos_model.deploy(\n",
" serverless_inference_config=ServerlessInferenceConfig(\n",
" memory_size_in_mb=6144, # Maximum available memory\n",
" max_concurrency=1,\n",
" ),\n",
" serializer=JSONSerializer(),\n",
" deserializer=JSONDeserializer(),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ed5056fa",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"### Query Serverless Endpoint"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "8cb6f9f4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'predictions': [{'mean': [-0.36, 4.02, 5.3, 2.45, -2.48, -5.14, -4.33, 0.06, 4.42, 5.14],\n",
" '0.1': [-1.68, 2.86, 4.01, 1.01, -3.77, -6.22, -5.39, -1.77, 2.6, 3.62],\n",
" '0.5': [-0.36, 4.02, 5.3, 2.45, -2.48, -5.14, -4.33, 0.06, 4.42, 5.14],\n",
" '0.9': [1.02, 5.02, 6.32, 3.82, -0.85, -3.92, -2.93, 1.83, 5.63, 6.44]}]}\n"
]
}
],
"source": [
"payload = {\n",
" \"inputs\": [\n",
" {\"target\": [0.0, 4.0, 5.0, 1.5, -3.0, -5.0, -3.0, 1.5, 5.0, 4.0, 0.0, -4.0, -5.0, -1.5, 3.0, 5.0, 3.0, -1.5, -5.0, -4.0]},\n",
" ],\n",
" \"parameters\": {\"prediction_length\": 10},\n",
"}\n",
"response = serverless_predictor.predict(payload)\n",
"print(pretty_format(response))"
]
},
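{
"cell_type": "markdown",
"id": "c01d5tar",
"metadata": {},
"source": [
"Cold starts are easy to observe directly: the first request after deployment (or after an idle period) takes substantially longer than subsequent warm requests. A minimal timing sketch, assuming the `serverless_predictor` and `payload` defined in the cells above:\n",
"\n",
"```python\n",
"import time\n",
"\n",
"for label in [\"first request (possibly cold)\", \"second request (warm)\"]:\n",
"    start = time.perf_counter()\n",
"    serverless_predictor.predict(payload)\n",
"    print(f\"{label}: {time.perf_counter() - start:.1f}s\")\n",
"```"
]
},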
{
"cell_type": "markdown",
"id": "b4224e6a",
"metadata": {},
"source": [
"### Clean Up"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85d3f37e",
"metadata": {},
"outputs": [],
"source": [
"serverless_predictor.delete_predictor()"
]
},
{
"cell_type": "markdown",
"id": "47b366e3",
"metadata": {},
"source": [
"---\n",
"## Section 3: Batch Transform\n",
"\n",
"[Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html) processes large datasets offline. SageMaker spins up compute, processes all data, and shuts down automatically.\n",
"\n",
"\n",
"**When to use:**\n",
"- Large-scale batch forecasting (thousands of time series)\n",
"- Scheduled or periodic forecasting jobs\n",
"- When latency is not critical\n",
"\n",
"**Limitations:**\n",
"- Not suitable for real-time predictions\n",
"- Requires data to be staged in S3"
]
},
{
"cell_type": "markdown",
"id": "7f94af60",
"metadata": {},
"source": [
"### Prepare Input Data\n",
"\n",
"The model uses the same API as described in the Endpoint API Reference at the end of the notebook, so you need to prepare your data in the expected JSON format.\n",
"\n",
"Batch Transform reads input from S3. Each line in the input file is a JSON payload that can contain multiple time series. For large datasets, use `items_per_record` to control how many time series are included per line (and thus per request)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebe07edc",
"metadata": {},
"outputs": [],
"source": [
"# Load sample data\n",
"df = pd.read_csv(\n",
" \"https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv\",\n",
" parse_dates=[\"timestamp\"],\n",
")\n",
"\n",
"prediction_length = 8\n",
"target_col = \"unit_sales\"\n",
"freq = pd.infer_freq(df[df.item_id == df.item_id[0]][\"timestamp\"])\n",
"\n",
"past_df = df.groupby(\"item_id\").head(-prediction_length)\n",
"future_df = df.groupby(\"item_id\").tail(prediction_length).drop(columns=[target_col])\n",
"\n",
"# Convert DataFrame to payload and split into chunks\n",
"payload = convert_df_to_payload(past_df, future_df, prediction_length=prediction_length, freq=freq, target=target_col)\n",
"items_per_record = 100 # Number of time series per JSONL line\n",
"inputs, params = payload[\"inputs\"], payload[\"parameters\"]\n",
"lines = [json.dumps({\"inputs\": inputs[i:i + items_per_record], \"parameters\": params}) for i in range(0, len(inputs), items_per_record)]\n",
"\n",
"# Upload input data to S3\n",
"input_key = f\"{s3_prefix}/batch-input/input.jsonl\"\n",
"boto3.client(\"s3\").put_object(Bucket=bucket, Key=input_key, Body=\"\\n\".join(lines).encode())\n",
"input_s3_uri = f\"s3://{bucket}/{input_key}\"\n",
"print(f\"Input data uploaded to: {input_s3_uri} ({len(lines)} records)\")"
]
},
{
"cell_type": "markdown",
"id": "87f9ad5f",
"metadata": {},
"source": [
"### Run Batch Transform\n",
"\n",
"This uses the same `chronos_model` created in the setup section above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d49d7b94",
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.transformer import Transformer\n",
"\n",
"output_s3_uri = f\"s3://{bucket}/{s3_prefix}/batch-output/\"\n",
"\n",
"transformer = Transformer(\n",
" model_name=chronos_model.name,\n",
" instance_count=1,\n",
" instance_type=\"ml.c5.4xlarge\", # CPU instance\n",
" output_path=output_s3_uri,\n",
" strategy=\"SingleRecord\", # Process one JSON line at a time\n",
" assemble_with=\"Line\",\n",
" accept=\"application/json\",\n",
")\n",
"\n",
"transformer.transform(\n",
" data=input_s3_uri,\n",
" content_type=\"application/json\",\n",
" split_type=\"Line\",\n",
" wait=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "dd65239f",
"metadata": {},
"source": [
"### Retrieve Batch Results"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "8c0fe2a4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" item_id timestamp mean 0.1 0.5 0.9\n",
"0 1062_101 2018-06-11 320.102539 186.102356 320.102539 486.852112\n",
"1 1062_101 2018-06-18 317.431396 174.692490 317.431396 495.592224\n",
"2 1062_101 2018-06-25 316.319000 169.798355 316.319000 507.396881\n",
"3 1062_101 2018-07-02 316.502472 170.463837 316.502472 505.163483\n",
"4 1062_101 2018-07-09 309.931396 164.362732 309.931396 505.276794"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Batch Transform names each output file after its input file, with an \".out\" suffix\n",
"output_key = f\"{s3_prefix}/batch-output/{input_key.split('/')[-1]}.out\"\n",
"result = boto3.client(\"s3\").get_object(Bucket=bucket, Key=output_key)\n",
"output_lines = result[\"Body\"].read().decode().strip().split(\"\\n\")\n",
"\n",
"# Combine predictions from all records\n",
"all_predictions = [p for line in output_lines for p in json.loads(line)[\"predictions\"]]\n",
"forecast_df = convert_response_to_df({\"predictions\": all_predictions}, freq=freq)\n",
"forecast_df.head()"
]
},
{
"cell_type": "markdown",
"id": "19a8d5d6",
"metadata": {},
"source": [
"---\n",
"## See Also\n",
"\n",
"- [Scale real-time endpoints to zero](https://docs.aws.amazon.com/sagemaker/latest/dg/endpoint-auto-scaling-zero-instances.html) to optimize costs when the endpoint is idle\n",
"- [Asynchronous Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html) handles traffic spikes better than real-time inference thanks to request queueing"
]
},
{
"cell_type": "markdown",
"id": "0164216d",
"metadata": {},
"source": [
"---\n",
"## Endpoint API Reference\n",
"\n",
"Below is a complete API specification for the Chronos-2 endpoint.\n",
"\n",
"* **inputs** (required): List with at most 1000 time series that need to be forecasted. Each time series is represented by a dictionary with the following keys:\n",
" * **target** (required): Observed time series values.\n",
" - For univariate forecasting: List of numeric values.\n",
" - For multivariate forecasting: List of lists, where each inner list represents one dimension. All dimensions must have the same length. If converted to a numpy array via `np.array(target)`, the shape would be `[num_dimensions, length]`.\n",
"    - It is recommended that each time series contain at least 30 observations.\n",
" - If any time series contains fewer than 5 observations, an error will be raised.\n",
" * **item_id**: String that uniquely identifies each time series.\n",
" - If provided, the ID must be unique for each time series.\n",
" - If provided, then the endpoint response will also include the **item_id** field for each forecast.\n",
" * **start**: Timestamp of the first time series observation in ISO format (`YYYY-MM-DD` or `YYYY-MM-DDThh:mm:ss`).\n",
" - If **start** field is provided, then **freq** must also be provided as part of **parameters**.\n",
" - If provided, then the endpoint response will also include the **start** field indicating the first timestamp of each forecast.\n",
" * **past_covariates**: Dictionary containing the past values of the covariates for this time series.\n",
" - Each key in **past_covariates** corresponds to the name of the covariate. Each value must be an array consisting of all-numeric or all-string values, with the length equal to the length of the **target**.\n",
" - Covariates that appear only in **past_covariates** (and not in **future_covariates**) are treated as past-only covariates.\n",
" * **future_covariates**: Dictionary containing the future values of the covariates for this time series (values during the forecast horizon).\n",
" - Each key in **future_covariates** corresponds to the name of the covariate. Each value must be an array consisting of all-numeric or all-string values, with the length equal to **prediction_length**.\n",
" - Covariates that appear in both **past_covariates** and **future_covariates** are treated as known future covariates.\n",
"* **parameters**: Optional parameters to configure the model.\n",
" * **prediction_length**: Integer corresponding to the number of future time series values that need to be predicted. Defaults to `1`. Values up to `1024` are supported.\n",
" * **quantile_levels**: List of floats in range (0, 1) specifying which quantiles should be included in the probabilistic forecast. Defaults to `[0.1, 0.5, 0.9]`.\n",
" - Chronos-2 natively supports quantile levels in range `[0.01, 0.99]`. Predictions outside the range will be clipped.\n",
" * **freq**: Frequency of the time series observations in [pandas-compatible format](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). For example, `1h` for hourly data or `2W` for bi-weekly data.\n",
" - If **freq** is provided, then **start** must also be provided for each time series in **inputs**.\n",
" * **batch_size**: Number of time series processed in parallel by the model. Larger values speed up inference but may lead to out of memory errors. Defaults to `256`.\n",
" * **cross_learning**: If `True`, the model will apply group attention to all items in the batch, instead of processing each item separately (described as \"full cross-learning mode\" in the [technical report](https://www.arxiv.org/abs/2510.15821)). This may produce more accurate forecasts for some tasks. Defaults to `False`.\n",
"\n",
"All keys not marked with (required) are optional.\n",
"\n",
"The endpoint response contains the probabilistic (quantile) forecast for each time series included in the request."
]
}
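,
{
"cell_type": "markdown",
"id": "f3a1c2e0",
"metadata": {},
"source": [
"For illustration, a minimal request following the schema above might look like this (all values are made up):\n",
"\n",
"```json\n",
"{\n",
"  \"inputs\": [\n",
"    {\n",
"      \"target\": [12.0, 15.0, 14.0, 18.0, 17.0],\n",
"      \"item_id\": \"store_1\",\n",
"      \"start\": \"2024-01-01\"\n",
"    }\n",
"  ],\n",
"  \"parameters\": {\n",
"    \"prediction_length\": 2,\n",
"    \"quantile_levels\": [0.1, 0.5, 0.9],\n",
"    \"freq\": \"D\"\n",
"  }\n",
"}\n",
"```\n",
"\n",
"Note that because **start** is provided, **freq** must be provided as well. As seen in the batch results retrieved above, the response contains one forecast per input series, with a mean forecast and one value per requested quantile level at each future time step."
]
}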
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "ag",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}