{ "cells": [ { "cell_type": "markdown", "id": "22d246ff", "metadata": {}, "source": [ "# Using Large Language Models as text classifiers with an sklearn interface" ] }, { "cell_type": "markdown", "id": "74501674", "metadata": {}, "source": [ "In this notebook, we will learn how to use skorch's `ZeroShotClassifier` and `FewShotClassifier` to perform classification without any training, by leveraging pretrained (Large) Language Models (LLMs). For this, we rely on the [Hugging Face transformers](https://huggingface.co/docs/transformers/index) library, which gives us access to the text generation models hosted on Hugging Face." ] }, { "cell_type": "markdown", "id": "595cf8f6", "metadata": {}, "source": [ "
\n", "\n", " Run in Google Colab \n", " | \n", "View source on GitHub |
ZeroShotClassifier(device='cuda:0', model_name='google/flan-t5-small', use_caching=False)
GridSearchCV(cv=2,\n", " estimator=ZeroShotClassifier(device='cuda:0', model_name='google/flan-t5-small', use_caching=False),\n", " param_grid={'prompt': ['You are a text classification assistant.\\n'\n", " '\\n'\n", " 'The text to classify:\\n'\n", " '\\n'\n", " '```\\n'\n", " '{text}\\n'\n", " '```\\n'\n", " '\\n'\n", " 'Choose the label among the following '\n", " 'possibilities with the highest '\n", " 'probability.\\n'\n", " 'Only return the label, nothing more:\\n'\n", " '\\n'\n", " '{labels}\\n'\n", " '\\n'\n", " 'Your response:\\n',\n", " 'Your task is to classify text.\\n'\n", " '\\n'\n", " 'Choose the label among the following '\n", " 'possibilities with the highest '\n", " 'probability.\\n'\n", " 'Only return the label, nothing more:\\n'\n", " '\\n'\n", " '{labels}\\n'\n", " '\\n'\n", " 'The text to classify:\\n'\n", " '\\n'\n", " '```\\n'\n", " '{text}\\n'\n", " '```\\n'\n", " '\\n'\n", " 'Your response:\\n']},\n", " refit=False, scoring=['accuracy', 'neg_log_loss'])
ZeroShotClassifier(device='cuda:0', model_name='google/flan-t5-small', use_caching=False)
| | mean_test_accuracy | mean_test_neg_log_loss | param_prompt | mean_score_time |
|---|---|---|---|---|
| 0 | 0.86 | -0.287012 | You are a text classification assistant.\\n\\nTh... | 5.084494 |
| 1 | 0.93 | -0.246949 | Your task is to classify text.\\n\\nChoose the l... | 5.091446 |
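The two prompts compared above differ in their opening sentence and in the order of their parts: the first places the text before the list of labels, the second lists the labels first (and reached 0.93 vs. 0.86 mean accuracy here). Both are plain templates with `{text}` and `{labels}` placeholders. The following minimal sketch shows how such a template could be filled with `str.format`. The helper name `fill_prompt`, the one-label-per-line rendering of `{labels}`, and the example text are our own assumptions; skorch may format the labels differently internally.

```python
# Three backticks, built programmatically so that this snippet can itself
# live inside a Markdown code fence.
FENCE = "`" * 3

# Second prompt from the param_grid above, reproduced verbatim.
PROMPT = (
    "Your task is to classify text.\n\n"
    "Choose the label among the following possibilities with the highest probability.\n"
    "Only return the label, nothing more:\n\n"
    "{labels}\n\n"
    "The text to classify:\n\n"
    + FENCE + "\n{text}\n" + FENCE + "\n\n"
    "Your response:\n"
)


def fill_prompt(template, text, labels):
    """Substitute the {text} and {labels} placeholders (hypothetical helper).

    Assumption: labels are rendered one per line; skorch may use another format.
    """
    return template.format(text=text, labels="\n".join(labels))


print(fill_prompt(PROMPT, "I loved this movie!", ["negative", "positive"]))
```

Because only the template is parsed by `format`, braces inside the text being classified are inserted literally and do not cause errors.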
FewShotClassifier(model='T5ForConditionalGeneration', tokenizer='T5TokenizerFast', use_caching=False)
GridSearchCV(cv=2,\n", " estimator=FewShotClassifier(model=T5ForConditionalGeneration(\n", " (shared): Embedding(32128, 512)\n", " (encoder): T5Stack(\n", " (embed_tokens): Embedding(32128, 512)\n", " (block): ModuleList(\n", " (0): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=512, out_features=384, bias=False)\n", " (k): Linear(in_features=512, out_...\n", "), tokenizer=PreTrainedTokenizerFast(name_or_path='google/flan-t5-small', vocab_size=32100, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>', '<extra_id_46>', '<extra_id_47>', '<extra_id_48>', '<extra_id_49>', '<extra_id_50>', '<extra_id_51>', '<extra_id_52>', '<extra_id_53>', '<extra_id_54>', '<extra_id_55>', '<extra_id_56>', '<extra_id_57>', '<extra_id_58>', '<extra_id_59>', '<extra_id_60>', '<extra_id_61>', '<extra_id_62>', '<extra_id_63>', '<extra_id_64>', '<extra_id_65>', '<extra_id_66>', '<extra_id_67>', '<extra_id_68>', '<extra_id_69>', '<extra_id_70>', '<extra_id_71>', '<extra_id_72>', '<extra_id_73>', '<extra_id_74>', 
'<extra_id_75>', '<extra_id_76>', '<extra_id_77>', '<extra_id_78>', '<extra_id_79>', '<extra_id_80>', '<extra_id_81>', '<extra_id_82>', '<extra_id_83>', '<extra_id_84>', '<extra_id_85>', '<extra_id_86>', '<extra_id_87>', '<extra_id_88>', '<extra_id_89>', '<extra_id_90>', '<extra_id_91>', '<extra_id_92>', '<extra_id_93>', '<extra_id_94>', '<extra_id_95>', '<extra_id_96>', '<extra_id_97>', '<extra_id_98>', '<extra_id_99>']}), use_caching=False),\n", " param_grid={'max_samples': [3, 5, 7]}, refit=False,\n", " scoring=['accuracy', 'neg_log_loss'])
FewShotClassifier(model='T5ForConditionalGeneration', tokenizer='T5TokenizerFast', use_caching=False)
| | mean_test_accuracy | mean_test_neg_log_loss | param_max_samples | mean_score_time |
|---|---|---|---|---|
| 0 | 0.92 | -0.235661 | 3 | 5.453266 |
| 1 | 0.91 | -0.242041 | 5 | 9.747535 |
| 2 | 0.92 | -0.235754 | 7 | 15.653747 |
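In the few-shot grid search, accuracy stays between 0.91 and 0.92 while `mean_score_time` roughly triples from `max_samples=3` to `max_samples=7`. A plausible explanation is that every additional example makes the prompt the model has to process longer. The sketch below illustrates this with a hypothetical prompt layout (`Text: ...` / `Label: ...`) and a hypothetical helper `build_few_shot_prompt` that randomly picks up to `max_samples` labeled examples; this is not skorch's actual template or sampling strategy.

```python
import random

# Hypothetical layout for one labeled example; skorch uses its own default prompt.
EXAMPLE_TEMPLATE = "Text: {text}\nLabel: {label}\n"


def build_few_shot_prompt(query, train_texts, train_labels, max_samples, seed=0):
    """Prepend up to `max_samples` randomly chosen labeled examples to the query.

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)
    n = min(max_samples, len(train_texts))
    # Draw n distinct example indices (the sampling strategy is an assumption).
    indices = rng.sample(range(len(train_texts)), n)
    examples = "".join(
        EXAMPLE_TEMPLATE.format(text=train_texts[i], label=train_labels[i])
        for i in indices
    )
    return examples + f"Text: {query}\nLabel:"


texts = ["great film", "boring plot", "loved it", "awful acting",
         "fun ride", "waste of time", "superb cast"]
labels = ["positive", "negative", "positive", "negative",
          "positive", "negative", "positive"]

for k in (3, 5, 7):
    prompt = build_few_shot_prompt("a delightful surprise", texts, labels, max_samples=k)
    print(f"max_samples={k}: prompt length {len(prompt)} characters")
```

Here the prompt length grows linearly with `max_samples`. Since a language model has to process every prompt token, this is consistent with (though not proof of) the growing score times in the table.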
ZeroShotClassifier(model_name='google/flan-t5-small', probas_sum_to_1=False, use_caching=False)
ZeroShotClassifier(model_name='gpt2', probas_sum_to_1=False)
ZeroShotClassifier(error_low_prob='raise', model_name='gpt2', threshold_low_prob=0.5)
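Several classifiers above set `probas_sum_to_1=False`, and the one directly above combines `threshold_low_prob=0.5` with `error_low_prob='raise'`. Judging by the parameter names, the idea is to check how much of the model's probability mass lands on the allowed labels at all, and to fail loudly when it is too low instead of silently renormalizing. The following sketch re-implements that idea for a single sample in plain Python. The function name `check_label_probas`, the dict-of-probabilities input, and the `'ignore'`/`'warn'` options are our own assumptions; the skorch documentation describes the exact semantics.

```python
import warnings


def check_label_probas(label_probas, threshold_low_prob=0.0,
                       error_low_prob="ignore", probas_sum_to_1=True):
    """Hypothetical low-probability check for a single sample.

    label_probas maps each candidate label to the probability the language
    model assigns to generating exactly that label.
    """
    total = sum(label_probas.values())
    if total < threshold_low_prob:
        msg = (f"Total probability of all labels is {total:.3f}, "
               f"below threshold_low_prob={threshold_low_prob}")
        if error_low_prob == "raise":
            raise ValueError(msg)
        if error_low_prob == "warn":
            warnings.warn(msg)
        # Any other value ("ignore") continues silently.
    if probas_sum_to_1 and total > 0:
        # Renormalize so that the label probabilities sum to 1.
        return {label: p / total for label, p in label_probas.items()}
    # Otherwise return the raw probabilities unchanged.
    return dict(label_probas)


print(check_label_probas({"positive": 0.6, "negative": 0.2}))
print(check_label_probas({"positive": 0.6, "negative": 0.2}, probas_sum_to_1=False))
try:
    check_label_probas({"positive": 0.2, "negative": 0.1},
                       threshold_low_prob=0.5, error_low_prob="raise")
except ValueError as exc:
    print("raised:", exc)
```

Keeping the raw probabilities preserves the information that, for example, the model put only 30% of its probability on valid labels, which renormalization would hide.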
GridSearchCV(cv=2,\n", " estimator=Pipeline(steps=[('tfidf', TfidfVectorizer()),\n", " ('clf', LogisticRegression())]),\n", " param_grid={'tfidf__max_features': [500, 1000],\n", " 'tfidf__ngram_range': [(1, 1), (1, 2), (1, 3)]},\n", " refit=False, scoring=['accuracy', 'neg_log_loss'])
Pipeline(steps=[('tfidf', TfidfVectorizer()), ('clf', LogisticRegression())])
TfidfVectorizer()
LogisticRegression()
| | mean_test_accuracy | mean_test_neg_log_loss | param_tfidf__max_features | param_tfidf__ngram_range |
|---|---|---|---|---|
| 2 | 0.69 | -0.662397 | 500 | (1, 3) |
| 5 | 0.71 | -0.663959 | 1000 | (1, 3) |
| 1 | 0.68 | -0.664004 | 500 | (1, 2) |
| 4 | 0.70 | -0.664215 | 1000 | (1, 2) |
| 0 | 0.65 | -0.664609 | 500 | (1, 1) |
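All result tables report `mean_test_neg_log_loss`. scikit-learn's `neg_log_loss` scorer is the negated log loss, so values closer to zero are better; the TF-IDF baseline table above is sorted by this column. Log loss is the average negative log-probability assigned to the true label, which the following stdlib-only sketch computes by hand. The clipping constant `eps` is an assumption of this sketch, and scikit-learn's own `log_loss` may handle clipping and normalization differently.

```python
import math


def log_loss(y_true, y_proba, labels, eps=1e-15):
    """Mean negative log-likelihood of the true labels.

    y_proba[i][j] is the predicted probability that sample i has label labels[j].
    """
    total = 0.0
    for target, row in zip(y_true, y_proba):
        p = row[labels.index(target)]
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total -= math.log(p)
    return total / len(y_true)


labels = ["negative", "positive"]
y_true = ["positive", "negative"]
y_proba = [[0.2, 0.8], [0.9, 0.1]]
score = -log_loss(y_true, y_proba, labels)  # the sign convention of 'neg_log_loss'
print(score)
```

For reference, a classifier that always predicts 0.5 for each of two classes has a log loss of ln 2 ≈ 0.693.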
ZeroShotClassifier(device='cuda:0', model_name='bigscience/bloomz-1b1', probas_sum_to_1=False, prompt='{text}\\nCorrect: Solution')