{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "<a href=\"https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/causal_language_modeling_flax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "markdown", "metadata": { "id": "uGYl4nCPKyZi" }, "source": [ "# Pre-Training a 🤗 Transformers model on TPU with **Flax/JAX**\n", "\n", "In this notebook, we will see how to pretrain one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on TPU using [**Flax**](https://flax.readthedocs.io/en/latest/index.html). \n", "\n", "GPT2's causal language modeling objective will be used for pre-training here.\n", "\n", "As can be seen in [this benchmark](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation), using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.\n", "\n", "[**Flax**](https://flax.readthedocs.io/en/latest/index.html) is a high-performance neural network library designed for flexibility and built on top of JAX (see below). It aims to provide users with full control of their training code and is carefully designed to work well with JAX transformations such as `grad` and `pmap` (see the [Flax philosophy](https://flax.readthedocs.io/en/latest/philosophy.html)). For an introduction to Flax, see the [Flax Basics Colab](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) or the list of curated [Flax examples](https://flax.readthedocs.io/en/latest/examples.html).\n", "\n", "[**JAX**](https://jax.readthedocs.io/en/latest/index.html) is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. 
It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwDAzFXQMd46" }, "source": [ "If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Tokenizers, as well as [Flax](https://github.com/google/flax.git) and [Optax](https://github.com/deepmind/optax). Optax is a gradient processing and optimization library for JAX, and is the optimizer library\n", "recommended by Flax." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QMkPrhvya_gI" }, "outputs": [], "source": [ "%%capture\n", "!pip install datasets\n", "!pip install git+https://github.com/huggingface/transformers.git\n", "!pip install tokenizers\n", "!pip install flax\n", "!pip install git+https://github.com/deepmind/optax.git" ] }, { "cell_type": "markdown", "metadata": { "id": "0wMrmHv-uGzR" }, "source": [ "You will also need to set up the TPU for JAX in this notebook. This can be done by executing the following lines." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3RlF785dbUB3" }, "outputs": [], "source": [ "import jax.tools.colab_tpu\n", "jax.tools.colab_tpu.setup_tpu()" ] }, { "cell_type": "markdown", "metadata": { "id": "If_SYBvU5V6u" }, "source": [ "If everything is set up correctly, the following command should return a list of 8 TPU devices." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3R5MP7PAbV7V", "outputId": "9cf6b9a4-7b9c-4029-d938-dcf9a3ebb4c4" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 41, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.local_devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "vehXZCipMa1V" }, "source": [ "In this notebook, we will pre-train an [autoregressive model](https://huggingface.co/transformers/model_summary.html#autoregressive-models) on one of the languages of the OSCAR corpus. [OSCAR](https://oscar-corpus.com/) is a huge multilingual corpus obtained by language classification and filtering of the Common Crawl corpus using the *goclassy* architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "iz8HrV8JPHn0" }, "source": [ "Let's first select the language that our model should learn.\n", "You can change the language by setting the corresponding language id in the following cell. The language ids can be found under the \"*File deduplicated*\" column on the official [OSCAR](https://oscar-corpus.com/) website.\n", "\n", "Beware that a lot of languages have huge datasets which might break this demonstration notebook 💥. 
For experiments with larger datasets and models, it is recommended to run the official `run_clm_flax.py` script outside of this notebook. The script can be found [here](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#causal-language-modeling).\n", "\n", "Here we select `is` for Icelandic 🇮🇸." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ii9XwLsmiY-E" }, "outputs": [], "source": [ "language = \"is\"" ] }, { "cell_type": "markdown", "metadata": { "id": "jVtv6T0oSjNq" }, "source": [ "Next, we select the model architecture to be trained from scratch.\n", "Here we choose [**`distilgpt2`**](https://huggingface.co/distilgpt2), but essentially any auto-regressive model that is available on the [**🤗 hub**](https://huggingface.co/models?filter=masked-lm,jax) in JAX/Flax can be used. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sj1mJNJa6PPS" }, "outputs": [], "source": [ "model_config = \"distilgpt2\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"causal_language_modeling_notebook\", framework=\"flax\")" ] }, { "cell_type": "markdown", "metadata": { "id": "j-tf_3Ch55_9" }, "source": [ "## 1. Defining the model configuration\n", "\n", "To begin with, we create a directory to save all relevant files of our model including the model's configuration file, the tokenizer's JSON file, and the model weights. 
We name the directory `\"distilgpt2-pretrained-is\"`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1dwuSvQxeM8-" }, "outputs": [], "source": [ "model_dir = model_config + f\"-pretrained-{language}\"" ] }, { "cell_type": "markdown", "metadata": { "id": "qGENnc6LeRFL" }, "source": [ "and create it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pWtsHzLQdAS3" }, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "Path(model_dir).mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "oWQD8IA9eAFY" }, "source": [ "Next, we'll download the model configuration:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DO1SwHdi55en" }, "outputs": [], "source": [ "from transformers import AutoConfig\n", "\n", "config = AutoConfig.from_pretrained(model_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "3exPFi-keYlT" }, "source": [ "and save it to the directory:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vip8WKEp6b6Y" }, "outputs": [], "source": [ "config.save_pretrained(f\"{model_dir}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "aJfEUbbI31n8" }, "source": [ "## 2. Training a tokenizer from scratch\n", "\n", "The raw text data has to be pre-processed into a format that the model can understand. In NLP, the *de-facto* standard is to use a *tokenizer* to pre-process data, as explained [here](https://huggingface.co/transformers/preprocessing.html). \n", "\n", "We can leverage the blazing-fast 🤗 Tokenizers library to train a [**ByteLevelBPETokenizer**](https://medium.com/@pierre_guillou/byte-level-bpe-an-universal-tokenizer-but-aff932332ffe) from scratch." ] }, { "cell_type": "markdown", "metadata": { "id": "jdoO3ZsUW9Bh" }, "source": [ "Let's import the necessary building blocks from `tokenizers` and the `load_dataset` function." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kJKw0tqOcDu6" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n", "from pathlib import Path" ] }, { "cell_type": "markdown", "metadata": { "id": "3cQXZ1p5XHtP" }, "source": [ "As described above, our tokenizer and model files are stored in the directory `model_dir`. We can conveniently load our chosen dataset using the [**`load_dataset`**](https://huggingface.co/docs/datasets/package_reference/loading_methods.html?highlight=load_dataset#datasets.load_dataset) function." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5oUW__q-4If7", "outputId": "4e5f1bd9-b6c1-42fe-ea21-c00b1c4ff47a" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)\n" ] } ], "source": [ "raw_dataset = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Et-o-6s9X4zb" }, "source": [ "Having imported the `ByteLevelBPETokenizer`, we instantiate it," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OCs_CQFt4WK_" }, "outputs": [], "source": [ "tokenizer = ByteLevelBPETokenizer()" ] }, { "cell_type": "markdown", "metadata": { "id": "qw4xMa4dZJs2" }, "source": [ "define a training iterator," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JBe6jAKj4YeY" }, "outputs": [], "source": [ "def batch_iterator(batch_size=1000):\n", "    # Iterate over the rows of the train split (not over the DatasetDict's splits).\n", "    for i in range(0, len(raw_dataset[\"train\"]), batch_size):\n", "        yield raw_dataset[\"train\"][i: i + batch_size][\"text\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "ZzZl1P-LZREm" }, "source": [ "and train the tokenizer 
by setting `vocab_size` to the vocabulary size of our model's configuration, along with `min_frequency` and some `special_tokens`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e6BAIGEz4aPL" }, "outputs": [], "source": [ "tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[\n", "    \"<s>\",\n", "    \"<pad>\",\n", "    \"</s>\",\n", "    \"<unk>\",\n", "    \"<mask>\",\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "7bVHeovIaFt9" }, "source": [ "Finally, we save the trained tokenizer in the model folder." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xLLnCvMM4yk3" }, "outputs": [], "source": [ "tokenizer.save(f\"{model_dir}/tokenizer.json\")" ] }, { "cell_type": "markdown", "metadata": { "id": "lnKd8I_jZ6yl" }, "source": [ "For more information on training tokenizers, see [this](https://huggingface.co/docs/tokenizers/python/latest/tutorials/python/training_from_memory.html) document." ] }, { "cell_type": "markdown", "metadata": { "id": "4hD8d1_P5huo" }, "source": [ "## 3. Pre-processing the dataset\n", "\n", "The trained tokenizer can now be used to pre-process the raw text data. \n", "GPT2 was trained with input sequences of up to `1024` tokens (see the paper [here](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)).\n", "However, since the memory required by the self-attention layers of Transformer models scales quadratically with the sequence length, we cap the maximum input length at 512 here. The raw text data is pre-processed accordingly." 
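] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following cell gives a rough, illustrative estimate of what the quadratic scaling means in practice. It only counts the attention score matrices of a single layer and ignores activations, gradients and optimizer state. The shapes are assumptions chosen to match this notebook: 12 attention heads (as in `distilgpt2`), a per-device batch of 16 sequences, and 2 bytes per element (`bfloat16`). The helper `attention_scores_bytes` is hypothetical and exists only for this estimate. Doubling the sequence length quadruples the result." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative estimate only: memory taken by ONE layer's attention score matrices.\n", "# Assumptions: 12 heads (as in distilgpt2), bfloat16 (2 bytes per element).\n", "def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_elem=2):\n", "    # one (seq_len x seq_len) score matrix per head and per sequence\n", "    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem\n", "\n", "\n", "for seq_len in (512, 1024):\n", "    mib = attention_scores_bytes(batch_size=16, num_heads=12, seq_len=seq_len) / 2**20\n", "    print(f\"seq_len={seq_len}: ~{mib:.0f} MiB per layer\")  # ~96 MiB, then ~384 MiB"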
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uDhqWoF-MAGv" }, "outputs": [], "source": [ "max_seq_length = 512" ] }, { "cell_type": "markdown", "metadata": { "id": "vDSc5QvujQhK" }, "source": [ "To evaluate the model's performance during pre-training, we hold out 5% of the data as a validation set.\n", "\n", "Since the loaded dataset is cached, the convenient slicing syntax `split=\"train[:X%]\"` can be used to split the dataset with no computational overhead.\n", "\n", "The last 95% of the samples will be used as the training data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KcEYmKo8cHe1", "outputId": "e66f92e8-07fe-4644-8b34-e56c80ae0896" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)\n" ] } ], "source": [ "raw_dataset[\"train\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[5%:]\")" ] }, { "cell_type": "markdown", "metadata": { "id": "P2kRx1nclCdU" }, "source": [ "and the first 5% as the validation data." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AFVfOPmocufo", "outputId": "1b2cf746-9cb4-43ab-dd58-8d3cbcf235be" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)\n" ] } ], "source": [ "raw_dataset[\"validation\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[:5%]\")" ] }, { "cell_type": "markdown", "metadata": { "id": "XvolUzmdv1F3" }, "source": [ "For demonstration purposes, we will use only the first 20000 samples of the training data and the first 2000 samples of the validation data, so that we don't have to wait too long for each cell to execute. \n", "\n", "If you want to run the Colab on the **full** dataset, please comment out (or skip) the following cell. In this case, the notebook will run for *ca.* 7 hours until convergence and give a final loss and perplexity of *ca.* 3.67 and 39.12 respectively. Running the Colab *as is* takes less than 15 minutes, but will not show good loss convergence." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aoXFHjEtwXWt" }, "outputs": [], "source": [ "# comment out these two lines to run on the full dataset\n", "raw_dataset[\"train\"] = raw_dataset[\"train\"].select(range(20000))\n", "raw_dataset[\"validation\"] = raw_dataset[\"validation\"].select(range(2000))" ] }, { "cell_type": "markdown", "metadata": { "id": "hYmElz46k7_E" }, "source": [ "Next, we load the previously trained byte-level BPE tokenizer with `AutoTokenizer` to pre-process the raw text data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lySwpeYVc_Lm", "outputId": "97242947-e8f7-4572-82ec-a2c443bf7ccf" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(f\"{model_dir}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "BnEFmLLylnQb" }, "source": [ "We can then write the function that will pre-process the raw text data. We simply feed the text samples, stored in the `\"text\"` column, to the tokenizer:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wcpWIxX8dIAO" }, "outputs": [], "source": [ "def tokenize_function(examples):\n", "    return tokenizer(examples[\"text\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "lco7GkZ8nF-a" }, "source": [ "and apply the tokenization function to every text sample via the convenient `map(...)` function of 🤗 Datasets. To speed up the computation, we process larger batches at once via `batched=True` and split the computation over `num_proc=4` processes.\n", "\n", "**Note**: Running this command on the whole dataset might take up to 10 minutes ☕." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "h6cjpFO2dTYC", "outputId": "5b57b7aa-79f5-4780-95be-eddc79760f3b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-668fe01fa18ae746.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-8c3e31332860f1ac.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-0214751322118ef0.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-0e993781985ea725.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at 
/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-c1d87c939cb205b9.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-13b87d9a50234587.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-d4365f699bbc79c3.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-e760050a45eb004a.arrow\n" ] } ], "source": [ "tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset[\"train\"].column_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "6_E0jsY9onEf" }, "source": [ "The model can process the training data most efficiently if all data samples have the same length. We therefore concatenate all text samples and split them evenly into chunks of `max_seq_length=512` tokens each. This way, no computation is wasted on padding tokens, and the many short samples are packed into fewer, full-length training samples.\n", "Causal language modeling simply consists of predicting the next token, which means that the labels are essentially the inputs shifted one position to the left. 
Thus, we simply copy `input_ids` to `labels` here; the one-token shift itself is applied later, inside the loss function.\n", "\n", "Let's define such a function to group the dataset into equally sized data samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HO_neGynddat" }, "outputs": [], "source": [ "def group_texts(examples):\n", "    # Concatenate all texts of the batch, column by column.\n", "    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n", "    # Drop the small remainder so that every chunk has exactly max_seq_length tokens.\n", "    total_length = (total_length // max_seq_length) * max_seq_length\n", "    # Split the concatenated sequences into chunks of max_seq_length.\n", "    result = {\n", "        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]\n", "        for k, t in concatenated_examples.items()\n", "    }\n", "    # Labels are a plain copy of the inputs; the shift happens in the loss function.\n", "    result[\"labels\"] = result[\"input_ids\"].copy()\n", "    return result" ] }, { "cell_type": "markdown", "metadata": { "id": "CR46Vvpwr6e5" }, "source": [ "We pass `group_texts` to the `map(...)` function and set `batched=True`, so that the function receives whole batches of samples (1,000 samples per batch by default) and can concatenate texts across samples. \n", "\n", "**Note**: Running this function on the whole dataset might take up to 50 minutes 🕒." 
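] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check, the toy cell below applies the same concatenate-and-chunk logic to a tiny hand-made batch, using a block size of 4 instead of `max_seq_length`. The helper `group_texts_toy` is hypothetical and is not used for training. It flattens the lists with `itertools.chain` instead of `sum(..., [])`, which avoids the quadratic cost of repeated list concatenation. Note how the trailing token `9` is dropped because it does not fill a complete block." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from itertools import chain\n", "\n", "\n", "def group_texts_toy(examples, block_size=4):\n", "    # Flatten each column (a list of lists) into one long list.\n", "    concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples}\n", "    # Keep only as many tokens as fill complete blocks.\n", "    total_length = (len(concatenated[\"input_ids\"]) // block_size) * block_size\n", "    result = {\n", "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", "        for k, t in concatenated.items()\n", "    }\n", "    # Labels are a plain copy; the one-token shift happens in the loss.\n", "    result[\"labels\"] = result[\"input_ids\"].copy()\n", "    return result\n", "\n", "\n", "toy_batch = {\n", "    \"input_ids\": [[1, 2, 3], [4, 5, 6, 7], [8, 9]],\n", "    \"attention_mask\": [[1, 1, 1], [1, 1, 1, 1], [1, 1]],\n", "}\n", "print(group_texts_toy(toy_batch))\n", "# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1]], 'labels': [[1, 2, 3, 4], [5, 6, 7, 8]]}"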
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UmzNAUVediDa", "outputId": "00f5fd1c-cb16-4539-f24e-a78d7164ff85" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-97c2be27a259abfd.arrow\n", "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-f490d080d7dedf65.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-c717c20d8a29b0c7.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-3b486fcdfc86c6d4.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-af63c7c8c3b5ad0a.arrow\n", "WARNING:datasets.arrow_dataset:Loading cached processed dataset at 
/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-3d949fd35aa4fd76.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-b61580379e98f5c6.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-c9735f22fa10eb4b.arrow\n" ] } ], "source": [ "tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)" ] }, { "cell_type": "markdown", "metadata": { "id": "jid2JqXOsVfR" }, "source": [ "Awesome, the data is now fully pre-processed and ready to be used for training 😎." ] }, { "cell_type": "markdown", "metadata": { "id": "ZRvfr609LzWu" }, "source": [ "## 4. Pre-Training the model\n", "\n", "Now we will see how the power of Google's tensor processing unit (TPU) can be leveraged with Flax/JAX for the compute-intensive pre-training of language models.\n", "\n", "We need to import `jax`, `flax`, `optax`, `numpy` to define our training loop. Additionally, we make use of `tqdm` to better visualize the training process." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5qOhue4Xm1TO" }, "outputs": [], "source": [ "import jax\n", "import optax\n", "import flax\n", "import jax.numpy as jnp\n", "import math\n", "\n", "from flax.training import train_state\n", "from flax.training.common_utils import get_metrics, onehot, shard\n", "\n", "import numpy as np\n", "\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "markdown", "metadata": { "id": "_MGleTRG6Vor" }, "source": [ "First, we define all relevant hyper-parameters for pre-training in this notebook:\n", "\n", "- Each TPU core processes a batch of `16` samples\n", "- The model is trained for `10` epochs\n", "- The learning rate starts at `3e-4` and is linearly decayed to `0` over the course of training, step by step\n", "- To make the training run reproducible, the random seed is set to `0`.\n", "\n", "From these, we can deduce the total batch size over all devices as well as the total number of training steps." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y8lsJQy8liud" }, "outputs": [], "source": [ "per_device_batch_size = 16\n", "num_epochs = 10\n", "training_seed = 0\n", "learning_rate = 3e-4\n", "\n", "total_batch_size = per_device_batch_size * jax.device_count()\n", "num_train_steps = len(tokenized_datasets[\"train\"]) // total_batch_size * num_epochs" ] }, { "cell_type": "markdown", "metadata": { "id": "FB9bRDBq5j3r" }, "source": [ "In the [official GPT2 paper](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), a batch size of 512 is used.\n", "\n", "Here, we use a batch size of `8 * 16 = 128` due to the TPU memory constraints of this notebook. When running this script locally on a TPUv3-8, one can easily use batch sizes of up to `8 * 64 = 512`." ] }, { "cell_type": "markdown", "metadata": { "id": "i0Tylp115u1r" }, "source": [ "Now we randomly initialize a `distilgpt2` model according to its configuration. 
To save memory and improve speed, we run the model's computations in `bfloat16` by setting `dtype=jnp.dtype(\"bfloat16\")`. Note that `dtype` only sets the data type of the computation; the model parameters themselves are still initialized in `float32`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aVr9TCzfacLN" }, "outputs": [], "source": [ "from transformers import FlaxAutoModelForCausalLM\n", "\n", "model = FlaxAutoModelForCausalLM.from_config(config, seed=training_seed, dtype=jnp.dtype(\"bfloat16\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "sMS_QkT76Lgk" }, "source": [ "Next, we define the learning rate schedule. A simple and effective learning rate schedule is the linear decay with warmup (click [here](https://huggingface.co/transformers/main_classes/optimizer_schedules.html#transformers.get_linear_schedule_with_warmup) for more information). For simplicity, we set the number of warmup steps to 0 here. The schedule is then fully defined by the number of training steps and the learning rate.\n", "\n", "It is recommended to use the [**optax**](https://github.com/deepmind/optax) library for training utilities, *e.g.* learning rate schedules and optimizers.\n", "\n", "To see how to define a learning rate schedule with warmup, please take a look at the [official Flax CLM pre-training script](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_clm_flax.py)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kfBkuV1ck4rq" }, "outputs": [], "source": [ "linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)" ] }, { "cell_type": "markdown", "metadata": { "id": "2p0yNxeU79F2" }, "source": [ "We will use the standard Adam optimizer with decoupled weight decay, called AdamW (Adam + weight decay). 
\n", "\n", "AdamW can easily be imported from [optax](https://github.com/deepmind/optax). It is created from the learning rate schedule we just defined, as well as a few other hyper-parameters (*beta1*, *beta2*, *epsilon*, and the weight decay) that are hard-coded in this notebook.\n", "\n", "For more information on AdamW (Adam + weight decay), take a look at [this](https://www.fast.ai/2018/07/02/adam-weight-decay/) blog post." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xRtpv_iamZd2" }, "outputs": [], "source": [ "adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)" ] }, { "cell_type": "markdown", "metadata": { "id": "6g_fEbV-72Hc" }, "source": [ "Next, we will create the *training state*. It bundles the model's forward function, its parameters and the optimizer, and is responsible for updating the model's parameters during training.\n", "\n", "Most JAX transformations (notably [jax.jit](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)) require the transformed functions to have no side effects. This is because any such side effects will only be executed once, when the Python version of the function is run during compilation (see [Stateful Computations in JAX](https://jax.readthedocs.io/en/latest/jax-101/07-state.html)). 
As a consequence, Flax models (which can be transformed by JAX transformations) are **immutable**, and the state of the model (i.e., its weight parameters) is stored *outside* of the model instance.\n", "\n", "Models are initialized and updated in a purely functional way: you pass the state to the model when calling it, and the model returns the new (possibly modified) state, leaving the model instance itself unchanged.\n", "\n", "Flax provides a convenience class [`flax.training.train_state.TrainState`](https://github.com/google/flax/blob/9da95cdd12591f42d2cd4c17089861bff7e43cc5/flax/training/train_state.py#L22). It stores the model's forward function (`apply_fn`), the model parameters, and the optimizer (`tx`) together with its state. It also exposes an `apply_gradients` function to update the model's weight parameters.\n", "\n", "Alright, let's create our *training state*. We use `TrainState.create` to store the model's forward pass as the `apply_fn`, the `params`, and the AdamW optimizer as `tx`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHYfR67AoKRc" }, "outputs": [], "source": [ "state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)" ] }, { "cell_type": "markdown", "metadata": { "id": "xiYCejDd81TX" }, "source": [ "Next, let's implement a data loader for both training and evaluation.\n", "The data loader can be defined as a [Python generator](https://wiki.python.org/moin/Generators) that yields a batch of model inputs each time it is iterated over.\n", "\n", "First, the indices of the whole dataset are prepared, and randomly permuted if `shuffle=True`. \n", "Then, at every step, the next batch of indices is taken, and the corresponding samples are extracted, converted to JAX arrays and sharded over all local TPU devices." 
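] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The index bookkeeping described above can be tried out at toy scale before we define the real loader. This sketch uses arbitrary sizes: 10 samples and a batch size of 4. The indices are permuted, the 2 leftover indices are dropped, and the rest are reshaped into 2 batches. Finally, Flax's `shard` is applied to a dummy batch of 8 rows. It splits the leading axis across the local devices, so the result has shape `(8, 1, 3)` on an 8-core TPU and `(1, 8, 3)` on a single-device runtime." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import numpy as np\n", "from flax.training.common_utils import shard\n", "\n", "# Toy sizes (arbitrary): 10 samples, batch size 4 -> 2 full batches, 2 samples dropped.\n", "toy_num_samples, toy_batch_size = 10, 4\n", "toy_steps = toy_num_samples // toy_batch_size\n", "\n", "# Shuffle the indices, drop the incomplete last batch and group them into batches.\n", "toy_idx = jax.random.permutation(jax.random.PRNGKey(0), toy_num_samples)\n", "toy_idx = toy_idx[: toy_steps * toy_batch_size].reshape((toy_steps, toy_batch_size))\n", "print(toy_idx.shape)  # (2, 4)\n", "\n", "# `shard` reshapes the leading axis to (number of local devices, per-device batch, ...).\n", "dummy_batch = {\"input_ids\": np.zeros((8, 3), dtype=np.int32)}\n", "print(shard(dummy_batch)[\"input_ids\"].shape)"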
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Aos9GltTb3Ve" }, "outputs": [], "source": [ "def data_loader(rng, dataset, batch_size, shuffle=False):\n", " steps_per_epoch = len(dataset) // batch_size\n", "\n", " if shuffle:\n", " batch_idx = jax.random.permutation(rng, len(dataset))\n", " else:\n", " batch_idx = jnp.arange(len(dataset))\n", "\n", " batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n", " batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n", "\n", " for idx in batch_idx:\n", " batch = dataset[idx]\n", " batch = {k: jnp.array(v) for k, v in batch.items()}\n", "\n", " batch = shard(batch)\n", "\n", " yield batch" ] }, { "cell_type": "markdown", "metadata": { "id": "L7uoTXDLUzb-" }, "source": [ "At each training epoch, the training dataset is shuffled. For both training and evaluation, the leftover samples that cannot fill a complete final batch (because the dataset size is not evenly divisible by the batch size) are dropped. Rather than reordering the dataset itself, `data_loader` works on the indices of the data samples. \n", "Only the training indices are shuffled, with a fresh random key every epoch; the evaluation indices keep their original order." ] }, { "cell_type": "markdown", "metadata": { "id": "MU6idLb29xYu" }, "source": [ "During pre-training, we want to update the model parameters and evaluate the performance after each epoch. \n", "\n", "Let's write the functions `train_step` and `eval_step` accordingly. During training, the weight parameters are updated as follows:\n", "\n", "1. Define a loss function `loss_fn` that first runs a forward pass of the model on the input batch. Remember that Flax models are immutable, so we explicitly pass them the state (in this case the model parameters and the dropout RNG). `loss_fn` returns a scalar loss, the mean softmax cross-entropy (computed with `optax.softmax_cross_entropy`) between the model's logits and the target tokens.\n", "2. 
Differentiate this loss function using [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#evaluate-a-function-and-its-gradient-using-value-and-grad). This JAX transformation uses [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) to compute the gradient of `loss_fn` with respect to its input (i.e., the parameters of the model), and returns the value and the gradient as a pair `(loss, gradients)`.\n", "3. Compute the mean gradient over all devices using the collective operation [lax.pmean](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html). As we will see below, each device runs `train_step` on a different batch of data, but by taking the mean here we ensure the model parameters stay identical on all devices.\n", "4. Use `state.apply_gradients`, which lets the optimizer apply the gradients to the weights and returns the updated training state.\n", "\n", "Below, you can see how each of the steps described above is put into practice.\n", "\n", "Also note that the `labels` are shifted one position to the left (the first label is dropped) and the logits of the last position are cut off. This way, the model learns to predict the **next** token, as required by causal language modeling."
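 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see why the logits and the labels are sliced differently, here is a toy example. It is a sketch under stated assumptions, not the notebook's code: the token ids, the vocabulary size of 5 and the random logits are made up, the labels are assumed to be a copy of the input ids, and the cross-entropy is written out by hand with `jax.nn.log_softmax` instead of `optax.softmax_cross_entropy`. The prediction made at position *t* is scored against the token at position *t + 1*, so the logits of the last position (which has no next token) and the first label (which no position predicts) are dropped. The `toy_` prefix keeps the notebook's own variables untouched." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Toy batch: one sequence of 4 token ids over a hypothetical vocabulary of 5 tokens.\n",
"toy_vocab_size = 5\n",
"toy_labels = jnp.array([[3, 1, 4, 2]])\n",
"\n",
"# Stand-in for the model output: random logits of shape (batch, seq_len, vocab_size).\n",
"toy_logits = jax.random.normal(jax.random.PRNGKey(0), (1, 4, toy_vocab_size))\n",
"\n",
"# Position t predicts token t + 1: drop the last logit and the first label.\n",
"toy_shift_logits = toy_logits[..., :-1, :]  # shape (1, 3, 5)\n",
"toy_shift_labels = toy_labels[..., 1:]  # [[1, 4, 2]]\n",
"\n",
"# Mean softmax cross-entropy over the 3 remaining positions.\n",
"toy_one_hot = jax.nn.one_hot(toy_shift_labels, toy_vocab_size)\n",
"toy_loss = -(toy_one_hot * jax.nn.log_softmax(toy_shift_logits)).sum(axis=-1).mean()\n",
"print(toy_shift_logits.shape, toy_shift_labels.tolist(), float(toy_loss))"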
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GjKzb0zJd-aH" }, "outputs": [], "source": [ "def train_step(state, batch, dropout_rng):\n", " dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)\n", "\n", " def loss_fn(params):\n", " labels = batch.pop(\"labels\")\n", " logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n", " \n", " loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()\n", " return loss\n", "\n", " grad_fn = jax.value_and_grad(loss_fn)\n", " loss, grad = grad_fn(state.params)\n", " grad = jax.lax.pmean(grad, \"batch\")\n", " new_state = state.apply_gradients(grads=grad)\n", "\n", " metrics = jax.lax.pmean(\n", " {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}, axis_name=\"batch\"\n", " )\n", "\n", " return new_state, metrics, new_dropout_rng" ] }, { "cell_type": "markdown", "metadata": { "id": "nCPedI-B-FMQ" }, "source": [ "Now, we want to do parallelized training over all TPU devices. To do so, we use [`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html?highlight=pmap#parallelization-pmap). This will compile the function once and run the same program on each device (it is an [SPMD program](https://en.wikipedia.org/wiki/SPMD)). `jax.pmap` maps over the first axis of every argument, so when calling this pmapped function, each input (`state`, `batch`, `dropout_rng`) must have a leading axis whose size equals the number of devices: the state is replicated on every device, the batch is sharded across devices, and the dropout RNG is split into one key per device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w3k1Lqerpw5k" }, "outputs": [], "source": [ "parallel_train_step = jax.pmap(train_step, \"batch\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0DWFAZM6A8uf" }, "source": [ "Similarly, we can now define the evaluation step. Here, the function is simpler, as we don't need to compute any gradients. 
To better monitor the performance improvement during training, the next-token loss and the corresponding perplexity (the exponential of the loss) are computed and stored in a `metrics` dictionary during evaluation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EGEv7dyfpW4p" }, "outputs": [], "source": [ "def eval_step(params, batch):\n", " labels = batch.pop(\"labels\")\n", "\n", " logits = model(**batch, params=params, train=False)[0]\n", "\n", " loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()\n", "\n", " # summarize metrics\n", " metrics = {\"loss\": loss, \"perplexity\": jnp.exp(loss)}\n", " metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n", " return metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "guaYWTvFA_66" }, "source": [ "Similarly, we also apply `jax.pmap` to the evaluation step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0B8U2r2RpzjV" }, "outputs": [], "source": [ "parallel_eval_step = jax.pmap(eval_step, \"batch\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DLaM60PCY8Ka" }, "source": [ "Next, we replicate the training state (the weight parameters together with the optimizer state) on each device, so that we can pass it to our parallelized mapped functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kncZTfALp3PG", "outputId": "3ce8ee5a-7bda-4ba9-8774-c52a363a98f5" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:317: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.\n", " \"jax.host_count has been renamed to jax.process_count. This alias \"\n", "/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:304: UserWarning: jax.host_id has been renamed to jax.process_index. 
This alias will eventually be removed; please update your code.\n", " \"jax.host_id has been renamed to jax.process_index. This alias \"\n" ] } ], "source": [ "state = flax.jax_utils.replicate(state)" ] }, { "cell_type": "markdown", "metadata": { "id": "i2xg8oI-ZJ3P" }, "source": [ "We can almost start training! In a final preparation step, we generate a seeded [**PRNGKey**](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax-random-prngkey) used as the random seed for dropout layers and dataset shuffling.\n", "\n", "Similar to how we had to replicate the state on all 8 TPU devices, we also need one `PRNGKey` per device, which is why we split the initial `rng` key into as many keys as there are local devices (8 here). " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "idu3E9ubqZH3" }, "outputs": [], "source": [ "rng = jax.random.PRNGKey(training_seed)\n", "dropout_rngs = jax.random.split(rng, jax.local_device_count())" ] }, { "cell_type": "markdown", "metadata": { "id": "bKuMWHicbede" }, "source": [ "Now, we are all set to finally start training! \n", "Let's put all the pieces together and write the training loop. \n", "\n", "We start each epoch by splitting off a new random key that is used to shuffle the training dataset. The per-device dropout keys are refreshed separately, inside `train_step`. \n", "\n", "Next, we create the training data loader, which yields shuffled batches that are already sharded across all 8 TPU devices.\n", "In the first nested loop - the training loop - we run the training step on each of these batches. \n", "\n", "Analogously, in the second nested loop - the evaluation loop - the (unshuffled) evaluation batches are sharded and the evaluation step is run.\n", "\n", "**Note**: It might seem that the following cell \"hangs\" when executed for the first time. This is because JAX traces and compiles the code the first time it is run. After the first training step, you should notice that execution is much faster."
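 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The per-device key handling can be made concrete with a small sketch. It assumes 8 devices to match the TPU used here, but it also runs on a CPU because `jax.vmap` stands in for `jax.pmap`; both map a function over the leading axis of their inputs. Each device's key is split into a key used in the current step and a key carried over to the next step, mirroring the first line of `train_step`. The `toy_` prefix keeps the notebook's own `rng` and `dropout_rngs` untouched." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import jax\n",
"\n",
"toy_num_devices = 8  # assumption: matches the 8 TPU cores used in this notebook\n",
"\n",
"# One independent key per device, stacked along the leading (device) axis.\n",
"toy_keys = jax.random.split(jax.random.PRNGKey(0), toy_num_devices)\n",
"\n",
"def split_for_step(key):\n",
"    # Use one half in the current step and carry the other half to the next step.\n",
"    key_now, key_next = jax.random.split(key)\n",
"    return key_now, key_next\n",
"\n",
"# jax.vmap stands in for jax.pmap here: both map over the leading axis.\n",
"toy_keys_now, toy_keys_next = jax.vmap(split_for_step)(toy_keys)\n",
"print(toy_keys.shape, toy_keys_now.shape, toy_keys_next.shape)"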
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 427, "referenced_widgets": [ "8c156c360afb4ddc962b87577e093cc4", "9808f288735b4f5d9eea7377d4603ccc", "68681edff0db43559012b6653571a289", "125e8d57908f4679824c96d5247f7119", "8134ef53fa8044c99b537386e88dbfc9", "9b35ba10c85048ce8044b87ab9948611", "e246db4008b14a23b55d2820be3b1ed8", "b08557f6936e4af9a23f659649979c8d", "dac789bcf81e4aac86717438ed6e8bbe", "a1452333389441eab2c8016cd019d2f5", "799b382dd6c74d528cf77813f82a7335", "51d0dfb67bbd47ff88964bb446a91967", "940ce122551e42509e8cec0659e47c77", "a4234dc426d3459a93263c8b5fe5a34b", "3549bf51a354408b9a39cb4532957617", "7f05c774880548d2b73e69fb847c0a2e", "71fa85dc266149e28685b1f5252185a5", "9cb943192a154dbc8c7961f128449bbd", "9de6ea64b67c443fbce1fd6aef4a7409", "9658cc6977bb4fdf8a081aa99c300872", "07b26680343f42fc805a3b52b398a5f2", "97216855bb7f483882afae68bf78a51d", "a1b9aa80b73748f18e3938df073b3d22", "ce70b524ffab4052912d47b8adee6748", "3c5df3c942ca4c768926b431064be351", "f2e66aebdb4d4812a1f61c1ff990390a", "269cc543a69e4ba3bae7f0bef67f1e25", "8cb8398bb3c94284ac2ff42c7b3bc7c5", "0b898451f79c4b6eb1678a8f366a5cbf", "7a5426d6898e4ba2b5e3c48a65096232", "2cf7d4c7dc6f413eb8cc9460e5de8d93", "f7851bd9702949ee8e6278ad69c9d981", "ab08bd64073d4d488f9c41e8595bac08", "5e1805a3c7f24c3f9f15e86d15462d44", "fc2e1a72620742499457ab2e31da12a9", "b3b4d96626044375bad11ba48cc25ee3", "c55e4e5f6d0a42b4ac25fb6d68a8e842", "c7cca560cd4340bb986297655f513746", "cd4892863d524637b8e567e2beed23bb", "c85760cb6ccb4020be999357f2501bf9", "7d0775fcf6a8497db20326dc94bd2739", "5215d791d9144c089b6996336c3085cd", "839daa7d5cdc4669942f1b5ef5bf4903", "4bc03fce15a14355b3429c138ea1eadd", "d7cddd2cafe84d17829096b31ee5e93e", "b4ce8892d5784011a649eeca7b17fd0d", "7a7b598979174b338f8fdfc90af61e47", "cd946e84023445d59eea9893fba02f1a", "4ec581224b3f47f2acdfaeb740c94526", "81c25bb635db4bc9a62636bf3a3868eb", "1462263d23c340e781d596a3f5414e03", 
"c8bd20a456cc4783a2fbbdec9a2880ea", "f0aa35a704a84967a3471e5f698bb04c", "f2c541b4ecd04729a518414f8b878099", "402795b6506a45c9831fb0ddb2ce0749", "0ba00c3340414029938c9c947d4a74b6", "d17f610741f142b782c74ac4e04e6481", "84c8a7cbfdb04690941f8c32299744bb", "f3f0b212907843f5be2131b7b60a955d", "0a6a241c99bb49ad9baaef052b157964", "b642c546813240a499a165279c743e7a", "6acb415dc8b947bb85ab8f8959bc10e1", "2c311c4506264995b7fbaba02333e747", "2fde17ab63e5444799ec1dc5501d946e", "6173f15a3e5a40819bd4f3833544fd78", "da2f69c01bcd4139980a84ddca297817", "b9470519dfff45228e54d0181337ee17", "b5824ca248b74052bb1696b47369838e", "11f8b1ddd5af4bb3babd492bbaf3fb7d", "861e2a6f8e3e4a5cacfd0a746180cc90", "181f0680e5bb4d23a69f9cc747207883", "0e2531c122b44823a012ba844066e60c", "a2324cab4e9745ddad69b35e77ed148e", "42c8cb98eaae4cee86d49c675d3a3604", "146ce87ac5b445e385285bdbbd70fba9", "46340cbdaf9c47719e005b42b7236be2", "cd26f91d46ef42609d265f0b709c3401", "6f5f003a277244b2824e09b877fefa74", "57a93be7bb3d4167a89f7d73d5915551", "db96a8d67e564cd58358916ddfb86948", "194f5ce951d54193b887f1cc91e3640b", "d8770c573383466f81d7dfd01f048677", "1fe6ad92e6834142844bb32e7a0d424d", "9d905b40bf2746e79abad7257cf8654f", "bfa957f8482847978542373db16cfa8a", "1f836daeb9b7495a9ae00bc2c28fa212", "0973226c646f4bb28f291ed3eb4e5cba", "fd72a18adfb34bc2b2dab769067dcc1a", "91fb9675236949e6bb8c99015061647c", "75f988a47d774cd38fd85977fda7a96c", "503ff750d98d4f2c8f7af32adc221685", "8b3645c111b842f4b1d9ff54bb94f26e", "dabd19b76eec4139a3028aa1776ac496", "b6410389dafe43bd885446a6d2d255b0", "23187b20efd44965966d6576c0d057aa", "6f876daff243410bb26b8feb1cf981e0", "2374ba361eaa41899b7eabe7da829259", "522331eac9cb49e4998384f42eeba912", "6eaf7a1f2e5e4aa590da55adf969ecc6", "29b0ec79b8134168a5e3ae382088c147", "96f9368df9dc453495badb09e767d091", "02b1a546bc334efa93ba8dff5890dd0c", "b7f8df98ab5d45f39b34ceb1018c01b5", "c251cf1c33c741049cc9710e89f09bc0", "7e067f02bdb74846a42909cd9ad1961b", "68cf981867f7461c999c9e9f85af6499", 
"b6eec0df5ca64d5dbf32cb5231127e3b", "6efef5f4637e4efaa373cb9214d4e911", "d539e4c93f4941e9af0b742df4c53ae0", "dafd837bebd44c00ae68641cc0b31d3d", "47808ac98d064711bccc77037e53dade", "acfab7adb0ed440b87fa08f81086522d", "6fa186bd796b4d16bfb67657178e4f65", "07fe952dec2b4f2687bcddec8701a70e", "6c9e7145c2d7439db6f63b29c505e63e", "43861346cadc460992ab7e2c30aa556c", "6baec27867e943a2bad6e70e48978ea5", "d2e74c9d62444e98b52078d6a15dd068", "fcf96c9bb04642089466de2421bd1af2", "b9d1f631577f4a2a8eb49d3f8535ba25", "e9c8d09e028741979a764789fd7f256e", "2bd15c15e3b748cd9ee001bb86c04b6d", "f5e27446f61747e1a65e477e2975dbcf", "1e17af39c928463b800d22ad1fe533e5", "98e148811880459ea7a7641ee3a301b6", "6669a1844b394ea499dc92e126132bfc", "987310405403439893813fb8c5da94a1", "f48e444c929748b2a875c6eecc028bc4", "1ed01b7ba2a64bebbbf0e641c8d1560e", "bfadf2c9e99c48e4860c2a67b2955120", "63f47b7d6bc849d1b869367a6ba8b40f", "1bfe0a0ab5694e1d857ea981090d4f09", "7f6628d5f61f424ba6373da2f6defe76", "01d1c1e078e7403daabd1e7d8d5a4fda", "3b802a49c0824b968d88350f82ff8422", "5a6ea875aeff4688ba37481af5b5d9fe", "6dfceed81b9b4f8198a8186dfdc72d42", "44ee8edfb7c744358f0cb4567535ce07", "5eb4c37fcf6a4b618a8f573b837a8c3a", "5bc32858674f442fa7b49efcccfc3c15", "3ebb7960558247ec84ef155042d2077a", "52a25fde354e45dd9635097d22c16def", "ce3117149c7e4ae6a93fbe959bd3f543", "900aa99ccfeb4a38a0ad90f8ce656fe6", "b414ff8df761462e8df8f51788f69508", "e0c2ef0fd4d544b9b80b00689d21f503", "52cae29e50974a42a3882f91b2a8f2ba", "02b454a4d7ce44b0b819fc7085b1d357", "6321a346b671411791979a1a343fb494", "22941d4a640b4d8bba1b97f723da0ef0", "b6f9da40231f4404bb0c90458c94c5c2", "7db664fc575747af8e10ad6f2783a166", "62fd4952400246189945f2de7f0192c6", "f07d424e031d457a98a88b94b0c844b8", "a9f8af0a5e594c29b553983b6a1c4c7d", "5829f975dedd45d3b315ab29c1b12253", "2eb6756f48074993825b5bfac47641e1", "389fcde31cf24c438a9fcdb155060495", "a5a86539a58f4fe98ba96f9cd3d46588", "32c067d64cb7445eb0464f3a10071e7c", "c014bf31dccb4448b561a702361ed9bc", 
"ef823903b36348978833920bb7e03881", "c6855514c5a649ebadc2db53c267e9e0", "8602122a90db462cacef9373fca054fe", "892e02ac20ef4963871d915085139f7d", "173704b44e2c472499bc4242ef9e0863", "2f647e8c1be2464fb9c69e8af5e5bf5a", "5fc3a90707f746c884d924e385b0cc1c" ] }, "id": "U946A-YZp-Pe", "outputId": "1fe21ffd-2e39-4470-fb87-6b11ab1d4024" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8c156c360afb4ddc962b87577e093cc4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Epoch ...', max=10.0, style=ProgressStyle(description_wid…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dac789bcf81e4aac86717438ed6e8bbe", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... (1/10 | Loss: 6.935000419616699, Learning Rate: 0.0002699999895412475)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "71fa85dc266149e28685b1f5252185a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... 
(1/10 | Loss: 7.108445644378662 | Perplexity: 1246.529052734375)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3c5df3c942ca4c768926b431064be351", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... (2/10 | Loss: 6.334000110626221, Learning Rate: 0.00023999999393709004)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ab08bd64073d4d488f9c41e8595bac08", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... (2/10 | Loss: 6.567610740661621 | Perplexity: 738.8753662109375)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7d0775fcf6a8497db20326dc94bd2739", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... 
(3/10 | Loss: 5.798000335693359, Learning Rate: 0.0002099999983329326)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ec581224b3f47f2acdfaeb740c94526", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... (3/10 | Loss: 6.278167247772217 | Perplexity: 557.9488525390625)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d17f610741f142b782c74ac4e04e6481", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... (4/10 | Loss: 5.557000160217285, Learning Rate: 0.00018000000272877514)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6173f15a3e5a40819bd4f3833544fd78", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... 
(4/10 | Loss: 6.062875270843506 | Perplexity: 451.3289794921875)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a2324cab4e9745ddad69b35e77ed148e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... (5/10 | Loss: 5.543000221252441, Learning Rate: 0.00014999999257270247)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "194f5ce951d54193b887f1cc91e3640b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... (5/10 | Loss: 5.920379161834717 | Perplexity: 392.97332763671875)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "91fb9675236949e6bb8c99015061647c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... 
(6/10 | Loss: 5.361000061035156, Learning Rate: 0.00011999999696854502)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2374ba361eaa41899b7eabe7da829259", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... (6/10 | Loss: 5.821027755737305 | Perplexity: 356.4353942871094)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7e067f02bdb74846a42909cd9ad1961b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... (7/10 | Loss: 5.207000255584717, Learning Rate: 9.000000136438757e-05)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fa186bd796b4d16bfb67657178e4f65", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... 
(7/10 | Loss: 5.748736381530762 | Perplexity: 332.1453857421875)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e9c8d09e028741979a764789fd7f256e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... (8/10 | Loss: 5.124000072479248, Learning Rate: 5.999999848427251e-05)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1ed01b7ba2a64bebbbf0e641c8d1560e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... (8/10 | Loss: 5.703180313110352 | Perplexity: 317.5106201171875)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6dfceed81b9b4f8198a8186dfdc72d42", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... 
(9/10 | Loss: 5.220000267028809, Learning Rate: 2.9999999242136255e-05)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b414ff8df761462e8df8f51788f69508", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... (9/10 | Loss: 5.674434185028076 | Perplexity: 308.7478942871094)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "62fd4952400246189945f2de7f0192c6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=137.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Train... (10/10 | Loss: 4.992000102996826, Learning Rate: 0.0)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c014bf31dccb4448b561a702361ed9bc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=12.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "\r", "Eval... 
(10/10 | Loss: 5.66389274597168 | Perplexity: 305.58953857421875)\n", "\n" ] } ], "source": [ "for epoch in tqdm(range(1, num_epochs + 1), desc=f\"Epoch ...\", position=0, leave=True):\n", " rng, input_rng = jax.random.split(rng)\n", "\n", " # -- Train --\n", " train_loader = data_loader(input_rng, tokenized_datasets[\"train\"], total_batch_size, shuffle=True)\n", " with tqdm(total=len(tokenized_datasets[\"train\"]) // total_batch_size, desc=\"Training...\", leave=False) as progress_bar_train:\n", " for model_inputs in train_loader:\n", " # Forward pass, backward pass and parameter update on all devices\n", " state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)\n", "\n", " progress_bar_train.update(1)\n", "\n", " progress_bar_train.write(\n", " f\"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})\"\n", " )\n", "\n", " # -- Eval (no shuffling, so input_rng is not used here) --\n", " eval_loader = data_loader(input_rng, tokenized_datasets[\"validation\"], total_batch_size)\n", " eval_metrics = []\n", " \n", " with tqdm(total=len(tokenized_datasets[\"validation\"]) // total_batch_size, desc=\"Evaluation...\", leave=False) as progress_bar_eval:\n", " for model_inputs in eval_loader:\n", " # Model forward\n", " eval_metric = parallel_eval_step(state.params, model_inputs)\n", " eval_metrics.append(eval_metric)\n", "\n", " progress_bar_eval.update(1)\n", " \n", " eval_metrics = get_metrics(eval_metrics)\n", " eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)\n", " progress_bar_eval.write(\n", " f\"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics['loss']} | Perplexity: {eval_metrics['perplexity']})\"\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "ZI4XIhY-7hyh" }, "source": [ "In this Colab, training already reaches a speed of about 2.42 training steps per second. 
Executing [**`run_clm_flax.py`**](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling/run_clm_flax.py) on a TPUv3-8 VM should reach about 7 training steps per second.\n", "\n", "For a more detailed comparison of runtimes, please refer to [this](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation) table." ] } ], "metadata": { "accelerator": "TPU", "colab": { "authorship_tag": "ABX9TyOdWFDf0k/7HZKWXWczgZRn", "collapsed_sections": [], "include_colab_link": true, "name": "Causal Language Model Training on TPU with 🤗 Transformers & JAX", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "01d1c1e078e7403daabd1e7d8d5a4fda": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, 
"object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "02b1a546bc334efa93ba8dff5890dd0c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "02b454a4d7ce44b0b819fc7085b1d357": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7db664fc575747af8e10ad6f2783a166", "placeholder": "​", "style": "IPY_MODEL_b6f9da40231f4404bb0c90458c94c5c2", "value": " 12/12 [00:05<00:00, 2.35it/s]" } }, "07b26680343f42fc805a3b52b398a5f2": { "model_module": "@jupyter-widgets/controls", "model_name": 
"ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "07fe952dec2b4f2687bcddec8701a70e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0973226c646f4bb28f291ed3eb4e5cba": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "0a6a241c99bb49ad9baaef052b157964": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2fde17ab63e5444799ec1dc5501d946e", "placeholder": "​", "style": "IPY_MODEL_2c311c4506264995b7fbaba02333e747", "value": " 137/137 [01:38<00:00, 1.40it/s]" } }, "0b898451f79c4b6eb1678a8f366a5cbf": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "0ba00c3340414029938c9c947d4a74b6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, 
"0e2531c122b44823a012ba844066e60c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "11f8b1ddd5af4bb3babd492bbaf3fb7d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "125e8d57908f4679824c96d5247f7119": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_b08557f6936e4af9a23f659649979c8d", "placeholder": "​", "style": "IPY_MODEL_e246db4008b14a23b55d2820be3b1ed8", "value": " 10/10 [19:05<00:00, 114.56s/it]" } }, "1462263d23c340e781d596a3f5414e03": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_f2c541b4ecd04729a518414f8b878099", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f0aa35a704a84967a3471e5f698bb04c", "value": 12 } }, "146ce87ac5b445e385285bdbbd70fba9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_6f5f003a277244b2824e09b877fefa74", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_cd26f91d46ef42609d265f0b709c3401", "value": 137 } }, "173704b44e2c472499bc4242ef9e0863": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, 
"grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "181f0680e5bb4d23a69f9cc747207883": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "194f5ce951d54193b887f1cc91e3640b": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_1fe6ad92e6834142844bb32e7a0d424d", "IPY_MODEL_9d905b40bf2746e79abad7257cf8654f" ], "layout": "IPY_MODEL_d8770c573383466f81d7dfd01f048677" } }, "1bfe0a0ab5694e1d857ea981090d4f09": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_5a6ea875aeff4688ba37481af5b5d9fe", "placeholder": "​", "style": "IPY_MODEL_3b802a49c0824b968d88350f82ff8422", "value": " 12/12 [00:05<00:00, 2.35it/s]" } }, "1e17af39c928463b800d22ad1fe533e5": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f48e444c929748b2a875c6eecc028bc4", "placeholder": "​", "style": "IPY_MODEL_987310405403439893813fb8c5da94a1", "value": " 137/137 [01:36<00:00, 1.41it/s]" } }, "1ed01b7ba2a64bebbbf0e641c8d1560e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_63f47b7d6bc849d1b869367a6ba8b40f", "IPY_MODEL_1bfe0a0ab5694e1d857ea981090d4f09" ], "layout": "IPY_MODEL_bfadf2c9e99c48e4860c2a67b2955120" } }, "1f836daeb9b7495a9ae00bc2c28fa212": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": 
null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1fe6ad92e6834142844bb32e7a0d424d": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_1f836daeb9b7495a9ae00bc2c28fa212", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_bfa957f8482847978542373db16cfa8a", "value": 12 } }, "22941d4a640b4d8bba1b97f723da0ef0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": 
null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "23187b20efd44965966d6576c0d057aa": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2374ba361eaa41899b7eabe7da829259": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_6eaf7a1f2e5e4aa590da55adf969ecc6", "IPY_MODEL_29b0ec79b8134168a5e3ae382088c147" ], "layout": "IPY_MODEL_522331eac9cb49e4998384f42eeba912" } }, "269cc543a69e4ba3bae7f0bef67f1e25": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_7a5426d6898e4ba2b5e3c48a65096232", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_0b898451f79c4b6eb1678a8f366a5cbf", "value": 137 } }, "29b0ec79b8134168a5e3ae382088c147": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c251cf1c33c741049cc9710e89f09bc0", "placeholder": "​", "style": "IPY_MODEL_b7f8df98ab5d45f39b34ceb1018c01b5", "value": " 12/12 [00:05<00:00, 2.35it/s]" } }, "2bd15c15e3b748cd9ee001bb86c04b6d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2c311c4506264995b7fbaba02333e747": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2cf7d4c7dc6f413eb8cc9460e5de8d93": { 
"model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2eb6756f48074993825b5bfac47641e1": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "2f647e8c1be2464fb9c69e8af5e5bf5a": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2fde17ab63e5444799ec1dc5501d946e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": 
null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "32c067d64cb7445eb0464f3a10071e7c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3549bf51a354408b9a39cb4532957617": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "389fcde31cf24c438a9fcdb155060495": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": 
"1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3b802a49c0824b968d88350f82ff8422": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "3c5df3c942ca4c768926b431064be351": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_269cc543a69e4ba3bae7f0bef67f1e25", "IPY_MODEL_8cb8398bb3c94284ac2ff42c7b3bc7c5" ], "layout": "IPY_MODEL_f2e66aebdb4d4812a1f61c1ff990390a" } }, "3ebb7960558247ec84ef155042d2077a": { "model_module": "@jupyter-widgets/controls", "model_name": 
"ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "402795b6506a45c9831fb0ddb2ce0749": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "42c8cb98eaae4cee86d49c675d3a3604": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "43861346cadc460992ab7e2c30aa556c": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b9d1f631577f4a2a8eb49d3f8535ba25", "placeholder": "​", "style": "IPY_MODEL_fcf96c9bb04642089466de2421bd1af2", "value": " 12/12 [00:05<00:00, 2.33it/s]" } }, "44ee8edfb7c744358f0cb4567535ce07": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "46340cbdaf9c47719e005b42b7236be2": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_db96a8d67e564cd58358916ddfb86948", "placeholder": "​", "style": "IPY_MODEL_57a93be7bb3d4167a89f7d73d5915551", "value": " 137/137 [01:37<00:00, 1.40it/s]" } }, "47808ac98d064711bccc77037e53dade": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "4bc03fce15a14355b3429c138ea1eadd": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_cd946e84023445d59eea9893fba02f1a", "placeholder": "​", "style": "IPY_MODEL_7a7b598979174b338f8fdfc90af61e47", "value": " 137/137 [01:37<00:00, 1.41it/s]" } }, "4ec581224b3f47f2acdfaeb740c94526": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_1462263d23c340e781d596a3f5414e03", "IPY_MODEL_c8bd20a456cc4783a2fbbdec9a2880ea" ], "layout": "IPY_MODEL_81c25bb635db4bc9a62636bf3a3868eb" } }, "503ff750d98d4f2c8f7af32adc221685": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": 
"FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_b6410389dafe43bd885446a6d2d255b0", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_dabd19b76eec4139a3028aa1776ac496", "value": 137 } }, "51d0dfb67bbd47ff88964bb446a91967": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7f05c774880548d2b73e69fb847c0a2e", "placeholder": "​", "style": "IPY_MODEL_3549bf51a354408b9a39cb4532957617", "value": " 137/137 [03:28<00:00, 1.41it/s]" } }, "5215d791d9144c089b6996336c3085cd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, 
"overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "522331eac9cb49e4998384f42eeba912": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "52a25fde354e45dd9635097d22c16def": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": 
null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "52cae29e50974a42a3882f91b2a8f2ba": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_22941d4a640b4d8bba1b97f723da0ef0", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_6321a346b671411791979a1a343fb494", "value": 12 } }, "57a93be7bb3d4167a89f7d73d5915551": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5829f975dedd45d3b315ab29c1b12253": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_32c067d64cb7445eb0464f3a10071e7c", "placeholder": "​", "style": "IPY_MODEL_a5a86539a58f4fe98ba96f9cd3d46588", "value": " 137/137 [01:37<00:00, 1.40it/s]" } 
}, "5a6ea875aeff4688ba37481af5b5d9fe": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5bc32858674f442fa7b49efcccfc3c15": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_900aa99ccfeb4a38a0ad90f8ce656fe6", "placeholder": "​", "style": "IPY_MODEL_ce3117149c7e4ae6a93fbe959bd3f543", "value": " 137/137 [01:37<00:00, 1.42it/s]" } }, "5e1805a3c7f24c3f9f15e86d15462d44": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5eb4c37fcf6a4b618a8f573b837a8c3a": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_52a25fde354e45dd9635097d22c16def", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_3ebb7960558247ec84ef155042d2077a", "value": 137 } }, "5fc3a90707f746c884d924e385b0cc1c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, 
"grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6173f15a3e5a40819bd4f3833544fd78": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b9470519dfff45228e54d0181337ee17", "IPY_MODEL_b5824ca248b74052bb1696b47369838e" ], "layout": "IPY_MODEL_da2f69c01bcd4139980a84ddca297817" } }, "62fd4952400246189945f2de7f0192c6": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a9f8af0a5e594c29b553983b6a1c4c7d", "IPY_MODEL_5829f975dedd45d3b315ab29c1b12253" ], "layout": "IPY_MODEL_f07d424e031d457a98a88b94b0c844b8" } }, "6321a346b671411791979a1a343fb494": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "63f47b7d6bc849d1b869367a6ba8b40f": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_01d1c1e078e7403daabd1e7d8d5a4fda", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7f6628d5f61f424ba6373da2f6defe76", "value": 12 } }, "6669a1844b394ea499dc92e126132bfc": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "68681edff0db43559012b6653571a289": { "model_module": "@jupyter-widgets/controls", 
"model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Epoch ...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_9b35ba10c85048ce8044b87ab9948611", "max": 10, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_8134ef53fa8044c99b537386e88dbfc9", "value": 10 } }, "68cf981867f7461c999c9e9f85af6499": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6acb415dc8b947bb85ab8f8959bc10e1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": 
"1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6baec27867e943a2bad6e70e48978ea5": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "6c9e7145c2d7439db6f63b29c505e63e": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_d2e74c9d62444e98b52078d6a15dd068", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_6baec27867e943a2bad6e70e48978ea5", "value": 12 } }, "6dfceed81b9b4f8198a8186dfdc72d42": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { 
"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_5eb4c37fcf6a4b618a8f573b837a8c3a", "IPY_MODEL_5bc32858674f442fa7b49efcccfc3c15" ], "layout": "IPY_MODEL_44ee8edfb7c744358f0cb4567535ce07" } }, "6eaf7a1f2e5e4aa590da55adf969ecc6": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_02b1a546bc334efa93ba8dff5890dd0c", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_96f9368df9dc453495badb09e767d091", "value": 12 } }, "6efef5f4637e4efaa373cb9214d4e911": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_acfab7adb0ed440b87fa08f81086522d", "placeholder": "​", "style": "IPY_MODEL_47808ac98d064711bccc77037e53dade", "value": " 137/137 [01:36<00:00, 1.42it/s]" } }, "6f5f003a277244b2824e09b877fefa74": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6f876daff243410bb26b8feb1cf981e0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6fa186bd796b4d16bfb67657178e4f65": { "model_module": 
"@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_6c9e7145c2d7439db6f63b29c505e63e", "IPY_MODEL_43861346cadc460992ab7e2c30aa556c" ], "layout": "IPY_MODEL_07fe952dec2b4f2687bcddec8701a70e" } }, "71fa85dc266149e28685b1f5252185a5": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_9de6ea64b67c443fbce1fd6aef4a7409", "IPY_MODEL_9658cc6977bb4fdf8a081aa99c300872" ], "layout": "IPY_MODEL_9cb943192a154dbc8c7961f128449bbd" } }, "75f988a47d774cd38fd85977fda7a96c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, 
"object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "799b382dd6c74d528cf77813f82a7335": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_a4234dc426d3459a93263c8b5fe5a34b", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_940ce122551e42509e8cec0659e47c77", "value": 137 } }, "7a5426d6898e4ba2b5e3c48a65096232": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7a7b598979174b338f8fdfc90af61e47": { 
"model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "7d0775fcf6a8497db20326dc94bd2739": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_839daa7d5cdc4669942f1b5ef5bf4903", "IPY_MODEL_4bc03fce15a14355b3429c138ea1eadd" ], "layout": "IPY_MODEL_5215d791d9144c089b6996336c3085cd" } }, "7db664fc575747af8e10ad6f2783a166": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, 
"visibility": null, "width": null } }, "7e067f02bdb74846a42909cd9ad1961b": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b6eec0df5ca64d5dbf32cb5231127e3b", "IPY_MODEL_6efef5f4637e4efaa373cb9214d4e911" ], "layout": "IPY_MODEL_68cf981867f7461c999c9e9f85af6499" } }, "7f05c774880548d2b73e69fb847c0a2e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7f6628d5f61f424ba6373da2f6defe76": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "8134ef53fa8044c99b537386e88dbfc9": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "81c25bb635db4bc9a62636bf3a3868eb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "839daa7d5cdc4669942f1b5ef5bf4903": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_b4ce8892d5784011a649eeca7b17fd0d", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d7cddd2cafe84d17829096b31ee5e93e", "value": 137 } }, "84c8a7cbfdb04690941f8c32299744bb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8602122a90db462cacef9373fca054fe": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5fc3a90707f746c884d924e385b0cc1c", "placeholder": "​", "style": 
"IPY_MODEL_2f647e8c1be2464fb9c69e8af5e5bf5a", "value": " 12/12 [00:05<00:00, 2.35it/s]" } }, "861e2a6f8e3e4a5cacfd0a746180cc90": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "892e02ac20ef4963871d915085139f7d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "8b3645c111b842f4b1d9ff54bb94f26e": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", 
"description": "", "description_tooltip": null, "layout": "IPY_MODEL_6f876daff243410bb26b8feb1cf981e0", "placeholder": "​", "style": "IPY_MODEL_23187b20efd44965966d6576c0d057aa", "value": " 137/137 [01:38<00:00, 1.43it/s]" } }, "8c156c360afb4ddc962b87577e093cc4": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_68681edff0db43559012b6653571a289", "IPY_MODEL_125e8d57908f4679824c96d5247f7119" ], "layout": "IPY_MODEL_9808f288735b4f5d9eea7377d4603ccc" } }, "8cb8398bb3c94284ac2ff42c7b3bc7c5": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f7851bd9702949ee8e6278ad69c9d981", "placeholder": "​", "style": "IPY_MODEL_2cf7d4c7dc6f413eb8cc9460e5de8d93", "value": " 137/137 [01:37<00:00, 1.41it/s]" } }, "900aa99ccfeb4a38a0ad90f8ce656fe6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, 
"grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "91fb9675236949e6bb8c99015061647c": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_503ff750d98d4f2c8f7af32adc221685", "IPY_MODEL_8b3645c111b842f4b1d9ff54bb94f26e" ], "layout": "IPY_MODEL_75f988a47d774cd38fd85977fda7a96c" } }, "940ce122551e42509e8cec0659e47c77": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "9658cc6977bb4fdf8a081aa99c300872": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ce70b524ffab4052912d47b8adee6748", "placeholder": "​", "style": 
"IPY_MODEL_a1b9aa80b73748f18e3938df073b3d22", "value": " 12/12 [00:12<00:00, 1.78it/s]" } }, "96f9368df9dc453495badb09e767d091": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "97216855bb7f483882afae68bf78a51d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9808f288735b4f5d9eea7377d4603ccc": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, 
"align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "987310405403439893813fb8c5da94a1": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "98e148811880459ea7a7641ee3a301b6": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "9b35ba10c85048ce8044b87ab9948611": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, 
"bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9cb943192a154dbc8c7961f128449bbd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9d905b40bf2746e79abad7257cf8654f": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": 
"1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_fd72a18adfb34bc2b2dab769067dcc1a", "placeholder": "​", "style": "IPY_MODEL_0973226c646f4bb28f291ed3eb4e5cba", "value": " 12/12 [00:05<00:00, 2.35it/s]" } }, "9de6ea64b67c443fbce1fd6aef4a7409": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_97216855bb7f483882afae68bf78a51d", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_07b26680343f42fc805a3b52b398a5f2", "value": 12 } }, "a1452333389441eab2c8016cd019d2f5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, 
"order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a1b9aa80b73748f18e3938df073b3d22": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "a2324cab4e9745ddad69b35e77ed148e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_146ce87ac5b445e385285bdbbd70fba9", "IPY_MODEL_46340cbdaf9c47719e005b42b7236be2" ], "layout": "IPY_MODEL_42c8cb98eaae4cee86d49c675d3a3604" } }, "a4234dc426d3459a93263c8b5fe5a34b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": 
null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a5a86539a58f4fe98ba96f9cd3d46588": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "a9f8af0a5e594c29b553983b6a1c4c7d": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_389fcde31cf24c438a9fcdb155060495", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_2eb6756f48074993825b5bfac47641e1", "value": 137 } }, "ab08bd64073d4d488f9c41e8595bac08": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_fc2e1a72620742499457ab2e31da12a9", "IPY_MODEL_b3b4d96626044375bad11ba48cc25ee3" ], "layout": "IPY_MODEL_5e1805a3c7f24c3f9f15e86d15462d44" } }, "acfab7adb0ed440b87fa08f81086522d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": 
"@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b08557f6936e4af9a23f659649979c8d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, 
"padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b3b4d96626044375bad11ba48cc25ee3": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c85760cb6ccb4020be999357f2501bf9", "placeholder": "​", "style": "IPY_MODEL_cd4892863d524637b8e567e2beed23bb", "value": " 12/12 [00:05<00:00, 2.31it/s]" } }, "b414ff8df761462e8df8f51788f69508": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_52cae29e50974a42a3882f91b2a8f2ba", "IPY_MODEL_02b454a4d7ce44b0b819fc7085b1d357" ], "layout": "IPY_MODEL_e0c2ef0fd4d544b9b80b00689d21f503" } }, "b4ce8892d5784011a649eeca7b17fd0d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, 
"justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b5824ca248b74052bb1696b47369838e": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0e2531c122b44823a012ba844066e60c", "placeholder": "​", "style": "IPY_MODEL_181f0680e5bb4d23a69f9cc747207883", "value": " 12/12 [00:05<00:00, 2.34it/s]" } }, "b6410389dafe43bd885446a6d2d255b0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, 
"right": null, "top": null, "visibility": null, "width": null } }, "b642c546813240a499a165279c743e7a": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "b6eec0df5ca64d5dbf32cb5231127e3b": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_dafd837bebd44c00ae68641cc0b31d3d", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d539e4c93f4941e9af0b742df4c53ae0", "value": 137 } }, "b6f9da40231f4404bb0c90458c94c5c2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "b7f8df98ab5d45f39b34ceb1018c01b5": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "b9470519dfff45228e54d0181337ee17": { 
"model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_861e2a6f8e3e4a5cacfd0a746180cc90", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_11f8b1ddd5af4bb3babd492bbaf3fb7d", "value": 12 } }, "b9d1f631577f4a2a8eb49d3f8535ba25": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bfa957f8482847978542373db16cfa8a": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, 
"_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "bfadf2c9e99c48e4860c2a67b2955120": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c014bf31dccb4448b561a702361ed9bc": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c6855514c5a649ebadc2db53c267e9e0", "IPY_MODEL_8602122a90db462cacef9373fca054fe" ], "layout": "IPY_MODEL_ef823903b36348978833920bb7e03881" } }, "c251cf1c33c741049cc9710e89f09bc0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", 
"_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c55e4e5f6d0a42b4ac25fb6d68a8e842": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "c6855514c5a649ebadc2db53c267e9e0": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_173704b44e2c472499bc4242ef9e0863", "max": 12, "min": 0, "orientation": "horizontal", "style": 
"IPY_MODEL_892e02ac20ef4963871d915085139f7d", "value": 12 } }, "c7cca560cd4340bb986297655f513746": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c85760cb6ccb4020be999357f2501bf9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, 
"max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c8bd20a456cc4783a2fbbdec9a2880ea": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0ba00c3340414029938c9c947d4a74b6", "placeholder": "​", "style": "IPY_MODEL_402795b6506a45c9831fb0ddb2ce0749", "value": " 12/12 [00:05<00:00, 2.35it/s]" } }, "cd26f91d46ef42609d265f0b709c3401": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "cd4892863d524637b8e567e2beed23bb": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "cd946e84023445d59eea9893fba02f1a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ce3117149c7e4ae6a93fbe959bd3f543": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ce70b524ffab4052912d47b8adee6748": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, 
"justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d17f610741f142b782c74ac4e04e6481": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_f3f0b212907843f5be2131b7b60a955d", "IPY_MODEL_0a6a241c99bb49ad9baaef052b157964" ], "layout": "IPY_MODEL_84c8a7cbfdb04690941f8c32299744bb" } }, "d2e74c9d62444e98b52078d6a15dd068": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, 
"d539e4c93f4941e9af0b742df4c53ae0": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "d7cddd2cafe84d17829096b31ee5e93e": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "d8770c573383466f81d7dfd01f048677": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "da2f69c01bcd4139980a84ddca297817": { 
"model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "dabd19b76eec4139a3028aa1776ac496": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "dac789bcf81e4aac86717438ed6e8bbe": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_799b382dd6c74d528cf77813f82a7335", "IPY_MODEL_51d0dfb67bbd47ff88964bb446a91967" ], 
"layout": "IPY_MODEL_a1452333389441eab2c8016cd019d2f5" } }, "dafd837bebd44c00ae68641cc0b31d3d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "db96a8d67e564cd58358916ddfb86948": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, 
"max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e0c2ef0fd4d544b9b80b00689d21f503": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e246db4008b14a23b55d2820be3b1ed8": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e9c8d09e028741979a764789fd7f256e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_f5e27446f61747e1a65e477e2975dbcf", "IPY_MODEL_1e17af39c928463b800d22ad1fe533e5" ], "layout": "IPY_MODEL_2bd15c15e3b748cd9ee001bb86c04b6d" } }, "ef823903b36348978833920bb7e03881": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f07d424e031d457a98a88b94b0c844b8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, 
"grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f0aa35a704a84967a3471e5f698bb04c": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "f2c541b4ecd04729a518414f8b878099": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": 
null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f2e66aebdb4d4812a1f61c1ff990390a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f3f0b212907843f5be2131b7b60a955d": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_6acb415dc8b947bb85ab8f8959bc10e1", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_b642c546813240a499a165279c743e7a", "value": 137 } }, "f48e444c929748b2a875c6eecc028bc4": { "model_module": "@jupyter-widgets/base", "model_name": 
"LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f5e27446f61747e1a65e477e2975dbcf": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_6669a1844b394ea499dc92e126132bfc", "max": 137, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_98e148811880459ea7a7641ee3a301b6", "value": 137 } }, "f7851bd9702949ee8e6278ad69c9d981": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", 
"_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fc2e1a72620742499457ab2e31da12a9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_c7cca560cd4340bb986297655f513746", "max": 12, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_c55e4e5f6d0a42b4ac25fb6d68a8e842", "value": 12 } }, "fcf96c9bb04642089466de2421bd1af2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "fd72a18adfb34bc2b2dab769067dcc1a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": 
"@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } } } } }, "nbformat": 4, "nbformat_minor": 1 }