{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\";\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Zero Shot Learning Using Natural Language Inference\n", "\n", "In this notebook, we will demonstrate **zero-shot** topic classification. **Zero-Shot Learning (ZSL)** is the ability to solve a task without having received any training examples of that task. The `ZeroShotClassifier` class in *ktrain* can be used to perform topic classification with no training examples. The technique is based on **Natural Language Inference (NLI)**, as described in [this interesting blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html) by Joe Davison.\n", "\n", "## STEP 1: Set Up the Zero Shot Classifier and Describe Topics\n", "\n", "We first instantiate the zero-shot classifier and then describe the topic labels for our classifier with strings." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from ktrain.text.zsl import ZeroShotClassifier" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "zsl = ZeroShotClassifier()\n", "labels=['politics', 'elections', 'sports', 'films', 'television']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Predict\n", "\n", "There is no training involved here, as we are using **zero-shot learning**. We simply supply the document to be classified and the `labels` defined earlier. The `predict` method uses Natural Language Inference (NLI) to infer a probability for each label independently, so the probabilities need not sum to 1." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('politics', 0.979189932346344),\n", " ('elections', 0.9874580502510071),\n", " ('sports', 0.0005765462410636246),\n", " ('films', 0.0022924456279724836),\n", " ('television', 0.0010546103585511446)]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'\n", "zsl.predict(doc, labels=labels, include_labels=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, our model correctly assigned the highest probabilities to `politics` and `elections`, as the text supplied pertains to both these topics.\n", "\n", "Let's try some other examples.\n", "#### document about `television`" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('politics', 0.00015667644038330764),\n", " ('elections', 0.00032881161314435303),\n", " ('sports', 0.00013884963118471205),\n", " ('films', 0.07557642459869385),\n", " ('television', 0.9813269376754761)]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "doc = 'What is your favorite sitcom of all time?'\n", "zsl.predict(doc, labels=labels, include_labels=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### document about both `politics` and `television`" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('politics', 0.8049427270889282),\n", " ('elections', 0.01889326609671116),\n", " ('sports', 0.005504833068698645),\n", " ('films', 0.05876927077770233),\n", " ('television', 0.8776823878288269)]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "doc = \"\"\"\n", "President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised \n", "the 
administration's response to the coronavirus pandemic as a \\\"great success story\\\" on Wednesday -- \n", "less than a day after the number of confirmed coronavirus cases in the United States topped 1 million. \n", "Kushner painted a rosy picture for \\\"Fox and Friends\\\" Wednesday morning, \n", "saying that \\\"the federal government rose to the challenge and \n", "this is a great success story and I think that that's really what needs to be told.\\\"\n", "\"\"\"\n", "zsl.predict(doc, labels=labels, include_labels=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### document about `sports`, `television`, and `films`" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('politics', 0.0005349867278710008),\n", " ('elections', 0.0007852867711335421),\n", " ('sports', 0.9848827123641968),\n", " ('films', 0.9576993584632874),\n", " ('television', 0.941143274307251)]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "doc = \"The Last Dance is a 2020 American basketball documentary miniseries co-produced by ESPN Films and Netflix.\"\n", "zsl.predict(doc, labels=labels, include_labels=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Customizing the Classifier for Zero-Shot Sentiment Analysis\n", "\n", "As stated above, the `ZeroShotClassifier` is implemented using Natural Language Inference (NLI). That is, the document is treated as a **premise**, and each label is treated as a **hypothesis**. To predict labels, an NLI model is used to predict whether or not each label is entailed by the premise. By default, the template used for the hypothesis is of the form `\"This text is about