{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "To set a custom kernel for this notebook, see https://scicomp.aalto.fi/triton/apps/jupyter/#installing-kernels-from-virtualenvs-or-anaconda-environments" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "import os\n", "import sys\n", "import time\n", "\n", "## Set environment variables; this must be done before importing transformers\n", "# Use only locally cached models, do not try to download\n", "os.environ['TRANSFORMERS_OFFLINE'] = '1'\n", "# Point the Hugging Face cache to the shared cache directory\n", "os.environ['HF_HOME'] = '/scratch/shareddata/dldata/huggingface-hub-cache'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] } ], "source": [ "print(os.environ['TRANSFORMERS_OFFLINE'])" ] }, { "cell_type": "markdown", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "# Why Hugging Face?\n", "- Open source.\n", "- A vast repository of pre-trained models across many domains.\n", "- Compatible with TensorFlow, PyTorch and JAX.\n", "- A community, not just a toolkit.\n", "- Useful for both research and engineering.\n", "- Fine-tuning capabilities.\n", "\n", "https://huggingface.co/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# The simplest way to use Hugging Face transformers for inference: the pipeline class\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).\n", "Using a pipeline without specifying a model name and revision in production is not recommended.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Results: [{'label': 'POSITIVE', 'score': 0.999874472618103}, {'label': 'NEGATIVE', 'score': 0.9937865734100342}]\n" ] } 
], "source": [ "from transformers import pipeline\n", "\n", "# Use a new name so the imported pipeline() factory function is not shadowed\n", "classifier = pipeline(\"sentiment-analysis\")\n", "\n", "# Prepare input text\n", "inputs = [\"What a lovely day today!\", \"It is freezing outside.\"]\n", "\n", "results = classifier(inputs)\n", "\n", "print(\"Results:\", results)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No model was supplied, defaulted to gpt2 and revision 6c0e608 (https://huggingface.co/gpt2).\n", "Using a pipeline without specifying a model name and revision in production is not recommended.\n", "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Generated text: The capital of France is that of the great French historian and poet Maurice de Loewie. His books, the History of France, which are the result of the same years and during the latter years of his life, are regarded as the definitive work\n" ] } ], "source": [ "from transformers import pipeline\n", "\n", "# Use a new name so the imported pipeline() factory function is not shadowed\n", "generator = pipeline(\"text-generation\")\n", "\n", "# Prepare input text\n", "input_text = \"The capital of France is\"\n", "\n", "output = generator(input_text, max_length=50)\n", "generated_text = output[0]['generated_text']\n", "print(\"Generated text:\", generated_text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Decompose the pipeline " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "**What happens in the pipeline?**\n", "\n", "Tokenization => Model => Post-processing\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tokenizer Name: gpt2\n", "Vocabulary Size: 50257\n", "Max Model Input Sizes: 1024\n", "Special Tokens: {'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}\n" ] } ], "source": [ "# Print relevant tokenizer information\n",
"print(\"Tokenizer Name:\", generator.tokenizer.name_or_path)\n", "print(\"Vocabulary Size:\", generator.tokenizer.vocab_size)\n", "print(\"Max Model Input Sizes:\", generator.tokenizer.model_max_length)\n", "print(\"Special Tokens:\", generator.tokenizer.special_tokens_map)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GPT2LMHeadModel(\n", " (transformer): GPT2Model(\n", " (wte): Embedding(50257, 768)\n", " (wpe): Embedding(1024, 768)\n", " (drop): Dropout(p=0.1, inplace=False)\n", " (h): ModuleList(\n", " (0-11): 12 x GPT2Block(\n", " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (attn): GPT2Attention(\n", " (c_attn): Conv1D()\n", " (c_proj): Conv1D()\n", " (attn_dropout): Dropout(p=0.1, inplace=False)\n", " (resid_dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (mlp): GPT2MLP(\n", " (c_fc): Conv1D()\n", " (c_proj): Conv1D()\n", " (act): NewGELUActivation()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check out the model architecture\n", "generator.model" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GPT2Config {\n", " \"_name_or_path\": \"gpt2\",\n", " \"activation_function\": \"gelu_new\",\n", " \"architectures\": [\n", " \"GPT2LMHeadModel\"\n", " ],\n", " \"attn_pdrop\": 0.1,\n", " \"bos_token_id\": 50256,\n", " \"do_sample\": true,\n", " \"embd_pdrop\": 0.1,\n", " \"eos_token_id\": 50256,\n", " \"initializer_range\": 0.02,\n", " \"layer_norm_epsilon\": 1e-05,\n", " \"max_length\": 50,\n", " \"model_type\": \"gpt2\",\n", " \"n_ctx\": 1024,\n",
" \"n_embd\": 768,\n", " \"n_head\": 12,\n", " \"n_inner\": null,\n", " \"n_layer\": 12,\n", " \"n_positions\": 1024,\n", " \"reorder_and_upcast_attn\": false,\n", " \"resid_pdrop\": 0.1,\n", " \"scale_attn_by_inverse_layer_idx\": false,\n", " \"scale_attn_weights\": true,\n", " \"summary_activation\": null,\n", " \"summary_first_dropout\": 0.1,\n", " \"summary_proj_to_labels\": true,\n", " \"summary_type\": \"cls_index\",\n", " \"summary_use_proj\": true,\n", " \"task_specific_params\": {\n", " \"text-generation\": {\n", " \"do_sample\": true,\n", " \"max_length\": 50\n", " }\n", " },\n", " \"transformers_version\": \"4.36.0\",\n", " \"use_cache\": true,\n", " \"vocab_size\": 50257\n", "}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check out the model config\n", "generator.model.config" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tokenization\n", "\n", "Tokenizers prepare text data for processing by Transformer models. \n", "\n", "**What a tokenizer does**:\n", "\n", "- Text Preprocessing: split text into tokens.\n", "\n", "- Convert Tokens to IDs: each token is mapped to a unique integer ID.\n", "- Add Special Tokens: \n", " - BERT models use [CLS] at the beginning of the input for classification tasks and [SEP] to separate different segments in the input. \n", " - In masked-language-model pre-training (as used for BERT), certain words in the input are replaced with the [MASK] token. The model then learns to predict the original value of these masked tokens, which helps it learn context and word relationships.\n", " - When the tokenizer encounters a word that is not in its vocabulary, it replaces it with the [UNK] (unknown) token. This is a way to handle out-of-vocabulary words.\n", " - GPT models use a beginning-of-sequence token [BOS] to mark the start and an end-of-sequence token [EOS] to mark the end of a text sequence; GPT-2 uses `<|endoftext|>` for both, as the special-token map printed above shows. \n", "- Handle Sequence Lengths: all sequences in a batch must have the same length, and each model accepts at most a fixed number of tokens (1024 for GPT-2, as printed above). 
Tokenizers pad shorter inputs with [PAD] tokens and truncate longer ones to meet the model's length requirements.\n", "\n", "- Attention Mask: the tokenizer generates an attention mask (1 for real tokens, 0 for [PAD] tokens) so that the model attends only to the relevant parts of the input.\n", "\n", "- Consistency Across Languages: for multilingual models, tokenizers apply the same tokenization rules across languages using a single shared vocabulary.\n", "\n", "There are three main tokenizer types: word-based, subword-based and character-based.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Most state-of-the-art models use subword-based tokenizers:\n", "\n", "- BERT (Bidirectional Encoder Representations from Transformers): uses the WordPiece tokenizer.\n", "\n", "- GPT-2 and GPT-3 (Generative Pre-trained Transformer): use a byte-level variant of Byte Pair Encoding (BPE).\n", "\n", "- T5 (Text-To-Text Transfer Transformer): uses the SentencePiece tokenizer, which is versatile and can be used across different languages and scripts." 
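] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see what subword tokenization does, here is a minimal, simplified sketch of the greedy longest-match-first idea behind WordPiece. It is not the transformers implementation. The function `wordpiece_tokenize` and the tiny vocabulary `toy_vocab` are hypothetical. The real BERT tokenizer also lowercases and splits on whitespace and punctuation first, and it uses a vocabulary of about 30,000 entries.\n", "\n", "For each word, the sketch takes the longest prefix found in the vocabulary and then repeats on the remainder. Continuation pieces are looked up with the `##` prefix. If no piece matches, the whole word becomes `[UNK]`. With a vocabulary containing `gp` and `##us`, the word `gpus` splits the same way as in the BERT output below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def wordpiece_tokenize(word, vocab, unk_token='[UNK]'):\n", "    # Greedy longest-match-first split of one (already lowercased) word\n", "    tokens = []\n", "    start = 0\n", "    while start < len(word):\n", "        end = len(word)\n", "        piece = None\n", "        # Try the longest remaining substring first, then shrink it from the right\n", "        while start < end:\n", "            candidate = word[start:end]\n", "            if start > 0:\n", "                # Pieces that continue a word carry the '##' prefix\n", "                candidate = '##' + candidate\n", "            if candidate in vocab:\n", "                piece = candidate\n", "                break\n", "            end -= 1\n", "        if piece is None:\n", "            # No vocabulary piece matches: the whole word becomes the unknown token\n", "            return [unk_token]\n", "        tokens.append(piece)\n", "        start = end\n", "    return tokens\n", "\n", "\n", "# Hypothetical toy vocabulary (bert-base-uncased has about 30,000 entries)\n", "toy_vocab = {'how', 'many', 'gp', '##us', 'do', 'you', 'need'}\n", "\n", "print(wordpiece_tokenize('gpus', toy_vocab))  # ['gp', '##us']\n", "print(wordpiece_tokenize('need', toy_vocab))  # ['need']\n", "print(wordpiece_tokenize('tpus', toy_vocab))  # ['[UNK]']" 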
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['hello', ',', 'how', 'many', 'gp', '##us', 'do', 'you', 'need', '?']\n", "[7592, 1010, 2129, 2116, 14246, 2271, 2079, 2017, 2342, 1029]\n" ] } ], "source": [ "from transformers import BertTokenizer\n", "\n", "# Load the tokenizer\n", "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "\n", "# Example text\n", "text = \"Hello, how many GPUs do you need?\"\n", "\n", "# Tokenize the text\n", "tokens = tokenizer.tokenize(text)\n", "print(tokens)\n", "\n", "# Convert tokens to token IDs\n", "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "print(token_ids)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from transformers import GPT2Tokenizer" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Hello', ',', 'Ġhow', 'Ġmany', 'ĠGPUs', 'Ġdo', 'Ġyou', 'Ġneed', '?']\n", "[15496, 11, 703, 867, 32516, 466, 345, 761, 30]\n" ] } ], "source": [ "# Load the tokenizer\n", "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n", "\n", "# Example text\n", "text = \"Hello, how many GPUs do you need?\"\n", "\n", "# Tokenize the text\n", "tokens = tokenizer.tokenize(text)\n", "print(tokens)\n", "\n", "# Convert tokens to token IDs\n", "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "print(token_ids)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from transformers import T5Tokenizer" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. 
If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['▁Hello', ',', '▁how', '▁many', '▁GPU', 's', '▁do', '▁you', '▁need', '?']\n", "[8774, 6, 149, 186, 23356, 7, 103, 25, 174, 58]\n" ] } ], "source": [ "# Initialize the tokenizer\n", "tokenizer = T5Tokenizer.from_pretrained('t5-base')\n", "\n", "# Example text\n", "text = \"Hello, how many GPUs do you need?\"\n", "\n", "# Tokenize the text\n", "tokens = tokenizer.tokenize(text,add_special_tokens=True)\n", "print(tokens)\n", "\n", "# Convert tokens to token IDs\n", "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "print(token_ids)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**NOTE: A pretrained model only performs properly when its input is tokenized with the same rules that were used to tokenize its training data. In practice, load the tokenizer from the same checkpoint as the model.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tokenizer Classes in Hugging Face:\n", "- PreTrainedTokenizer: the base class for all slow (pure-Python) tokenizers. It provides the methods and attributes shared across tokenizer types, but it is not meant to be used directly to load a specific model's tokenizer.\n", "- Model-specific tokenizers, for example BertTokenizer for BERT. These inherit from PreTrainedTokenizer.\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
\n", "The tokenizer class you load from this checkpoint is 'BertTokenizer'. \n", "The class this function is called from is 'PreTrainedTokenizer'.\n" ] }, { "ename": "NotImplementedError", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[16], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m PreTrainedTokenizer\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m#Directely call a PreTrainedTokenizer, this will throw errors.\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m \u001b[43mPreTrainedTokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mbert-base-uncased\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m encoded_input \u001b[38;5;241m=\u001b[39m tokenizer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello, Hugging Face!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m/scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:2028\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, *init_inputs, **kwargs)\u001b[0m\n\u001b[1;32m 2025\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2026\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mloading file \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from cache at 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresolved_vocab_files[file_id]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 2028\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_from_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2029\u001b[0m \u001b[43m \u001b[49m\u001b[43mresolved_vocab_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2030\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2031\u001b[0m \u001b[43m \u001b[49m\u001b[43minit_configuration\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2032\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minit_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2033\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2034\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2035\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2036\u001b[0m \u001b[43m \u001b[49m\u001b[43m_commit_hash\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_hash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2037\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_local\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_local\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2038\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2039\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File 
\u001b[0;32m/scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:2260\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase._from_pretrained\u001b[0;34m(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, token, cache_dir, local_files_only, _commit_hash, _is_local, *init_inputs, **kwargs)\u001b[0m\n\u001b[1;32m 2258\u001b[0m \u001b[38;5;66;03m# Instantiate the tokenizer.\u001b[39;00m\n\u001b[1;32m 2259\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 2260\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minit_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minit_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2261\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m:\n\u001b[1;32m 2262\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m(\n\u001b[1;32m 2263\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnable to load vocabulary from file. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2264\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease check that the provided vocabulary is accessible and not corrupted.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2265\u001b[0m )\n", "File \u001b[0;32m/scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils.py:367\u001b[0m, in \u001b[0;36mPreTrainedTokenizer.__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 363\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 365\u001b[0m \u001b[38;5;66;03m# 4. 
If some of the special tokens are not part of the vocab, we add them, at the end.\u001b[39;00m\n\u001b[1;32m 366\u001b[0m \u001b[38;5;66;03m# the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers`\u001b[39;00m\n\u001b[0;32m--> 367\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_add_tokens\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 368\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mall_special_tokens_extended\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_added_tokens_encoder\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 369\u001b[0m \u001b[43m \u001b[49m\u001b[43mspecial_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 370\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 372\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_decode_use_source_tokenizer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", "File \u001b[0;32m/scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils.py:467\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._add_tokens\u001b[0;34m(self, new_tokens, special_tokens)\u001b[0m\n\u001b[1;32m 465\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m added_tokens\n\u001b[1;32m 466\u001b[0m \u001b[38;5;66;03m# TODO this 
is fairly slow to improve!\u001b[39;00m\n\u001b[0;32m--> 467\u001b[0m current_vocab \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_vocab\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mcopy()\n\u001b[1;32m 468\u001b[0m new_idx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(current_vocab) \u001b[38;5;66;03m# only call this once, len gives the last index + 1\u001b[39;00m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m token \u001b[38;5;129;01min\u001b[39;00m new_tokens:\n", "File \u001b[0;32m/scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1675\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.get_vocab\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1665\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_vocab\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mint\u001b[39m]:\n\u001b[1;32m 1666\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1667\u001b[0m \u001b[38;5;124;03m Returns the vocabulary as a dictionary of token to index.\u001b[39;00m\n\u001b[1;32m 1668\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1673\u001b[0m \u001b[38;5;124;03m `Dict[str, int]`: The vocabulary.\u001b[39;00m\n\u001b[1;32m 1674\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1675\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m()\n", "\u001b[0;31mNotImplementedError\u001b[0m: " ] } ], "source": [ "from transformers import PreTrainedTokenizer\n", "\n", "# Calling PreTrainedTokenizer directly fails: it is an abstract base class.\n", "tokenizer = PreTrainedTokenizer.from_pretrained('bert-base-uncased')\n", "encoded_input = tokenizer(\"Hello, Hugging Face!\")\n" ] }, { "cell_type": "code", "execution_count": 24, 
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['the', 'capital', 'of', 'finland', 'is', '?']\n" ] } ], "source": [ "from transformers import BertTokenizer\n", "\n", "# padding, truncation and max_length are arguments of the tokenizer call, not of from_pretrained()\n", "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "\n", "# Example text\n", "text = \"The capital of Finland is?\"\n", "\n", "# Tokenize the text\n", "tokens = tokenizer.tokenize(text)\n", "print(tokens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Common tokenizer arguments:\n", "\n", "- `padding`: padding strategy; `True` (or `'longest'`) pads to the longest sequence in the batch, `'max_length'` pads to `max_length`\n", "- `truncation`: truncation strategy; `True` truncates to `max_length` (or to the model's maximum input length if `max_length` is not given)\n", "- `max_length`: maximum length of the returned sequences\n", "- ..." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Keyword arguments {'padding': True, 'truncation': True, 'max_length': 20} not recognized.\n" ] }, { "data": { "text/plain": [ "['the', 'capital', 'of', 'finland', 'is', '?']" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens = tokenizer.tokenize(text, padding=True,truncation=True,max_length=20)\n", "tokens" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"**NOTE:** Calling a tokenizer directly (which invokes its `__call__` method) is how you prepare data for model input, for training or inference. The tokenize() method, in contrast, is for token-level analysis or manipulation of the text.\n", "\n", "Arguments such as `padding`, `truncation` and `max_length` are not recognized by the tokenize() method; as the warning above shows, they are ignored.\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('input_ids', [[101, 7592, 1010, 17662, 2227, 999, 2425, 2033, 2055, 2035, 2115, 19204, 17629, 4127, 1012, 102], [101, 7592, 1010, 2088, 999, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", "('token_type_ids', [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", "('attention_mask', [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n" ] } ], "source": [ "text = [\"Hello, Hugging Face! Tell me about all your tokenizer types.\", \"Hello, world!\"]\n", "\n", "# Call the tokenizer directly (this invokes its __call__ method)\n", "encoded_input = tokenizer(text, padding=True, truncation=True, max_length=20)\n", "for item in encoded_input.items():\n", " print(item)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model\n", "### Hugging Face Model Classes:\n", "https://huggingface.co/docs/transformers/model_doc/auto\n", "- **Base model**:\n", "\n", "A base model, also referred to as a pre-trained model, is a model that has already been trained on a large, generic dataset. Its primary purpose is to capture a wide range of language features, such as grammar, context, and basic associations. A base model provides a robust foundation of language understanding that can be adapted to specific tasks.\n", "\n", "Base models on Hugging Face are often named after their architecture, like bert-base-uncased, gpt2-medium, t5-base, etc. (here `base` and `medium` refer to the model size).\n", "- **Fine-tuned model:**\n", "\n", "A fine-tuned model is a model that has undergone additional training (fine-tuning) on a smaller, task-specific dataset. 
Typical tasks include sentiment analysis, question answering, or domain-specific language understanding.\n", "\n", "Fine-tuned models usually have additional descriptors in their names indicating the specific task or dataset they are fine-tuned for. For instance, **\"bert-base-uncased-finetuned-squad\"** is a BERT model fine-tuned on the SQuAD dataset for question answering, whereas **\"bert-base-uncased\"** is a base model.\n", "\n", "More information can usually be found in the README (model card) in the model repository.\n", "Inspecting the model's configuration or architecture can also give hints." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Which model to use?\n", "[https://huggingface.co/models](https://huggingface.co/models)\n", "* Task Type\n", "* Specific language (especially non-English languages)\n", "* Model Size and Performance\n", "* Fine-Tuning and Customization\n", "* Community and Support\n", "* Documentation and Examples\n", "* Ethical Considerations\n", "* Licensing and Cost\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up the tokenizer, load the model and perform inference, step by step." 
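] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The post-processing step turns the model's output scores (logits) back into text. The following toy, self-contained sketch gives a rough mental model of what greedy generation does, and it needs no model download. Everything in it is hypothetical: the 8-entry vocabulary `VOCAB`, the stand-in `toy_model` that returns made-up logits, and the helper `greedy_generate`. It is a simplification, not the transformers implementation.\n", "\n", "At each step, a softmax converts the logits to probabilities and the most likely token is appended. Generation stops at the end-of-sequence token or once `max_length` tokens (prompt included) are reached." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "# Hypothetical toy vocabulary; id 0 plays the role of the end-of-sequence token\n", "VOCAB = ['<eos>', 'The', 'capital', 'of', 'France', 'is', 'Paris', '.']\n", "EOS_ID = 0\n", "\n", "\n", "def toy_model(token_ids):\n", "    # Stand-in for a language model: made-up next-token logits that depend only\n", "    # on the last token (a real model conditions on the whole sequence)\n", "    next_id = {5: 6, 6: 7}.get(token_ids[-1], EOS_ID)\n", "    return [5.0 if i == next_id else 0.0 for i in range(len(VOCAB))]\n", "\n", "\n", "def softmax(logits):\n", "    # Subtract the maximum for numerical stability, then normalize to sum to 1\n", "    m = max(logits)\n", "    exps = [math.exp(x - m) for x in logits]\n", "    total = sum(exps)\n", "    return [e / total for e in exps]\n", "\n", "\n", "def greedy_generate(input_ids, max_length=10):\n", "    ids = list(input_ids)\n", "    while len(ids) < max_length:\n", "        probs = softmax(toy_model(ids))\n", "        # Greedy decoding: always pick the most probable next token\n", "        next_id = max(range(len(probs)), key=lambda i: probs[i])\n", "        ids.append(next_id)\n", "        if next_id == EOS_ID:\n", "            break\n", "    return ids\n", "\n", "\n", "prompt_ids = [1, 2, 3, 4, 5]  # 'The capital of France is'\n", "output_ids = greedy_generate(prompt_ids)\n", "# Decode back to text, skipping the special end-of-sequence token\n", "print(' '.join(VOCAB[i] for i in output_ids if i != EOS_ID))  # The capital of France is Paris ." 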
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated text: The capital of France is the capital of the French Republic, and the capital of the French Republic is the capital of the French Republic.\n", "\n", "The French Republic is the capital of the French Republic.\n", "\n", "The French Republic is the capital of the\n" ] } ], "source": [ "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", "\n", "# Initialize the tokenizer for GPT-2\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", "\n", "# Load the pre-trained GPT-2 model\n", "model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n", "\n", "# Prepare input text; calling the tokenizer returns both the input IDs and the attention mask\n", "input_text = \"The capital of France is\"\n", "encoded = tokenizer(input_text, return_tensors=\"pt\")\n", "input_ids = encoded.input_ids\n", "attention_mask = encoded.attention_mask\n", "\n", "# Generate output; GPT-2 has no pad token, so reuse the end-of-sequence token as pad_token_id\n", "outputs = model.generate(input_ids, attention_mask=attention_mask, pad_token_id=tokenizer.eos_token_id, max_length=50)\n", "generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n", "\n", "print(\"Generated text:\", generated_text)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 1, 1, 1, 1]])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "attention_mask" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Do I need to look for the specific tokenizer and model classes for my task every time?**\n", "\n", "In many cases, no. The architecture can usually be inferred from the name or path of the pretrained model. 
Hugging Face provides **AutoClasses**, which automatically retrieve the relevant tokenizer or model class given the name or path of the pretrained weights, config or vocabulary.\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated text: The capital of France is the capital of the French Republic, and the capital of the French Republic is the capital of the French Republic.\n", "\n", "The French Republic is the capital of the French Republic.\n", "\n", "The French Republic is the capital of the\n" ] } ], "source": [ "## NOTE: AutoModel instantiates the base model without a task-specific head, so for text generation\n", "## we use AutoModelForCausalLM, which adds the language-modeling head\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "\n", "# Initialize the tokenizer for GPT-2\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", "\n", "# Load the pre-trained GPT-2 model\n", "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n", "\n", "# Prepare input text; calling the tokenizer returns both the input IDs and the attention mask\n", "input_text = \"The capital of France is\"\n", "encoded = tokenizer(input_text, return_tensors=\"pt\")\n", "input_ids = encoded.input_ids\n", "attention_mask = encoded.attention_mask\n", "\n", "# Generate output; GPT-2 has no pad token, so reuse the end-of-sequence token as pad_token_id\n", "outputs = model.generate(input_ids, attention_mask=attention_mask, pad_token_id=tokenizer.eos_token_id, max_length=50)\n", "generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n", "\n", "print(\"Generated text:\", generated_text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Key outputs from a language model\n", "- Logits: the raw, unnormalized scores for each vocabulary token at each position in the output sequence. Models with a language-modeling head (such as AutoModelForCausalLM) return the logits from their forward pass by default.\n", "- Hidden States: Representations from each layer of the model. 
These are the activations of the model's neurons at each layer. Set `output_hidden_states=True` in the configuration or when calling the model to obtain Hidden States.\n", "- Attentions: Attention weights from each layer of the model. These weights show how much each token in a sequence attends to every other token at each layer. Set `output_attentions=True` in the configuration or when calling the model to obtain Attentions." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(28996, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): 
Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", ")" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[ 0.2629, 0.0496, 0.1699, ..., -0.0339, 0.2812, -0.0489],\n", " [-0.3381, -0.2910, 0.2394, ..., 0.4664, -0.4263, 0.2448],\n", " [-0.3315, -0.1127, -0.1425, ..., 0.6752, -0.1898, 0.5174],\n", " ...,\n", " [-0.1510, 0.4374, -0.2816, ..., 0.3068, 0.4450, 0.4092],\n", " [ 0.0758, 0.1059, 0.0871, ..., 0.3782, 0.2463, -0.2250],\n", " [-0.0174, -0.1541, -1.0330, ..., 0.4842, 0.6491, 0.2534]]],\n", " grad_fn=)\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModel\n", "\n", "model = AutoModel.from_pretrained(\"bert-base-cased\")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n", "\n", "# Prepare input text\n", "input_text = \"The capital of France is\"\n", "input_ids = tokenizer.encode(input_text, return_tensors=\"pt\")\n", "\n", "# Forward pass: the bare BertModel returns hidden states, not logits\n", "outputs = model(input_ids)\n", "print(outputs.last_hidden_state)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "logits: tensor([[[ -36.2874, -35.0114, -38.0793, ..., -40.5164, -41.3760,\n", " -34.9193],\n", " [ -75.1021, -75.6483, -82.6827, ..., -82.5961, -79.3913,\n", " -76.2687],\n", " [ -80.0968, -78.6868, -81.2341, ..., -83.7548, -85.6541,\n", " -79.8042],\n", " [ -86.0085, -86.4618, -91.0184, ..., -98.6912, -93.3734,\n", " -87.9286],\n", " [-108.9542, -108.9327, -112.5793, ..., -118.3345, -113.1505,\n", " -110.3779]]], grad_fn=)\n", "Attentions: (tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.4640e-01, 1.5360e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.0135e-01, 2.2373e-01, 7.4919e-02, 0.0000e+00, 
0.0000e+00],\n", " [6.0768e-01, 1.7884e-01, 1.4391e-01, 6.9565e-02, 0.0000e+00],\n", " [6.0990e-01, 1.5188e-01, 6.2560e-02, 9.4493e-02, 8.1164e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.4739e-04, 9.9985e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [2.7310e-03, 8.2182e-04, 9.9645e-01, 0.0000e+00, 0.0000e+00],\n", " [3.3176e-04, 2.4361e-03, 1.4651e-03, 9.9577e-01, 0.0000e+00],\n", " [3.2342e-03, 2.1838e-03, 1.5252e-02, 1.1193e-03, 9.7821e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.1416e-01, 8.5841e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.7615e-01, 1.3776e-01, 1.8609e-01, 0.0000e+00, 0.0000e+00],\n", " [4.5474e-01, 2.0124e-01, 1.1568e-01, 2.2834e-01, 0.0000e+00],\n", " [4.6935e-01, 7.1428e-02, 2.0851e-01, 7.0084e-02, 1.8062e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.5014e-01, 8.4986e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.3307e-01, 9.9187e-02, 7.6774e-01, 0.0000e+00, 0.0000e+00],\n", " [2.8450e-02, 1.4408e-02, 1.2834e-03, 9.5586e-01, 0.0000e+00],\n", " [1.0551e-01, 1.1721e-02, 6.7463e-02, 1.7100e-02, 7.9820e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [3.9975e-01, 6.0025e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [4.2455e-01, 5.0290e-01, 7.2550e-02, 0.0000e+00, 0.0000e+00],\n", " [5.8611e-02, 1.8143e-02, 1.3441e-02, 9.0980e-01, 0.0000e+00],\n", " [2.7160e-01, 1.7796e-01, 1.3677e-01, 1.3924e-01, 2.7444e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [2.7893e-02, 9.7211e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [2.9158e-01, 2.6920e-02, 6.8150e-01, 0.0000e+00, 0.0000e+00],\n", " [1.4353e-02, 2.2892e-04, 7.7237e-06, 9.8541e-01, 0.0000e+00],\n", " [6.2411e-02, 9.3040e-03, 8.2062e-03, 1.0775e-03, 9.1900e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.9259e-01, 3.0741e-01, 
0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [4.7506e-01, 4.8846e-01, 3.6488e-02, 0.0000e+00, 0.0000e+00],\n", " [4.7012e-01, 2.7895e-01, 3.9287e-02, 2.1164e-01, 0.0000e+00],\n", " [2.1669e-01, 3.4696e-01, 3.3619e-02, 3.5308e-01, 4.9655e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7262e-01, 2.7382e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [5.1730e-01, 3.3254e-01, 1.5016e-01, 0.0000e+00, 0.0000e+00],\n", " [4.0525e-01, 9.0645e-02, 3.4848e-01, 1.5563e-01, 0.0000e+00],\n", " [2.3696e-01, 1.1449e-01, 1.5967e-01, 1.5589e-01, 3.3299e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.0300e-01, 9.6995e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [4.8718e-01, 1.0044e-01, 4.1238e-01, 0.0000e+00, 0.0000e+00],\n", " [6.1475e-01, 1.2612e-01, 1.7063e-01, 8.8498e-02, 0.0000e+00],\n", " [2.3666e-01, 4.0859e-02, 2.1303e-01, 3.6777e-02, 4.7267e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5478e-01, 4.5223e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.4656e-01, 1.0486e-01, 2.4857e-01, 0.0000e+00, 0.0000e+00],\n", " [6.0311e-01, 1.5732e-01, 1.9775e-01, 4.1819e-02, 0.0000e+00],\n", " [4.2266e-01, 1.1131e-01, 2.1519e-01, 8.6133e-02, 1.6471e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.1178e-01, 2.8822e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [5.8995e-01, 9.0134e-02, 3.1991e-01, 0.0000e+00, 0.0000e+00],\n", " [4.8390e-01, 1.0257e-01, 1.5057e-01, 2.6296e-01, 0.0000e+00],\n", " [4.0053e-01, 9.5700e-02, 1.9708e-01, 3.7479e-02, 2.6921e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.4934e-01, 1.5066e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.3224e-01, 2.2437e-01, 1.4339e-01, 0.0000e+00, 0.0000e+00],\n", " [3.9314e-01, 2.7727e-01, 1.5651e-01, 1.7308e-01, 0.0000e+00],\n", " [4.6852e-01, 1.1897e-01, 1.0104e-01, 1.2253e-01, 1.8895e-01]]]],\n", " 
grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7209e-01, 2.7912e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [5.8621e-01, 3.8064e-01, 3.3142e-02, 0.0000e+00, 0.0000e+00],\n", " [4.8232e-01, 1.2759e-01, 2.5068e-01, 1.3941e-01, 0.0000e+00],\n", " [4.5573e-01, 2.2592e-01, 1.1976e-01, 4.9527e-02, 1.4906e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8905e-01, 1.0951e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.3572e-01, 8.3821e-02, 8.0454e-02, 0.0000e+00, 0.0000e+00],\n", " [6.8450e-01, 4.7232e-02, 2.0742e-01, 6.0846e-02, 0.0000e+00],\n", " [7.3180e-01, 3.9842e-02, 5.9887e-02, 5.9164e-02, 1.0930e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9220e-01, 7.7975e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.6141e-01, 1.8003e-02, 2.2059e-01, 0.0000e+00, 0.0000e+00],\n", " [6.0938e-01, 2.2223e-02, 2.3404e-01, 1.3435e-01, 0.0000e+00],\n", " [4.5839e-01, 1.5597e-02, 1.6847e-01, 1.0259e-01, 2.5495e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.4533e-01, 3.5467e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.1803e-01, 1.8478e-01, 1.9720e-01, 0.0000e+00, 0.0000e+00],\n", " [5.5501e-01, 1.2625e-01, 1.3541e-01, 1.8333e-01, 0.0000e+00],\n", " [5.1175e-01, 8.8049e-02, 1.0100e-01, 1.3373e-01, 1.6547e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.9888e-01, 1.0112e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.8684e-01, 1.3459e-01, 7.8575e-02, 0.0000e+00, 0.0000e+00],\n", " [7.1967e-01, 9.3313e-02, 6.5694e-02, 1.2132e-01, 0.0000e+00],\n", " [6.3908e-01, 8.8730e-02, 5.7430e-02, 1.1279e-01, 1.0197e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6850e-01, 3.1500e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [5.3495e-01, 3.9359e-03, 4.6111e-01, 0.0000e+00, 0.0000e+00],\n", " [7.2299e-01, 1.3465e-01, 
3.9143e-02, 1.0322e-01, 0.0000e+00],\n", " [6.0384e-01, 1.3601e-02, 5.6788e-02, 1.7283e-02, 3.0849e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3237e-01, 6.7633e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.1456e-01, 1.3563e-01, 4.9807e-02, 0.0000e+00, 0.0000e+00],\n", " [6.8662e-01, 9.9697e-02, 6.8591e-02, 1.4509e-01, 0.0000e+00],\n", " [7.6440e-01, 1.0466e-01, 5.6559e-02, 2.9471e-02, 4.4903e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.0040e-01, 9.9604e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.5084e-01, 1.5419e-01, 9.4970e-02, 0.0000e+00, 0.0000e+00],\n", " [6.1832e-01, 1.9061e-01, 9.1609e-02, 9.9462e-02, 0.0000e+00],\n", " [5.6564e-01, 1.3307e-01, 7.9954e-02, 1.1634e-01, 1.0500e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6794e-01, 3.2058e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.7170e-01, 4.3231e-02, 8.5070e-02, 0.0000e+00, 0.0000e+00],\n", " [7.7739e-01, 4.0126e-02, 1.2773e-01, 5.4759e-02, 0.0000e+00],\n", " [6.4928e-01, 5.4109e-02, 8.0593e-02, 7.3304e-02, 1.4271e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9020e-01, 9.8014e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.2868e-01, 1.3217e-02, 5.8099e-02, 0.0000e+00, 0.0000e+00],\n", " [8.9125e-01, 1.5728e-02, 7.7379e-02, 1.5646e-02, 0.0000e+00],\n", " [8.5727e-01, 1.4522e-02, 3.8854e-02, 1.3800e-02, 7.5554e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [2.0297e-04, 9.9980e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.7313e-04, 5.9022e-01, 4.0891e-01, 0.0000e+00, 0.0000e+00],\n", " [5.2686e-04, 3.3926e-01, 2.4788e-01, 4.1233e-01, 0.0000e+00],\n", " [8.0960e-04, 2.2636e-01, 1.7240e-01, 2.9428e-01, 3.0615e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [4.9741e-01, 5.0259e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " 
[4.5825e-01, 1.1647e-03, 5.4058e-01, 0.0000e+00, 0.0000e+00],\n", " [5.6682e-02, 2.2267e-03, 2.0238e-02, 9.2085e-01, 0.0000e+00],\n", " [1.2164e-01, 1.7738e-03, 2.0815e-02, 2.8863e-03, 8.5288e-01]]]],\n", " grad_fn=), tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.8816, 0.1184, 0.0000, 0.0000, 0.0000],\n", " [0.3900, 0.3470, 0.2630, 0.0000, 0.0000],\n", " [0.4344, 0.0606, 0.4109, 0.0942, 0.0000],\n", " [0.3666, 0.0940, 0.3683, 0.1146, 0.0565]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9685, 0.0315, 0.0000, 0.0000, 0.0000],\n", " [0.8868, 0.0496, 0.0636, 0.0000, 0.0000],\n", " [0.7343, 0.0683, 0.1390, 0.0584, 0.0000],\n", " [0.7296, 0.0639, 0.0412, 0.1309, 0.0345]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9432, 0.0568, 0.0000, 0.0000, 0.0000],\n", " [0.1469, 0.8471, 0.0060, 0.0000, 0.0000],\n", " [0.1572, 0.2076, 0.5838, 0.0514, 0.0000],\n", " [0.2244, 0.0963, 0.1862, 0.3246, 0.1686]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9846, 0.0154, 0.0000, 0.0000, 0.0000],\n", " [0.8928, 0.0777, 0.0295, 0.0000, 0.0000],\n", " [0.6239, 0.0908, 0.1662, 0.1192, 0.0000],\n", " [0.6062, 0.0426, 0.2135, 0.0812, 0.0565]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9782, 0.0218, 0.0000, 0.0000, 0.0000],\n", " [0.1871, 0.7836, 0.0293, 0.0000, 0.0000],\n", " [0.6109, 0.1238, 0.2466, 0.0187, 0.0000],\n", " [0.6416, 0.1750, 0.0845, 0.0081, 0.0908]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9765, 0.0235, 0.0000, 0.0000, 0.0000],\n", " [0.7527, 0.1426, 0.1048, 0.0000, 0.0000],\n", " [0.4404, 0.1803, 0.2962, 0.0830, 0.0000],\n", " [0.4287, 0.0690, 0.1540, 0.1683, 0.1800]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9733, 0.0267, 0.0000, 0.0000, 0.0000],\n", " [0.8905, 0.0347, 0.0747, 0.0000, 0.0000],\n", " [0.8629, 0.0255, 0.0468, 0.0648, 0.0000],\n", " [0.7898, 0.0282, 0.0614, 0.0545, 0.0661]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 
0.0000],\n", " [0.7351, 0.2649, 0.0000, 0.0000, 0.0000],\n", " [0.5648, 0.1821, 0.2531, 0.0000, 0.0000],\n", " [0.3673, 0.1207, 0.1130, 0.3989, 0.0000],\n", " [0.2760, 0.0600, 0.0901, 0.2262, 0.3478]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9666, 0.0334, 0.0000, 0.0000, 0.0000],\n", " [0.5883, 0.2326, 0.1791, 0.0000, 0.0000],\n", " [0.3643, 0.2940, 0.3139, 0.0278, 0.0000],\n", " [0.2934, 0.0450, 0.4518, 0.0157, 0.1942]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.8553, 0.1447, 0.0000, 0.0000, 0.0000],\n", " [0.2809, 0.6997, 0.0194, 0.0000, 0.0000],\n", " [0.2087, 0.5914, 0.0692, 0.1307, 0.0000],\n", " [0.1291, 0.2955, 0.0696, 0.4771, 0.0287]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.8377, 0.1623, 0.0000, 0.0000, 0.0000],\n", " [0.7713, 0.1540, 0.0747, 0.0000, 0.0000],\n", " [0.5854, 0.1238, 0.0504, 0.2403, 0.0000],\n", " [0.5171, 0.0990, 0.0531, 0.1615, 0.1693]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9215, 0.0785, 0.0000, 0.0000, 0.0000],\n", " [0.8768, 0.0633, 0.0599, 0.0000, 0.0000],\n", " [0.8145, 0.0514, 0.0311, 0.1030, 0.0000],\n", " [0.7703, 0.0372, 0.0291, 0.0807, 0.0826]]]],\n", " grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8268e-01, 1.7316e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9714e-01, 2.4256e-03, 4.3351e-04, 0.0000e+00, 0.0000e+00],\n", " [9.4433e-01, 3.0325e-05, 1.0623e-04, 5.5529e-02, 0.0000e+00],\n", " [9.6740e-01, 1.8683e-04, 2.7851e-04, 4.5622e-04, 3.1674e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7486e-01, 2.5143e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.1439e-01, 3.3117e-02, 5.2495e-02, 0.0000e+00, 0.0000e+00],\n", " [7.0372e-01, 1.6742e-01, 7.7478e-02, 5.1381e-02, 0.0000e+00],\n", " [8.5533e-01, 2.5359e-02, 7.8012e-03, 1.3220e-02, 9.8292e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.0629e-01, 
9.3707e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.0309e-01, 2.8151e-01, 1.1539e-01, 0.0000e+00, 0.0000e+00],\n", " [1.7294e-01, 7.0645e-02, 4.8604e-01, 2.7038e-01, 0.0000e+00],\n", " [2.1087e-01, 4.5951e-02, 4.1121e-01, 2.4295e-01, 8.9014e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9231e-01, 7.6922e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.8290e-02, 9.6893e-01, 1.2778e-02, 0.0000e+00, 0.0000e+00],\n", " [3.3402e-01, 3.9228e-01, 2.6383e-01, 9.8732e-03, 0.0000e+00],\n", " [7.7857e-01, 8.7052e-02, 6.7886e-02, 2.3009e-02, 4.3481e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9965e-01, 3.5409e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9085e-01, 3.1320e-03, 6.0142e-03, 0.0000e+00, 0.0000e+00],\n", " [9.3759e-01, 3.5951e-03, 1.7021e-02, 4.1793e-02, 0.0000e+00],\n", " [9.4960e-01, 2.0509e-03, 2.6982e-03, 1.1689e-02, 3.3964e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6235e-01, 3.7648e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.4446e-01, 8.6601e-02, 6.8941e-02, 0.0000e+00, 0.0000e+00],\n", " [8.4127e-01, 6.9890e-02, 5.7835e-02, 3.1001e-02, 0.0000e+00],\n", " [6.2594e-01, 1.5199e-01, 1.1569e-01, 5.1549e-02, 5.4835e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.7570e-01, 1.2430e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [3.5125e-01, 4.5002e-01, 1.9873e-01, 0.0000e+00, 0.0000e+00],\n", " [3.4573e-01, 2.2871e-01, 1.0734e-01, 3.1822e-01, 0.0000e+00],\n", " [1.4163e-01, 2.1510e-01, 4.3367e-02, 5.0629e-01, 9.3614e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7368e-01, 2.6315e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.0697e-01, 8.3537e-01, 5.7653e-02, 0.0000e+00, 0.0000e+00],\n", " [2.4387e-01, 2.1662e-01, 3.3404e-01, 2.0547e-01, 0.0000e+00],\n", " [1.8795e-01, 2.9717e-01, 8.5084e-02, 3.4794e-01, 8.1855e-02]],\n", 
"\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7094e-01, 2.9060e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [5.5981e-01, 1.0658e-01, 3.3361e-01, 0.0000e+00, 0.0000e+00],\n", " [2.1676e-01, 4.6400e-02, 6.6238e-01, 7.4458e-02, 0.0000e+00],\n", " [2.9181e-01, 1.1980e-02, 3.1150e-01, 9.6777e-02, 2.8793e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6866e-01, 3.1337e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.3343e-01, 1.9364e-01, 7.2931e-02, 0.0000e+00, 0.0000e+00],\n", " [6.4080e-01, 2.5656e-01, 7.0099e-02, 3.2543e-02, 0.0000e+00],\n", " [7.2064e-01, 7.5143e-02, 5.0348e-02, 5.4231e-03, 1.4845e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7586e-01, 2.4136e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.2307e-01, 4.7648e-02, 1.2928e-01, 0.0000e+00, 0.0000e+00],\n", " [7.8726e-01, 1.5413e-02, 2.2299e-02, 1.7503e-01, 0.0000e+00],\n", " [7.7859e-01, 1.0238e-02, 9.9621e-03, 9.4805e-02, 1.0640e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.3870e-01, 3.6130e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [2.9023e-02, 9.5808e-01, 1.2895e-02, 0.0000e+00, 0.0000e+00],\n", " [1.6684e-01, 7.1874e-01, 9.9462e-02, 1.4962e-02, 0.0000e+00],\n", " [8.6535e-02, 8.2408e-01, 5.8495e-02, 7.2601e-03, 2.3631e-02]]]],\n", " grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8763e-01, 1.2369e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.5513e-01, 1.9957e-01, 4.5302e-02, 0.0000e+00, 0.0000e+00],\n", " [1.6612e-01, 2.8464e-01, 5.0885e-01, 4.0393e-02, 0.0000e+00],\n", " [2.7830e-01, 4.6737e-02, 6.3905e-01, 1.1467e-02, 2.4444e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5025e-01, 4.9747e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.2971e-01, 8.2469e-01, 4.5595e-02, 0.0000e+00, 0.0000e+00],\n", " [5.5166e-01, 2.8164e-01, 
1.4566e-01, 2.1041e-02, 0.0000e+00],\n", " [8.4285e-01, 1.1891e-02, 1.3288e-02, 7.6044e-03, 1.2436e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9410e-01, 5.8977e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3703e-01, 3.3505e-02, 2.9464e-02, 0.0000e+00, 0.0000e+00],\n", " [9.5604e-01, 1.2286e-02, 2.3039e-02, 8.6327e-03, 0.0000e+00],\n", " [9.1609e-01, 1.2754e-02, 1.0708e-02, 3.6108e-02, 2.4342e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3213e-01, 6.7874e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.0969e-01, 8.7633e-01, 1.3974e-02, 0.0000e+00, 0.0000e+00],\n", " [3.6800e-01, 4.9800e-01, 2.5786e-02, 1.0822e-01, 0.0000e+00],\n", " [7.8568e-02, 4.8116e-01, 3.8334e-01, 4.9499e-02, 7.4341e-03]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8429e-01, 1.5711e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.1697e-01, 3.0108e-01, 8.1957e-02, 0.0000e+00, 0.0000e+00],\n", " [7.6427e-01, 6.6728e-02, 9.3641e-02, 7.5364e-02, 0.0000e+00],\n", " [7.8823e-01, 6.3681e-02, 2.6464e-02, 5.1123e-02, 7.0499e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5242e-01, 4.7581e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.2807e-01, 2.4630e-01, 1.2563e-01, 0.0000e+00, 0.0000e+00],\n", " [5.1321e-01, 9.6500e-02, 3.4471e-01, 4.5581e-02, 0.0000e+00],\n", " [3.8266e-01, 1.7605e-01, 3.3798e-01, 6.5144e-02, 3.8159e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8426e-01, 1.5742e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.0698e-01, 1.3276e-01, 6.0261e-02, 0.0000e+00, 0.0000e+00],\n", " [7.5616e-01, 6.8283e-02, 7.8978e-02, 9.6574e-02, 0.0000e+00],\n", " [5.0845e-01, 6.3645e-02, 8.5731e-02, 2.4365e-01, 9.8529e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6902e-01, 3.0983e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " 
[5.3681e-01, 4.5272e-02, 4.1791e-01, 0.0000e+00, 0.0000e+00],\n", " [4.9830e-01, 2.4730e-02, 9.5733e-02, 3.8124e-01, 0.0000e+00],\n", " [3.9466e-01, 1.4528e-02, 9.0154e-02, 4.5654e-02, 4.5500e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9760e-01, 2.4048e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5771e-01, 1.1905e-02, 3.0384e-02, 0.0000e+00, 0.0000e+00],\n", " [9.3648e-01, 3.0840e-03, 9.9071e-03, 5.0528e-02, 0.0000e+00],\n", " [8.9304e-01, 6.8501e-03, 3.8995e-02, 2.6777e-02, 3.4342e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6133e-01, 3.8670e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5277e-01, 3.9294e-02, 7.9354e-03, 0.0000e+00, 0.0000e+00],\n", " [8.8129e-01, 7.9166e-02, 3.2179e-02, 7.3699e-03, 0.0000e+00],\n", " [4.4197e-01, 1.8336e-01, 3.2659e-01, 2.6132e-02, 2.1946e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9621e-01, 3.7885e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5926e-01, 1.7418e-02, 2.3323e-02, 0.0000e+00, 0.0000e+00],\n", " [8.7619e-01, 7.0181e-03, 3.9708e-02, 7.7079e-02, 0.0000e+00],\n", " [9.3684e-01, 8.5476e-03, 1.8532e-02, 1.4737e-02, 2.1345e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.0000e+00, 3.7350e-06, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [1.1243e-05, 9.9694e-01, 3.0456e-03, 0.0000e+00, 0.0000e+00],\n", " [5.4056e-08, 1.0553e-09, 1.0000e+00, 2.4611e-07, 0.0000e+00],\n", " [2.7906e-08, 7.2148e-08, 2.3880e-07, 9.9999e-01, 4.6058e-06]]]],\n", " grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9132e-01, 8.6850e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9452e-01, 2.3545e-03, 3.1232e-03, 0.0000e+00, 0.0000e+00],\n", " [9.6897e-01, 1.1418e-03, 1.0935e-03, 2.8795e-02, 0.0000e+00],\n", " [9.7149e-01, 1.3023e-04, 5.7703e-04, 6.0039e-03, 2.1799e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 
0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9957e-01, 4.2876e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9992e-01, 1.5292e-05, 6.4213e-05, 0.0000e+00, 0.0000e+00],\n", " [9.9904e-01, 5.1686e-06, 1.8187e-06, 9.5816e-04, 0.0000e+00],\n", " [9.9772e-01, 2.1845e-05, 4.3966e-07, 1.9277e-04, 2.0637e-03]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.4495e-01, 5.5048e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [3.9356e-01, 5.7606e-01, 3.0387e-02, 0.0000e+00, 0.0000e+00],\n", " [3.6411e-01, 3.0299e-01, 3.1687e-01, 1.6032e-02, 0.0000e+00],\n", " [3.2875e-01, 2.4410e-01, 3.0265e-01, 7.0624e-02, 5.3878e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3013e-01, 6.9867e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.7017e-01, 8.9004e-02, 4.0831e-02, 0.0000e+00, 0.0000e+00],\n", " [6.9796e-01, 1.5881e-01, 1.1179e-01, 3.1447e-02, 0.0000e+00],\n", " [6.5844e-01, 1.5417e-01, 4.4342e-02, 6.0779e-02, 8.2266e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6758e-01, 3.2418e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.3914e-01, 2.9498e-01, 6.5881e-02, 0.0000e+00, 0.0000e+00],\n", " [3.9811e-01, 9.0394e-02, 4.8018e-01, 3.1310e-02, 0.0000e+00],\n", " [6.7694e-01, 8.3124e-02, 1.0220e-01, 6.6097e-02, 7.1634e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9676e-01, 3.2360e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9158e-01, 4.4951e-03, 3.9272e-03, 0.0000e+00, 0.0000e+00],\n", " [9.9506e-01, 2.2144e-03, 6.9065e-05, 2.6588e-03, 0.0000e+00],\n", " [9.5271e-01, 6.7518e-03, 1.3870e-02, 3.2582e-03, 2.3410e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9196e-01, 8.0396e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.6724e-01, 1.3005e-01, 2.7055e-03, 0.0000e+00, 0.0000e+00],\n", " [9.1523e-01, 3.7084e-02, 2.4366e-02, 2.3315e-02, 0.0000e+00],\n", " [9.3293e-01, 
4.1201e-03, 4.5888e-04, 4.6213e-02, 1.6283e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9209e-01, 7.9106e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [5.6535e-01, 3.9672e-01, 3.7929e-02, 0.0000e+00, 0.0000e+00],\n", " [6.8777e-01, 2.3920e-01, 6.3899e-02, 9.1290e-03, 0.0000e+00],\n", " [3.2970e-01, 4.7910e-01, 1.4188e-01, 2.3569e-02, 2.5749e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8486e-01, 1.5137e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8760e-01, 6.8198e-03, 5.5837e-03, 0.0000e+00, 0.0000e+00],\n", " [9.5438e-01, 1.0198e-02, 4.4678e-03, 3.0952e-02, 0.0000e+00],\n", " [9.5295e-01, 6.5464e-03, 2.5009e-03, 1.1826e-02, 2.6180e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9748e-01, 2.5237e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8592e-01, 1.0216e-02, 3.8620e-03, 0.0000e+00, 0.0000e+00],\n", " [9.7857e-01, 9.8701e-03, 1.9850e-03, 9.5784e-03, 0.0000e+00],\n", " [9.3577e-01, 2.1244e-02, 6.9420e-03, 1.6773e-02, 1.9270e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.4617e-01, 5.3829e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.6554e-01, 1.8921e-01, 1.4525e-01, 0.0000e+00, 0.0000e+00],\n", " [8.4404e-01, 4.7671e-02, 4.6438e-02, 6.1854e-02, 0.0000e+00],\n", " [4.1475e-01, 1.2054e-01, 1.0108e-01, 2.6399e-01, 9.9632e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7725e-01, 2.2754e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.4521e-01, 4.2486e-02, 1.1230e-01, 0.0000e+00, 0.0000e+00],\n", " [7.0981e-01, 1.4360e-01, 6.2041e-02, 8.4544e-02, 0.0000e+00],\n", " [6.3260e-01, 6.4462e-02, 4.2408e-02, 5.3823e-02, 2.0671e-01]]]],\n", " grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3107e-01, 6.8927e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.2165e-01, 1.8779e-01, 9.0556e-02, 
0.0000e+00, 0.0000e+00],\n", " [4.6368e-01, 1.8552e-01, 2.9214e-01, 5.8654e-02, 0.0000e+00],\n", " [4.4191e-01, 2.1471e-01, 1.9086e-01, 1.1681e-01, 3.5719e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6630e-01, 3.3704e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.5391e-01, 1.1163e-01, 3.4456e-02, 0.0000e+00, 0.0000e+00],\n", " [9.2002e-01, 2.2570e-02, 4.1658e-02, 1.5755e-02, 0.0000e+00],\n", " [3.4370e-01, 3.0937e-01, 3.2868e-01, 4.5510e-03, 1.3704e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8087e-01, 1.9129e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.4840e-01, 1.1681e-02, 3.9920e-02, 0.0000e+00, 0.0000e+00],\n", " [8.9529e-01, 1.0272e-02, 3.1628e-02, 6.2805e-02, 0.0000e+00],\n", " [8.2158e-01, 1.2511e-02, 2.0896e-02, 8.3904e-02, 6.1105e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8491e-01, 1.5086e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6406e-01, 1.9128e-02, 1.6817e-02, 0.0000e+00, 0.0000e+00],\n", " [8.3050e-01, 3.6039e-02, 8.6933e-02, 4.6528e-02, 0.0000e+00],\n", " [7.5217e-01, 2.4834e-02, 7.3416e-02, 8.6154e-02, 6.3430e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.4821e-01, 1.5179e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.3873e-01, 9.3524e-02, 6.7741e-02, 0.0000e+00, 0.0000e+00],\n", " [9.3704e-01, 2.1367e-02, 1.2340e-02, 2.9252e-02, 0.0000e+00],\n", " [7.2025e-01, 2.1460e-02, 2.7114e-02, 1.6127e-01, 6.9903e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7658e-01, 2.3415e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.4204e-01, 3.6862e-02, 2.1100e-02, 0.0000e+00, 0.0000e+00],\n", " [8.7403e-01, 2.0384e-02, 2.8869e-02, 7.6717e-02, 0.0000e+00],\n", " [8.1515e-01, 3.4231e-02, 2.6072e-02, 6.2564e-02, 6.1980e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6897e-01, 
3.1028e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3794e-01, 2.7844e-02, 3.4217e-02, 0.0000e+00, 0.0000e+00],\n", " [9.2826e-01, 1.8394e-02, 2.4454e-03, 5.0899e-02, 0.0000e+00],\n", " [8.8630e-01, 2.5514e-02, 5.7392e-03, 2.7406e-02, 5.5041e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.1927e-01, 8.0730e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.3202e-01, 1.1687e-01, 5.1114e-02, 0.0000e+00, 0.0000e+00],\n", " [8.1200e-01, 4.5452e-02, 6.5906e-02, 7.6644e-02, 0.0000e+00],\n", " [6.6297e-01, 6.4572e-02, 8.3910e-02, 1.6866e-01, 1.9883e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.2020e-01, 7.9799e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [3.2156e-01, 6.7010e-01, 8.3458e-03, 0.0000e+00, 0.0000e+00],\n", " [6.2888e-01, 1.7730e-01, 1.7765e-01, 1.6179e-02, 0.0000e+00],\n", " [7.1778e-01, 1.0738e-01, 5.3128e-02, 9.8006e-02, 2.3706e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9811e-01, 1.8942e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9769e-01, 9.5311e-04, 1.3615e-03, 0.0000e+00, 0.0000e+00],\n", " [9.9786e-01, 1.8052e-04, 8.9000e-05, 1.8703e-03, 0.0000e+00],\n", " [9.9209e-01, 2.7830e-04, 1.6944e-04, 3.0577e-03, 4.4005e-03]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8963e-01, 1.0372e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7655e-01, 1.2386e-02, 1.1060e-02, 0.0000e+00, 0.0000e+00],\n", " [9.2873e-01, 7.9086e-03, 6.7334e-03, 5.6627e-02, 0.0000e+00],\n", " [9.4687e-01, 3.3349e-03, 3.8526e-03, 2.0949e-02, 2.4993e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6317e-01, 3.6828e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.1453e-01, 1.5719e-01, 2.8284e-02, 0.0000e+00, 0.0000e+00],\n", " [8.4460e-01, 3.3552e-02, 1.1012e-01, 1.1731e-02, 0.0000e+00],\n", " [8.2286e-01, 5.7689e-02, 4.7373e-02, 4.5768e-02, 2.6306e-02]]]],\n", 
" grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6141e-01, 3.8593e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.0344e-01, 6.1702e-02, 3.4861e-02, 0.0000e+00, 0.0000e+00],\n", " [7.9124e-01, 2.2201e-02, 1.6933e-01, 1.7236e-02, 0.0000e+00],\n", " [9.5321e-01, 3.3564e-03, 6.9881e-03, 1.9175e-02, 1.7265e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6476e-01, 3.5243e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8655e-01, 9.8743e-03, 3.5791e-03, 0.0000e+00, 0.0000e+00],\n", " [9.4783e-01, 5.8394e-03, 3.1383e-03, 4.3189e-02, 0.0000e+00],\n", " [9.5562e-01, 6.2320e-03, 6.4283e-03, 2.5439e-02, 6.2816e-03]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9179e-01, 8.2106e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9778e-01, 7.7677e-04, 1.4401e-03, 0.0000e+00, 0.0000e+00],\n", " [9.9611e-01, 7.4620e-05, 7.2848e-04, 3.0882e-03, 0.0000e+00],\n", " [9.9262e-01, 2.0443e-04, 5.5199e-04, 5.4467e-03, 1.1742e-03]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3731e-01, 6.2687e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.8342e-01, 1.4043e-01, 7.6153e-02, 0.0000e+00, 0.0000e+00],\n", " [8.3782e-01, 4.1296e-02, 4.8374e-02, 7.2512e-02, 0.0000e+00],\n", " [7.7558e-01, 7.1647e-02, 2.1223e-02, 9.3007e-02, 3.8544e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7783e-01, 2.2168e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8155e-01, 1.4647e-02, 3.8022e-03, 0.0000e+00, 0.0000e+00],\n", " [5.0584e-01, 1.1217e-01, 3.3681e-01, 4.5181e-02, 0.0000e+00],\n", " [8.6466e-01, 4.3985e-02, 5.2316e-02, 1.8824e-02, 2.0212e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.4112e-01, 5.8879e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.1998e-01, 4.8161e-02, 3.1862e-02, 0.0000e+00, 0.0000e+00],\n", " [7.8413e-01, 5.7792e-02, 
1.1763e-01, 4.0456e-02, 0.0000e+00],\n", " [8.0141e-01, 3.4749e-02, 3.2050e-02, 1.0286e-01, 2.8935e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9054e-01, 9.4629e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6498e-01, 9.4434e-03, 2.5576e-02, 0.0000e+00, 0.0000e+00],\n", " [8.5708e-01, 8.2129e-03, 2.5883e-02, 1.0882e-01, 0.0000e+00],\n", " [8.8442e-01, 4.7589e-03, 1.1856e-02, 8.4370e-02, 1.4590e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5244e-01, 4.7557e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8452e-01, 8.4211e-03, 7.0570e-03, 0.0000e+00, 0.0000e+00],\n", " [9.5362e-01, 9.5855e-03, 5.3842e-03, 3.1412e-02, 0.0000e+00],\n", " [9.6968e-01, 4.2234e-03, 3.7680e-03, 1.0003e-02, 1.2325e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.4719e-01, 5.2808e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.9683e-01, 6.0842e-02, 4.2329e-02, 0.0000e+00, 0.0000e+00],\n", " [4.7391e-01, 2.8253e-01, 2.1125e-01, 3.2311e-02, 0.0000e+00],\n", " [5.3445e-01, 2.6445e-01, 1.2724e-01, 5.9709e-02, 1.4149e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6439e-01, 3.5606e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3692e-01, 5.0612e-02, 1.2466e-02, 0.0000e+00, 0.0000e+00],\n", " [9.5147e-01, 2.5873e-02, 1.2778e-02, 9.8748e-03, 0.0000e+00],\n", " [8.9290e-01, 2.6920e-02, 4.0316e-03, 3.8151e-02, 3.7993e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9627e-01, 3.7340e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9756e-01, 1.4468e-03, 9.9798e-04, 0.0000e+00, 0.0000e+00],\n", " [9.9382e-01, 4.3808e-04, 4.7226e-04, 5.2677e-03, 0.0000e+00],\n", " [9.8542e-01, 1.1592e-03, 1.3844e-03, 8.3883e-03, 3.6457e-03]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7623e-01, 2.3771e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " 
[9.9240e-01, 2.1715e-03, 5.4276e-03, 0.0000e+00, 0.0000e+00],\n", " [9.7309e-01, 7.1527e-04, 1.4431e-02, 1.1764e-02, 0.0000e+00],\n", " [9.8076e-01, 4.1391e-04, 1.7371e-03, 5.9987e-03, 1.1089e-02]]]],\n", " grad_fn=), tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9763, 0.0237, 0.0000, 0.0000, 0.0000],\n", " [0.8533, 0.0129, 0.1338, 0.0000, 0.0000],\n", " [0.8869, 0.0167, 0.0442, 0.0523, 0.0000],\n", " [0.8437, 0.0111, 0.0380, 0.0377, 0.0695]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9941, 0.0059, 0.0000, 0.0000, 0.0000],\n", " [0.9844, 0.0065, 0.0091, 0.0000, 0.0000],\n", " [0.9806, 0.0017, 0.0087, 0.0090, 0.0000],\n", " [0.9638, 0.0018, 0.0092, 0.0168, 0.0084]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9541, 0.0459, 0.0000, 0.0000, 0.0000],\n", " [0.9257, 0.0401, 0.0343, 0.0000, 0.0000],\n", " [0.9612, 0.0068, 0.0181, 0.0138, 0.0000],\n", " [0.9195, 0.0095, 0.0134, 0.0390, 0.0186]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9704, 0.0296, 0.0000, 0.0000, 0.0000],\n", " [0.9800, 0.0112, 0.0088, 0.0000, 0.0000],\n", " [0.9829, 0.0071, 0.0032, 0.0068, 0.0000],\n", " [0.8489, 0.0203, 0.0402, 0.0247, 0.0660]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9750, 0.0250, 0.0000, 0.0000, 0.0000],\n", " [0.8638, 0.0884, 0.0478, 0.0000, 0.0000],\n", " [0.7929, 0.0489, 0.1298, 0.0284, 0.0000],\n", " [0.6916, 0.0867, 0.0752, 0.1165, 0.0300]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9613, 0.0387, 0.0000, 0.0000, 0.0000],\n", " [0.9194, 0.0337, 0.0470, 0.0000, 0.0000],\n", " [0.7719, 0.0633, 0.1241, 0.0407, 0.0000],\n", " [0.6144, 0.0881, 0.0239, 0.2148, 0.0588]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9615, 0.0385, 0.0000, 0.0000, 0.0000],\n", " [0.9602, 0.0233, 0.0165, 0.0000, 0.0000],\n", " [0.9305, 0.0045, 0.0343, 0.0308, 0.0000],\n", " [0.8467, 0.0134, 0.0293, 0.0560, 0.0545]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 
0.0000],\n", " [0.9363, 0.0637, 0.0000, 0.0000, 0.0000],\n", " [0.4824, 0.1020, 0.4156, 0.0000, 0.0000],\n", " [0.3589, 0.0568, 0.5517, 0.0326, 0.0000],\n", " [0.4467, 0.0870, 0.0771, 0.1530, 0.2362]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9491, 0.0509, 0.0000, 0.0000, 0.0000],\n", " [0.9092, 0.0701, 0.0207, 0.0000, 0.0000],\n", " [0.9470, 0.0234, 0.0119, 0.0177, 0.0000],\n", " [0.7555, 0.0474, 0.0371, 0.0684, 0.0917]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9607, 0.0393, 0.0000, 0.0000, 0.0000],\n", " [0.8362, 0.0832, 0.0806, 0.0000, 0.0000],\n", " [0.9208, 0.0155, 0.0268, 0.0369, 0.0000],\n", " [0.8085, 0.0206, 0.0381, 0.0475, 0.0853]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9299, 0.0701, 0.0000, 0.0000, 0.0000],\n", " [0.8491, 0.0858, 0.0651, 0.0000, 0.0000],\n", " [0.9007, 0.0214, 0.0235, 0.0544, 0.0000],\n", " [0.7305, 0.0466, 0.0381, 0.1136, 0.0712]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9450, 0.0550, 0.0000, 0.0000, 0.0000],\n", " [0.8148, 0.1320, 0.0532, 0.0000, 0.0000],\n", " [0.8963, 0.0091, 0.0066, 0.0880, 0.0000],\n", " [0.7990, 0.0131, 0.0100, 0.1260, 0.0518]]]],\n", " grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3734e-01, 6.2660e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7290e-01, 2.1706e-02, 5.3903e-03, 0.0000e+00, 0.0000e+00],\n", " [9.4636e-01, 3.4305e-02, 4.0194e-03, 1.5316e-02, 0.0000e+00],\n", " [6.3312e-01, 2.3632e-01, 1.3801e-02, 4.7460e-02, 6.9299e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9656e-01, 3.4444e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8700e-01, 5.0174e-03, 7.9826e-03, 0.0000e+00, 0.0000e+00],\n", " [9.7410e-01, 1.7798e-03, 1.5016e-02, 9.1018e-03, 0.0000e+00],\n", " [9.7638e-01, 1.8578e-03, 3.9721e-03, 6.3269e-03, 1.1460e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7258e-01, 
2.7418e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7676e-01, 2.0872e-02, 2.3695e-03, 0.0000e+00, 0.0000e+00],\n", " [9.1199e-01, 4.0753e-02, 7.5556e-03, 3.9700e-02, 0.0000e+00],\n", " [7.7058e-01, 7.7728e-02, 1.7429e-02, 6.0164e-02, 7.4093e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.0931e-01, 9.0695e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [6.9254e-01, 2.1706e-01, 9.0391e-02, 0.0000e+00, 0.0000e+00],\n", " [7.7797e-01, 6.4412e-02, 1.0519e-01, 5.2428e-02, 0.0000e+00],\n", " [5.8848e-01, 1.4391e-01, 9.1029e-03, 1.7338e-01, 8.5124e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8742e-01, 1.2580e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6524e-01, 2.6900e-02, 7.8572e-03, 0.0000e+00, 0.0000e+00],\n", " [9.7642e-01, 7.5138e-03, 4.3141e-03, 1.1750e-02, 0.0000e+00],\n", " [9.1618e-01, 2.3913e-02, 1.6411e-02, 2.9356e-02, 1.4135e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.2926e-01, 7.0740e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.0852e-01, 1.1819e-01, 7.3287e-02, 0.0000e+00, 0.0000e+00],\n", " [8.6614e-01, 3.8358e-02, 2.9329e-02, 6.6175e-02, 0.0000e+00],\n", " [7.9866e-01, 5.3226e-02, 2.3835e-02, 7.9351e-02, 4.4932e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7826e-01, 2.1736e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8715e-01, 9.6764e-03, 3.1743e-03, 0.0000e+00, 0.0000e+00],\n", " [9.8490e-01, 6.7368e-03, 3.0691e-03, 5.2954e-03, 0.0000e+00],\n", " [9.3938e-01, 1.1518e-02, 4.7326e-03, 3.2758e-02, 1.1610e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8930e-01, 1.0700e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.2223e-01, 2.7651e-02, 5.0124e-02, 0.0000e+00, 0.0000e+00],\n", " [9.5263e-01, 6.9835e-03, 2.7537e-02, 1.2847e-02, 0.0000e+00],\n", " [8.7247e-01, 1.2950e-02, 2.1495e-02, 6.7099e-02, 2.5981e-02]],\n", 
"\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.6224e-01, 1.3776e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.7177e-01, 1.9972e-01, 2.8508e-02, 0.0000e+00, 0.0000e+00],\n", " [9.6954e-01, 5.4648e-03, 3.1324e-03, 2.1867e-02, 0.0000e+00],\n", " [3.4693e-01, 1.6847e-02, 7.2692e-03, 5.7058e-01, 5.8381e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9252e-01, 7.4822e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9383e-01, 5.0692e-03, 1.0965e-03, 0.0000e+00, 0.0000e+00],\n", " [9.9701e-01, 9.0284e-04, 3.2885e-04, 1.7612e-03, 0.0000e+00],\n", " [9.5627e-01, 5.9275e-03, 2.2744e-03, 2.8549e-02, 6.9835e-03]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8147e-01, 1.8532e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5299e-01, 2.0602e-02, 2.6409e-02, 0.0000e+00, 0.0000e+00],\n", " [6.2735e-01, 9.5907e-02, 2.3771e-01, 3.9033e-02, 0.0000e+00],\n", " [7.5219e-01, 8.5292e-02, 2.4630e-02, 4.7091e-02, 9.0795e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9783e-01, 2.1697e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.9067e-01, 5.2246e-03, 4.1037e-03, 0.0000e+00, 0.0000e+00],\n", " [9.9046e-01, 1.1382e-03, 5.2418e-03, 3.1609e-03, 0.0000e+00],\n", " [9.8521e-01, 1.8107e-03, 2.1089e-03, 6.7891e-03, 4.0793e-03]]]],\n", " grad_fn=), tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9339, 0.0661, 0.0000, 0.0000, 0.0000],\n", " [0.8350, 0.1076, 0.0574, 0.0000, 0.0000],\n", " [0.9477, 0.0126, 0.0054, 0.0342, 0.0000],\n", " [0.6596, 0.0298, 0.0083, 0.2848, 0.0175]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9831, 0.0169, 0.0000, 0.0000, 0.0000],\n", " [0.9702, 0.0253, 0.0045, 0.0000, 0.0000],\n", " [0.9699, 0.0080, 0.0049, 0.0172, 0.0000],\n", " [0.9169, 0.0214, 0.0061, 0.0443, 0.0113]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9892, 0.0108, 0.0000, 0.0000, 
0.0000],\n", " [0.9817, 0.0110, 0.0073, 0.0000, 0.0000],\n", " [0.9556, 0.0100, 0.0169, 0.0176, 0.0000],\n", " [0.9302, 0.0113, 0.0085, 0.0361, 0.0139]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9441, 0.0559, 0.0000, 0.0000, 0.0000],\n", " [0.6784, 0.1339, 0.1877, 0.0000, 0.0000],\n", " [0.9437, 0.0257, 0.0081, 0.0225, 0.0000],\n", " [0.7729, 0.0632, 0.0215, 0.0909, 0.0515]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9885, 0.0115, 0.0000, 0.0000, 0.0000],\n", " [0.9187, 0.0477, 0.0336, 0.0000, 0.0000],\n", " [0.9301, 0.0121, 0.0399, 0.0180, 0.0000],\n", " [0.8441, 0.0120, 0.0173, 0.0896, 0.0369]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9504, 0.0496, 0.0000, 0.0000, 0.0000],\n", " [0.9814, 0.0093, 0.0093, 0.0000, 0.0000],\n", " [0.5579, 0.2399, 0.1236, 0.0785, 0.0000],\n", " [0.9161, 0.0337, 0.0077, 0.0196, 0.0228]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9162, 0.0838, 0.0000, 0.0000, 0.0000],\n", " [0.9020, 0.0839, 0.0141, 0.0000, 0.0000],\n", " [0.9402, 0.0360, 0.0078, 0.0160, 0.0000],\n", " [0.7907, 0.1090, 0.0101, 0.0662, 0.0240]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.8641, 0.1359, 0.0000, 0.0000, 0.0000],\n", " [0.8441, 0.0754, 0.0805, 0.0000, 0.0000],\n", " [0.8984, 0.0119, 0.0237, 0.0660, 0.0000],\n", " [0.6092, 0.0413, 0.0390, 0.2633, 0.0472]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9853, 0.0147, 0.0000, 0.0000, 0.0000],\n", " [0.9720, 0.0101, 0.0179, 0.0000, 0.0000],\n", " [0.9818, 0.0037, 0.0067, 0.0078, 0.0000],\n", " [0.9799, 0.0021, 0.0079, 0.0045, 0.0056]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9309, 0.0691, 0.0000, 0.0000, 0.0000],\n", " [0.9018, 0.0287, 0.0695, 0.0000, 0.0000],\n", " [0.4176, 0.4065, 0.0637, 0.1122, 0.0000],\n", " [0.8082, 0.0574, 0.0078, 0.0640, 0.0626]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9741, 0.0259, 0.0000, 0.0000, 0.0000],\n", " 
[0.9720, 0.0156, 0.0124, 0.0000, 0.0000],\n", " [0.9828, 0.0027, 0.0032, 0.0113, 0.0000],\n", " [0.8386, 0.0111, 0.0051, 0.1371, 0.0080]],\n", "\n", " [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.9328, 0.0672, 0.0000, 0.0000, 0.0000],\n", " [0.9101, 0.0257, 0.0643, 0.0000, 0.0000],\n", " [0.8983, 0.0159, 0.0747, 0.0111, 0.0000],\n", " [0.7882, 0.0640, 0.0092, 0.0809, 0.0578]]]],\n", " grad_fn=), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [5.3151e-01, 4.6849e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [3.1730e-01, 4.3761e-01, 2.4509e-01, 0.0000e+00, 0.0000e+00],\n", " [2.6517e-01, 2.0439e-01, 2.4353e-01, 2.8692e-01, 0.0000e+00],\n", " [3.3609e-01, 2.2130e-01, 1.1017e-01, 1.9541e-01, 1.3704e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7409e-01, 2.5907e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.7520e-01, 3.0482e-02, 9.4323e-02, 0.0000e+00, 0.0000e+00],\n", " [9.2752e-01, 1.6799e-02, 3.2042e-02, 2.3639e-02, 0.0000e+00],\n", " [8.3775e-01, 2.8084e-02, 5.1645e-02, 4.0437e-02, 4.2088e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.2957e-01, 7.0432e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.1524e-01, 4.5500e-02, 1.3926e-01, 0.0000e+00, 0.0000e+00],\n", " [9.4933e-01, 1.3745e-02, 1.4443e-02, 2.2482e-02, 0.0000e+00],\n", " [8.3442e-01, 2.5942e-02, 3.9919e-02, 7.1195e-02, 2.8529e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.6437e-01, 1.3563e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.0515e-01, 1.0092e-01, 9.3926e-02, 0.0000e+00, 0.0000e+00],\n", " [6.4718e-01, 1.5522e-01, 1.2086e-01, 7.6748e-02, 0.0000e+00],\n", " [6.1426e-01, 1.1765e-01, 5.6041e-02, 9.6448e-02, 1.1561e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.8664e-01, 1.1336e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.1557e-01, 5.0012e-02, 1.3442e-01, 0.0000e+00, 
0.0000e+00],\n", " [8.3685e-01, 4.3709e-02, 5.9272e-02, 6.0167e-02, 0.0000e+00],\n", " [6.7246e-01, 6.8544e-02, 6.9060e-02, 9.6767e-02, 9.3171e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6281e-01, 3.7191e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.5268e-01, 2.2550e-02, 2.4769e-02, 0.0000e+00, 0.0000e+00],\n", " [9.1691e-01, 1.8713e-02, 4.0654e-02, 2.3725e-02, 0.0000e+00],\n", " [9.3674e-01, 1.0702e-02, 2.3813e-02, 1.1739e-02, 1.7011e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6620e-01, 3.3804e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6304e-01, 2.0481e-02, 1.6477e-02, 0.0000e+00, 0.0000e+00],\n", " [9.5439e-01, 1.7477e-02, 1.2198e-02, 1.5939e-02, 0.0000e+00],\n", " [9.0202e-01, 3.0633e-02, 1.0887e-02, 4.2302e-02, 1.4158e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.6207e-01, 3.7926e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.7194e-01, 2.2417e-02, 1.0564e-01, 0.0000e+00, 0.0000e+00],\n", " [9.2389e-01, 1.4402e-02, 4.3938e-02, 1.7769e-02, 0.0000e+00],\n", " [8.2730e-01, 1.7167e-02, 5.6471e-02, 3.9279e-02, 5.9779e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [3.8025e-04, 9.9962e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.2789e-05, 4.6021e-01, 5.3970e-01, 0.0000e+00, 0.0000e+00],\n", " [2.6519e-05, 5.6117e-01, 1.3128e-01, 3.0752e-01, 0.0000e+00],\n", " [3.6884e-04, 3.4064e-01, 2.2959e-01, 1.7917e-01, 2.5023e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.8572e-01, 1.4284e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.7978e-01, 1.0668e-02, 9.5542e-03, 0.0000e+00, 0.0000e+00],\n", " [9.6238e-01, 6.6888e-03, 1.6962e-02, 1.3967e-02, 0.0000e+00],\n", " [9.3741e-01, 1.0572e-02, 1.6528e-02, 1.6398e-02, 1.9097e-02]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.7891e-01, 1.2109e-01, 
0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [7.3038e-01, 1.2828e-01, 1.4134e-01, 0.0000e+00, 0.0000e+00],\n", " [6.1038e-01, 1.4152e-01, 1.6569e-01, 8.2418e-02, 0.0000e+00],\n", " [4.1426e-01, 8.5663e-02, 5.9702e-02, 3.3670e-01, 1.0368e-01]],\n", "\n", " [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [9.3430e-01, 6.5705e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [8.8622e-01, 4.1646e-02, 7.2138e-02, 0.0000e+00, 0.0000e+00],\n", " [7.2493e-01, 1.3754e-01, 8.3539e-02, 5.3992e-02, 0.0000e+00],\n", " [6.6676e-01, 8.8356e-02, 4.1274e-02, 6.3429e-02, 1.4018e-01]]]],\n", " grad_fn=))\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "\n", "# Initialize the tokenizer for GPT-2\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", "\n", "# Load the pre-trained GPT-2 model\n", "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n", "\n", "# Prepare input text\n", "input_text = \"The capital of France is\"\n", "\n", "# Tokenize once: returns both input_ids and attention_mask as PyTorch tensors\n", "inputs = tokenizer(input_text, return_tensors=\"pt\")\n", "\n", "# GPT-2 ships without a pad token; reuse the end-of-sequence token if none is set\n", "if model.config.pad_token_id is None:\n", "    model.config.pad_token_id = model.config.eos_token_id\n", "\n", "# Run a single forward pass (this computes logits; it does not generate new tokens)\n", "outputs = model(\n", "    input_ids=inputs.input_ids,\n", "    attention_mask=inputs.attention_mask,\n", "    output_hidden_states=True,\n", "    output_attentions=True,\n", ")\n", "\n", "print(\"logits:\", outputs.logits)\n", "print(\"Attentions:\", outputs.attentions)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 5, 50257])" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "outputs.logits.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Configurations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model configuration\n", "Hyperparameters that define a model's architecture, such as the number of layers, attention heads and embedding size. 
 \n" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GPT2Model(\n", " (wte): Embedding(50257, 768)\n", " (wpe): Embedding(1024, 768)\n", " (drop): Dropout(p=0.1, inplace=False)\n", " (h): ModuleList(\n", " (0-11): 12 x GPT2Block(\n", " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (attn): GPT2Attention(\n", " (c_attn): Conv1D()\n", " (c_proj): Conv1D()\n", " (attn_dropout): Dropout(p=0.1, inplace=False)\n", " (resid_dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (mlp): GPT2MLP(\n", " (c_fc): Conv1D()\n", " (c_proj): Conv1D()\n", " (act): NewGELUActivation()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", ")" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import GPT2Model, GPT2Config\n", "\n", "# Load pre-trained GPT-2 with its default configuration (12 layers, 12 heads)\n", "model = GPT2Model.from_pretrained(\"gpt2\")\n", "model" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GPT2Config {\n", " \"_name_or_path\": \"gpt2\",\n", " \"activation_function\": \"gelu_new\",\n", " \"architectures\": [\n", " \"GPT2LMHeadModel\"\n", " ],\n", " \"attn_pdrop\": 0.1,\n", " \"bos_token_id\": 50256,\n", " \"embd_pdrop\": 0.1,\n", " \"eos_token_id\": 50256,\n", " \"initializer_range\": 0.02,\n", " \"layer_norm_epsilon\": 1e-05,\n", " \"model_type\": \"gpt2\",\n", " \"n_ctx\": 1024,\n", " \"n_embd\": 768,\n", " \"n_head\": 12,\n", " \"n_inner\": null,\n", " \"n_layer\": 12,\n", " \"n_positions\": 1024,\n", " \"reorder_and_upcast_attn\": false,\n", " \"resid_pdrop\": 0.1,\n", " \"scale_attn_by_inverse_layer_idx\": false,\n", " \"scale_attn_weights\": true,\n", " \"summary_activation\": null,\n", " \"summary_first_dropout\": 0.1,\n", " \"summary_proj_to_labels\": 
true,\n", " \"summary_type\": \"cls_index\",\n", " \"summary_use_proj\": true,\n", " \"task_specific_params\": {\n", " \"text-generation\": {\n", " \"do_sample\": true,\n", " \"max_length\": 50\n", " }\n", " },\n", " \"transformers_version\": \"4.36.0\",\n", " \"use_cache\": true,\n", " \"vocab_size\": 50257\n", "}" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.config" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GPT2Model(\n", " (wte): Embedding(50257, 768)\n", " (wpe): Embedding(1024, 768)\n", " (drop): Dropout(p=0.1, inplace=False)\n", " (h): ModuleList(\n", " (0-5): 6 x GPT2Block(\n", " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (attn): GPT2Attention(\n", " (c_attn): Conv1D()\n", " (c_proj): Conv1D()\n", " (attn_dropout): Dropout(p=0.1, inplace=False)\n", " (resid_dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (mlp): GPT2MLP(\n", " (c_fc): Conv1D()\n", " (c_proj): Conv1D()\n", " (act): NewGELUActivation()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", ")" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a custom configuration: 6 transformer blocks and 8 attention heads (default GPT-2 has 12 of each)\n", "config = GPT2Config(\n", " n_layer=6,\n", " n_head=8\n", ")\n", "# Load the pretrained GPT-2 weights into a model built from this configuration.\n", "# Only blocks 0-5 receive pretrained weights; the checkpoint's blocks 6-11 are not used.\n", "# n_head must divide n_embd (768 / 8 = 96). Weight shapes do not depend on n_head, so loading\n", "# succeeds, but weights trained for 12 heads are now split into 8, which changes what the model computes.\n", "model = GPT2Model.from_pretrained(\"gpt2\", config=config)\n", "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generation/inference configuration\n", "\n", "**Different decoding strategies**:\n", "\n", "https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb\n", "\n", "**Generation parameters**: \n", "\n", 
"https://huggingface.co/docs/transformers/v4.35.2/en/main_classes/text_generation#transformers.GenerationConfig\n" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Result: I liked \"Breaking Bad\" and \"Band of Brothers\". Do you have any recommendations of other shows I might like?\n", "\n", "Brock (Nrama): What's your favorite show?\n", "\n", "Hollywood Reporter: The only show that really made\n", "\n", "Result: I liked \"Breaking Bad\" and \"Band of Brothers\". Do you have any recommendations of other shows I might like?\n", "\n", "I've had the urge to watch a lot of shows and I have a few. I've seen \"Pulp Fiction\n", "\n", "Result: I liked \"Breaking Bad\" and \"Band of Brothers\". Do you have any recommendations of other shows I might like?\n", "\n", "I haven't even tried any other shows yet. I haven't tried it for a while. I didn't see a\n", "\n" ] } ], "source": [ "from transformers import pipeline\n", "import torch\n", "\n", "model_name = \"gpt2\"\n", "\n", "# Use a new variable name so the imported `pipeline` function is not overwritten.\n", "# GPT-2 is built into transformers, so trust_remote_code is not needed.\n", "generator = pipeline(\n", " \"text-generation\",\n", " model=model_name,\n", " torch_dtype=torch.float32\n", ")\n", "\n", "sequences = generator(\n", " 'I liked \"Breaking Bad\" and \"Band of Brothers\". Do you have any recommendations of other shows I might like?\\n',\n", " do_sample=True,  # sample from the distribution instead of greedy decoding\n", " top_k=20,  # sample only among the 20 most likely next tokens\n", " pad_token_id=generator.tokenizer.eos_token_id,  # GPT-2 has no pad token\n", " temperature=1.0,  # 1.0 leaves the distribution unchanged; lower values are more deterministic\n", " max_length=50,  # counts prompt + generated tokens; max_new_tokens caps only the new ones\n", " num_return_sequences=3\n", ")\n", "for seq in sequences:\n", " print(f\"Result: {seq['generated_text']}\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise 1: Exploring Pre-trained Models**\n", "\n", "Objective: Familiarize yourself with the Hugging Face Model Hub.\n", "\n", "Task: Browse the Hugging Face Model Hub and find a pre-trained model suitable for sentiment analysis. 
Write a short script to explore the model's architecture, configuration, output, etc. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise 2: Text Generation**\n", "\n", "Objective: Understand the capabilities of text generation models.\n", "\n", "Task: Use a text generation model to generate a short text based on a given prompt. Experiment with different temperature settings and observe how they affect the creativity of the output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Candidate topics for the next session (TBD):\n", "- How to load a model architecture with randomly initialized weights instead of pretrained weights\n", "- Fine-tuning workflow\n", "- Hugging Face Datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "LLM-tools", "language": "python", "name": "llm-tools" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "226.188px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }