{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "uGYl4nCPKyZi" }, "source": [ "# Pre-Training a 🤗 Transformers model on TPU with **Flax/JAX**\n", "\n", "In this notebook, we will see how to pre-train one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on TPU using [**Flax**](https://flax.readthedocs.io/en/latest/index.html). \n", "\n", "The popular masked language modeling (MLM) objective, *cf.* [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805), will be used as the pre-training objective.\n", "\n", "As can be seen in [this benchmark](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation), using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.\n", "\n", "[**Flax**](https://flax.readthedocs.io/en/latest/index.html) is a high-performance neural network library designed for flexibility and built on top of JAX (see below). It aims to give users full control over their training code and is carefully designed to work well with JAX transformations such as `grad` and `pmap` (see the [Flax philosophy](https://flax.readthedocs.io/en/latest/philosophy.html)). For an introduction to Flax, see the [Flax Basic Colab](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) or the list of curated [Flax examples](https://flax.readthedocs.io/en/latest/examples.html).\n", "\n", "[**JAX**](https://jax.readthedocs.io/en/latest/index.html) is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. 
A great place for getting started with JAX is the [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwDAzFXQMd46" }, "source": [ "If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets, 🤗 Tokenizers as well as [Flax](https://github.com/google/flax.git) and [Optax](https://github.com/deepmind/optax). Optax is a gradient processing and optimization library for JAX, and is the optimizer library\n", "recommended by Flax." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "QMkPrhvya_gI" }, "outputs": [], "source": [ "%%capture\n", "!pip install datasets\n", "!pip install git+https://github.com/huggingface/transformers.git\n", "!pip install tokenizers\n", "!pip install flax\n", "!pip install git+https://github.com/deepmind/optax.git" ] }, { "cell_type": "markdown", "metadata": { "id": "0wMrmHv-uGzR" }, "source": [ "You will also need to set up the TPU for JAX in this notebook. This can be done by executing the following lines." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "3RlF785dbUB3" }, "outputs": [], "source": [ "import jax.tools.colab_tpu\n", "jax.tools.colab_tpu.setup_tpu()" ] }, { "cell_type": "markdown", "metadata": { "id": "If_SYBvU5V6u" }, "source": [ "If everything is set up correctly, the following command should return a list of 8 TPU devices." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3R5MP7PAbV7V", "outputId": "e7144204-7da3-445e-959a-b51a13446a2e" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 3, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "jax.local_devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "vehXZCipMa1V" }, "source": [ "In this notebook, we will pre-train an [autoencoding model](https://huggingface.co/transformers/model_summary.html#autoencoding-models) on one of the languages of the OSCAR corpus. [OSCAR](https://oscar-corpus.com/) is a huge multilingual corpus obtained by language classification and filtering of the Common Crawl corpus using the *goclassy* architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "iz8HrV8JPHn0" }, "source": [ "Let's first select the language that our model should learn.\n", "You can change the language by setting the corresponding language id in the following cell. The language ids can be found under the \"*deduplicated*\" column on the official [OSCAR](https://oscar-corpus.com/) website.\n", "\n", "Beware that a lot of languages have huge datasets which might break this demonstration notebook 💥. 
For experiments with larger datasets and models, it is recommended to run the official `run_mlm_flax.py` script, which can be found [here](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#masked-language-modeling), offline instead.\n", "\n", "Here we select `is` for Icelandic 🇮🇸." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "ii9XwLsmiY-E" }, "outputs": [], "source": [ "language = \"is\"" ] }, { "cell_type": "markdown", "metadata": { "id": "jVtv6T0oSjNq" }, "source": [ "Next, we select the model architecture to be trained from scratch.\n", "Here we choose [**`roberta-base`**](https://huggingface.co/roberta-base), but essentially any auto-encoding model that is available on the [**🤗 hub**](https://huggingface.co/models?filter=masked-lm,jax) in JAX/Flax can be used. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "Sj1mJNJa6PPS" }, "outputs": [], "source": [ "model_config = \"roberta-base\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry: this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"masked_language_modeling_notebook\", framework=\"flax\")" ] }, { "cell_type": "markdown", "metadata": { "id": "j-tf_3Ch55_9" }, "source": [ "## 1. Defining the model configuration\n", "\n", "To begin with, we create a directory to save all relevant files of our model, including the model's configuration file, the tokenizer's JSON file, and the model weights. 
We call the directory `\"roberta-base-pretrained-is\"`:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "1dwuSvQxeM8-" }, "outputs": [], "source": [ "model_dir = model_config + f\"-pretrained-{language}\"" ] }, { "cell_type": "markdown", "metadata": { "id": "qGENnc6LeRFL" }, "source": [ "and create it:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "pWtsHzLQdAS3" }, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "Path(model_dir).mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "oWQD8IA9eAFY" }, "source": [ "Next, we'll download the model configuration:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 67, "referenced_widgets": [ "1507ed751ef54eabb98315e353d549ef", "129c87d4c4bd41608eb6600854170ce4", "4bfbdc64093c401abc32fb47f9b910b9", "eb3fd3dc5fda4d728b0b45bff5f4dbbf", "1a9b970c8efb4277a798da1c145c19e4", "0a2cd49929ba4a5cb789f7dcfdf63f9a", "2f58e28ca978441db9d75313d263b1ba", "bdb9619c2f3449caaa4d4273d1cfa0cf" ] }, "id": "DO1SwHdi55en", "outputId": "a1abf087-151c-40e8-8899-9bc2ee7cc669" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1507ed751ef54eabb98315e353d549ef", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "from transformers import AutoConfig\n", "\n", "config = AutoConfig.from_pretrained(model_config)" ] }, { "cell_type": "markdown", "metadata": { "id": "3exPFi-keYlT" }, "source": [ " and save it to the directory:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "Vip8WKEp6b6Y" }, "outputs": [], "source": [ "config.save_pretrained(f\"{model_dir}\")" ] }, { 
"cell_type": "markdown", "metadata": { "id": "aJfEUbbI31n8" }, "source": [ "## 2. Training a tokenizer from scratch\n", "\n", "The raw text data first has to be pre-processed into a format that the model can understand. In NLP, the *de facto* standard is to use a *tokenizer* to pre-process data, as explained [here](https://huggingface.co/transformers/preprocessing.html). \n", "\n", "We can leverage the blazing-fast 🤗 Tokenizers library to train a [**ByteLevelBPETokenizer**](https://medium.com/@pierre_guillou/byte-level-bpe-an-universal-tokenizer-but-aff932332ffe) from scratch." ] }, { "cell_type": "markdown", "metadata": { "id": "jdoO3ZsUW9Bh" }, "source": [ "Let's import the necessary building blocks from `tokenizers` and the `load_dataset` function from `datasets`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "kJKw0tqOcDu6" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "from tokenizers import ByteLevelBPETokenizer\n", "from pathlib import Path" ] }, { "cell_type": "markdown", "metadata": { "id": "3cQXZ1p5XHtP" }, "source": [ "We will store the tokenizer files and model files in the directory `model_dir` created above. The chosen dataset can be loaded conveniently with the [**`load_dataset`**](https://huggingface.co/docs/datasets/package_reference/loading_methods.html?highlight=load_dataset#datasets.load_dataset) function." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 273, "referenced_widgets": [ "8b7829a8ce7b4892b8047f8c6a19201a", "a95d6d6ad00a472abbc6ac7beded4abb", "5533776328e64aedbda3950db146f857", "79a4220737f04106af1fcc33a9ef9409", "f406dbf7c95b49d091c268c90080b8a5", "f6506d6ffd46498db87ec4ac66c95a9d", "6af5ca40e43a415bb535773ac5974dd2", "332a3530366745d994126f354e5f0289", "2334037d360a495b9644e60f897da983", "66aa387a4a5f4a488ee1cb29cbb42da9", "b567119e92374510885c44812a2ac115", "35680c4679af41009739ebc34ccb6959", "16c34052aa8143d3ab137a4b1adb6550", "94bcede0deb74189acd3c30f748be1ff", "406c959be1504ce09e68a84340e768cb", "69ebeebba68145a282efed06395e61b2", "f15842f820b2492eaf344303bb31cb9e", "9513b5d85ae442859f0223f83a48b953", "3f5abb3236fa49cb8eefe1ca736afd07", "f9177108ca8d437fb83cd6ec87071268", "add23e1725b34adfa63c0560fdfc653c", "a20f3f908f1644dc8b02e7cc82ba0f75", "e02d02fb4e3c4932ade9f214b93ce4e5", "88498392c9b6435a90b44f42699bc2bd", "f2e1e2c29e8a4e4dae1b535311703e66", "4d0b2a650f1a43a18214d2f5be471c5e", "c3913e79f785433989cf4ff6a528056e", "0c29902ffea24196916b6f56dad78f59", "232d2e43c91742aab712216ae4e72bb0", "8827cea3be154f1c8979072ddfbb9f8d", "776079278200407abafe88edd29ff702", "358f223d671943a4978ceb81016b8b3d", "d3948f470523480697d5d7221b0fd1f4", "269c1ae885ca46769d0a049b7df5e839", "494957d1b11f4fef8aa0de984239fe6f", "eab56d1def3749ac9f95ccffce80dacc", "f87488c14e1c46808435448d258a062d", "99739cb14d9546f0a7058502b7562a52", "2226ef9eb3d04cf2992d2fb9976f167e", "885d037ed38542fe82d9a5404476c044" ] }, "id": "5oUW__q-4If7", "outputId": "9e8a6065-5e7d-4cfd-ed79-fe07d173c262" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8b7829a8ce7b4892b8047f8c6a19201a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5577.0, style=ProgressStyle(description…" ] }, "metadata": { "tags": [] }, 
"output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2334037d360a495b9644e60f897da983", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=358718.0, style=ProgressStyle(descripti…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Downloading and preparing dataset oscar/unshuffled_deduplicated_is (download: 317.45 MiB, generated: 849.77 MiB, post-processed: Unknown size, total: 1.14 GiB) to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f15842f820b2492eaf344303bb31cb9e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=81.0, style=ProgressStyle(description_w…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f2e1e2c29e8a4e4dae1b535311703e66", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=332871683.0, style=ProgressStyle(descri…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d3948f470523480697d5d7221b0fd1f4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", 
"text": [ "\r", "Dataset oscar downloaded and prepared to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d. Subsequent calls will reuse this data.\n" ] } ], "source": [ "raw_dataset = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Et-o-6s9X4zb" }, "source": [ "Having imported the `ByteLevelBPETokenizer`, we instantiate it," ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "OCs_CQFt4WK_" }, "outputs": [], "source": [ "tokenizer = ByteLevelBPETokenizer()" ] }, { "cell_type": "markdown", "metadata": { "id": "qw4xMa4dZJs2" }, "source": [ "define a training iterator," ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "JBe6jAKj4YeY" }, "outputs": [], "source": [ "def batch_iterator(batch_size=1000):\n", " # Yield the training texts in chunks of `batch_size` rows\n", " for i in range(0, len(raw_dataset[\"train\"]), batch_size):\n", " yield raw_dataset[\"train\"][i: i + batch_size][\"text\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "ZzZl1P-LZREm" }, "source": [ "and train the tokenizer, setting `vocab_size` to match our model's configuration and specifying a `min_frequency` as well as the `special_tokens` used by RoBERTa:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "e6BAIGEz4aPL" }, "outputs": [], "source": [ "tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[\n", " \"<s>\",\n", " \"<pad>\",\n", " \"</s>\",\n", " \"<unk>\",\n", " \"<mask>\",\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "7bVHeovIaFt9" }, "source": [ "Finally, we save the trained tokenizer in the model folder." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "xLLnCvMM4yk3" }, "outputs": [], "source": [ "tokenizer.save(f\"{model_dir}/tokenizer.json\")" ] }, { "cell_type": "markdown", "metadata": { "id": "lnKd8I_jZ6yl" }, "source": [ "For more information on training tokenizers, see [this tutorial](https://huggingface.co/docs/tokenizers/python/latest/tutorials/python/training_from_memory.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "4hD8d1_P5huo" }, "source": [ "## 3. Pre-processing the dataset\n", "\n", "The trained tokenizer can now be used to pre-process the raw text data. Most auto-encoding models, such as [*BERT*](https://arxiv.org/abs/1810.04805) and [*RoBERTa*](https://arxiv.org/abs/1907.11692), are trained to handle sequences of up to `512` tokens. However, natural language understanding (NLU) tasks often require the model to process inputs of only up to 128 tokens, *cf.* [How to Train BERT with an Academic Budget](https://arxiv.org/abs/2104.07705).\n", "\n", "Since the memory required by the self-attention layers of Transformer models scales quadratically with the sequence length, we cap the maximum input length at 128 here. The raw text data is pre-processed accordingly." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "uDhqWoF-MAGv" }, "outputs": [], "source": [ "max_seq_length = 128" ] }, { "cell_type": "markdown", "metadata": { "id": "vDSc5QvujQhK" }, "source": [ "To monitor the model's performance during pre-training, we hold out 5% of the data as the validation set.\n", "\n", "Since the loaded dataset is cached, the convenient `split=\"train[:X%]\"` syntax can be used to split the dataset with no computational overhead.\n", "\n", "The last 95% of the data will be used as the training data:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KcEYmKo8cHe1", "outputId": "d91d63f6-1de0-408a-e1f8-6b2681cabff4" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)\n" ] } ], "source": [ "raw_dataset[\"train\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[5%:]\")" ] }, { "cell_type": "markdown", "metadata": { "id": "P2kRx1nclCdU" }, "source": [ "and the first 5% as the validation data." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AFVfOPmocufo", "outputId": "08c7a102-8d77-46f0-8f73-d11085dd0e68" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)\n" ] } ], "source": [ "raw_dataset[\"validation\"] = load_dataset(\"oscar\", f\"unshuffled_deduplicated_{language}\", split=\"train[:5%]\")" ] }, { "cell_type": "markdown", "metadata": { "id": "XvolUzmdv1F3" }, "source": [ "For demonstration purposes, we will use only the first 10000 samples of the training data and the first 1000 samples of the validation data, so that each cell finishes quickly. \n", "\n", "If you want to run the Colab on the **full** dataset, please comment out the following cell. Using the full dataset, the notebook will run for *ca.* 12 hours until loss convergence and give a final accuracy of around *50%*. Running the Colab *as is* takes less than 15 minutes, but will not show good loss convergence." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "aoXFHjEtwXWt" }, "outputs": [], "source": [ "# comment out these lines to run on the full dataset\n", "raw_dataset[\"train\"] = raw_dataset[\"train\"].select(range(10000))\n", "raw_dataset[\"validation\"] = raw_dataset[\"validation\"].select(range(1000))" ] }, { "cell_type": "markdown", "metadata": { "id": "hYmElz46k7_E" }, "source": [ "Next, we load the previously trained `ByteLevelBPETokenizer` from `model_dir` with `AutoTokenizer` to pre-process the raw text data:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "lySwpeYVc_Lm" }, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(f\"{model_dir}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "BnEFmLLylnQb" }, "source": [ "We can then write the function that pre-processes the raw text data. We simply feed the text samples (stored in the `\"text\"` column) to the tokenizer and make sure that the special tokens mask is created:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "wcpWIxX8dIAO" }, "outputs": [], "source": [ "def tokenize_function(examples):\n", " return tokenizer(examples[\"text\"], return_special_tokens_mask=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "lco7GkZ8nF-a" }, "source": [ "and apply the tokenization function to every text sample via the `map(...)` method of 🤗 Datasets. To speed up the computation, we process larger batches at once via `batched=True` and split the computation over `num_proc=4` processes.\n", "\n", "**Note**: Running this command on the whole dataset might take up to 10 minutes ☕." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 453, "referenced_widgets": [ "18aca4b7e88248e0ac232f67afd3f3ab", "ebaa5224f0cf47d78d08f93990c08098", "4a57a381bac941f285574cc3e563b048", "db39640e706d4309a1c58d0f332b2aab", "444b38ddd3994f35b4dc6de5fa486e5a", "63746a342ad848fc8f45da71489e26af", "f6ad3867bbd74702a807841b926705ca", "5ae3dffe0e9b4bb5b2ccb4df55b72455", "25f1623a25cd4f859400d696140f79d9", "da656163b59b487f9857ac7dce3989bf", "34c78198209243f896a8479f1d7257fa", "171c6872201847e08598cd8c5a0b2362", "9e6495156c9a4fc0b9ed24f8a216a388", "82ba69bcf1804675b77b1b708555bab9", "1cb65ed1e9b845c5b8554b931b4690b5", "6def23bac354403696d8f60777836653", "669da797864e4a5b8b1b2feab627bb8e", "a354d37e5f404a17bbd68a8f0845dd99", "999db518d5174464b6c03b952ea3af1a", "a62f9f0af78741bc8f1045a683d54688", "e49dc5ab78d74be3ab6d9760397e04c2", "49bff82dc4264b0f815ea4e26734804d", "1b57f6133c8045aa8ad993d47af6f541", "0e46d00678a34c86bfa9ebc73c4fea31", "2c494e518396468b945342279d4a91e8", "5bcf7478340a4d84850c43ba6364ce5e", "e89648ca4e014ec78c9776ac90bbf6d2", "ac4384c5d8304441b3c6e764dfdb9d9a", "cae9b026ff684dffa0bcacb1ddaf1fd8", "cea4982e4b32491698c26178a4c445ff", "9e57068083904dcda413fef9728a1b8e", "fdb39fdbd0ca424f8a98449d7a929deb", "24f9f85b12e14f83b6f0d300c5bf2c7b", "a3ca03557e19482fb170ab90b6c6e136", "c6534a1d59ba4c6586c20f4a6dd709ba", "01b8401f9f26458d97ac47eb62a53820", "d7cfc8a005a74fb5aee9cfc5737e10bd", "9f7ad7e1e3164bc79ddc589a9649625f", "15f7835673894849b6b6dbe4cfd336f5", "b38bb5cd9489446b8fc24537759e60c6", "7522a60a290d4b749142b7c3bef2e51e", "fa60d2398af0441d96fb81391cafdf96", "eaf74e107314477d93302b2296145ef9", "ff231177ded8474e844806243efde855", "7e428c1823a748d8a0367107dbdaa70d", "4c71d8f73af54865af3bb984cfa78375", "d25c31eb97ff496baa48c4b7a317a976", "2a4aff76064241bbae6d7945811c43f3", "3fe30aad373046998a001fceec61e79e", "29c0e9a2c0264fed8c0ef1c88c1d6daf", "9491d032aadf4d27bea141190cb50b2f", 
"9b17d8470fd44c159896e26c9a4d9559", "244f1c8857a54e368b744b7596129f5a", "d5a036d3c1f0461698a0b1299a009fcd", "c30e725164724058b83da04ac22b61b8", "a8e66259adcf4a3ab57ab2fc0d3847c2", "5c33fc07e8944ead8479baf09cd365f4", "a4a90eb988c94bc48aed9bc14918e9a4", "e46226b475b14aa09437551b73cf0c1b", "998cb781fbe84f06b2064459c128015b", "0a6ba63d60924e7a869869c6dc359f42", "44f0cb6e4c07434fb0f89fbbd89e492d", "e06b47e60bd5415fa536603f2c103f86", "569ee56d7fea40ae89ef2349e1db457d" ] }, "id": "h6cjpFO2dTYC", "outputId": "e5ac51c0-12b2-4c07-887e-60514501ddbd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "18aca4b7e88248e0ac232f67afd3f3ab", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #2', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "25f1623a25cd4f859400d696140f79d9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #1', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "669da797864e4a5b8b1b2feab627bb8e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #0', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2c494e518396468b945342279d4a91e8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #3', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" 
}, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", " " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "24f9f85b12e14f83b6f0d300c5bf2c7b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #2', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7522a60a290d4b749142b7c3bef2e51e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #0', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3fe30aad373046998a001fceec61e79e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #1', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5c33fc07e8944ead8479baf09cd365f4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #3', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n" ] } ], "source": [ "tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset[\"train\"].column_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "6_E0jsY9onEf" }, "source": [ "Following [RoBERTa: A Robustly Optimized BERT Pretraining Approach]( https://arxiv.org/abs/1907.11692), our model is pre-trained just with a masked language modeling (MLM) objective which is independent of whether 
the input sequence ends with a finished or unfinished sentence. \n", "\n", "The model can process the training data most efficiently if all samples have the same length. We therefore concatenate all text samples and split the result into chunks of `max_seq_length=128` tokens each. This way, no computation is wasted on padding tokens, and the number of training samples is reduced. Tokens at the end of each batch that do not fill a complete chunk are dropped.\n", "\n", "Let's define a function that groups the dataset into equally sized samples:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "HO_neGynddat" }, "outputs": [], "source": [ "def group_texts(examples):\n", " # Concatenate the token lists of all examples in the batch, per column\n", " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", " # Drop the remainder that does not fill a complete chunk of max_seq_length tokens\n", " total_length = (total_length // max_seq_length) * max_seq_length\n", " # Split every concatenated column into chunks of max_seq_length tokens\n", " result = {\n", " k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]\n", " for k, t in concatenated_examples.items()\n", " }\n", " return result" ] }, { "cell_type": "markdown", "metadata": { "id": "CR46Vvpwr6e5" }, "source": [ "We pass `group_texts` to the `map(...)` method and set `batched=True` so that the function is applied to large batches of samples at once. \n", "\n", "**Note**: Running this function on the whole dataset might take up to 50 minutes 🕒." 
] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 453, "referenced_widgets": [ "a12e3f6679564ea4a2ff9e1f973a6415", "80fc7c43339d4282bc8edfaf167c1224", "8ed57e3298854dd8a8e08ad233c78410", "1fa55643d9234e7facbfecf6d2ef7cb6", "60d81ad4c9474506a99c1c9a9a8cef7a", "2cbeaf721bf644379aff6366632a1a62", "dd286ac4f020476cb5c8e6495b717ee5", "34b0d26a14454933a08b7e2b39c36011", "2d4ecab20fbc4e148642e001662898f7", "d9fd9c0bef814996b61fc32edf8d6cc6", "a7dfa9d4c0a44542b35f94b8d9c7806f", "caf772c6f33a46c2b20ae99fe3c31dae", "7eea344bf03645478852c513401b3115", "56e6b6239d5544bdb9cfcde143d4edf7", "60f69d3d22b740e0a5f83cd2c6e8ca83", "95f6ac83377e49379240c8eeae384af0", "d2ad23714f2d49b08205d069b12899c8", "20b8b6e3b79942cdbe51131f3ebaf14a", "be4bd93181264b4ba86ab629cea48890", "d240794aecdb4f1c8dd5948a87095bfe", "4444b057ae0747beb0a0593c62277967", "eb27225d1f694ef49b5f5a1d704ee2ac", "29c818fbcbfa4c4d9c2ace5afcd816bd", "36aee3953ebe4b4596d01b9d379aa1e8", "9c9e95f42e904a34a97b8ebe17f997eb", "43215c6261f0492b8ae9f8600b51b25f", "94cbae744604407e8c8aa9c6bd8ce151", "f6fc3918d3f149c5ab7f7dd630e5e303", "2b87607755784544824172984f509b13", "dbb51464281d40a2b5b8c8414304afa0", "de6d3f88db2f48a48287566f152902d6", "b34b97728ac248d0860a1eb85b33e309", "7b2a7c286bf3418c89b58390a5d071dc", "1f47f12d6bde427a892d31443e40829b", "f8a52dc00c554a9488b2ad11f1eecdd6", "d5d34937611f41749ec872eb5162d087", "40ed180285114ee3b57c174e3954e871", "6e91147560c94d419bbe2e10db39f694", "72dc8f192bb24271849b618522ad3990", "56fd3e43a9fb472ab180a25763acbf41", "6fd5803d251d4dc4a5f20375b1e99385", "752b65a12f2a440bb9ab1539f0fddf85", "f82eaf223bb241029e23982325f421c4", "354c2e15312a4251902d776fc770947a", "540e23f228ae45bcb5b76ccb94f7e031", "420643ba0ccd4a87887b28e722e74f89", "50a6b75288694171bf84d99d676bb9ac", "160a103ca36c4ca68fb7f8b3817acf66", "b2607cb39d7f4df69e029473df4e0bb6", "92c59dbc4f6a467ebe5c4d60928c768f", "b1d2b908593c422d946d7656da6cbaf9", 
"7cbfb28a8a10402eb34a0b07af3e7973", "f1bf3ed55e454e05ad3cc4a441041a5b", "d30fe8af738f419eafc4fbb3dbf69164", "c1fd2afcaa2d45c9ad6aa2e4ce40f1ae", "c12d9b21c94d4a5b987a66d19d926ced", "70b3393474e6416a85f42e9f07a9550b", "93e69e0936f342608e422705f40d64dd", "cf45447cb37e4b4198154f409f974030", "501540ebc32b47878367999f6b15dba2", "bbfa545245f04d028b6c054c24853997", "0636601f4d8a42bba470861b19e293c1", "d1ccee8b62754023b4e6df520ae3966e", "6b3121474b6445d88a96506c66408dd9" ] }, "id": "UmzNAUVediDa", "outputId": "560ad7e3-041a-4cd6-a210-3972ca570c50" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a12e3f6679564ea4a2ff9e1f973a6415", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #1', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2d4ecab20fbc4e148642e001662898f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #0', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d2ad23714f2d49b08205d069b12899c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #2', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9c9e95f42e904a34a97b8ebe17f997eb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #3', max=3.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" 
}, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n", " " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7b2a7c286bf3418c89b58390a5d071dc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #2', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fd5803d251d4dc4a5f20375b1e99385", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #0', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b2607cb39d7f4df69e029473df4e0bb6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #1', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "70b3393474e6416a85f42e9f07a9550b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description=' #3', max=1.0, style=ProgressStyle(description_width='ini…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n" ] } ], "source": [ "tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)" ] }, { "cell_type": "markdown", "metadata": { "id": "jid2JqXOsVfR" }, "source": [ "Awesome, the data is now fully pre-processed and ready to be used for training 😎." ] }, { "cell_type": "markdown", "metadata": { "id": "ZRvfr609LzWu" }, "source": [ "## 4. 
Pre-Training the model\n", "\n", "Now we will see how the power of Google's Tensor Processing Units (TPUs) can be leveraged with Flax/JAX for the compute-intensive pre-training of language models.\n", "\n", "We need to import `jax`, `flax`, `optax` and `numpy` to define our training loop. Additionally, we use `tqdm` to visualize the training progress." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "5qOhue4Xm1TO" }, "outputs": [], "source": [ "import jax\n", "import optax\n", "import flax\n", "import jax.numpy as jnp\n", "\n", "from flax.training import train_state\n", "from flax.training.common_utils import get_metrics, onehot, shard\n", "\n", "import numpy as np\n", "\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "markdown", "metadata": { "id": "_MGleTRG6Vor" }, "source": [ "First, we define all relevant hyper-parameters for pretraining in this notebook:\n", "\n", "- Each TPU device processes a batch size of `64`\n", "- The model is trained for `10` epochs\n", "- The learning rate starts at `5e-5` and is linearly decayed to `0` over the course of training\n", "- For reproducibility, the random seed is set to `0`\n", "\n", "From these, we can derive the total batch size over all devices as well as the total number of training steps."
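, "\n", "\n", "As a worked example, assume the 8 TPU cores of a Colab TPU (so that `jax.device_count()` returns 8) and `num_epochs = 10` as set in the next cell. Writing $N_{\\text{train}}$ for the number of samples in `tokenized_datasets[\"train\"]`, we get:\n", "\n", "$$\n", "\\text{total\\_batch\\_size} = 64 \\times 8 = 512, \\qquad \\text{num\\_train\\_steps} = \\left\\lfloor \\frac{N_{\\text{train}}}{512} \\right\\rfloor \\times 10\n", "$$"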
] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "y8lsJQy8liud" }, "outputs": [], "source": [ "per_device_batch_size = 64\n", "num_epochs = 10\n", "training_seed = 0\n", "learning_rate = 5e-5\n", "\n", "total_batch_size = per_device_batch_size * jax.device_count()\n", "num_train_steps = len(tokenized_datasets[\"train\"]) // total_batch_size * num_epochs" ] }, { "cell_type": "markdown", "metadata": { "id": "FB9bRDBq5j3r" }, "source": [ "It has been shown that for MLM pretraining it is more efficient to use much larger batch sizes, though this requires many GPUs or TPUs.\n", "\n", "- [How to Train BERT with an Academic Budget](https://arxiv.org/abs/2104.07705) shows that pretraining BERT becomes cheaper with larger batch sizes of up to 16384.\n", "- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962) shows how to pretrain BERT in a bit more than an hour in a highly distributed setting using a batch size of 32868.\n", "\n", "We use a batch size of `8 * 64 = 512` here due to the TPU memory constraints of this notebook. When running this script on a TPUv3-8 VM, one can easily use batch sizes of up to `8 * 256 = 2048`." ] }, { "cell_type": "markdown", "metadata": { "id": "i0Tylp115u1r" }, "source": [ "Now we randomly initialize a `roberta-base` model according to its configuration. To save memory and improve speed, we run the model's computations in `bfloat16` by setting `dtype=jnp.dtype(\"bfloat16\")`. Note that `dtype` only sets the data type of the computation; the parameters themselves are still initialized in `float32`." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "aVr9TCzfacLN" }, "outputs": [], "source": [ "from transformers import FlaxAutoModelForMaskedLM\n", "\n", "model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_seed, dtype=jnp.dtype(\"bfloat16\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "sMS_QkT76Lgk" }, "source": [ "Next, we define the learning rate schedule. 
A simple and effective learning rate schedule is a linear decay with warmup (click [here](https://huggingface.co/transformers/main_classes/optimizer_schedules.html#transformers.get_linear_schedule_with_warmup) for more information). For simplicity, we set the number of warmup steps to 0 here. The schedule is then fully defined by the number of training steps and the initial learning rate.\n", "\n", "It is recommended to use the [**optax**](https://github.com/deepmind/optax) library for training utilities, *e.g.* learning rate schedules and optimizers.\n", "\n", "To see how to define a learning rate schedule with warmup, please take a look at the [official Flax MLM pre-training script](https://github.com/huggingface/transformers/blob/80d712fac6ccae308a2f408ebbc0c4d8c482d509/examples/flax/language-modeling/run_mlm_flax.py#L514)." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "kfBkuV1ck4rq" }, "outputs": [], "source": [ "linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)" ] }, { "cell_type": "markdown", "metadata": { "id": "2p0yNxeU79F2" }, "source": [ "We will use the AdamW optimizer, i.e. the standard Adam optimizer with decoupled weight decay.\n", "\n", "AdamW is available in [optax](https://github.com/deepmind/optax) as `optax.adamw`. It is created from the learning rate schedule defined above as well as a few other hyper-parameters (*beta1*, *beta2*, *epsilon*, *weight decay*) that are hard-coded in this notebook.\n", "\n", "For more information on AdamW, take a look at [this](https://www.fast.ai/2018/07/02/adam-weight-decay/) blog post."
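, "\n", "\n", "For reference, here is a sketch of the resulting update rule. It assumes the standard formulation behind `optax.linear_schedule` and `optax.adamw`: bias-corrected Adam moments, with the decoupled weight decay term added to the Adam update before scaling by the learning rate. For update $t = 1, 2, \\dots$ with gradient $g_t$, $T$ = `num_train_steps` and $\\eta_0$ = `learning_rate`:\n", "\n", "$$\n", "\\begin{aligned}\n", "\\eta_t &= \\eta_0 \\left(1 - \\frac{\\min(t - 1, T)}{T}\\right) \\\\\n", "m_t &= \\beta_1 m_{t-1} + (1 - \\beta_1)\\, g_t, \\qquad v_t = \\beta_2 v_{t-1} + (1 - \\beta_2)\\, g_t^2 \\\\\n", "\\hat{m}_t &= \\frac{m_t}{1 - \\beta_1^t}, \\qquad \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\\n", "\\theta_t &= \\theta_{t-1} - \\eta_t \\left(\\frac{\\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon} + \\lambda\\, \\theta_{t-1}\\right)\n", "\\end{aligned}\n", "$$\n", "\n", "Here $\\beta_1 = 0.9$, $\\beta_2 = 0.98$, $\\epsilon = 10^{-8}$ and the weight decay $\\lambda = 0.01$, as passed to `optax.adamw` below."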
] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "xRtpv_iamZd2" }, "outputs": [], "source": [ "adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)" ] }, { "cell_type": "markdown", "metadata": { "id": "6g_fEbV-72Hc" }, "source": [ "Next, we create the *training state*. It bundles the model's forward function, its parameters and the optimizer, and it is responsible for updating the model's parameters during training.\n", "\n", "Most JAX transformations (notably [jax.jit](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)) require the functions they transform to have no side effects. This is because any such side effects are only executed once, when the Python function is traced during compilation (see [Stateful Computations in JAX](https://jax.readthedocs.io/en/latest/jax-101/07-state.html)). As a consequence, Flax models (which can be transformed by JAX transformations) are **immutable**, and the state of the model (i.e., its weight parameters) is stored *outside* of the model instance.\n", "\n", "Models are initialized and updated in a purely functional way: you pass the state to the model when calling it, and the model returns the new (possibly modified) state, leaving the model instance itself unchanged.\n", "\n", "Flax provides a convenience class [`flax.training.train_state.TrainState`](https://github.com/google/flax/blob/9da95cdd12591f42d2cd4c17089861bff7e43cc5/flax/training/train_state.py#L22). It stores the current training step, the model's forward function (`apply_fn`), the model parameters, and the optimizer (`tx`) together with its state. It also exposes an `apply_gradients` function to update the model's weight parameters.\n", "\n", "Alright, let's begin by creating our *training state*. We create a `TrainState` instance that stores the model's forward pass as the `apply_fn`, the `params`, and the AdamW optimizer as `tx`."
] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "JHYfR67AoKRc" }, "outputs": [], "source": [ "state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)" ] }, { "cell_type": "markdown", "metadata": { "id": "xiYCejDd81TX" }, "source": [ "For masked language modeling (MLM) pretraining, some of the input tokens are randomly masked, and the objective is to predict the original vocabulary id of each masked token based only on its context.\n", "More precisely, for [BERT-like MLM pretraining](https://arxiv.org/abs/1810.04805), **15%** of all input tokens are selected. Each selected token is replaced by a mask token with **80%** probability, replaced by another random token with **10%** probability, and left unchanged with **10%** probability.\n", "\n", "Let's implement a data collator that, given a training batch, randomly masks some input tokens according to the BERT-like MLM scheme above.\n", "Note that the **85%** of tokens that are not selected for MLM would be trivial for the model to predict, since it has access to the token itself. To make sure the model learns to predict masked tokens instead of simply copying input tokens to the output, we indicate that no loss should be computed on the **85%** of non-selected tokens by setting their label to `-100`."
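, "\n", "\n", "As a quick check of these numbers, the collator below selects each non-special token with probability $0.15$. It then replaces a selected token with the mask token with probability $0.8$. Otherwise it replaces the token with a random token with probability $0.5$, so that the remaining $20\\%$ is split evenly. Per token, this gives:\n", "\n", "$$\n", "\\begin{aligned}\n", "P(\\text{mask token}) &= 0.15 \\times 0.8 = 0.12 \\\\\n", "P(\\text{random token}) &= 0.15 \\times 0.2 \\times 0.5 = 0.015 \\\\\n", "P(\\text{selected, unchanged}) &= 0.15 \\times 0.2 \\times 0.5 = 0.015\n", "\\end{aligned}\n", "$$"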
] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "Aos9GltTb3Ve" }, "outputs": [], "source": [ "@flax.struct.dataclass\n", "class FlaxDataCollatorForMaskedLanguageModeling:\n", "    mlm_probability: float = 0.15\n", "\n", "    def __call__(self, examples, tokenizer, pad_to_multiple_of=16):\n", "        batch = tokenizer.pad(examples, return_tensors=\"np\", pad_to_multiple_of=pad_to_multiple_of)\n", "\n", "        special_tokens_mask = batch.pop(\"special_tokens_mask\", None)\n", "        batch[\"input_ids\"], batch[\"labels\"] = self.mask_tokens(\n", "            batch[\"input_ids\"], special_tokens_mask, tokenizer\n", "        )\n", "\n", "        return batch\n", "\n", "    def mask_tokens(self, inputs, special_tokens_mask, tokenizer):\n", "        labels = inputs.copy()\n", "        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)\n", "        probability_matrix = np.full(labels.shape, self.mlm_probability)\n", "        special_tokens_mask = special_tokens_mask.astype(\"bool\")\n", "\n", "        # special tokens (e.g. BOS/EOS and padding) are never selected for masking\n", "        probability_matrix[special_tokens_mask] = 0.0\n", "        masked_indices = np.random.binomial(1, probability_matrix).astype(\"bool\")\n", "        labels[~masked_indices] = -100  # We only compute loss on masked tokens\n", "\n", "        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\n", "        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype(\"bool\") & masked_indices\n", "        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)\n", "\n", "        # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random token\n", "        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype(\"bool\")\n", "        indices_random &= masked_indices & ~indices_replaced\n", "        random_words = np.random.randint(tokenizer.vocab_size, size=labels.shape, dtype=\"i4\")\n", "        inputs[indices_random] = random_words[indices_random]\n", "\n", "        # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n", "        return inputs, 
labels" ] }, { "cell_type": "markdown", "metadata": { "id": "PGUrqbvlUZOG" }, "source": [ "Having defined the MLM data collator, we can now instantiate one." ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "BX9dfQLVdqga" }, "outputs": [], "source": [ "data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)" ] }, { "cell_type": "markdown", "metadata": { "id": "L7uoTXDLUzb-" }, "source": [ "At each training epoch, the dataset should be shuffled, and superfluous samples that make the dataset not evenly divisible by the batch size are thrown away. Instead of passing the dataset around, we prepare the indices of the data samples to be used in each training and evaluation epoch.\n", "Only the indices of the training dataset are additionally shuffled before each epoch." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "94khBqjzplxw" }, "outputs": [], "source": [ "def generate_batch_splits(num_samples, batch_size, rng=None):\n", "    samples_idx = jax.numpy.arange(num_samples)\n", "\n", "    # if a random key is provided, shuffle the sample indices\n", "    if rng is not None:\n", "        samples_idx = jax.random.permutation(rng, samples_idx)\n", "\n", "    samples_to_remove = num_samples % batch_size\n", "\n", "    # throw away the incomplete last batch\n", "    if samples_to_remove != 0:\n", "        samples_idx = samples_idx[:-samples_to_remove]\n", "\n", "    batch_idx = np.split(samples_idx, num_samples // batch_size)\n", "    return batch_idx" ] }, { "cell_type": "markdown", "metadata": { "id": "MU6idLb29xYu" }, "source": [ "During pre-training, we want to update the model parameters and evaluate the performance after each epoch.\n", "\n", "Let's write the functions `train_step` and `eval_step` accordingly. During training, the weight parameters should be updated as follows:\n", "\n", "1. Define a loss function `loss_fn` that first runs a forward pass of the model given the data input. Remember that Flax models are immutable, so we explicitly pass them the state (in this case the model parameters and the dropout RNG). `loss_fn` returns a scalar loss: the cross-entropy between the model output and the labels, averaged over the masked tokens.\n", "2. Differentiate this loss function using [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#evaluate-a-function-and-its-gradient-using-value-and-grad). This JAX transformation uses [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) to compute the gradient of `loss_fn` with respect to its input (i.e., the parameters of the model). It returns the value and the gradient as a pair `(loss, gradients)`.\n", "3. Compute the mean gradient over all devices using the collective operation [lax.pmean](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html). As we will see below, each device runs `train_step` on a different batch of data. Taking the mean here ensures that the model parameters stay the same on all devices.\n", "4. Use `state.apply_gradients`, which applies the gradients to the weights through the optimizer.\n", "\n", "Below, you can see how each of the described steps above is put into practice."
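, "\n", "\n", "Written out, let $m_i = 1$ if token $i$ was selected for masking (the `label_mask` in the code, i.e. `labels > 0`) and $m_i = 0$ otherwise. Let $y_i$ be the token's original id, and let $p_\\theta(y_i \\mid \\tilde{x})$ be the model's predicted probability for it given the masked input $\\tilde{x}$. Each device $d$ computes the loss $\\mathcal{L}_d$ on its own batch, and step 3 averages the per-device gradients over the $D$ devices:\n", "\n", "$$\n", "\\mathcal{L}_d(\\theta) = \\frac{\\sum_i m_i \\left(-\\log p_\\theta(y_i \\mid \\tilde{x})\\right)}{\\sum_i m_i}, \\qquad g = \\frac{1}{D} \\sum_{d=1}^{D} \\nabla_\\theta \\mathcal{L}_d(\\theta)\n", "$$"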
] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "GjKzb0zJd-aH" }, "outputs": [], "source": [ "def train_step(state, batch, dropout_rng):\n", "    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)\n", "\n", "    # pop the labels outside of the differentiated function to keep `loss_fn` free of side effects\n", "    labels = batch.pop(\"labels\")\n", "\n", "    def loss_fn(params):\n", "        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n", "\n", "        # compute the loss only on masked tokens (non-masked tokens have the label -100)\n", "        label_mask = jax.numpy.where(labels > 0, 1.0, 0.0)\n", "        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask\n", "\n", "        # take the average over the masked tokens\n", "        loss = loss.sum() / label_mask.sum()\n", "\n", "        return loss\n", "\n", "    grad_fn = jax.value_and_grad(loss_fn)\n", "    loss, grad = grad_fn(state.params)\n", "    grad = jax.lax.pmean(grad, \"batch\")\n", "    new_state = state.apply_gradients(grads=grad)\n", "\n", "    metrics = jax.lax.pmean(\n", "        {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}, axis_name=\"batch\"\n", "    )\n", "\n", "    return new_state, metrics, new_dropout_rng" ] }, { "cell_type": "markdown", "metadata": { "id": "nCPedI-B-FMQ" }, "source": [ "Now, we want to do parallelized training over all TPU devices. To do so, we use [`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html?highlight=pmap#parallelization-pmap). This compiles the function once and runs the same program on each device (it is an [SPMD program](https://en.wikipedia.org/wiki/SPMD)). When calling this pmapped function, all inputs (`\"state\"`, `\"batch\"`, `\"dropout_rng\"`) need a leading device axis, which is used to map over all TPU devices. The state is replicated on every device, the batch is sharded so that each device receives a different slice, and `dropout_rng` holds one random key per device."
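, "\n", "\n", "For example, assume 8 local devices, a per-device batch size of 64 and a padded sequence length $S$. The arguments of `parallel_train_step` then have the shapes shown below. Here `shard` (from `flax.training.common_utils`) reshapes the leading batch axis into a device axis, `flax.jax_utils.replicate` stacks one copy of every state array per device, and `dropout_rngs` holds one legacy `uint32[2]` key per device:\n", "\n", "$$\n", "\\begin{aligned}\n", "\\texttt{input\\_ids}&: (512, S) \\;\\xrightarrow{\\texttt{shard}}\\; (8, 64, S) \\\\\n", "\\text{each parameter array}&: (\\dots) \\;\\xrightarrow{\\texttt{replicate}}\\; (8, \\dots) \\\\\n", "\\texttt{dropout\\_rngs}&: (8, 2)\n", "\\end{aligned}\n", "$$"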
] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "w3k1Lqerpw5k" }, "outputs": [], "source": [ "parallel_train_step = jax.pmap(train_step, \"batch\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0DWFAZM6A8uf" }, "source": [ "Similarly, we can now define the evaluation step. This function is much simpler, as we don't need to compute any gradients. To better monitor the performance improvement during training, the accuracy is computed alongside the loss and stored in a `metrics` dictionary during evaluation." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "EGEv7dyfpW4p" }, "outputs": [], "source": [ "def eval_step(params, batch):\n", "    labels = batch.pop(\"labels\")\n", "\n", "    logits = model(**batch, params=params, train=False)[0]\n", "\n", "    # only masked tokens contribute to the metrics (all other labels are -100)\n", "    label_mask = jax.numpy.where(labels > 0, 1.0, 0.0)\n", "    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask\n", "\n", "    # compute accuracy\n", "    accuracy = jax.numpy.equal(jax.numpy.argmax(logits, axis=-1), labels) * label_mask\n", "\n", "    # summarize metrics: the sums are normalized later by the total number of masked tokens\n", "    metrics = {\"loss\": loss.sum(), \"accuracy\": accuracy.sum(), \"normalizer\": label_mask.sum()}\n", "    metrics = jax.lax.psum(metrics, axis_name=\"batch\")\n", "\n", "    return metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "guaYWTvFA_66" }, "source": [ "Similarly, we also apply `jax.pmap` to the evaluation step." ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "id": "0B8U2r2RpzjV" }, "outputs": [], "source": [ "parallel_eval_step = jax.pmap(eval_step, \"batch\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DLaM60PCY8Ka" }, "source": [ "Next, we replicate the training state (model parameters and optimizer state) on each device, so that we can pass it to our parallelized functions."
] }, { "cell_type": "code", "execution_count": 38, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kncZTfALp3PG", "outputId": "59866246-a679-4178-b827-deff529bd844" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:317: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.\n", " \"jax.host_count has been renamed to jax.process_count. This alias \"\n", "/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:304: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.\n", " \"jax.host_id has been renamed to jax.process_index. This alias \"\n" ] } ], "source": [ "state = flax.jax_utils.replicate(state)" ] }, { "cell_type": "markdown", "metadata": { "id": "hgCdgdptZTTN" }, "source": [ "To monitor the performance during training, we accumulate the loss and the accuracy of each evaluation step. Because the loss is only computed on the masked tokens, we normalize the summed loss and accuracy by the total number of masked tokens before reporting them.\n", "\n", "Let's wrap this logic into a `process_eval_metrics` function to avoid cluttering the training loop." ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "id": "sYaW3LsIx6cK" }, "outputs": [], "source": [ "def process_eval_metrics(metrics):\n", "    # take each step's metrics from the first device (psum made all copies identical) and stack them\n", "    metrics = get_metrics(metrics)\n", "    # sum over all evaluation steps\n", "    metrics = jax.tree_map(jax.numpy.sum, metrics)\n", "    # divide the summed loss and accuracy by the total number of masked tokens\n", "    normalizer = metrics.pop(\"normalizer\")\n", "    metrics = jax.tree_map(lambda x: x / normalizer, metrics)\n", "    return metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "i2xg8oI-ZJ3P" }, "source": [ "We can almost start training! 
In a final preparation step, we generate a seeded [**PRNGKey**](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax-random-prngkey) used as the random seed for the dropout layers and dataset shuffling.\n", "\n", "Just as we replicated the state on all 8 TPU devices, we also need one `PRNGKey` per device. This is why we split the initial `rng` key into 8 keys." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "id": "idu3E9ubqZH3" }, "outputs": [], "source": [ "rng = jax.random.PRNGKey(training_seed)\n", "dropout_rngs = jax.random.split(rng, jax.local_device_count())" ] }, { "cell_type": "markdown", "metadata": { "id": "bKuMWHicbede" }, "source": [ "Now, we are all set to finally start training!\n", "Let's put all the pieces together and write the training loop.\n", "\n", "We start each epoch by splitting off a new random key, `input_rng`, which is used to shuffle the training dataset. The dropout layers use the per-device `dropout_rngs` instead, and the input token masking in the data collator relies on NumPy's global random state.\n", "\n", "Next, we generate the training dataset indices.\n", "In the first nested loop (the training loop), we shard each input batch across all 8 TPU devices and run the training step.\n", "\n", "Analogously, in the second nested loop (the evaluation loop), the evaluation batches are sharded and the evaluation step is run.\n", "\n", "**Note**: It might seem that the following cell \"hangs\" when executed for the first time. This is because JAX traces and compiles the code the first time it is run. After the first training step, you should notice that execution is much faster."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 117, "referenced_widgets": [ "262758972960448ea46c762caaae24ca", "114b5b5c202d4f5da8f764305eae2582", "5e61cb16b44949269c405f68656a30d5", "037f015c5d98455f896f015db811260e", "dcb74fb0f3774685aa56e64ace39bde3", "8e267f7131de429db31f99d0f2fd310a", "8121511fd6dc41e6b16a02806bbdecb6", "a68eccd445d244a18a983c3d4b573358", "aa802d8d41204fff94e49acbb3dedcc0", "943d5c7876d44836bf15599392511ed0", "6f88b74eeb2c49b3856d358c1a21e8a4", "afaa3f9241b04276b4b8f9f3cf79b444", "aac392b4e9d34bbebb7da08ccf0ed2b8", "c60a29f8061d43338adc80ae3c09b6ef", "b33b0ad250cf40ca803680633cc1b947", "d44f508c2f494b3bafdf6e90925f84dd", "da76d7739a3544839bc88aaf00970d1a", "0febc23ecf6c457399ac6d103c1abc2f", "b2c6b88e7df74b468e535daedb323e92", "0c4c8206041d47d386f1da41c4972f21", "78046e506c7948f3a6b7f3aabcb1a0f8", "cb84d2716aa3457a9d5f1e0460e1ec60", "e6d7b59745d24529b434ac1df1389fff", "026d9ac7cac64731afa27ac5fcd8d9bb", "df151562aa3249cd9635a3cd238a00e5", "ba2ad8afb5cf4e8ebbec3d01516f141c", "c3beb5d9c8f441738d25143a5f3e9ef9", "ac05b51eb9c342db8633ab389d3fd97b", "dda84868d89e44feba92c4d2bd4abae2", "08753a03bf6c42198e505cc54de9ca90", "6aaf8d5489c7455db84c395beb5fa45b", "d7ab871edf5a437e9dee66027f0da73d" ] }, "id": "U946A-YZp-Pe", "outputId": "266a3265-6f78-4201-9ed1-dacf2d1415fb" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "262758972960448ea46c762caaae24ca", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Epoch ...', max=10.0, style=ProgressStyle(description_wid…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa802d8d41204fff94e49acbb3dedcc0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=71.0, 
style=ProgressStyle(description_w…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "Train... (1/10 | Loss: 8.718000411987305, Learning Rate: 4.5000000682193786e-05)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "da76d7739a3544839bc88aaf00970d1a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Evaluation...', max=5.0, style=ProgressStyle(description_…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\r", "Eval... (1/10 | Loss: 8.744632720947266, Acc: 0.048040375113487244)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "df151562aa3249cd9635a3cd238a00e5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Training...', max=71.0, style=ProgressStyle(description_w…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "for epoch in tqdm(range(1, num_epochs + 1), desc=\"Epoch ...\", position=0, leave=True):\n", "    rng, input_rng = jax.random.split(rng)\n", "\n", "    # -- Train --\n", "    train_batch_idx = generate_batch_splits(len(tokenized_datasets[\"train\"]), total_batch_size, rng=input_rng)\n", "\n", "    with tqdm(total=len(train_batch_idx), desc=\"Training...\", leave=False) as progress_bar_train:\n", "        for batch_idx in train_batch_idx:\n", "            model_inputs = data_collator(tokenized_datasets[\"train\"][batch_idx], tokenizer=tokenizer, pad_to_multiple_of=16)\n", "\n", "            # Model forward\n", "            model_inputs = shard(model_inputs.data)\n", "            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)\n", "\n", "            progress_bar_train.update(1)\n", "\n", "        progress_bar_train.write(\n", "            f\"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})\"\n", "        )\n", "\n", "    # -- Eval --\n", "    eval_batch_idx = generate_batch_splits(len(tokenized_datasets[\"validation\"]), total_batch_size)\n", "    eval_metrics = []\n", "\n", "    with tqdm(total=len(eval_batch_idx), desc=\"Evaluation...\", leave=False) as progress_bar_eval:\n", "        for batch_idx in eval_batch_idx:\n", "            model_inputs = data_collator(tokenized_datasets[\"validation\"][batch_idx], tokenizer=tokenizer)\n", "\n", "            # Model forward\n", "            model_inputs = shard(model_inputs.data)\n", "            eval_metric = parallel_eval_step(state.params, model_inputs)\n", "            eval_metrics.append(eval_metric)\n", "\n", "            progress_bar_eval.update(1)\n", "\n", "        eval_metrics_dict = process_eval_metrics(eval_metrics)\n", "        progress_bar_eval.write(\n", "            f\"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})\"\n", "        )" ] }, { "cell_type": "markdown", "metadata": { "id": "ZI4XIhY-7hyh" }, "source": [ "In this Colab, training already reaches a speed of about 1.26 training steps per second. Executing [**`run_mlm_flax.py`**](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling/run_mlm_flax.py) on a TPUv3-8 VM should reach about 6 training steps per second.\n", "\n", "For a more detailed comparison of runtimes, please refer to [this](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation) table."
] } ], "metadata": { "accelerator": "TPU", "colab": { "authorship_tag": "ABX9TyOwoqeayhHJzOvhOXZIu07e", "collapsed_sections": [], "include_colab_link": true, "name": "Masked Language Model Pretraining on TPU with 🤗 Transformers & JAX", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "01b8401f9f26458d97ac47eb62a53820": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b38bb5cd9489446b8fc24537759e60c6", "placeholder": "​", "style": "IPY_MODEL_15f7835673894849b6b6dbe4cfd336f5", "value": " 1/1 [00:02<00:00, 2.37s/ba]" } }, "026d9ac7cac64731afa27ac5fcd8d9bb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, 
"height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "037f015c5d98455f896f015db811260e": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a68eccd445d244a18a983c3d4b573358", "placeholder": "​", "style": "IPY_MODEL_8121511fd6dc41e6b16a02806bbdecb6", "value": " 1/10 [08:02<1:12:20, 482.31s/it]" } }, "0636601f4d8a42bba470861b19e293c1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, 
"padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "08753a03bf6c42198e505cc54de9ca90": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0a2cd49929ba4a5cb789f7dcfdf63f9a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": 
null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0a6ba63d60924e7a869869c6dc359f42": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "0c29902ffea24196916b6f56dad78f59": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_358f223d671943a4978ceb81016b8b3d", "placeholder": "​", "style": "IPY_MODEL_776079278200407abafe88edd29ff702", "value": " 333M/333M [00:15<00:00, 21.3MB/s]" } }, "0c4c8206041d47d386f1da41c4972f21": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_026d9ac7cac64731afa27ac5fcd8d9bb", "placeholder": "​", "style": "IPY_MODEL_e6d7b59745d24529b434ac1df1389fff", "value": " 5/5 [00:11<00:00, 2.69s/it]" } }, "0e46d00678a34c86bfa9ebc73c4fea31": { "model_module": 
"@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0febc23ecf6c457399ac6d103c1abc2f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": 
null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "114b5b5c202d4f5da8f764305eae2582": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "129c87d4c4bd41608eb6600854170ce4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": 
null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1507ed751ef54eabb98315e353d549ef": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_4bfbdc64093c401abc32fb47f9b910b9", "IPY_MODEL_eb3fd3dc5fda4d728b0b45bff5f4dbbf" ], "layout": "IPY_MODEL_129c87d4c4bd41608eb6600854170ce4" } }, "15f7835673894849b6b6dbe4cfd336f5": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "160a103ca36c4ca68fb7f8b3817acf66": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, 
"grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "16c34052aa8143d3ab137a4b1adb6550": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "171c6872201847e08598cd8c5a0b2362": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_6def23bac354403696d8f60777836653", "placeholder": "​", "style": "IPY_MODEL_1cb65ed1e9b845c5b8554b931b4690b5", "value": " 3/3 [00:20<00:00, 6.84s/ba]" } }, "18aca4b7e88248e0ac232f67afd3f3ab": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_4a57a381bac941f285574cc3e563b048", "IPY_MODEL_db39640e706d4309a1c58d0f332b2aab" ], "layout": 
"IPY_MODEL_ebaa5224f0cf47d78d08f93990c08098" } }, "1a9b970c8efb4277a798da1c145c19e4": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "1b57f6133c8045aa8ad993d47af6f541": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "1cb65ed1e9b845c5b8554b931b4690b5": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "1f47f12d6bde427a892d31443e40829b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, 
"grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1fa55643d9234e7facbfecf6d2ef7cb6": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_34b0d26a14454933a08b7e2b39c36011", "placeholder": "​", "style": "IPY_MODEL_dd286ac4f020476cb5c8e6495b717ee5", "value": " 3/3 [00:47<00:00, 15.70s/ba]" } }, "20b8b6e3b79942cdbe51131f3ebaf14a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": 
null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2226ef9eb3d04cf2992d2fb9976f167e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "232d2e43c91742aab712216ae4e72bb0": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "2334037d360a495b9644e60f897da983": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b567119e92374510885c44812a2ac115", "IPY_MODEL_35680c4679af41009739ebc34ccb6959" ], "layout": "IPY_MODEL_66aa387a4a5f4a488ee1cb29cbb42da9" } }, "244f1c8857a54e368b744b7596129f5a": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "24f9f85b12e14f83b6f0d300c5bf2c7b": { "model_module": 
"@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c6534a1d59ba4c6586c20f4a6dd709ba", "IPY_MODEL_01b8401f9f26458d97ac47eb62a53820" ], "layout": "IPY_MODEL_a3ca03557e19482fb170ab90b6c6e136" } }, "25f1623a25cd4f859400d696140f79d9": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_34c78198209243f896a8479f1d7257fa", "IPY_MODEL_171c6872201847e08598cd8c5a0b2362" ], "layout": "IPY_MODEL_da656163b59b487f9857ac7dce3989bf" } }, "262758972960448ea46c762caaae24ca": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_5e61cb16b44949269c405f68656a30d5", "IPY_MODEL_037f015c5d98455f896f015db811260e" ], "layout": "IPY_MODEL_114b5b5c202d4f5da8f764305eae2582" } }, "269c1ae885ca46769d0a049b7df5e839": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, 
"align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "29c0e9a2c0264fed8c0ef1c88c1d6daf": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "29c818fbcbfa4c4d9c2ace5afcd816bd": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2a4aff76064241bbae6d7945811c43f3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2b87607755784544824172984f509b13": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "2c494e518396468b945342279d4a91e8": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": 
"1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_e89648ca4e014ec78c9776ac90bbf6d2", "IPY_MODEL_ac4384c5d8304441b3c6e764dfdb9d9a" ], "layout": "IPY_MODEL_5bcf7478340a4d84850c43ba6364ce5e" } }, "2cbeaf721bf644379aff6366632a1a62": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2d4ecab20fbc4e148642e001662898f7": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a7dfa9d4c0a44542b35f94b8d9c7806f", "IPY_MODEL_caf772c6f33a46c2b20ae99fe3c31dae" ], "layout": 
"IPY_MODEL_d9fd9c0bef814996b61fc32edf8d6cc6" } }, "2f58e28ca978441db9d75313d263b1ba": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "332a3530366745d994126f354e5f0289": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "34b0d26a14454933a08b7e2b39c36011": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": 
null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "34c78198209243f896a8479f1d7257fa": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #1: 100%", "description_tooltip": null, "layout": "IPY_MODEL_82ba69bcf1804675b77b1b708555bab9", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_9e6495156c9a4fc0b9ed24f8a216a388", "value": 3 } }, "354c2e15312a4251902d776fc770947a": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_160a103ca36c4ca68fb7f8b3817acf66", "placeholder": "​", "style": "IPY_MODEL_50a6b75288694171bf84d99d676bb9ac", "value": " 1/1 [00:01<00:00, 1.66s/ba]" } }, "35680c4679af41009739ebc34ccb6959": { "model_module": 
"@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_69ebeebba68145a282efed06395e61b2", "placeholder": "​", "style": "IPY_MODEL_406c959be1504ce09e68a84340e768cb", "value": " 3.07M/? [00:00<00:00, 12.1MB/s]" } }, "358f223d671943a4978ceb81016b8b3d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "36aee3953ebe4b4596d01b9d379aa1e8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": 
"LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3f5abb3236fa49cb8eefe1ca736afd07": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_a20f3f908f1644dc8b02e7cc82ba0f75", "max": 81, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_add23e1725b34adfa63c0560fdfc653c", "value": 81 } }, "3fe30aad373046998a001fceec61e79e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_9491d032aadf4d27bea141190cb50b2f", "IPY_MODEL_9b17d8470fd44c159896e26c9a4d9559" ], "layout": "IPY_MODEL_29c0e9a2c0264fed8c0ef1c88c1d6daf" } }, 
"406c959be1504ce09e68a84340e768cb": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "40ed180285114ee3b57c174e3954e871": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "420643ba0ccd4a87887b28e722e74f89": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "43215c6261f0492b8ae9f8600b51b25f": { "model_module": 
"@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4444b057ae0747beb0a0593c62277967": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "444b38ddd3994f35b4dc6de5fa486e5a": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "44f0cb6e4c07434fb0f89fbbd89e492d": { "model_module": "@jupyter-widgets/base", "model_name": 
"LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "494957d1b11f4fef8aa0de984239fe6f": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_99739cb14d9546f0a7058502b7562a52", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f87488c14e1c46808435448d258a062d", "value": 1 } }, "49bff82dc4264b0f815ea4e26734804d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": 
"LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4a57a381bac941f285574cc3e563b048": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #2: 100%", "description_tooltip": null, "layout": "IPY_MODEL_63746a342ad848fc8f45da71489e26af", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_444b38ddd3994f35b4dc6de5fa486e5a", "value": 3 } }, "4bfbdc64093c401abc32fb47f9b910b9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_0a2cd49929ba4a5cb789f7dcfdf63f9a", "max": 481, "min": 0, 
"orientation": "horizontal", "style": "IPY_MODEL_1a9b970c8efb4277a798da1c145c19e4", "value": 481 } }, "4c71d8f73af54865af3bb984cfa78375": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4d0b2a650f1a43a18214d2f5be471c5e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": 
null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "501540ebc32b47878367999f6b15dba2": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_6b3121474b6445d88a96506c66408dd9", "placeholder": "​", "style": "IPY_MODEL_d1ccee8b62754023b4e6df520ae3966e", "value": " 1/1 [00:01<00:00, 1.86s/ba]" } }, "50a6b75288694171bf84d99d676bb9ac": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "540e23f228ae45bcb5b76ccb94f7e031": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "5533776328e64aedbda3950db146f857": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": 
"FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_f6506d6ffd46498db87ec4ac66c95a9d", "max": 5577, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f406dbf7c95b49d091c268c90080b8a5", "value": 5577 } }, "569ee56d7fea40ae89ef2349e1db457d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "56e6b6239d5544bdb9cfcde143d4edf7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, 
"flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "56fd3e43a9fb472ab180a25763acbf41": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5ae3dffe0e9b4bb5b2ccb4df55b72455": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, 
"_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5bcf7478340a4d84850c43ba6364ce5e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, 
"5c33fc07e8944ead8479baf09cd365f4": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_e46226b475b14aa09437551b73cf0c1b", "IPY_MODEL_998cb781fbe84f06b2064459c128015b" ], "layout": "IPY_MODEL_a4a90eb988c94bc48aed9bc14918e9a4" } }, "5e61cb16b44949269c405f68656a30d5": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Epoch ...: 10%", "description_tooltip": null, "layout": "IPY_MODEL_8e267f7131de429db31f99d0f2fd310a", "max": 10, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_dcb74fb0f3774685aa56e64ace39bde3", "value": 1 } }, "60d81ad4c9474506a99c1c9a9a8cef7a": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "60f69d3d22b740e0a5f83cd2c6e8ca83": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": 
"StyleView", "description_width": "" } }, "63746a342ad848fc8f45da71489e26af": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "669da797864e4a5b8b1b2feab627bb8e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_999db518d5174464b6c03b952ea3af1a", "IPY_MODEL_a62f9f0af78741bc8f1045a683d54688" ], "layout": "IPY_MODEL_a354d37e5f404a17bbd68a8f0845dd99" } }, "66aa387a4a5f4a488ee1cb29cbb42da9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "69ebeebba68145a282efed06395e61b2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6aaf8d5489c7455db84c395beb5fa45b": { "model_module": 
"@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "6af5ca40e43a415bb535773ac5974dd2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "6b3121474b6445d88a96506c66408dd9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6def23bac354403696d8f60777836653": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": 
"@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6e91147560c94d419bbe2e10db39f694": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, 
"padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6f88b74eeb2c49b3856d358c1a21e8a4": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_c60a29f8061d43338adc80ae3c09b6ef", "max": 71, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_aac392b4e9d34bbebb7da08ccf0ed2b8", "value": 71 } }, "6fd5803d251d4dc4a5f20375b1e99385": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_f82eaf223bb241029e23982325f421c4", "IPY_MODEL_354c2e15312a4251902d776fc770947a" ], "layout": "IPY_MODEL_752b65a12f2a440bb9ab1539f0fddf85" } }, "70b3393474e6416a85f42e9f07a9550b": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_cf45447cb37e4b4198154f409f974030", "IPY_MODEL_501540ebc32b47878367999f6b15dba2" ], "layout": "IPY_MODEL_93e69e0936f342608e422705f40d64dd" } }, "72dc8f192bb24271849b618522ad3990": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { 
"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "7522a60a290d4b749142b7c3bef2e51e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_eaf74e107314477d93302b2296145ef9", "IPY_MODEL_ff231177ded8474e844806243efde855" ], "layout": "IPY_MODEL_fa60d2398af0441d96fb81391cafdf96" } }, "752b65a12f2a440bb9ab1539f0fddf85": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "776079278200407abafe88edd29ff702": { "model_module": 
"@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "78046e506c7948f3a6b7f3aabcb1a0f8": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "79a4220737f04106af1fcc33a9ef9409": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_332a3530366745d994126f354e5f0289", "placeholder": "​", "style": "IPY_MODEL_6af5ca40e43a415bb535773ac5974dd2", "value": " 14.7k/? 
[00:00<00:00, 16.7kB/s]" } }, "7b2a7c286bf3418c89b58390a5d071dc": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_f8a52dc00c554a9488b2ad11f1eecdd6", "IPY_MODEL_d5d34937611f41749ec872eb5162d087" ], "layout": "IPY_MODEL_1f47f12d6bde427a892d31443e40829b" } }, "7cbfb28a8a10402eb34a0b07af3e7973": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c12d9b21c94d4a5b987a66d19d926ced", "placeholder": "​", "style": "IPY_MODEL_c1fd2afcaa2d45c9ad6aa2e4ce40f1ae", "value": " 1/1 [00:01<00:00, 1.96s/ba]" } }, "7e428c1823a748d8a0367107dbdaa70d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "7eea344bf03645478852c513401b3115": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, 
"description_width": "initial" } }, "80fc7c43339d4282bc8edfaf167c1224": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8121511fd6dc41e6b16a02806bbdecb6": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "82ba69bcf1804675b77b1b708555bab9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, 
"display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8827cea3be154f1c8979072ddfbb9f8d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "88498392c9b6435a90b44f42699bc2bd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "885d037ed38542fe82d9a5404476c044": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, 
"8b7829a8ce7b4892b8047f8c6a19201a": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_5533776328e64aedbda3950db146f857", "IPY_MODEL_79a4220737f04106af1fcc33a9ef9409" ], "layout": "IPY_MODEL_a95d6d6ad00a472abbc6ac7beded4abb" } }, "8e267f7131de429db31f99d0f2fd310a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8ed57e3298854dd8a8e08ad233c78410": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", 
"_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #1: 100%", "description_tooltip": null, "layout": "IPY_MODEL_2cbeaf721bf644379aff6366632a1a62", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_60d81ad4c9474506a99c1c9a9a8cef7a", "value": 3 } }, "92c59dbc4f6a467ebe5c4d60928c768f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "93e69e0936f342608e422705f40d64dd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, 
"grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "943d5c7876d44836bf15599392511ed0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9491d032aadf4d27bea141190cb50b2f": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", 
"_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #1: 100%", "description_tooltip": null, "layout": "IPY_MODEL_d5a036d3c1f0461698a0b1299a009fcd", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_244f1c8857a54e368b744b7596129f5a", "value": 1 } }, "94bcede0deb74189acd3c30f748be1ff": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "94cbae744604407e8c8aa9c6bd8ce151": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #3: 100%", "description_tooltip": null, "layout": "IPY_MODEL_dbb51464281d40a2b5b8c8414304afa0", "max": 3, "min": 0, "orientation": 
"horizontal", "style": "IPY_MODEL_2b87607755784544824172984f509b13", "value": 3 } }, "9513b5d85ae442859f0223f83a48b953": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "95f6ac83377e49379240c8eeae384af0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": 
null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "99739cb14d9546f0a7058502b7562a52": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "998cb781fbe84f06b2064459c128015b": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_569ee56d7fea40ae89ef2349e1db457d", "placeholder": "​", "style": "IPY_MODEL_e06b47e60bd5415fa536603f2c103f86", "value": " 1/1 [00:02<00:00, 
2.08s/ba]" } }, "999db518d5174464b6c03b952ea3af1a": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #0: 100%", "description_tooltip": null, "layout": "IPY_MODEL_49bff82dc4264b0f815ea4e26734804d", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_e49dc5ab78d74be3ab6d9760397e04c2", "value": 3 } }, "9b17d8470fd44c159896e26c9a4d9559": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a8e66259adcf4a3ab57ab2fc0d3847c2", "placeholder": "​", "style": "IPY_MODEL_c30e725164724058b83da04ac22b61b8", "value": " 1/1 [00:02<00:00, 2.45s/ba]" } }, "9c9e95f42e904a34a97b8ebe17f997eb": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_94cbae744604407e8c8aa9c6bd8ce151", "IPY_MODEL_f6fc3918d3f149c5ab7f7dd630e5e303" ], "layout": "IPY_MODEL_43215c6261f0492b8ae9f8600b51b25f" } }, "9e57068083904dcda413fef9728a1b8e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "9e6495156c9a4fc0b9ed24f8a216a388": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "9f7ad7e1e3164bc79ddc589a9649625f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a12e3f6679564ea4a2ff9e1f973a6415": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": 
"1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_8ed57e3298854dd8a8e08ad233c78410", "IPY_MODEL_1fa55643d9234e7facbfecf6d2ef7cb6" ], "layout": "IPY_MODEL_80fc7c43339d4282bc8edfaf167c1224" } }, "a20f3f908f1644dc8b02e7cc82ba0f75": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a354d37e5f404a17bbd68a8f0845dd99": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, 
"grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a3ca03557e19482fb170ab90b6c6e136": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a4a90eb988c94bc48aed9bc14918e9a4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", 
"_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a62f9f0af78741bc8f1045a683d54688": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0e46d00678a34c86bfa9ebc73c4fea31", "placeholder": "​", "style": "IPY_MODEL_1b57f6133c8045aa8ad993d47af6f541", "value": " 3/3 [00:19<00:00, 6.45s/ba]" } }, "a68eccd445d244a18a983c3d4b573358": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, 
"grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a7dfa9d4c0a44542b35f94b8d9c7806f": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #0: 100%", "description_tooltip": null, "layout": "IPY_MODEL_56e6b6239d5544bdb9cfcde143d4edf7", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7eea344bf03645478852c513401b3115", "value": 3 } }, "a8e66259adcf4a3ab57ab2fc0d3847c2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, 
"max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a95d6d6ad00a472abbc6ac7beded4abb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "aa802d8d41204fff94e49acbb3dedcc0": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_6f88b74eeb2c49b3856d358c1a21e8a4", "IPY_MODEL_afaa3f9241b04276b4b8f9f3cf79b444" ], "layout": "IPY_MODEL_943d5c7876d44836bf15599392511ed0" } }, "aac392b4e9d34bbebb7da08ccf0ed2b8": { "model_module": 
"@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "ac05b51eb9c342db8633ab389d3fd97b": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_d7ab871edf5a437e9dee66027f0da73d", "placeholder": "​", "style": "IPY_MODEL_6aaf8d5489c7455db84c395beb5fa45b", "value": " 12/71 [00:09<00:48, 1.23it/s]" } }, "ac4384c5d8304441b3c6e764dfdb9d9a": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_fdb39fdbd0ca424f8a98449d7a929deb", "placeholder": "​", "style": "IPY_MODEL_9e57068083904dcda413fef9728a1b8e", "value": " 3/3 [00:15<00:00, 5.32s/ba]" } }, "add23e1725b34adfa63c0560fdfc653c": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, 
"afaa3f9241b04276b4b8f9f3cf79b444": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_d44f508c2f494b3bafdf6e90925f84dd", "placeholder": "​", "style": "IPY_MODEL_b33b0ad250cf40ca803680633cc1b947", "value": " 71/71 [07:43<00:00, 1.23it/s]" } }, "b1d2b908593c422d946d7656da6cbaf9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #1: 100%", "description_tooltip": null, "layout": "IPY_MODEL_d30fe8af738f419eafc4fbb3dbf69164", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f1bf3ed55e454e05ad3cc4a441041a5b", "value": 1 } }, "b2607cb39d7f4df69e029473df4e0bb6": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b1d2b908593c422d946d7656da6cbaf9", "IPY_MODEL_7cbfb28a8a10402eb34a0b07af3e7973" ], "layout": "IPY_MODEL_92c59dbc4f6a467ebe5c4d60928c768f" } }, "b2c6b88e7df74b468e535daedb323e92": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluation...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_cb84d2716aa3457a9d5f1e0460e1ec60", "max": 5, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_78046e506c7948f3a6b7f3aabcb1a0f8", "value": 5 } }, "b33b0ad250cf40ca803680633cc1b947": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "b34b97728ac248d0860a1eb85b33e309": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, 
"b38bb5cd9489446b8fc24537759e60c6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b567119e92374510885c44812a2ac115": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_94bcede0deb74189acd3c30f748be1ff", "max": 358718, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_16c34052aa8143d3ab137a4b1adb6550", "value": 358718 } }, "ba2ad8afb5cf4e8ebbec3d01516f141c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bbfa545245f04d028b6c054c24853997": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "bdb9619c2f3449caaa4d4273d1cfa0cf": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, 
"grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "be4bd93181264b4ba86ab629cea48890": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #2: 100%", "description_tooltip": null, "layout": "IPY_MODEL_eb27225d1f694ef49b5f5a1d704ee2ac", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_4444b057ae0747beb0a0593c62277967", "value": 3 } }, "c12d9b21c94d4a5b987a66d19d926ced": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, 
"object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c1fd2afcaa2d45c9ad6aa2e4ce40f1ae": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c30e725164724058b83da04ac22b61b8": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c3913e79f785433989cf4ff6a528056e": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_8827cea3be154f1c8979072ddfbb9f8d", "max": 332871683, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_232d2e43c91742aab712216ae4e72bb0", "value": 332871683 } }, "c3beb5d9c8f441738d25143a5f3e9ef9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 17%", "description_tooltip": null, "layout": "IPY_MODEL_08753a03bf6c42198e505cc54de9ca90", "max": 71, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_dda84868d89e44feba92c4d2bd4abae2", "value": 12 } }, "c60a29f8061d43338adc80ae3c09b6ef": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c6534a1d59ba4c6586c20f4a6dd709ba": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #2: 100%", "description_tooltip": null, "layout": "IPY_MODEL_9f7ad7e1e3164bc79ddc589a9649625f", "max": 
1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d7cfc8a005a74fb5aee9cfc5737e10bd", "value": 1 } }, "cae9b026ff684dffa0bcacb1ddaf1fd8": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "caf772c6f33a46c2b20ae99fe3c31dae": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_95f6ac83377e49379240c8eeae384af0", "placeholder": "​", "style": "IPY_MODEL_60f69d3d22b740e0a5f83cd2c6e8ca83", "value": " 3/3 [00:48<00:00, 16.19s/ba]" } }, "cb84d2716aa3457a9d5f1e0460e1ec60": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": 
null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "cea4982e4b32491698c26178a4c445ff": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "cf45447cb37e4b4198154f409f974030": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #3: 100%", "description_tooltip": null, "layout": "IPY_MODEL_0636601f4d8a42bba470861b19e293c1", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_bbfa545245f04d028b6c054c24853997", "value": 
1 } }, "d1ccee8b62754023b4e6df520ae3966e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "d240794aecdb4f1c8dd5948a87095bfe": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_36aee3953ebe4b4596d01b9d379aa1e8", "placeholder": "​", "style": "IPY_MODEL_29c818fbcbfa4c4d9c2ace5afcd816bd", "value": " 3/3 [00:34<00:00, 11.35s/ba]" } }, "d25c31eb97ff496baa48c4b7a317a976": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "d2ad23714f2d49b08205d069b12899c8": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_be4bd93181264b4ba86ab629cea48890", "IPY_MODEL_d240794aecdb4f1c8dd5948a87095bfe" ], "layout": "IPY_MODEL_20b8b6e3b79942cdbe51131f3ebaf14a" } }, 
"d30fe8af738f419eafc4fbb3dbf69164": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d3948f470523480697d5d7221b0fd1f4": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_494957d1b11f4fef8aa0de984239fe6f", "IPY_MODEL_eab56d1def3749ac9f95ccffce80dacc" ], "layout": "IPY_MODEL_269c1ae885ca46769d0a049b7df5e839" } }, "d44f508c2f494b3bafdf6e90925f84dd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": 
"LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d5a036d3c1f0461698a0b1299a009fcd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d5d34937611f41749ec872eb5162d087": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": 
{ "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_56fd3e43a9fb472ab180a25763acbf41", "placeholder": "​", "style": "IPY_MODEL_72dc8f192bb24271849b618522ad3990", "value": " 1/1 [00:01<00:00, 1.86s/ba]" } }, "d7ab871edf5a437e9dee66027f0da73d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d7cfc8a005a74fb5aee9cfc5737e10bd": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": 
"initial" } }, "d9fd9c0bef814996b61fc32edf8d6cc6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "da656163b59b487f9857ac7dce3989bf": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, 
"min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "da76d7739a3544839bc88aaf00970d1a": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b2c6b88e7df74b468e535daedb323e92", "IPY_MODEL_0c4c8206041d47d386f1da41c4972f21" ], "layout": "IPY_MODEL_0febc23ecf6c457399ac6d103c1abc2f" } }, "db39640e706d4309a1c58d0f332b2aab": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5ae3dffe0e9b4bb5b2ccb4df55b72455", "placeholder": "​", "style": "IPY_MODEL_f6ad3867bbd74702a807841b926705ca", "value": " 3/3 [00:14<00:00, 4.89s/ba]" } }, "dbb51464281d40a2b5b8c8414304afa0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, 
"grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "dcb74fb0f3774685aa56e64ace39bde3": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "dd286ac4f020476cb5c8e6495b717ee5": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "dda84868d89e44feba92c4d2bd4abae2": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "de6d3f88db2f48a48287566f152902d6": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": 
"DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "df151562aa3249cd9635a3cd238a00e5": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c3beb5d9c8f441738d25143a5f3e9ef9", "IPY_MODEL_ac05b51eb9c342db8633ab389d3fd97b" ], "layout": "IPY_MODEL_ba2ad8afb5cf4e8ebbec3d01516f141c" } }, "e02d02fb4e3c4932ade9f214b93ce4e5": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e06b47e60bd5415fa536603f2c103f86": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e46226b475b14aa09437551b73cf0c1b": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #3: 100%", 
"description_tooltip": null, "layout": "IPY_MODEL_44f0cb6e4c07434fb0f89fbbd89e492d", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_0a6ba63d60924e7a869869c6dc359f42", "value": 1 } }, "e49dc5ab78d74be3ab6d9760397e04c2": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "e6d7b59745d24529b434ac1df1389fff": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e89648ca4e014ec78c9776ac90bbf6d2": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #3: 100%", "description_tooltip": null, "layout": "IPY_MODEL_cea4982e4b32491698c26178a4c445ff", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_cae9b026ff684dffa0bcacb1ddaf1fd8", "value": 3 } }, "eab56d1def3749ac9f95ccffce80dacc": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_885d037ed38542fe82d9a5404476c044", "placeholder": "​", "style": "IPY_MODEL_2226ef9eb3d04cf2992d2fb9976f167e", "value": " 389515/0 [00:51<00:00, 10309.57 examples/s]" } }, "eaf74e107314477d93302b2296145ef9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #0: 100%", "description_tooltip": null, "layout": "IPY_MODEL_4c71d8f73af54865af3bb984cfa78375", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7e428c1823a748d8a0367107dbdaa70d", "value": 1 } }, "eb27225d1f694ef49b5f5a1d704ee2ac": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": 
null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "eb3fd3dc5fda4d728b0b45bff5f4dbbf": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_bdb9619c2f3449caaa4d4273d1cfa0cf", "placeholder": "​", "style": "IPY_MODEL_2f58e28ca978441db9d75313d263b1ba", "value": " 481/481 [00:00<00:00, 1.16kB/s]" } }, "ebaa5224f0cf47d78d08f93990c08098": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f15842f820b2492eaf344303bb31cb9e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_3f5abb3236fa49cb8eefe1ca736afd07", "IPY_MODEL_f9177108ca8d437fb83cd6ec87071268" ], "layout": "IPY_MODEL_9513b5d85ae442859f0223f83a48b953" } }, "f1bf3ed55e454e05ad3cc4a441041a5b": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "f2e1e2c29e8a4e4dae1b535311703e66": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c3913e79f785433989cf4ff6a528056e", "IPY_MODEL_0c29902ffea24196916b6f56dad78f59" ], "layout": "IPY_MODEL_4d0b2a650f1a43a18214d2f5be471c5e" } }, "f406dbf7c95b49d091c268c90080b8a5": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "f6506d6ffd46498db87ec4ac66c95a9d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": 
"LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f6ad3867bbd74702a807841b926705ca": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f6fc3918d3f149c5ab7f7dd630e5e303": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b34b97728ac248d0860a1eb85b33e309", "placeholder": "​", "style": "IPY_MODEL_de6d3f88db2f48a48287566f152902d6", "value": " 3/3 [00:38<00:00, 12.84s/ba]" } }, "f82eaf223bb241029e23982325f421c4": { "model_module": "@jupyter-widgets/controls", "model_name": 
"FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #0: 100%", "description_tooltip": null, "layout": "IPY_MODEL_420643ba0ccd4a87887b28e722e74f89", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_540e23f228ae45bcb5b76ccb94f7e031", "value": 1 } }, "f87488c14e1c46808435448d258a062d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "f8a52dc00c554a9488b2ad11f1eecdd6": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": " #2: 100%", "description_tooltip": null, "layout": "IPY_MODEL_6e91147560c94d419bbe2e10db39f694", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_40ed180285114ee3b57c174e3954e871", "value": 1 } }, "f9177108ca8d437fb83cd6ec87071268": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": 
"HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_88498392c9b6435a90b44f42699bc2bd", "placeholder": "​", "style": "IPY_MODEL_e02d02fb4e3c4932ade9f214b93ce4e5", "value": " 81.0/81.0 [00:00<00:00, 129B/s]" } }, "fa60d2398af0441d96fb81391cafdf96": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fdb39fdbd0ca424f8a98449d7a929deb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, 
"grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ff231177ded8474e844806243efde855": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2a4aff76064241bbae6d7945811c43f3", "placeholder": "​", "style": "IPY_MODEL_d25c31eb97ff496baa48c4b7a317a976", "value": " 1/1 [00:01<00:00, 1.66s/ba]" } } } } }, "nbformat": 4, "nbformat_minor": 1 }