{ "cells": [ { "cell_type": "markdown", "id": "22d246ff", "metadata": {}, "source": [ "# Using Large Language Models as text classifiers with an sklearn interface" ] }, { "cell_type": "markdown", "id": "74501674", "metadata": {}, "source": [ "In this notebook, we will learn how to use skorch's `ZeroShotClassifier` and `FewShotClassifier` to perform classification without any training thanks to the power of (Large) Language Models (LLMs). For this, we rely on the the [Hugging Face transformers](https://huggingface.co/docs/transformers/index) library, which allows us to use all the available text generation models provided by Hugging Face." ] }, { "cell_type": "markdown", "id": "595cf8f6", "metadata": {}, "source": [ "
\n", "\n", " Run in Google Colab \n", "\n", "View source on GitHub
" ] }, { "cell_type": "markdown", "id": "e14fd81e", "metadata": {}, "source": [ "The notebook requires Hugging Face `transformers` and `datasets` as additional dependencies. If you have not already installed it, you can do so like this:\n", "\n", "`python -m pip install transformers datasets`" ] }, { "cell_type": "code", "execution_count": 1, "id": "78e06744", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "\n", "# Installation on Google Colab\n", "try:\n", " import google.colab\n", " subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'transformers', 'datasets'])\n", "except ImportError:\n", " pass" ] }, { "cell_type": "markdown", "id": "c18a9e70", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "id": "7fb9b863", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import datasets\n", "import numpy as np\n", "import pandas as pd\n", "import transformers\n", "import torch\n", "from sklearn.metrics import accuracy_score, log_loss\n", "from sklearn.model_selection import GridSearchCV" ] }, { "cell_type": "code", "execution_count": 3, "id": "170f07d6", "metadata": {}, "outputs": [], "source": [ "# let's reduce some of the noise from transformers and datasets logs\n", "transformers.logging.set_verbosity_warning()\n", "datasets.logging.set_verbosity_error()" ] }, { "cell_type": "code", "execution_count": 4, "id": "d3a748a9", "metadata": {}, "outputs": [], "source": [ "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "markdown", "id": "93b17651", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "markdown", "id": "765481ec", "metadata": {}, "source": [ "For this example, we make use of the IMDB dataset. It consists of movie reviews written by IMDB users and the target is the sentiment, i.e. \"positive\" or \"negative\"." ] }, { "cell_type": "code", "execution_count": 5, "id": "4e7cbed2", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3/3 [00:00<00:00, 877.10it/s]\n" ] } ], "source": [ "imdb = datasets.load_dataset('imdb').shuffle(seed=0)" ] }, { "cell_type": "markdown", "id": "c14e4f75", "metadata": {}, "source": [ "We limit the number of samples to 100. Using zero/few-shot learning mostly makes sense when there are few labeled samples, otherwise, supervised machine learning methods will probably give better results." ] }, { "cell_type": "code", "execution_count": 6, "id": "f1abbf04", "metadata": {}, "outputs": [], "source": [ "X = imdb['train'][:100]['text']\n", "y = imdb['train'][:100]['label']" ] }, { "cell_type": "markdown", "id": "a3a682af", "metadata": {}, "source": [ "Let's take a quick look at the data. Our `X` contains the user reviews:" ] }, { "cell_type": "code", "execution_count": 7, "id": "5e8b7b80", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "We always watch American movies with their particular accents from each region (south, west, etc). We have the same here. All foreign people must to watch this movie and need to have a open mind to accept another culture, besides American and European almost dominate the cinematographic industry.

This movie tell us about a parallel world which it isn't figured even for those who live in a big city like São Paulo. All actors are improvising and they are very realistic. The camera give us an idea of their confuse world, the loneliness of each character and invite us to share their world.

It's a real great movie and worst a rent even have it at home.\n" ] } ], "source": [ "print(X[0])" ] }, { "cell_type": "markdown", "id": "faadcf34", "metadata": {}, "source": [ "Our `y` contains the label-encoded targets:" ] }, { "cell_type": "code", "execution_count": 8, "id": "c580060a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1, 1, 0, 1, 0]\n" ] } ], "source": [ "print(y[:5])" ] }, { "cell_type": "markdown", "id": "eb3baf13", "metadata": {}, "source": [ "For a standard machine learning solution, having label-encoded targets is desired. Here, we prefer to have the actual labels, however. It is much easier for the language model to predict the label \"positive\" for the text above than to predict \"1\". How would it know what \"1\" means? Sure, if we provide a few examples, it may work, but let's not make the language model's life harder than necessary and thus provide the actual labels." ] }, { "cell_type": "code", "execution_count": 9, "id": "33c309be", "metadata": {}, "outputs": [], "source": [ "labels = np.array(['negative', 'positive'])[y]" ] }, { "cell_type": "code", "execution_count": 10, "id": "7f5623a2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['positive', 'positive', 'negative', 'positive', 'negative'],\n", " dtype='#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 
div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}
ZeroShotClassifier(device='cuda:0', model_name='google/flan-t5-small', use_caching=False)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ZeroShotClassifier(device='cuda:0', model_name='google/flan-t5-small', use_caching=False)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time clf.fit(X=None, y=['positive', 'negative'])" ] }, { "cell_type": "markdown", "id": "c3a7ba7b", "metadata": {}, "source": [ "In general, fitting is fast because, basically, nothing happens. If the transformers model and tokenizer are not cached locally, they will, however, be downloaded from Hugging Face, which may take some time." ] }, { "cell_type": "markdown", "id": "6143c8fb", "metadata": {}, "source": [ "### evaluation" ] }, { "cell_type": "markdown", "id": "f9e6df68", "metadata": {}, "source": [ "Let's evaluate how well the model works. As with any sklearn-compatible model, we can just call `predict_proba` to get the probabilities that the model assigns to each sample:" ] }, { "cell_type": "code", "execution_count": 14, "id": "a17bf2a3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Token indices sequence length is longer than the specified maximum sequence length for this model (844 > 512). Running this sequence through the model will result in indexing errors\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 24 s, sys: 221 ms, total: 24.3 s\n", "Wall time: 5.74 s\n" ] } ], "source": [ "%time y_proba = clf.predict_proba(X)" ] }, { "cell_type": "markdown", "id": "6f31ed19", "metadata": {}, "source": [ "The prediction speed is a bit slow, as should be expected from a language model. If runtime is a big concern, this is probably not the right approach.\n", "\n", "Now let's check how well the model does. First we check the log loss, then the accuracy:" ] }, { "cell_type": "code", "execution_count": 15, "id": "5fc5b571", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "0.2870119934413143" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_loss(y, y_proba)" ] }, { "cell_type": "code", "execution_count": 16, "id": "d981cc01", "metadata": {}, "outputs": [], "source": [ "y_pred = y_proba.argmax(1)" ] }, { "cell_type": "code", "execution_count": 17, "id": "c0d4b490", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.86" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_score(y, y_pred)" ] }, { "cell_type": "markdown", "id": "14179b24", "metadata": {}, "source": [ "Given that this is zero-shot, those scores are actually not so bad!\n", "\n", "Sure, on the [leaderboard](https://huggingface.co/spaces/autoevaluate/leaderboards?dataset=imdb&only_verified=0&task=-any-&config=-unspecified-&split=-unspecified-&metric=accuracy) we can find models with better accuracy, but those are fine-tuned on the dataset.\n", "\n", "Notice that if we call `predict`, we get back the labels, i.e. \"positive\" or \"negative\"." ] }, { "cell_type": "code", "execution_count": 18, "id": "6807dc86", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['positive'], dtype=' 512). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (885 > 512). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (844 > 512). 
Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (885 > 512). Running this sequence through the model will result in indexing errors\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1min 39s, sys: 553 ms, total: 1min 39s\n", "Wall time: 24.3 s\n" ] }, { "data": { "text/html": [ "
GridSearchCV(cv=2,\n",
       "             estimator=ZeroShotClassifier(device='cuda:0', model_name='google/flan-t5-small', use_caching=False),\n",
       "             param_grid={'prompt': ['You are a text classification assistant.\\n'\n",
       "                                    '\\n'\n",
       "                                    'The text to classify:\\n'\n",
       "                                    '\\n'\n",
       "                                    '```\\n'\n",
       "                                    '{text}\\n'\n",
       "                                    '```\\n'\n",
       "                                    '\\n'\n",
       "                                    'Choose the label among the following '\n",
       "                                    'possibilities with the highest '\n",
       "                                    'probability.\\n'\n",
       "                                    'Only return the label, nothing more:\\n'\n",
       "                                    '\\n'\n",
       "                                    '{labels}\\n'\n",
       "                                    '\\n'\n",
       "                                    'Your response:\\n',\n",
       "                                    'Your task is to classify text.\\n'\n",
       "                                    '\\n'\n",
       "                                    'Choose the label among the following '\n",
       "                                    'possibilities with the highest '\n",
       "                                    'probability.\\n'\n",
       "                                    'Only return the label, nothing more:\\n'\n",
       "                                    '\\n'\n",
       "                                    '{labels}\\n'\n",
       "                                    '\\n'\n",
       "                                    'The text to classify:\\n'\n",
       "                                    '\\n'\n",
       "                                    '```\\n'\n",
       "                                    '{text}\\n'\n",
       "                                    '```\\n'\n",
       "                                    '\\n'\n",
       "                                    'Your response:\\n']},\n",
       "             refit=False, scoring=['accuracy', 'neg_log_loss'])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "GridSearchCV(cv=2,\n", " estimator=ZeroShotClassifier(device='cuda:0', model_name='google/flan-t5-small', use_caching=False),\n", " param_grid={'prompt': ['You are a text classification assistant.\\n'\n", " '\\n'\n", " 'The text to classify:\\n'\n", " '\\n'\n", " '```\\n'\n", " '{text}\\n'\n", " '```\\n'\n", " '\\n'\n", " 'Choose the label among the following '\n", " 'possibilities with the highest '\n", " 'probability.\\n'\n", " 'Only return the label, nothing more:\\n'\n", " '\\n'\n", " '{labels}\\n'\n", " '\\n'\n", " 'Your response:\\n',\n", " 'Your task is to classify text.\\n'\n", " '\\n'\n", " 'Choose the label among the following '\n", " 'possibilities with the highest '\n", " 'probability.\\n'\n", " 'Only return the label, nothing more:\\n'\n", " '\\n'\n", " '{labels}\\n'\n", " '\\n'\n", " 'The text to classify:\\n'\n", " '\\n'\n", " '```\\n'\n", " '{text}\\n'\n", " '```\\n'\n", " '\\n'\n", " 'Your response:\\n']},\n", " refit=False, scoring=['accuracy', 'neg_log_loss'])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time search.fit(X, labels)" ] }, { "cell_type": "markdown", "id": "965c3bf1", "metadata": {}, "source": [ "grid search results:" ] }, { "cell_type": "code", "execution_count": 24, "id": "92132751", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mean_test_accuracymean_test_neg_log_lossparam_promptmean_score_time
00.86-0.287012You are a text classification assistant.\\n\\nTh...5.084494
10.93-0.246949Your task is to classify text.\\n\\nChoose the l...5.091446
\n", "
" ], "text/plain": [ " mean_test_accuracy mean_test_neg_log_loss \\\n", "0 0.86 -0.287012 \n", "1 0.93 -0.246949 \n", "\n", " param_prompt mean_score_time \n", "0 You are a text classification assistant.\\n\\nTh... 5.084494 \n", "1 Your task is to classify text.\\n\\nChoose the l... 5.091446 " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_prompt', 'mean_score_time']]" ] }, { "cell_type": "markdown", "id": "0ac85f73", "metadata": {}, "source": [ "**Conclusion**: `prompt1` is performing better than `prompt0`. The mean test accuracy of 93% and log loss of 0.25 are pretty good overall, given that we use zero-shot and don't perform any fine-tuning.\n", "\n", "Going further, we could also grid search different language models, or combinations of LLMs and prompts, to find the best working zero-shot model." ] }, { "cell_type": "markdown", "id": "e5102989", "metadata": {}, "source": [ "## Few-shot classification" ] }, { "cell_type": "markdown", "id": "b0d6f4f4", "metadata": {}, "source": [ "Sometimes, helping the language model out by providing a few examples will boost the performance. To test this, we skorch provides the `FewShotClassifier` class. Let's try it out." ] }, { "cell_type": "code", "execution_count": 25, "id": "2a6faa26", "metadata": {}, "outputs": [], "source": [ "from skorch.llm import FewShotClassifier\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer" ] }, { "cell_type": "markdown", "id": "5e9c34f3", "metadata": {}, "source": [ "### train the few-shot classifier" ] }, { "cell_type": "markdown", "id": "b5e9b863", "metadata": {}, "source": [ "Instead of passing the model name to initialize the classifier, as in `clf = FewShotClassifier('google/flan-t5-small')`, it is also possible to pass the model and tokenizer explicitly. This is a good option if you need more control over them. In our case, it amounts to the same result. It's useful to keep this option in mind, though, if the model requires any changes or if you want to provide a model that is not uploaded to Hugging Face." ] }, { "cell_type": "code", "execution_count": 26, "id": "d385ea2f", "metadata": {}, "outputs": [], "source": [ "model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small').to(device)\n", "tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')" ] }, { "cell_type": "markdown", "id": "4b6995d4", "metadata": {}, "source": [ "To control the amount of samples used for few-shot learning, use `max_samples` parameter. In this case, let's use 5 examples:" ] }, { "cell_type": "code", "execution_count": 27, "id": "61b63a7f", "metadata": { "scrolled": true }, "outputs": [], "source": [ "clf = FewShotClassifier(\n", " model=model, tokenizer=tokenizer, max_samples=5, use_caching=False\n", ")" ] }, { "cell_type": "code", "execution_count": 28, "id": "e8803fa6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 919 µs, sys: 46 µs, total: 965 µs\n", "Wall time: 489 µs\n" ] }, { "data": { "text/html": [ "
{ "cell_type": "markdown", "id": "e5102989", "metadata": {}, "source": [ "## Few-shot classification" ] }, { "cell_type": "markdown", "id": "b0d6f4f4", "metadata": {}, "source": [ "Sometimes, helping the language model out by providing a few examples will boost the performance. To test this, skorch provides the `FewShotClassifier` class. Let's try it out." ] }, { "cell_type": "code", "execution_count": 25, "id": "2a6faa26", "metadata": {}, "outputs": [], "source": [ "from skorch.llm import FewShotClassifier\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer" ] }, { "cell_type": "markdown", "id": "5e9c34f3", "metadata": {}, "source": [ "### train the few-shot classifier" ] }, { "cell_type": "markdown", "id": "b5e9b863", "metadata": {}, "source": [ "Instead of passing the model name to initialize the classifier, as in `clf = FewShotClassifier('google/flan-t5-small')`, it is also possible to pass the model and tokenizer explicitly. This is a good option if you need more control over them. In our case, it amounts to the same result. It's useful to keep this option in mind, though, if the model requires any changes or if you want to provide a model that is not uploaded to Hugging Face." ] }, { "cell_type": "code", "execution_count": 26, "id": "d385ea2f", "metadata": {}, "outputs": [], "source": [ "model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small').to(device)\n", "tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')" ] }, { "cell_type": "markdown", "id": "4b6995d4", "metadata": {}, "source": [ "To control the number of samples used for few-shot learning, use the `max_samples` parameter. In this case, let's use 5 examples:" ] }, { "cell_type": "code", "execution_count": 27, "id": "61b63a7f", "metadata": { "scrolled": true }, "outputs": [], "source": [ "clf = FewShotClassifier(\n", " model=model, tokenizer=tokenizer, max_samples=5, use_caching=False\n", ")" ] }, { "cell_type": "code", "execution_count": 28, "id": "e8803fa6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 919 µs, sys: 46 µs, total: 965 µs\n", "Wall time: 489 µs\n" ] }, { "data": { "text/html": [ "
FewShotClassifier(model='T5ForConditionalGeneration', tokenizer='T5TokenizerFast', use_caching=False)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "FewShotClassifier(model='T5ForConditionalGeneration', tokenizer='T5TokenizerFast', use_caching=False)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time clf.fit(X, labels)" ] }, { "cell_type": "markdown", "id": "2dfe57f9", "metadata": {}, "source": [ "Let's make sure that everything works as expected by inspecting the prompt. This is possible using the `get_prompt` method:" ] }, { "cell_type": "code", "execution_count": 29, "id": "4a368d1e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "You are a text classification assistant.\n", "\n", "Choose the label among the following possibilities with the highest probability.\n", "Only return the label, nothing more:\n", "\n", "['negative', 'positive']\n", "\n", "Here are a few examples:\n", "\n", "```\n", "Reese Witherspooon's first movie. Loved it. The plot and the acting was top notch. You are emotionally involved with the characters. In my opinion, a must see.

After watching this movie you will see why Reese Witherspoon's acting career has been so successful.

The other cast members do a great job also.

The movie flows extremely well. There is not a boring moment in the whole picture. The Man in the Moon's length is just right.

As I said earlier, I think this movie was excellent. I have seen it numerous times, and have enjoyed every one of the viewings.\n", "```\n", "\n", "Your response:\n", "positive\n", "\n", "```\n", "This movie set out to be better than the average action movie and in that regard they succeeded.This movie had spectacular cinematography featuring spectacular mountain snow and heights,a very fit Stallone putting in a good performance as well,an exciting plot,and a great performance from it's main villain becouse he will really shock you with his evil ways.The movie does not rank an all time great becouse of the weak screen play.The plot and story cries for this movie to make Stallone an extra special human,much like the Rambo or Rocky or Bond movie characters.They chose to humanise Stallone's character in this one which is ok but considering the plot's style,weakens the excitement factor.Also,the dialogue was cheesy and carelessly condescending at times.The script should have been more realistic and less \"talky\".Another weak point was the unrealistic shooting scenes.The movie makers should have been more carefull how they hadled the shooting hits and misses.They should have continued the quality of the scenes of the shooting sequences during the plane hijacking early in the movie.Instead,they decided to water down a lot of the shooting sequences (ala \"A-Team\" TV series) as soon as the villains set foot on the mountain tops.This movie had a lot of all time great potential.Crisper action sequences,better dialogue and more Rambo/Rocky style emotion/determination from Stallone would have taken this movie to a higher level.I know this was not Stallone's fault.I sense the movie's director wanted to tone down Stallone's character and try to steal the movie by taking credit for his direction which was not all that great if not for his cinematographer.Sill a good movie though........\n", "```\n", "\n", "Your response:\n", "positive\n", "\n", "```\n", "One of the previous reviewers wrote that there appeared to be no middle ground for opinions of Love Story; one loved it or hated it. But there seems to be a remarkable distribution of opinions throughout the scale of 1 to 10. For me, this movie rated a 4. There are some beautiful scenes and locations, and Ray Milland turns in a fabulous job as Oliver's father. But the movie did not do a particularly compelling job of telling its story, and the story was not so unique as to warrant multiple viewings, at least, not for me. I may be a bit of a snob, but I tend to avoid movies with Ryan O'Neal -- I still haven't seen Barry Lyndon -- because most of them, but not all, are ruined for me by his presence. The lone exception is What's Up, Doc?, in which his straight performance is the perfect underlining for Barbra Streisand's goofball protagonist -- and, not coincidentally, he takes a shot at Love Story for good measure! McGraw and O'Neal tend to mug their lines, rather than act them.

This movie is notable for the beginning of one fine career: it was Tommy Lee Jones's first movie.\n", "```\n", "\n", "Your response:\n", "negative\n", "\n", "```\n", "I'm not tired to say this is one of the best political thrillers ever made. The story takes place in a fictional state, but obviously it deals with the murder of Kennedy. A truthful and honest district attorney (played by Yves Montand) does not believe that the murder was planned and executed by the single man Daslow (=Oswald) and though all other officials want to close the case he continuous to investigate with his team.

The screenplay is written tight and fast and holds the tension till the end. Just the part dealing with the Milgram experiment about authorities is (though not uninteresting) a bit out of place. The ending sequence - explaining who Icarus really is - partly shot in slow motion and intensified by a Morricone soundtrack is the most powerful sequence I have ever seen in a movie.\n", "```\n", "\n", "Your response:\n", "positive\n", "\n", "```\n", "wow! i watched the trailer for this one and though 'nah, this one is not for me'. i watched my husband and our friend's faces during the trailer, and knew this was a 'boy movie'. i mean, hallo! a bunch of chick barmaids that dance - another striptease?

then, i started watching it, it didn't look all that bad. so i carried on watching. i watched it right to the end. what an awesome movie. if anything, this is a chick-flick. these girls have attitude. it is really a feel-good movie, and a bit of a love story. really leaves you with a nice feeling.

basically, the story of a small-town girl making it big in the city, after going through the usual big-city c**p. there have been a couple of these, it is almost a new urban legend. but it also makes you think of your life, and what you have achieved. well, me anyway. i think it is because the whole working in a bar scenario is very familiar, not just for me, but for many people i know. Don't trust the trailers for this one - it is aimed at bringing the men in.\n", "```\n", "\n", "Your response:\n", "negative\n", "\n", "\n", "The text to classify:\n", "\n", "```\n", "A masterpiece, instant classic, 5 stars out of 5\n", "```\n", "\n", "Your response:\n", "\n" ] } ], "source": [ "print(clf.get_prompt(\"A masterpiece, instant classic, 5 stars out of 5\"))" ] }, { "cell_type": "markdown", "id": "b9eb015f", "metadata": {}, "source": [ "If we're unhappy with the prompt, we can also provide our own prompt using the `prompt` argument, as we saw earlier in this notebook." ] }, { "cell_type": "markdown", "id": "a8477a3a", "metadata": {}, "source": [ "### evaluation" ] }, { "cell_type": "code", "execution_count": 30, "id": "45e77c4a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Token indices sequence length is longer than the specified maximum sequence length for this model (1678 > 512). Running this sequence through the model will result in indexing errors\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 27.3 s, sys: 5.87 ms, total: 27.4 s\n", "Wall time: 8.86 s\n" ] } ], "source": [ "%time y_proba = clf.predict_proba(X)" ] }, { "cell_type": "code", "execution_count": 31, "id": "f53a4fda", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "0.22576427762687323" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_loss(y, y_proba)" ] }, { "cell_type": "code", "execution_count": 32, "id": "844d95d8", "metadata": {}, "outputs": [], "source": [ "y_pred = y_proba.argmax(1)" ] }, { "cell_type": "code", "execution_count": 33, "id": "cf46193d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.92" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_score(y, y_pred)" ] }, { "cell_type": "code", "execution_count": 34, "id": "686347e1", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "array(['negative'], dtype='#sk-container-id-4 {color: black;background-color: white;}#sk-container-id-4 pre{padding: 0;}#sk-container-id-4 div.sk-toggleable {background-color: white;}#sk-container-id-4 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-4 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-4 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-4 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-4 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-4 
input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-4 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-4 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-4 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-4 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-4 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-4 div.sk-item {position: relative;z-index: 1;}#sk-container-id-4 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-4 div.sk-item::before, #sk-container-id-4 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-4 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-4 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-4 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-4 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-4 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-4 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-4 div.sk-label-container {text-align: center;}#sk-container-id-4 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-4 div.sk-text-repr-fallback {display: none;}
GridSearchCV(cv=2,\n",
       "             estimator=FewShotClassifier(model=T5ForConditionalGeneration(\n",
       "  (shared): Embedding(32128, 512)\n",
       "  (encoder): T5Stack(\n",
       "    (embed_tokens): Embedding(32128, 512)\n",
       "    (block): ModuleList(\n",
       "      (0): T5Block(\n",
       "        (layer): ModuleList(\n",
       "          (0): T5LayerSelfAttention(\n",
       "            (SelfAttention): T5Attention(\n",
       "              (q): Linear(in_features=512, out_features=384, bias=False)\n",
       "              (k): Linear(in_features=512, out_...\n",
       "), tokenizer=PreTrainedTokenizerFast(name_or_path='google/flan-t5-small', vocab_size=32100, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>', '<extra_id_46>', '<extra_id_47>', '<extra_id_48>', '<extra_id_49>', '<extra_id_50>', '<extra_id_51>', '<extra_id_52>', '<extra_id_53>', '<extra_id_54>', '<extra_id_55>', '<extra_id_56>', '<extra_id_57>', '<extra_id_58>', '<extra_id_59>', '<extra_id_60>', '<extra_id_61>', '<extra_id_62>', '<extra_id_63>', '<extra_id_64>', '<extra_id_65>', '<extra_id_66>', '<extra_id_67>', '<extra_id_68>', '<extra_id_69>', '<extra_id_70>', '<extra_id_71>', '<extra_id_72>', '<extra_id_73>', '<extra_id_74>', '<extra_id_75>', '<extra_id_76>', '<extra_id_77>', '<extra_id_78>', '<extra_id_79>', '<extra_id_80>', '<extra_id_81>', '<extra_id_82>', '<extra_id_83>', '<extra_id_84>', '<extra_id_85>', '<extra_id_86>', '<extra_id_87>', '<extra_id_88>', '<extra_id_89>', '<extra_id_90>', '<extra_id_91>', '<extra_id_92>', '<extra_id_93>', '<extra_id_94>', '<extra_id_95>', '<extra_id_96>', '<extra_id_97>', '<extra_id_98>', '<extra_id_99>']}), use_caching=False),\n",
       "             param_grid={'max_samples': [3, 5, 7]}, refit=False,\n",
       "             scoring=['accuracy', 'neg_log_loss'])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "GridSearchCV(cv=2,\n", " estimator=FewShotClassifier(model=T5ForConditionalGeneration(\n", " (shared): Embedding(32128, 512)\n", " (encoder): T5Stack(\n", " (embed_tokens): Embedding(32128, 512)\n", " (block): ModuleList(\n", " (0): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=512, out_features=384, bias=False)\n", " (k): Linear(in_features=512, out_...\n", "), tokenizer=PreTrainedTokenizerFast(name_or_path='google/flan-t5-small', vocab_size=32100, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '', 'unk_token': '', 'pad_token': '', 'additional_special_tokens': ['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '']}), use_caching=False),\n", " param_grid={'max_samples': [3, 5, 7]}, refit=False,\n", " scoring=['accuracy', 'neg_log_loss'])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time search.fit(X, labels)" ] }, { "cell_type": "code", "execution_count": 38, "id": "f03ff250", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mean_test_accuracymean_test_neg_log_lossparam_max_samplesmean_score_time
00.92-0.23566135.453266
10.91-0.24204159.747535
20.92-0.235754715.653747
\n", "
" ], "text/plain": [ " mean_test_accuracy mean_test_neg_log_loss param_max_samples \\\n", "0 0.92 -0.235661 3 \n", "1 0.91 -0.242041 5 \n", "2 0.92 -0.235754 7 \n", "\n", " mean_score_time \n", "0 5.453266 \n", "1 9.747535 \n", "2 15.653747 " ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_max_samples', 'mean_score_time']]" ] }, { "cell_type": "markdown", "id": "d357f92c", "metadata": {}, "source": [ "**Conclusion**: There is no significant change in accuracy compared to zero-shot but a small improvement in log loss. Having more samples doesn't help but slows down the inference time, as we can see when looking at `mean_score_time`. Overall, few-shot learning helps a bit but makes inference slower. It's up to you to decide if the trade-off is worth it in this specific case." ] }, { "cell_type": "markdown", "id": "625dbfb1", "metadata": {}, "source": [ "## Debugging" ] }, { "cell_type": "markdown", "id": "bd6fd974", "metadata": {}, "source": [ "Working with LLMs can be difficult because it is hard to know for certain if the prompt works well and if the LLM is capable of classifying the input. For this reason, skorch provides a few options to help identify those issues." ] }, { "cell_type": "markdown", "id": "6781c236", "metadata": {}, "source": [ "### Returning unnormalized probabilities" ] }, { "cell_type": "markdown", "id": "39bd1569", "metadata": {}, "source": [ "By default, the model will normalize the probabilities to sum to 1. This is what is expected when calling `predict_proba`. However, this can hide underlying issues. The LLM can in theory predict any token from its vocabulary, there is no guarantee that it will choose one of the provided labels. skorch will force the LLM to use one of the labels, but we also track the probabilities assigned, or not assigned, to these labels.\n", "\n", "To give an example, for a given input, it's possible that the LLM predicts a probability of 10% that the label is 'negative' and 70% that it is 'positive'. By default, we normalize the probability to be 1, i.e. we return 0.125 and 0.875. The problem is that we would return the same normalized probabilities even if the model predicts 1% and 7%. But if the model predicts such low probabilities, there is probably something wrong and we would like to know about it.\n", "\n", "For this reason, we added the option to disable the normalization of probabilities. Let's check how well our zero-shot flan-t5 model is doing without normalization:" ] }, { "cell_type": "code", "execution_count": 39, "id": "b413434a", "metadata": {}, "outputs": [], "source": [ "clf = ZeroShotClassifier('google/flan-t5-small', use_caching=False, probas_sum_to_1=False)" ] }, { "cell_type": "code", "execution_count": 40, "id": "2a8cbf42", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 "\n", "For this reason, we added the option to disable the normalization of probabilities. Let's check how well our zero-shot flan-t5 model is doing without normalization:" ] }, { "cell_type": "code", "execution_count": 39, "id": "b413434a", "metadata": {}, "outputs": [], "source": [ "clf = ZeroShotClassifier('google/flan-t5-small', use_caching=False, probas_sum_to_1=False)" ] }, { "cell_type": "code", "execution_count": 40, "id": "2a8cbf42", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
ZeroShotClassifier(model_name='google/flan-t5-small', probas_sum_to_1=False, use_caching=False)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ZeroShotClassifier(model_name='google/flan-t5-small', probas_sum_to_1=False, use_caching=False)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.fit(X=None, y=['positive', 'negative'])" ] }, { "cell_type": "code", "execution_count": 41, "id": "22f084a5", "metadata": {}, "outputs": [], "source": [ "y_proba = clf.predict_proba(X[:3])" ] }, { "cell_type": "code", "execution_count": 42, "id": "ec3008a6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.55589342, 0.43614346],\n", " [0.56059068, 0.43085128],\n", " [0.9431383 , 0.04362515]])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba" ] }, { "cell_type": "markdown", "id": "e668cf37", "metadata": {}, "source": [ "Let's check the sum of the two classes combined:" ] }, { "cell_type": "code", "execution_count": 43, "id": "1b11fb92", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.99203688, 0.99144197, 0.98676345])" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba.sum(1)" ] }, { "cell_type": "markdown", "id": "69dc1cc1", "metadata": {}, "source": [ "As you can see, the summed probabilities returned by flan-t5 are quite high. Without normalization, they still sum up to ~99%, which is very good.\n", "\n", "Now let's take a look at an LLM that doesn't work well for this task, GPT2.\n", "\n", "Note that, in contrast to flan-t5, GPT2 is a decoder-only language model, we don't need to set `use_caching=False`." ] }, { "cell_type": "code", "execution_count": 44, "id": "3446c6f1", "metadata": {}, "outputs": [], "source": [ "clf = ZeroShotClassifier('gpt2', probas_sum_to_1=False)" ] }, { "cell_type": "code", "execution_count": 45, "id": "1cd3b9bc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
ZeroShotClassifier(model_name='gpt2', probas_sum_to_1=False)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ZeroShotClassifier(model_name='gpt2', probas_sum_to_1=False)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.fit(X=None, y=['positive', 'negative'])" ] }, { "cell_type": "code", "execution_count": 46, "id": "c4f07575", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] } ], "source": [ "y_proba = clf.predict_proba(X[:3])" ] }, { "cell_type": "code", "execution_count": 47, "id": "9e0177a3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[3.86097296e-13, 1.38606856e-12],\n", " [2.50067030e-13, 8.08187610e-13],\n", " [3.82730414e-13, 1.23712900e-12]])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba" ] }, { "cell_type": "markdown", "id": "907cad23", "metadata": {}, "source": [ "As we can see, the probabilities are really low, but if we had normalized them, we might not have noticed:" ] }, { "cell_type": "code", "execution_count": 48, "id": "882ec313", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.21786747, 0.78213253],\n", " [0.23630138, 0.76369862],\n", " [0.23627385, 0.76372615]])" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# normalize probabilities to sum up to 1\n", "y_proba / y_proba.sum(1, keepdims=True)" ] }, { "cell_type": "markdown", "id": "18b5842e", "metadata": {}, "source": [ "This means we should probably use a different LLM or tinker with the prompt until we get better results." ] }, { "cell_type": "markdown", "id": "9bbf72cd", "metadata": {}, "source": [ "### Specific actions when probabilities are low" ] }, { "cell_type": "markdown", "id": "c235b6ef", "metadata": {}, "source": [ "There are more options to identify low probabilities in a way that does not require manually inspecting the probabilities. For this, we provide two arguments for `ZeroShotClassifier` and `FewShotClassifier`:\n", "\n", "The first argument is called `error_low_prob`. It should be one of the following strings: `'ignore'`, `'warn'`, `'raise'`, or `'return_none'`.\n", "\n", "By default, it is `'ignore'`, which means that nothing happens, no matter how low the predicted proabilities. By setting it to `'warn'`, there will be a warning when the total probabilities of at least one predicted sample is too low. Use this option if you want to get the result but be alerted about possible problems.\n", "\n", "By passing `error_low_prob='raise'`, an error will be raised as soon as a sample with low total probabilities is encountered. This is useful if you want inference to stop immediately, instead of waiting for all predictions to be made.\n", "\n", "Finally, you can set `error_low_prob='return_none'`. In this case, nothing changes when calling `predict_proba`. When calling `predict`, however, the probabilities for the samples will be checked and if they're too low, the prediction will be replaced by `None`. This is useful if the predictions are generally good, but some examples are, for one reason or another, hard to predict.\n", "\n", "The second parameter, which should be used in conjunction with `error_low_prob`, is called `threshold_low_prob`. 
This is simply a float between 0 and 1 that indicates what the probability is that should be considered \"low\". Note that this value is compared to the _sum of the probability for all labels_ of a given sample. So when setting `threshold_low_prob=0.1`, and the probability for 'negative' is 0.05, but the probability for 'positive' is 0.2, this would be fine because in total, their probabilities exceed 0.1.\n", "\n", "Let's see how this works in practice by using the option to raise an error and setting the threshold to 0.5:" ] }, { "cell_type": "code", "execution_count": 49, "id": "035392d9", "metadata": {}, "outputs": [], "source": [ "# note that since GPT2 is a decoder-only language model, we don't need to set use_caching=False\n", "clf = ZeroShotClassifier('gpt2', error_low_prob='raise', threshold_low_prob=0.5)" ] }, { "cell_type": "code", "execution_count": 50, "id": "9c9fa255", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
ZeroShotClassifier(error_low_prob='raise', model_name='gpt2', threshold_low_prob=0.5)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ZeroShotClassifier(error_low_prob='raise', model_name='gpt2', threshold_low_prob=0.5)" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.fit(X=None, y=['positive', 'negative'])" ] }, { "cell_type": "code", "execution_count": 51, "id": "3092ea4e", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "There was an error: The sum of all probabilities is 0.000, which is below the minimum threshold of 0.500\n" ] } ], "source": [ "try:\n", " clf.predict_proba(X[:3])\n", "except Exception as exc:\n", " print(\"There was an error:\", exc)" ] }, { "cell_type": "markdown", "id": "fae9eabf", "metadata": {}, "source": [ "As you can see, we indeed got an error, alerting us immediately to potential issues." ] }, { "cell_type": "markdown", "id": "5899b0bc", "metadata": {}, "source": [ "## Testing MNLI" ] }, { "cell_type": "markdown", "id": "318fb816", "metadata": {}, "source": [ "There are other zero-shot classification methods out there. One such method is to use natural language inference (NLI). In a nutshell, this method works by creating the text embedding for the input and the embeddings for each label, then calculating the probability based on the similarity of the text and label embeddings.\n", "\n", "Let's compare the results to https://huggingface.co/facebook/bart-large-mnli, which is the most used zero-shot classifier on Hugging Face at the time of writing." ] }, { "cell_type": "code", "execution_count": 52, "id": "1764b077", "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline" ] }, { "cell_type": "code", "execution_count": 53, "id": "d3bf7743", "metadata": {}, "outputs": [], "source": [ "classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli', device=device)" ] }, { "cell_type": "code", "execution_count": 54, "id": "d1899abe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 18.2 s, sys: 33 ms, total: 18.2 s\n", "Wall time: 6.73 s\n" ] } ], "source": [ "%time preds = classifier(imdb['train'][:100]['text'], ['negative', 'positive'])" ] }, { "cell_type": "code", "execution_count": 55, "id": "1ef6c2a8", "metadata": {}, "outputs": [], "source": [ "y_proba = np.vstack([p['scores'] if p['labels'] == ['negative', 'positive'] else p['scores'][::-1] for p in preds])" ] }, { "cell_type": "code", "execution_count": 56, "id": "66d29e1f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.84" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_score(y, y_proba.argmax(1))" ] }, { "cell_type": "code", "execution_count": 57, "id": "4da4ea3a", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "0.34437056518170317" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_loss(y, y_proba)" ] }, { "cell_type": "markdown", "id": "be5e7ce7", "metadata": {}, "source": [ "**Conclusion**: This model is slower than the tested zero-shot classifier, it is less flexible (we cannot adjust prompt or other parameters), and it performs worse. For this task, it is, therefore, better to use skorch's `ZeroShotClassifier`." 
] }, { "cell_type": "markdown", "id": "0f8196ec", "metadata": {}, "source": [ "## Testing a standard machine learning solution" ] }, { "cell_type": "markdown", "id": "245748e2", "metadata": {}, "source": [ "Finally, let's compare the results to a classical supervised machine learning approach. For this, we use TFIDF to vectorize the input and a logistic regression for classification. This a standard pipeline for text classification tasks and works really well with enough data." ] }, { "cell_type": "code", "execution_count": 58, "id": "37dbcd6c", "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import cross_validate" ] }, { "cell_type": "code", "execution_count": 59, "id": "5f39164e", "metadata": {}, "outputs": [], "source": [ "tfidf = Pipeline([\n", " ('tfidf', TfidfVectorizer()),\n", " ('clf', LogisticRegression()),\n", "])" ] }, { "cell_type": "markdown", "id": "fef2d783", "metadata": {}, "source": [ "Let's run a grid search on a couple of hyper-parameters to ensure we pick good ones." ] }, { "cell_type": "code", "execution_count": 60, "id": "55cfe9e0", "metadata": {}, "outputs": [], "source": [ "params = {'tfidf__max_features': [500, 1000], 'tfidf__ngram_range': [(1, 1), (1, 2), (1, 3)]}" ] }, { "cell_type": "code", "execution_count": 61, "id": "0f465c76", "metadata": {}, "outputs": [], "source": [ "search = GridSearchCV(\n", " tfidf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False\n", ")" ] }, { "cell_type": "code", "execution_count": 62, "id": "5a4eca10", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 607 ms, sys: 0 ns, total: 607 ms\n", "Wall time: 607 ms\n" ] }, { "data": { "text/html": [ "
GridSearchCV(cv=2,\n",
       "             estimator=Pipeline(steps=[('tfidf', TfidfVectorizer()),\n",
       "                                       ('clf', LogisticRegression())]),\n",
       "             param_grid={'tfidf__max_features': [500, 1000],\n",
       "                         'tfidf__ngram_range': [(1, 1), (1, 2), (1, 3)]},\n",
       "             refit=False, scoring=['accuracy', 'neg_log_loss'])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "GridSearchCV(cv=2,\n", " estimator=Pipeline(steps=[('tfidf', TfidfVectorizer()),\n", " ('clf', LogisticRegression())]),\n", " param_grid={'tfidf__max_features': [500, 1000],\n", " 'tfidf__ngram_range': [(1, 1), (1, 2), (1, 3)]},\n", " refit=False, scoring=['accuracy', 'neg_log_loss'])" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time search.fit(X, y)" ] }, { "cell_type": "markdown", "id": "f9a8f9af", "metadata": {}, "source": [ "The table is quite big, let's look at the top 5 best log losses:" ] }, { "cell_type": "code", "execution_count": 63, "id": "26a18072", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mean_test_accuracymean_test_neg_log_lossparam_tfidf__max_featuresparam_tfidf__ngram_range
20.69-0.662397500(1, 3)
50.71-0.6639591000(1, 3)
10.68-0.664004500(1, 2)
40.70-0.6642151000(1, 2)
00.65-0.664609500(1, 1)
\n", "
" ], "text/plain": [ " mean_test_accuracy mean_test_neg_log_loss param_tfidf__max_features \\\n", "2 0.69 -0.662397 500 \n", "5 0.71 -0.663959 1000 \n", "1 0.68 -0.664004 500 \n", "4 0.70 -0.664215 1000 \n", "0 0.65 -0.664609 500 \n", "\n", " param_tfidf__ngram_range \n", "2 (1, 3) \n", "5 (1, 3) \n", "1 (1, 2) \n", "4 (1, 2) \n", "0 (1, 1) " ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols = ['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_tfidf__max_features', 'param_tfidf__ngram_range']\n", "pd.DataFrame(search.cv_results_)[cols].sort_values('mean_test_neg_log_loss', ascending=False).head()" ] }, { "cell_type": "markdown", "id": "34859e8c", "metadata": {}, "source": [ "**Conclusion**: This classical model is much faster, even if we include the training time, because it is much smaller than a language model. However, it's scores are also much worse, which is due to the small size of the dataset. If speed is no concern, using an LLM classifier would thus be a good option for this task." ] }, { "cell_type": "markdown", "id": "a8a5092b", "metadata": {}, "source": [ "## Summary" ] }, { "cell_type": "markdown", "id": "54d07164", "metadata": {}, "source": [ "In this notebook, we learned how to use skorch's `ZeroShotClassifier` and `FewShotClassifier` for a text classification task. Let's list a few advantages that we gained from using those classes:\n", "\n", "- On this particular dataset, zero- and few-shot learning outperformed a classical supervised machine learning approach. We also got better scores than what we got from MNLI.\n", "- We can use `ZeroShotClassifier` and `FewShotClassifier` as drop-in replacement for other sklearn text classification models because `fit`, `predict`, and `predict_proba` work as expected from an sklearn model.\n", "- It is trivial to run a grid search. This way, we can find out what model works best, what prompt is optimal, and how many few-shot samples to provide.\n", "- We can call `predict_proba` to get the (relative) probability the model assigns to each label, which is not something we normally get from a language model.\n", "- `ZeroShotClassifier` and `FewShotClassifier` also give us some nice extra features. Most notably, they force the language models to predict one of the provided labels, which is typically not a guarantee when using language models. We also get easy ways to detect issues and caching (for decoder-only models)." ] }, { "cell_type": "markdown", "id": "711b5c47-1495-4c6d-b989-7dc7fdf2c784", "metadata": {}, "source": [ "---\n", "\n", "## ✨ Bonus ✨\n", "\n", "Not every task is a classification task but some tasks can be broken down into a classification task!\n", "\n", "As an example we show you how a task such as [PIQA](https://huggingface.co/datasets/piqa)\n", "can be reformulated to be solved with the LLM classifier. 
PIQA defines the task of picking the right solution\n", "out of two options to achieve a given goal, where the more sensible of the two is labelled correct.\n", "\n", "Two example entries of the PIQA dataset:\n", "\n", "| goal (string) | sol1 (string) | sol2 (string) | label (class label) |\n", "| - | - | - | - |\n", "| \"When boiling butter, when it's ready, you can\" | \"Pour it onto a plate\" | \"Pour it into a jar\" | 1 |\n", "| \"To permanently attach metal legs to a chair, you can\" | \"Weld the metal together to get it to stay firmly in place\" | \"Nail the metal together to get it to stay firmly in place\" | 0 |\n", "\n", "A generative approach to this problem would be to ask the model to name the correct solution, match its answer against the given options to determine its number, and compare that number with the correct label.\n", "This approach doesn't work with the LLM classifier, of course. But we can rephrase the task a bit: we give each solution a number and ask the model to predict the number of the correct one.\n", "\n", "Therefore, for a zero-shot formulation, we could prompt the model like this:\n", "\n", "```\n", "prompt = \"\"\"Goal is: Do cardio exercise without running.\n", "Solution 1: Use a jump rope for 15 minutes.\n", "Solution 2: Run around a chair for 15 minutes.\n", "Correct: Solution \"\"\"\n", "```\n", "\n", "We then expect the model to complete `1` or `2`, which are now our classes. As always, the ideal way of prompting may differ depending\n", "on the model used, and instruction-tuned models may need a more precise prompt.\n", "\n", "We will test this task on `bloomz-1b1` to cover another popular LLM and because we know [what to expect from this model on this task](https://huggingface.co/datasets/bigscience/evaluation-results/viewer/bloom-1b1/test) (at best 67.14% zero-shot)." ] }, { "cell_type": "code", "execution_count": 64, "id": "7c04244a-7dba-4447-a7e1-144ebd1cf988", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 3/3 [00:00<00:00, 1131.56it/s]\n" ] } ], "source": [ "dataset = datasets.load_dataset('piqa').shuffle(seed=42)" ] }, { "cell_type": "code", "execution_count": 65, "id": "1a44793d-11bc-4827-960f-ac629260474f", "metadata": {}, "outputs": [], "source": [ "template = \"\"\"Goal is: {goal}\n", "Solution 1: {sol1}\n", "Solution 2: {sol2}\"\"\"\n", "X = []\n", "y = []\n", "# iterating over dataset['train'] directly is not possible, since that only yields the keys\n", "for i in range(len(dataset['train'])):\n", "    row = dataset['train'][i]\n", "    X.append(template.format(**row))\n", "    y.append(\" 1\" if row['label'] == 0 else \" 2\")" ] }, { "cell_type": "markdown", "id": "d860747b-336d-42d1-b929-cf4e6f4ba5d3", "metadata": {}, "source": [ "Take note that we chose to use \" 1\" and \" 2\" as labels. `bloomz` seems to be trained in such a way that it favors \" 1\" over \"1\". This is not the case for `flan-t5-*`, for example, but it is something to keep in mind when prompting and testing these models." 
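] }, { "cell_type": "markdown", "id": "a0b1c2d3", "metadata": {}, "source": [ "If you want to check how a tokenizer treats such label variants yourself, looking at the raw token ids can help. The following cell is a minimal sketch using the standard Hugging Face `AutoTokenizer` API; the exact ids depend on the model's vocabulary, so treat it as an illustrative check rather than part of the skorch API:" ] }, { "cell_type": "code", "execution_count": null, "id": "e4f5a6b7", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "# illustrative check (not part of skorch): compare how '1' and ' 1' tokenize\n", "tokenizer = AutoTokenizer.from_pretrained('bigscience/bloomz-1b1')\n", "for label in ['1', ' 1', '2', ' 2']:\n", "    # a label that maps to its own distinct token is the easy case for the classifier\n", "    print(repr(label), '->', tokenizer.encode(label))" 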
] }, { "cell_type": "code", "execution_count": 66, "id": "0433c2a7-7c99-41c3-8775-7231b5d39d15", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Goal is: Do cardio exercise without running.\n", "Solution 1: Use a jump rope for 15 minutes.\n", "Solution 2: Run around a chair for 15 minutes.\n" ] } ], "source": [ "print(X[0])" ] }, { "cell_type": "code", "execution_count": 67, "id": "527ddff9-c51d-4ce3-8551-a47a2315e533", "metadata": {}, "outputs": [], "source": [ "model = 'bigscience/bloomz-1b1'\n", "\n", "prompt = \"\"\"{text}\n", "Correct: Solution\"\"\"" ] }, { "cell_type": "code", "execution_count": 68, "id": "d5be1de6-799c-44db-9f76-f7151a2376e9", "metadata": {}, "outputs": [], "source": [ "clf = ZeroShotClassifier(model_name=model, prompt=prompt, probas_sum_to_1=False, device=device)" ] }, { "cell_type": "code", "execution_count": 69, "id": "5b2f9be5-3ae0-4d00-87a1-75e98fde861f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 10.3 s, sys: 1.63 s, total: 11.9 s\n", "Wall time: 9.86 s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/work/experiments/python/notebooks/llama/skorch.git/skorch/llm/classifier.py:81: UserWarning: The prompt may not be correct, it expects 2 placeholders: 'labels', 'text', missing keys: 'labels'\n", " warnings.warn(msg)\n" ] }, { "data": { "text/html": [ "
ZeroShotClassifier(device='cuda:0', model_name='bigscience/bloomz-1b1', probas_sum_to_1=False, prompt='{text}\\nCorrect: Solution')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ZeroShotClassifier(device='cuda:0', model_name='bigscience/bloomz-1b1', probas_sum_to_1=False, prompt='{text}\\nCorrect: Solution')" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time clf.fit(X=None, y=y)" ] }, { "cell_type": "markdown", "id": "4981486c-c596-451c-b132-3430e38659a4", "metadata": {}, "source": [ "To save you some time we will just classify 1000 of the ~16,000 samples. This is an example and not a benchmark, after all." ] }, { "cell_type": "code", "execution_count": 70, "id": "4b6f76f7-fb69-44b0-9c4b-512e3eef023a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2min 25s, sys: 168 ms, total: 2min 25s\n", "Wall time: 24.6 s\n" ] } ], "source": [ "max_n = 1000\n", "\n", "%time y_proba = clf.predict_proba(X[:max_n])" ] }, { "cell_type": "markdown", "id": "f041616e-63b4-4071-8e21-44cd319e910c", "metadata": {}, "source": [ "### Evaluation" ] }, { "cell_type": "code", "execution_count": 71, "id": "3472af08-eb38-4535-a148-a25a6166f74e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7061630520500096" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_loss(y[:max_n], y_proba)" ] }, { "cell_type": "code", "execution_count": 72, "id": "82ae0d79-ae8d-42eb-a391-37db513df3d3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.524" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = clf.classes_[y_proba.argmax(1)]\n", "\n", "accuracy_score(y[:max_n], y_pred)" ] }, { "cell_type": "code", "execution_count": 73, "id": "bf74d394-95ff-465d-8467-190091b60a21", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.15824902634500856" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba_normed = y_proba / y_proba.sum(axis=1)[:, None]\n", "\n", "abs(y_proba_normed[:, 0] - y_proba_normed[:, 1]).mean()" ] }, { "cell_type": "markdown", "id": "9df8b1af-6f42-4e00-af3a-dc6336007c88", "metadata": {}, "source": [ "You can see that the accuracy is below what we expected (67.14%) and the probabilities are very close. Why is that?\n", "\n", "The reported accuracy of the reference benchmark is determined by choosing the answer which has the higher log-probability. What the EleutherAI benchmark does is to ask for the probabilities of `\" \"` and `\" \"`: the prompt with the higher probability is the winner. This is *leveraging common knowledge* (i.e. a more likely phrase correlates with a more common phrase which is a good bias for correctness - you are more likely to find 'eat a burger' than 'throw a ball' for the goal context 'i am hungry'). \n", "\n", "Another aspect is that we are introducing an *indirection* with our task framing: we're answering but we're choosing a symbol *for* the answer. This makes it a lot harder for the model to choose the correct answer because it not only needs to understand to only answer with the valid options (1 and 2) but only to memorize what they signifiy in terms of the goal phrase. This is also indicated by the very low mean absolute difference between the probabilities of both options (of only ~15 percent points). Different models seem to perform differently on this. `flan-t5` does this really well, `bloomz` seems to perform worse, only `bloomz-3b` is able to achieve accuracies >65% with this setup." 
] }, { "cell_type": "code", "execution_count": 74, "id": "8b45d511-6b15-48c0-856e-5d3c0d4fe5ac", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([0.15728807, 0.23420842]), array([0.63083631, 0.81155944]))" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba.min(0), y_proba.max(0)" ] }, { "cell_type": "markdown", "id": "71a6985f-0d7a-440e-b66f-9855bc004b59", "metadata": {}, "source": [ "By looking at the minimum and maximum (absolute) probabilities over all samples we can also see that the model is never really strongly certain about an answer which, in itself, is not a problem but combined with the low accuracy is indicative for a limited capability in \"understanding\" the task at hand.\n", "\n", "**So... is this bad?!** - the performance? Yes. But in general: **no** - we tasked the model with a more complex task and arguably it is good thing that the probabilities we computed revealed to us that the model is not able to fit the data. If you can see that the difference in probabilities between the options is quite small it is likely that the model is not able to solve the task and might not be well-suited for the task at hand. This could be an indicator for you to chose a different model (or maybe even simpler, a different prompt). Note that the reason might not simply be 'model complexity' but could also mean that there's an unfortunate tokenization that needs a bigger model to sort out - that's very hard to say.\n", "\n", "The lesson here is that the probabilistic view lets you reason a bit more about the performance of these models in a familiar way.\n", "\n", "And this concludes this bonus section. ✨\n", "\n", "You have seen how frame atypical tasks into a classification problem and experienced first-hand how good or bad the capability to handle indirection can vary between models. You've also learned the importance of dealing with tokenization preferences (`\"1\"` vs `\" 1\"`) of language models and saw that having an interface to look into the probabilities and their differences can tell you a bit about your model and it's ability to solve your task." ] }, { "cell_type": "code", "execution_count": null, "id": "101726d3-033c-445a-9b09-4010d0680a70", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }