{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Text classification of a social media message stream\n",
    "\n",
    "This notebook consumes social media messages from an Apache Kafka topic, classifies the text of each message as negative or positive with a SageMaker JumpStart DistilBERT endpoint, and produces the enriched messages (or processing errors) to output Kafka topics."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Install the Apache Kafka Python library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install confluent-kafka"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import html  # to decode HTML entities (e.g. &amp;) that remain after stripping tags\n",
    "import json  # to work with social media messages that are represented as JSON objects\n",
    "import os  # to read connection settings from environment variables\n",
    "import re  # for helper functionality to clean HTML tags from social media messages\n",
    "\n",
    "import boto3  # to programmatically create, configure, and manage AWS resources\n",
    "from confluent_kafka import Consumer, KafkaError, Producer  # to produce and consume data from Apache Kafka topics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a function to run model inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# SageMaker endpoint that hosts the text classification model\n",
    "sagemaker_endpoint_name = 'jumpstart-dft-distilbert-tc-base-multilingual-cased'\n",
    "\n",
    "# Create the SageMaker runtime client once and reuse it for every request\n",
    "sagemaker_runtime_client = boto3.client('runtime.sagemaker')\n",
    "\n",
    "# Map the model's raw labels to negative/positive labels\n",
    "label_mapping = {'LABEL_0': 'negative', 'LABEL_1': 'positive'}\n",
    "\n",
    "def get_prediction(text, endpoint_name=None):\n",
    "    \"\"\"Classify `text` with the SageMaker endpoint.\n",
    "\n",
    "    Args:\n",
    "        text: Plain text to classify.\n",
    "        endpoint_name: Endpoint to invoke; defaults to `sagemaker_endpoint_name`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (probabilities, labels, predicted_label) taken from the endpoint's\n",
    "        verbose JSON response; predicted_label is translated through\n",
    "        `label_mapping` and left unchanged when it has no mapping.\n",
    "    \"\"\"\n",
    "    query_response = sagemaker_runtime_client.invoke_endpoint(\n",
    "        EndpointName=endpoint_name or sagemaker_endpoint_name,\n",
    "        ContentType='application/x-text',\n",
    "        Body=text.encode('utf-8'),\n",
    "        Accept='application/json;verbose',\n",
    "    )\n",
    "    model_predictions = json.loads(query_response['Body'].read())\n",
    "    probabilities = model_predictions['probabilities']\n",
    "    labels = model_predictions['labels']\n",
    "    raw_label = model_predictions['predicted_label']\n",
    "    predicted_label = label_mapping.get(raw_label, raw_label)\n",
    "    return probabilities, labels, predicted_label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up Apache Kafka connection properties\n",
    "\n",
    "Set the `APACHE_KAFKA_URI` environment variable to the broker address (or replace the placeholder below) and place the SSL files `ca.pem`, `service.cert` and `service.key` in the notebook's working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "apache_kafka_ssl_config = {\n",
    "    'ssl.ca.location': 'ca.pem',\n",
    "    'ssl.certificate.location': 'service.cert',\n",
    "    'ssl.key.location': 'service.key',\n",
    "    'security.protocol': 'ssl',\n",
    "}\n",
    "\n",
    "# Broker address is read from the environment so it is not hardcoded; the fallback is a placeholder\n",
    "apache_kafka_uri = os.environ.get('APACHE_KAFKA_URI', 'your-kafka-uri')\n",
    "\n",
    "apache_kafka_input_topic_name = 'social_media_messages'\n",
    "apache_kafka_enriched_output_topic_name = 'enriched_data'\n",
    "apache_kafka_processing_errors_topic_name = 'processing_errors'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the Apache Kafka consumer and consuming logic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "consumer = Consumer({\n",
    "    'bootstrap.servers': apache_kafka_uri,\n",
    "    'group.id': 'mygroup',\n",
    "    'auto.offset.reset': 'earliest',\n",
    "    **apache_kafka_ssl_config,\n",
    "})\n",
    "consumer.subscribe([apache_kafka_input_topic_name])\n",
    "\n",
    "# Matches one HTML tag; [^>] (unlike .) also matches newlines, so multi-line tags are removed too\n",
    "CLEANR = re.compile('<[^>]*>')\n",
    "\n",
    "def get_json_body(message):\n",
    "    \"\"\"Decode a Kafka message value as a UTF-8 encoded JSON object.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the value is missing, is not valid UTF-8 or JSON,\n",
    "            or does not decode to a JSON object.\n",
    "    \"\"\"\n",
    "    raw_value = message.value()\n",
    "    if raw_value is None:\n",
    "        raise ValueError('Message has no value')\n",
    "    # UnicodeDecodeError and json.JSONDecodeError are both ValueError subclasses\n",
    "    json_message = json.loads(raw_value.decode('utf-8'))\n",
    "    if not isinstance(json_message, dict):\n",
    "        raise ValueError('Message is not a JSON object')\n",
    "    return json_message\n",
    "\n",
    "def get_clean_content(json_object):\n",
    "    \"\"\"Return the 'content' property as plain text.\n",
    "\n",
    "    HTML tags are removed, HTML entities are decoded and surrounding whitespace\n",
    "    is stripped. A missing, null or non-string 'content' yields ''.\n",
    "    \"\"\"\n",
    "    content = json_object.get('content')\n",
    "    if not isinstance(content, str):\n",
    "        return ''\n",
    "    only_text = html.unescape(CLEANR.sub('', content))\n",
    "    return only_text.strip()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the Apache Kafka producer and producing logic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "producer = Producer({\n",
    "    'bootstrap.servers': apache_kafka_uri,\n",
    "    **apache_kafka_ssl_config,\n",
    "})\n",
    "\n",
    "def send_message(message, topic_name):\n",
    "    \"\"\"Serialize `message` as UTF-8 JSON and produce it to `topic_name`.\n",
    "\n",
    "    flush() waits for outstanding messages before returning, so each send is synchronous.\n",
    "    \"\"\"\n",
    "    producer.produce(topic_name, json.dumps(message).encode('utf-8'))\n",
    "    producer.flush()\n",
    "\n",
    "def send_enriched_data(message, probabilities, predicted_label):\n",
    "    \"\"\"Send a copy of `message` extended with the model output to the enriched-data topic.\"\"\"\n",
    "    enriched_message = {**message, 'probabilities': probabilities, 'prediction': predicted_label}\n",
    "    send_message(enriched_message, apache_kafka_enriched_output_topic_name)\n",
    "\n",
    "def report_processing_error(message, error_code, error_message):\n",
    "    \"\"\"Send a copy of `message` extended with error details to the processing-errors topic.\"\"\"\n",
    "    failed_message = {\n",
    "        **message,\n",
    "        'processing_error_code': error_code,\n",
    "        'processing_error_message': error_message,\n",
    "    }\n",
    "    send_message(failed_message, apache_kafka_processing_errors_topic_name)\n",
    "\n",
    "def describe_exception(error):\n",
    "    \"\"\"Return (code, message) describing `error`.\n",
    "\n",
    "    Exceptions that carry a botocore-style `response` dict (such as a ClientError\n",
    "    raised by invoke_endpoint) provide response['Error']['Code'] and ['Message'];\n",
    "    otherwise the exception type name and its text are used.\n",
    "    \"\"\"\n",
    "    response = getattr(error, 'response', None)\n",
    "    details = response.get('Error', {}) if isinstance(response, dict) else {}\n",
    "    return details.get('Code', type(error).__name__), details.get('Message', str(error))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Process records from the Apache Kafka input topic: enrich them or report errors\n",
    "\n",
    "The loop runs until the kernel is interrupted; the consumer is then closed and pending output is flushed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print('Processing messages (interrupt the kernel to stop)')\n",
    "try:\n",
    "    while True:\n",
    "        message = consumer.poll(1.0)  # Poll for messages, with a timeout of 1 second\n",
    "\n",
    "        if message is None:\n",
    "            continue\n",
    "\n",
    "        if message.error():\n",
    "            if message.error().code() == KafkaError._PARTITION_EOF:\n",
    "                # End of partition event\n",
    "                print(f'Reached end of partition for topic {message.topic()} [{message.partition()}]')\n",
    "            else:\n",
    "                print(f'Error while consuming message: {message.error()}')\n",
    "            continue\n",
    "\n",
    "        # A malformed record must not stop the stream: report it and move on\n",
    "        try:\n",
    "            json_body = get_json_body(message)\n",
    "        except ValueError as e:\n",
    "            print(f'Invalid message: {e}')\n",
    "            raw_value = message.value()\n",
    "            raw_text = raw_value.decode('utf-8', errors='replace') if raw_value is not None else None\n",
    "            report_processing_error({'raw_message': raw_text}, 'InvalidMessage', str(e))\n",
    "            continue\n",
    "\n",
    "        content_property = get_clean_content(json_body)\n",
    "        if not content_property:\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            probabilities, labels, predicted_label = get_prediction(content_property)\n",
    "            print(f'Inference:\\n'\n",
    "                  f\"Input text: '{content_property}'\\n\"\n",
    "                  f'Model prediction: {probabilities}\\n'\n",
    "                  f'Predicted label: {predicted_label}\\n')\n",
    "\n",
    "            send_enriched_data(json_body, probabilities, predicted_label)\n",
    "        except Exception as e:\n",
    "            print(f'An error occurred: {e}')\n",
    "            error_code, error_message = describe_exception(e)\n",
    "            report_processing_error(json_body, error_code, error_message)\n",
    "except KeyboardInterrupt:\n",
    "    print('Stopped processing messages')\n",
    "finally:\n",
    "    # Runs on interrupt or on any unhandled error, so the consumer is always closed\n",
    "    consumer.close()\n",
    "    producer.flush()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}