{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning a 🤗 Transformers model on TPU with **Flax/JAX**" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on TPU using [**Flax**](https://flax.readthedocs.io/en/latest/index.html). As can be seen on [this benchmark](https://github.com/huggingface/transformers/tree/master/examples/flax/text-classification#runtime-evaluation) using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.\n", "\n", "[**Flax**](https://flax.readthedocs.io/en/latest/index.html) is a high-performance neural network library designed for flexibility built on top of JAX (see below). It aims to provide users with full control of their training code and is carefully designed to work well with JAX transformations such as `grad` and `pmap` (see the [Flax philosophy](https://flax.readthedocs.io/en/latest/philosophy.html)). For an introduction to Flax see the [Flax Basic Colab](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) or the list of curated [Flax examples](https://flax.readthedocs.io/en/latest/examples.html).\n", "\n", "[**JAX**](https://jax.readthedocs.io/en/latest/index.html) is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets as well as [Flax](https://github.com/google/flax.git) and [Optax](https://github.com/deepmind/optax). Optax is a gradient processing and optimization library for JAX, and is the optimizer library\n", "recommended by Flax." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "MOsHUjgdIrIW" }, "outputs": [], "source": [ "%%capture\n", "!pip install datasets\n", "!pip install git+https://github.com/huggingface/transformers.git\n", "!pip install flax\n", "!pip install git+https://github.com/deepmind/optax.git" ] }, { "cell_type": "markdown", "metadata": { "id": "EwRqMUxqzOKN" }, "source": [ "You also will need to set up the TPU for JAX in this notebook. This can be done by executing the following lines." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "drej_4loxdCF" }, "outputs": [], "source": [ "import jax.tools.colab_tpu\n", "jax.tools.colab_tpu.setup_tpu()" ] }, { "cell_type": "markdown", "metadata": { "id": "bAIlBLFXXix7" }, "source": [ "If everything is set up correctly, the following command should return a list of 8 TPU devices." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Kc0dzzrsXktW", "outputId": "4dbdaadf-96a9-48a9-f4a9-fa94fb10a431" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.local_devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "HFASsisvIrIb" }, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", "\n", "You can find a script version of this notebook to fine-tune your model [here](https://github.com/huggingface/transformers/tree/master/examples/flax/text-classification)." ] }, { "cell_type": "markdown", "metadata": { "id": "DlpF5MnonmmQ" }, "source": [ "As an example, we will fine-tune a pretrained [*auto-encoding model*](https://huggingface.co/transformers/model_summary.html#autoencoding-models) on a text classification task of the [GLUE Benchmark](https://gluebenchmark.com/). Note that this notebook does not focus so much on data preprocessing, but rather on how to write a training and evaluation loop in **JAX/Flax**. If you want more detailed explanations regarding the data preprocessing, please check out [this](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/text_classification.ipynb#scrollTo=545PP3o8IrJV) notebook.\n", "\n", "The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences which are:\n", "[CoLA](https://nyu-mll.github.io/CoLA/), [MNLI](https://arxiv.org/abs/1704.05426), [MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398), [QNLI](https://rajpurkar.github.io/SQuAD-explorer/), [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs), [RTE](https://aclweb.org/aclwiki/Recognizing_Textual_Entailment), [SST-2](https://nlp.stanford.edu/sentiment/index.html), [STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark), [WNLI](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html).\n", "\n", "We will see how to easily load the dataset for each one of those tasks and how to write a training loop in Flax. Each task is named by its acronym, with `mnli-mm` standing for the mismatched version of MNLI (so same training set as `mnli` but different validation and test sets):" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "YZbiBDuGIrId" }, "outputs": [], "source": [ "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook is built to run on any of the tasks in the list above, with any Flax/JAX model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a classification head. Depending on the model you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those three parameters, then the rest of the notebook should run smoothly:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "task = \"cola\"\n", "model_checkpoint = \"bert-base-cased\"\n", "per_device_batch_size = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"text_classification_notebook\", framework=\"flax\")" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`. " ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [], "source": [ "from datasets import load_dataset, load_metric" ] }, { "cell_type": "markdown", "metadata": { "id": "CKx2zKs5IrIq" }, "source": [ "Apart from `mnli-mm` being a special code, we can directly pass our task name to those functions. `load_dataset` will cache the dataset to avoid downloading it again the next time you run this cell." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 271, "referenced_widgets": [ "85804123d219427aa8554e5c1abcd188", "936a8e8c640b4d48ba547cea71c10a05", "c392dbcdc45e462088de3a538817f89c", "057905968ef34b748e0d7c0595cf4b0f", "3039f03a5e1c41258b9f108e93e7610d", "c11d5ebb907e4f64bbcd38b822144085", "196260e063ea4355bc5e81383c7c9780", "7c67c6815d7e4420b4edbcfa0cf993f3", "b3939ab991d44a30b385075981945b58", "4c5f23f61db74c7ab831a742d94e4d7c", "cdba2232a1db40acb3b0d773deec8bd7", "b38dc8c0b13e48d99a96ef51a6be1048", "ce4e36d3fba04286807cafa4f02a8c8f", "9eba683ec8e44e9eb00bce6feca40d2b", "818f4039d4eb4f0c912fa726c5f8e566", "45bf8fd612424c06a4db1c272490bbfd", "a3c46f9c799448349dc47b2d4e06817b", "e96d8cf1ebc64aabaaf17328ec6916a2", "de0f85b55b7a40c3bd48d27d8afa6436", "ffb11704f93143a286fa45c507246336", "10ee2d2e0ff44ec5b960e4d46e04fe29", "fd2d5ec23d884063b43b2ec6bd74bc41", "8501dde32f014c3489996ebcd58a86a2", "170bc89474774b83920b8f6b60eb6966", "c4e33462dac044c7a53d9f68974b1ccd", "9d8fbf8b04ab4975b7d8abdbb87f7b48", "bd8ac3ec6e38427db5362fc96b622864", "0b6d4fa5e005408f9d66abc2b3ebbd8e", "6477db7c602848cc96d82eb5e7a303ad", "290c78228238478cb790e38b420ad19d", "3048ef3256c849309b77d163fa5dc198", "e2d9852bdb664d6490b08885833f2cf8", "38458420fec544f6b7cc35a8f17851cf", "fce53153459d4d62b2f56d2d1638a7ee", "63cd8679fd504d158eae06bbd1a65037", "c0558be5fdf244678687150b067c37b7", "1b601a8984a14420adf9e6fc08c7f3ba", "2627873a55e94464b9ef86ac93fa26cd", "9dba428f61f14559a56813d2d51801d0", "69f48b05eb404f52a05dfb5dfef1b8ff", "b239280cbe5a4f5991959b14c20dd6e2", "7e0ab271758843e188a9b67d2248f320", "40e6d68d7ef24502be7f7d937b9e3db8", "077d4c7f496d4a0f90954ebd86629055", "7f3af6d2a01b4bf8a1173393bf906e2c", "5b056e2c5e654f069bb6eae7e8c3d060", "14fb47c050504776934f208068b53363", "3884494e0b724944b1da710342ffaf1b", "7bf3d5d1563d4e8cbb840f0267325c48", "3f815388d365421fbad24f391b46af47", "ccc693c6f938477db01c9fb69526f348", "8f002deb5d03403f82b0968adc2da45f", "f57ab763e32e4aa59519df04c32b44b1", "78eca35ab52d479db5b907e70cd46b1a", "dcc6a32d60974dfc984769343c323c65", "1a920e46cbf048c6a55dee80ee0d4643" ] }, "id": "s_AY1ATSIrIq", "outputId": "9e7841da-5da7-48aa-c98f-e502fd11fbeb" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_253934/2906828129.py:5: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n", " metric = load_metric('glue', actual_task)\n" ] } ], "source": [ "actual_task = \"mnli\" if task == \"mnli-mm\" else task\n", "is_regression = task == \"stsb\"\n", "\n", "raw_dataset = load_dataset(\"glue\", actual_task)\n", "metric = load_metric('glue', actual_task)" ] }, { "cell_type": "markdown", "metadata": { "id": "n9qywopnIrJH" }, "source": [ "## Preprocessing the data" ] }, { "cell_type": "markdown", "metadata": { "id": "YVx71GdAIrJH" }, "source": [ "Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer` which will (as the name indicates) tokenize the inputs. This includes converting the tokens to their corresponding IDs in the pretrained vocabulary and putting them in a format the model expects, as well as generate the other inputs that the model requires.\n", "\n", "To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:\n", "\n", "- we get a tokenizer that corresponds to the model architecture we want to use,\n", "- we download the vocabulary used when pretraining this specific checkpoint.\n", "\n", "That vocabulary will be cached, so it's not downloaded again the next time we run the cell." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 216, "referenced_widgets": [ "616b7ddb0a644791b10e06e6cba743b1", "9add322f037242f7945f9779bd6ba82d", "92fceaea6f1b4c9484b0ec860b77609c", "8da21f5fa32949858e1e9d1b9dc9861a", "b7702d8f0d6146078c9219846de80802", "22c310d51b324e05aca35a9bfda468c8", "e7d9d01bc9724a8299e1960375b71298", "bcfbbdeda9864282af5e8d108b036c08", "ddd42c117b7e41a6a33e302b38a7f3df", "849f7ad495b64111a4d642b371fb5e7f", "f2f76ebc808d415db2c1d0175aca27cc", "9fcb9473cd074e0c888f9105c148a3b1", "27a673e8f31f488b933496e737f0577e", "ad82960c7f9448e29118c3bd8cd7adaf", "041899de5a914f678ab50028fd1be25f", "c36c74ad5815443b8c0b1eafff12489e", "5bbda35c1c3e4c0ebd32987f84f0c936", "fa10efd7c0c443218c89c76b889cffd7", "d2389b5b7bba40bca36ffceb80453576", "cce1a6d399914221b5c52c3a5fe3911c", "a5749038eade4b6e983ef854019194e1", "085651db75a94e9ba2eabcb83e0e8ea1", "d877c534e0354564a379c8da96f190a4", "a9f2c4ec1ed84aecbdbaa12cb4ea7eee", "4bd0dc33abbc4f9087876b31f44b8c22", "d490e07908694463b378a0d273cd2137", "92a5d91b961b4f56aa2a89b2be86c552", "30a4e1848fbf4d0695a2ed73f990633a", "da0dc5ce3a6545e78ce8b8c4421d4474", "9a784c51e03a4667b536db14a8624599", "8c1f773ce9ac4cbe8170432da358a78d", "b69795bbbec740868f25cf5c09ed525e" ] }, "id": "eXNLu_-nIrJI", "outputId": "16ba468f-64f9-475e-9a9f-4863b539e7ec" }, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", " \n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" ] }, { "cell_type": "markdown", "metadata": { "id": "qo_0B1M2IrJM" }, "source": [ "To preprocess our dataset, we will thus need the names of the columns containing the sentence(s). The following dictionary keeps track of the correspondence task to column names:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "fyGdtK9oIrJM" }, "outputs": [], "source": [ "task_to_keys = {\n", " \"cola\": (\"sentence\", None),\n", " \"mnli\": (\"premise\", \"hypothesis\"),\n", " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n", " \"mrpc\": (\"sentence1\", \"sentence2\"),\n", " \"qnli\": (\"question\", \"sentence\"),\n", " \"qqp\": (\"question1\", \"question2\"),\n", " \"rte\": (\"sentence1\", \"sentence2\"),\n", " \"sst2\": (\"sentence\", None),\n", " \"stsb\": (\"sentence1\", \"sentence2\"),\n", " \"wnli\": (\"sentence1\", \"sentence2\"),\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "2C0hcmp9IrJQ" }, "source": [ "We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This will ensure that an input longer than what the model can handle will be truncated to the maximum length accepted by the model." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "FgjEr0fKMRQ_" }, "outputs": [], "source": [ "sentence1_key, sentence2_key = task_to_keys[task]\n", "\n", "def preprocess_function(examples):\n", " texts = (\n", " (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])\n", " )\n", " processed = tokenizer(*texts, padding=\"max_length\", max_length=128, truncation=True)\n", " \n", " processed[\"labels\"] = examples[\"label\"]\n", " return processed" ] }, { "cell_type": "markdown", "metadata": { "id": "zS-6iXTkIrJT" }, "source": [ "To apply this function to all the sentences (or pairs of sentences) in our dataset, we just use the `map` method of our dataset object we created earlier. This will apply the function on all the elements of all the splits in the dataset, so our training, validation, and testing data will be preprocessed in one single command.\n", "\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 166, "referenced_widgets": [ "1df0ad12e2cc4a3294d0e7aee40944aa", "c2aa97eb306a4eaea4b8d0bf703dcd35", "6cd686871a9d4672b734e5f7a846723b", "03a2c7cb6e304c8f82b68b26860ff2a6", "feab068033474760844c405518aa43fc", "0698b1aa83d84f54954e28d12a9e869b", "3881043dd9d243ebb8652d49caf99dd2", "b6dde7b93ea44a8281644f6546ee09c3", "31ea7b1deedb4409bd49c50fb4adf832", "c799f7065de644fbaa2ffba036d3efa6", "a15081c66eb74cd7ad54fbf2fed198c2", "b35ce55a60bb41b6991c7d127a1b7080", "a63a4f1f8bd04263bb8e1a8ec7217f03", "00124fd463924c049e8e9b75bd2e61e2", "af9c5c97dc4d44688d1b0548e69d37d2", "c3844041a3974d1aa62f1ee9f28112c7", "29ff497dab9d43c6bd6179f10e4b8d45", "c3cbb2bd91f946a982416f6c6af9a805", "5425e4bd26cd43b59106a11c083def05", "c0caf7e383fd4b2fa49d1b5548c53b2d", "be22f92d61804bceb82c477c92d23370", "c831ea00bd964b40b832740fa04a138e", "e7b841e3d5574570a28e981277e55b65", "3f20b7115c8f44bb891c61e3abe07701" ] }, "id": "DDtsaJeVIrJT", "outputId": "38e1e7e8-00cd-46f6-c0f3-aff4e53d27b7" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f71182eb482845a88dfc370a225fdaf0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/1063 [00:00` and `` with your credentials accordingly." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D5ILJ1B-up41" }, "outputs": [], "source": [ "!git config --global user.email \"\" # e.g. \"patrick.v.platen@gmail.com\"\n", "!git config --global user.name \"\" # e.g. \"Patrick von Platen\"" ] }, { "cell_type": "markdown", "metadata": { "id": "QwPQyaTOAM5H" }, "source": [ "You will need to pass your user authentification token `hf_auth_token` to allow the 🤗 hub to upload model weights under your username. To find your authentification token, you need to log in [here](https://huggingface.co/login), click on your icon (top right), go to *Settings*, then *API Tokens*. You can copy the *User API token* and replace it with the field `` below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OgET7o-Jv2OR" }, "outputs": [], "source": [ "hf_auth_token = \"\" # e.g. api_DaYgaaVnGdRtznIgiNfotCHFUqmOdARmPx" ] }, { "cell_type": "markdown", "metadata": { "id": "FRM1k_8FCMiP" }, "source": [ "Finally, you can give your fine-tuned model a nice `model_id` or leave the default one as noted below. The model will be uploaded under `https://huggingface.co//`, *e.g.*\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XGT5eAzBoUK" }, "outputs": [], "source": [ "model_id = f\"{model_checkpoint}_fine_tuned_glue_{task}\"" ] }, { "cell_type": "markdown", "metadata": { "id": "FJHgcqvkBsIc" }, "source": [ "Great! Now all that is left to do is to upload your model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 336 }, "id": "cIT_aQ-OoxvG", "outputId": "8d8f28ce-03e5-4f91-9d60-fc566551a307" }, "outputs": [], "source": [ "model.push_to_hub(model_id, use_auth_token=hf_auth_token)\n", "tokenizer.push_to_hub(model_id, use_auth_token=hf_auth_token)" ] }, { "cell_type": "markdown", "metadata": { "id": "JlviIJzcC_WD" }, "source": [ "You can now go to your model page to check it out 😊.\n", "\n", "We strongly recommend to add a model card so that the community can actually make use of your fine-tuned model. You can do so by clicking on *Create Model Card* and adding a descriptive text in markdown format. \n", "\n", "A simple description for your fine-tuned model is given by running the following cell. You can simply copy-paste the output to be used as your model card `README.md`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QSKl6rX7D_-N" }, "outputs": [], "source": [ "print(f\"\"\"---\n", "language: en\n", "license: apache-2.0\n", "datasets:\n", "- glue\n", "---\n", "# {\" \".join([x.capitalize() for x in model_id.split(\"_\")])}\n", "\n", "This checkpoint was initialized from the pre-trained checkpoint {model_checkpoint} and subsequently fine-tuned on GLUE task: {task} using [this](https://colab.research.google.com/drive/162pW3wonGcMMrGxmA-jdxwy1rhqXd90x?usp=sharing) notebook.\n", "Training was conducted for {num_train_epochs} epochs, using a linear decaying learning rate of {learning_rate}, and a total batch size of {total_batch_size}.\n", "\n", "The model has a final training loss of {loss} and a {metric_name} of {eval_score}.\n", "\"\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "OpFZeamCI-FJ" }, "source": [ "An uploaded model would, *e.g.*, look as follows: [patrickvonplaten/bert-base-cased_fine_tuned_glue_mrpc_demo](https://huggingface.co/patrickvonplaten/bert-base-cased_fine_tuned_glue_mrpc_demo)." ] } ], "metadata": { "accelerator": "TPU", "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "Text Classification on GLUE on TPU using Jax/Flax", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "00124fd463924c049e8e9b75bd2e61e2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "005ddf17003e48acaca848a09f7ac39c": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a7732825135944dab5cb53cb5c5c11ec", "IPY_MODEL_38092590bf484681851fe46b032474db" ], "layout": "IPY_MODEL_20521ff323fd4e97bb40fa87647e3d0c" } }, "01ad6e8eaaff463197cde256018b0be2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "03a2c7cb6e304c8f82b68b26860ff2a6": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b6dde7b93ea44a8281644f6546ee09c3", "placeholder": "​", "style": "IPY_MODEL_3881043dd9d243ebb8652d49caf99dd2", "value": " 9/9 [00:01<00:00, 6.77ba/s]" } }, "041899de5a914f678ab50028fd1be25f": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "057905968ef34b748e0d7c0595cf4b0f": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7c67c6815d7e4420b4edbcfa0cf993f3", "placeholder": "​", "style": "IPY_MODEL_196260e063ea4355bc5e81383c7c9780", "value": " 28.8k/? [00:00<00:00, 60.6kB/s]" } }, "05ff232524ff44ed999aeb4349afa471": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_55b3e1d36b504c61b922dc063cc5b004", "IPY_MODEL_b1faa6a04e894f40945dd4742c21245d" ], "layout": "IPY_MODEL_fd13060f59bc4d2e80bf91f7cdf2bc3e" } }, "0698b1aa83d84f54954e28d12a9e869b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "077d4c7f496d4a0f90954ebd86629055": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_3884494e0b724944b1da710342ffaf1b", "placeholder": "​", "style": "IPY_MODEL_14fb47c050504776934f208068b53363", "value": " 1063/0 [00:00<00:00, 994.35 examples/s]" } }, "07f50a68554d44c6a366f68678aa8b2f": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluating...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_253d243c7ad946c597d38998a10a9a04", "max": 32, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_71dd6360f4a94398b7c941e41527daff", "value": 32 } }, "085651db75a94e9ba2eabcb83e0e8ea1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0b6d4fa5e005408f9d66abc2b3ebbd8e": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e2d9852bdb664d6490b08885833f2cf8", "placeholder": "​", "style": "IPY_MODEL_3048ef3256c849309b77d163fa5dc198", "value": " 8551/0 [00:05<00:00, 1.32s/ examples]" } }, "10d375495cc04b6e97f0063c6fd9ff9f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "10ee2d2e0ff44ec5b960e4d46e04fe29": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "14fb47c050504776934f208068b53363": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "170bc89474774b83920b8f6b60eb6966": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "196260e063ea4355bc5e81383c7c9780": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "1a920e46cbf048c6a55dee80ee0d4643": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1a96e9131a2a447ebd64007406c5080c": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_07f50a68554d44c6a366f68678aa8b2f", "IPY_MODEL_fa4fffa7ef61478aa4a95858291b4be5" ], "layout": "IPY_MODEL_e9c58401f58449009b1b3bc2d7af804a" } }, "1b601a8984a14420adf9e6fc08c7f3ba": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "1c486838e5b9473995178e43a50f114e": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "1df0ad12e2cc4a3294d0e7aee40944aa": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_6cd686871a9d4672b734e5f7a846723b", "IPY_MODEL_03a2c7cb6e304c8f82b68b26860ff2a6" ], "layout": "IPY_MODEL_c2aa97eb306a4eaea4b8d0bf703dcd35" } }, "1f5332b51dc34929bafc7ff50ee29790": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_ab15355b677e447888cea229be1cf078", "max": 267, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_b26ae06ad32942e6af506b91077521fc", "value": 267 } }, "20521ff323fd4e97bb40fa87647e3d0c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "20ca2b8a944544b8819869cedd82b08a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "22c310d51b324e05aca35a9bfda468c8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "253d243c7ad946c597d38998a10a9a04": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2627873a55e94464b9ef86ac93fa26cd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2770f1482e784b6b8531d7b2937f4f70": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "27a673e8f31f488b933496e737f0577e": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "290c78228238478cb790e38b420ad19d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "29ff497dab9d43c6bd6179f10e4b8d45": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_5425e4bd26cd43b59106a11c083def05", "IPY_MODEL_c0caf7e383fd4b2fa49d1b5548c53b2d" ], "layout": "IPY_MODEL_c3cbb2bd91f946a982416f6c6af9a805" } }, "2d07f4cb109e491f9b667ffa85cbcee9": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_20ca2b8a944544b8819869cedd82b08a", "placeholder": "​", "style": "IPY_MODEL_c93f2e1923dc480892908808b28e7599", "value": " 433M/433M [00:18<00:00, 23.0MB/s]" } }, "2da8456b20d34f8ea06613d623ec92be": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_bc5ebfc000824c0e9b3ade682e8c745e", "IPY_MODEL_2d07f4cb109e491f9b667ffa85cbcee9" ], "layout": "IPY_MODEL_beb5aad666684683b184fc2b228666d8" } }, "3039f03a5e1c41258b9f108e93e7610d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "3048ef3256c849309b77d163fa5dc198": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "30a4e1848fbf4d0695a2ed73f990633a": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b69795bbbec740868f25cf5c09ed525e", "placeholder": "​", "style": "IPY_MODEL_8c1f773ce9ac4cbe8170432da358a78d", "value": " 29.0/29.0 [00:00<00:00, 108B/s]" } }, "310ffddcdbd94cbd99773f13e33ec5e8": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluating...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_31ec854d9c4e46429a6cd38fcc12c2b1", "max": 32, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_34eced6679ae4af88ba616ae77a25815", "value": 32 } }, "31ea7b1deedb4409bd49c50fb4adf832": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a15081c66eb74cd7ad54fbf2fed198c2", "IPY_MODEL_b35ce55a60bb41b6991c7d127a1b7080" ], "layout": "IPY_MODEL_c799f7065de644fbaa2ffba036d3efa6" } }, "31ec854d9c4e46429a6cd38fcc12c2b1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "34eced6679ae4af88ba616ae77a25815": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "37714174d00841aabf360bfd5852fa7b": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_310ffddcdbd94cbd99773f13e33ec5e8", "IPY_MODEL_a5fee02451df4b22b5a78e89d5a88a5f" ], "layout": "IPY_MODEL_d26d31d72c3348b6938f2b6739486086" } }, "38092590bf484681851fe46b032474db": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_aaf5625419c144779e0edd64fd90562e", "placeholder": "​", "style": "IPY_MODEL_ead0ec7f99824edc9832f403999d1108", "value": " 32/32 [00:03<00:00, 11.71it/s]" } }, "38458420fec544f6b7cc35a8f17851cf": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_63cd8679fd504d158eae06bbd1a65037", "IPY_MODEL_c0558be5fdf244678687150b067c37b7" ], "layout": "IPY_MODEL_fce53153459d4d62b2f56d2d1638a7ee" } }, "3881043dd9d243ebb8652d49caf99dd2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "3884494e0b724944b1da710342ffaf1b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3a85f94868b6459a84dcc82017755a8b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3f20b7115c8f44bb891c61e3abe07701": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3f815388d365421fbad24f391b46af47": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "40e6d68d7ef24502be7f7d937b9e3db8": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5b056e2c5e654f069bb6eae7e8c3d060", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7f3af6d2a01b4bf8a1173393bf906e2c", "value": 1 } }, "45bf8fd612424c06a4db1c272490bbfd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "466f9d448cb4403f9bc957dd1172facf": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_4c76d754f08b464f9564a8b07fc898c9", "IPY_MODEL_a05efc0c2cbd441b87d8e9994ea17ab5" ], "layout": "IPY_MODEL_cb0aace06e9048fca19f5045067304e8" } }, "4a3e4f81a6db440495e82cac0d975eea": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "4bd0dc33abbc4f9087876b31f44b8c22": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_92a5d91b961b4f56aa2a89b2be86c552", "IPY_MODEL_30a4e1848fbf4d0695a2ed73f990633a" ], "layout": "IPY_MODEL_d490e07908694463b378a0d273cd2137" } }, "4c5f23f61db74c7ab831a742d94e4d7c": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4c76d754f08b464f9564a8b07fc898c9": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Epoch ...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_af7439d93afe46d39b4916001388dddb", "max": 3, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d7b90d73649f4df186680dc0da01107f", "value": 3 } }, "5425e4bd26cd43b59106a11c083def05": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "100%", "description_tooltip": null, "layout": "IPY_MODEL_c831ea00bd964b40b832740fa04a138e", "max": 2, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_be22f92d61804bceb82c477c92d23370", "value": 2 } }, "55b3e1d36b504c61b922dc063cc5b004": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_9076e9988d684059853325ffe1f750dc", "max": 267, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d81c90a0f5f841a1bb0f0c9dfead879d", "value": 267 } }, "59497a1d13974766a1633235b8700341": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_1f5332b51dc34929bafc7ff50ee29790", "IPY_MODEL_75f2c8cbfffd4f808b69a5eb38c04762" ], "layout": "IPY_MODEL_c963839a422d4ded939c88dc474f5ce4" } }, "5b056e2c5e654f069bb6eae7e8c3d060": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5bbda35c1c3e4c0ebd32987f84f0c936": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_d2389b5b7bba40bca36ffceb80453576", "IPY_MODEL_cce1a6d399914221b5c52c3a5fe3911c" ], "layout": "IPY_MODEL_fa10efd7c0c443218c89c76b889cffd7" } }, "5c9976c5ce17422b8c992595aff5ce34": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Training...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_a326d0eb97fc4688941584009b4b1014", "max": 267, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f6a235461ff34719aaf60e844710d3ca", "value": 267 } }, "616b7ddb0a644791b10e06e6cba743b1": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_92fceaea6f1b4c9484b0ec860b77609c", "IPY_MODEL_8da21f5fa32949858e1e9d1b9dc9861a" ], "layout": "IPY_MODEL_9add322f037242f7945f9779bd6ba82d" } }, "63cd8679fd504d158eae06bbd1a65037": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2627873a55e94464b9ef86ac93fa26cd", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_1b601a8984a14420adf9e6fc08c7f3ba", "value": 1 } }, "6477db7c602848cc96d82eb5e7a303ad": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "69f48b05eb404f52a05dfb5dfef1b8ff": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6cd686871a9d4672b734e5f7a846723b": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "100%", "description_tooltip": null, "layout": "IPY_MODEL_0698b1aa83d84f54954e28d12a9e869b", "max": 9, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_feab068033474760844c405518aa43fc", "value": 9 } }, "71dd6360f4a94398b7c941e41527daff": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "7332e7753d71493780cc160c2607f689": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "75f2c8cbfffd4f808b69a5eb38c04762": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e246074395344659937f9233ee82b6f9", "placeholder": "​", "style": "IPY_MODEL_2770f1482e784b6b8531d7b2937f4f70", "value": " 267/267 [01:31<00:00, 3.12it/s]" } }, "7650e4481d3a4ce9a3b3f1ce347fbc3a": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "78eca35ab52d479db5b907e70cd46b1a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7bf3d5d1563d4e8cbb840f0267325c48": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_ccc693c6f938477db01c9fb69526f348", "IPY_MODEL_8f002deb5d03403f82b0968adc2da45f" ], "layout": "IPY_MODEL_3f815388d365421fbad24f391b46af47" } }, "7c67c6815d7e4420b4edbcfa0cf993f3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7cc56ba6c2e64afd85856f4eeb486ef1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7e0ab271758843e188a9b67d2248f320": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7f3af6d2a01b4bf8a1173393bf906e2c": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "818f4039d4eb4f0c912fa726c5f8e566": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "849f7ad495b64111a4d642b371fb5e7f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8501dde32f014c3489996ebcd58a86a2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "85804123d219427aa8554e5c1abcd188": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c392dbcdc45e462088de3a538817f89c", "IPY_MODEL_057905968ef34b748e0d7c0595cf4b0f" ], "layout": "IPY_MODEL_936a8e8c640b4d48ba547cea71c10a05" } }, "8c1f773ce9ac4cbe8170432da358a78d": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8da21f5fa32949858e1e9d1b9dc9861a": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_bcfbbdeda9864282af5e8d108b036c08", "placeholder": "​", "style": "IPY_MODEL_e7d9d01bc9724a8299e1960375b71298", "value": " 570/570 [00:01<00:00, 308B/s]" } }, "8f002deb5d03403f82b0968adc2da45f": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_1a920e46cbf048c6a55dee80ee0d4643", "placeholder": "​", "style": "IPY_MODEL_dcc6a32d60974dfc984769343c323c65", "value": " 5.75k/? [00:04<00:00, 1.43kB/s]" } }, "9076e9988d684059853325ffe1f750dc": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "90a57c6925a2494c8472a592f88df9c3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "92a5d91b961b4f56aa2a89b2be86c552": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_9a784c51e03a4667b536db14a8624599", "max": 29, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_da0dc5ce3a6545e78ce8b8c4421d4474", "value": 29 } }, "92fceaea6f1b4c9484b0ec860b77609c": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_22c310d51b324e05aca35a9bfda468c8", "max": 570, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_b7702d8f0d6146078c9219846de80802", "value": 570 } }, "936a8e8c640b4d48ba547cea71c10a05": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9a784c51e03a4667b536db14a8624599": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9add322f037242f7945f9779bd6ba82d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9d8fbf8b04ab4975b7d8abdbb87f7b48": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9dba428f61f14559a56813d2d51801d0": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "9eba683ec8e44e9eb00bce6feca40d2b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9fcb9473cd074e0c888f9105c148a3b1": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c36c74ad5815443b8c0b1eafff12489e", "placeholder": "​", "style": "IPY_MODEL_041899de5a914f678ab50028fd1be25f", "value": " 213k/213k [00:00<00:00, 493kB/s]" } }, "a05efc0c2cbd441b87d8e9994ea17ab5": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7cc56ba6c2e64afd85856f4eeb486ef1", "placeholder": "​", "style": "IPY_MODEL_1c486838e5b9473995178e43a50f114e", "value": " 3/3 [12:04<00:00, 241.54s/it]" } }, "a15081c66eb74cd7ad54fbf2fed198c2": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "100%", "description_tooltip": null, "layout": "IPY_MODEL_00124fd463924c049e8e9b75bd2e61e2", "max": 2, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_a63a4f1f8bd04263bb8e1a8ec7217f03", "value": 2 } }, "a326d0eb97fc4688941584009b4b1014": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a3c46f9c799448349dc47b2d4e06817b": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_de0f85b55b7a40c3bd48d27d8afa6436", "IPY_MODEL_ffb11704f93143a286fa45c507246336" ], "layout": "IPY_MODEL_e96d8cf1ebc64aabaaf17328ec6916a2" } }, "a5749038eade4b6e983ef854019194e1": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "a5fee02451df4b22b5a78e89d5a88a5f": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_01ad6e8eaaff463197cde256018b0be2", "placeholder": "​", "style": "IPY_MODEL_4a3e4f81a6db440495e82cac0d975eea", "value": " 32/32 [00:02<00:00, 11.31it/s]" } }, "a63a4f1f8bd04263bb8e1a8ec7217f03": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "a7732825135944dab5cb53cb5c5c11ec": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "", "description": "Evaluating...: 100%", "description_tooltip": null, "layout": "IPY_MODEL_10d375495cc04b6e97f0063c6fd9ff9f", "max": 32, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_ec191bcc7d1341039a944f9167379c1a", "value": 32 } }, "a9f2c4ec1ed84aecbdbaa12cb4ea7eee": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "aaf5625419c144779e0edd64fd90562e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ab15355b677e447888cea229be1cf078": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ad82960c7f9448e29118c3bd8cd7adaf": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "af53f9ae0dff455c834cb354e4b9f90b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "af7439d93afe46d39b4916001388dddb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "af9c5c97dc4d44688d1b0548e69d37d2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "b0d330f5b203413ebe9c6365d9240914": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b1faa6a04e894f40945dd4742c21245d": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_bc8f5114bb54453fbb28a53bf5153558", "placeholder": "​", "style": "IPY_MODEL_7332e7753d71493780cc160c2607f689", "value": " 267/267 [08:49<00:00, 3.11it/s]" } }, "b239280cbe5a4f5991959b14c20dd6e2": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_40e6d68d7ef24502be7f7d937b9e3db8", "IPY_MODEL_077d4c7f496d4a0f90954ebd86629055" ], "layout": "IPY_MODEL_7e0ab271758843e188a9b67d2248f320" } }, "b26ae06ad32942e6af506b91077521fc": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "b35ce55a60bb41b6991c7d127a1b7080": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c3844041a3974d1aa62f1ee9f28112c7", "placeholder": "​", "style": "IPY_MODEL_af9c5c97dc4d44688d1b0548e69d37d2", "value": " 2/2 [00:52<00:00, 26.42s/ba]" } }, "b376f8d2f6d5426d95d9cf0feddd5a77": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_5c9976c5ce17422b8c992595aff5ce34", "IPY_MODEL_fd5744ff76fa4d0593177f28c5fd4333" ], "layout": "IPY_MODEL_3a85f94868b6459a84dcc82017755a8b" } }, "b38dc8c0b13e48d99a96ef51a6be1048": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_45bf8fd612424c06a4db1c272490bbfd", "placeholder": "​", "style": "IPY_MODEL_818f4039d4eb4f0c912fa726c5f8e566", "value": " 28.7k/? [00:00<00:00, 183kB/s]" } }, "b3939ab991d44a30b385075981945b58": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_cdba2232a1db40acb3b0d773deec8bd7", "IPY_MODEL_b38dc8c0b13e48d99a96ef51a6be1048" ], "layout": "IPY_MODEL_4c5f23f61db74c7ab831a742d94e4d7c" } }, "b69795bbbec740868f25cf5c09ed525e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b6dde7b93ea44a8281644f6546ee09c3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b7702d8f0d6146078c9219846de80802": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "bc5ebfc000824c0e9b3ade682e8c745e": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_90a57c6925a2494c8472a592f88df9c3", "max": 433370483, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_cc6105241c164185aa943ce34fcbcccf", "value": 433370483 } }, "bc8f5114bb54453fbb28a53bf5153558": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bcfbbdeda9864282af5e8d108b036c08": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bd8ac3ec6e38427db5362fc96b622864": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_290c78228238478cb790e38b420ad19d", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_6477db7c602848cc96d82eb5e7a303ad", "value": 1 } }, "be22f92d61804bceb82c477c92d23370": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "beb5aad666684683b184fc2b228666d8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bed61df41ef9463d8cfc2ef4e046b417": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c0558be5fdf244678687150b067c37b7": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_69f48b05eb404f52a05dfb5dfef1b8ff", "placeholder": "​", "style": "IPY_MODEL_9dba428f61f14559a56813d2d51801d0", "value": " 1043/0 [00:00<00:00, 2689.25 examples/s]" } }, "c0caf7e383fd4b2fa49d1b5548c53b2d": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_3f20b7115c8f44bb891c61e3abe07701", "placeholder": "​", "style": "IPY_MODEL_e7b841e3d5574570a28e981277e55b65", "value": " 2/2 [00:00<00:00, 6.82ba/s]" } }, "c11d5ebb907e4f64bbcd38b822144085": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c2aa97eb306a4eaea4b8d0bf703dcd35": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c36c74ad5815443b8c0b1eafff12489e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c3844041a3974d1aa62f1ee9f28112c7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c392dbcdc45e462088de3a538817f89c": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_c11d5ebb907e4f64bbcd38b822144085", "max": 7777, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_3039f03a5e1c41258b9f108e93e7610d", "value": 7777 } }, "c3cbb2bd91f946a982416f6c6af9a805": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c4e33462dac044c7a53d9f68974b1ccd": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_bd8ac3ec6e38427db5362fc96b622864", "IPY_MODEL_0b6d4fa5e005408f9d66abc2b3ebbd8e" ], "layout": "IPY_MODEL_9d8fbf8b04ab4975b7d8abdbb87f7b48" } }, "c799f7065de644fbaa2ffba036d3efa6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c831ea00bd964b40b832740fa04a138e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c93f2e1923dc480892908808b28e7599": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c963839a422d4ded939c88dc474f5ce4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "cb0aace06e9048fca19f5045067304e8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "cc6105241c164185aa943ce34fcbcccf": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "ccc693c6f938477db01c9fb69526f348": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_78eca35ab52d479db5b907e70cd46b1a", "max": 1848, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f57ab763e32e4aa59519df04c32b44b1", "value": 1848 } }, "cce1a6d399914221b5c52c3a5fe3911c": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a9f2c4ec1ed84aecbdbaa12cb4ea7eee", "placeholder": "​", "style": "IPY_MODEL_d877c534e0354564a379c8da96f190a4", "value": " 436k/436k [00:01<00:00, 399kB/s]" } }, "cdba2232a1db40acb3b0d773deec8bd7": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_9eba683ec8e44e9eb00bce6feca40d2b", "max": 4473, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_ce4e36d3fba04286807cafa4f02a8c8f", "value": 4473 } }, "ce4e36d3fba04286807cafa4f02a8c8f": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "d2389b5b7bba40bca36ffceb80453576": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_085651db75a94e9ba2eabcb83e0e8ea1", "max": 435797, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_a5749038eade4b6e983ef854019194e1", "value": 435797 } }, "d26d31d72c3348b6938f2b6739486086": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d490e07908694463b378a0d273cd2137": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d7b90d73649f4df186680dc0da01107f": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "d81c90a0f5f841a1bb0f0c9dfead879d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "d877c534e0354564a379c8da96f190a4": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "da0dc5ce3a6545e78ce8b8c4421d4474": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "dcc6a32d60974dfc984769343c323c65": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ddd42c117b7e41a6a33e302b38a7f3df": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_f2f76ebc808d415db2c1d0175aca27cc", "IPY_MODEL_9fcb9473cd074e0c888f9105c148a3b1" ], "layout": "IPY_MODEL_849f7ad495b64111a4d642b371fb5e7f" } }, "de0f85b55b7a40c3bd48d27d8afa6436": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_fd2d5ec23d884063b43b2ec6bd74bc41", "max": 376971, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_10ee2d2e0ff44ec5b960e4d46e04fe29", "value": 376971 } }, "e246074395344659937f9233ee82b6f9": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e2d9852bdb664d6490b08885833f2cf8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e7b841e3d5574570a28e981277e55b65": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e7d9d01bc9724a8299e1960375b71298": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e96d8cf1ebc64aabaaf17328ec6916a2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e9c58401f58449009b1b3bc2d7af804a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ead0ec7f99824edc9832f403999d1108": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ec191bcc7d1341039a944f9167379c1a": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "f2f76ebc808d415db2c1d0175aca27cc": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_ad82960c7f9448e29118c3bd8cd7adaf", "max": 213450, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_27a673e8f31f488b933496e737f0577e", "value": 213450 } }, "f57ab763e32e4aa59519df04c32b44b1": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "f6a235461ff34719aaf60e844710d3ca": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "fa10efd7c0c443218c89c76b889cffd7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fa4fffa7ef61478aa4a95858291b4be5": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b0d330f5b203413ebe9c6365d9240914", "placeholder": "​", "style": "IPY_MODEL_bed61df41ef9463d8cfc2ef4e046b417", "value": " 32/32 [00:09<00:00, 8.47it/s]" } }, "fce53153459d4d62b2f56d2d1638a7ee": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fd13060f59bc4d2e80bf91f7cdf2bc3e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fd2d5ec23d884063b43b2ec6bd74bc41": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fd5744ff76fa4d0593177f28c5fd4333": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_af53f9ae0dff455c834cb354e4b9f90b", "placeholder": "​", "style": "IPY_MODEL_7650e4481d3a4ce9a3b3f1ce347fbc3a", "value": " 267/267 [01:26<00:00, 3.16it/s]" } }, "feab068033474760844c405518aa43fc": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "ffb11704f93143a286fa45c507246336": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_170bc89474774b83920b8f6b60eb6966", "placeholder": "​", "style": "IPY_MODEL_8501dde32f014c3489996ebcd58a86a2", "value": " 377k/377k [00:00<00:00, 1.14MB/s]" } } } } }, "nbformat": 4, "nbformat_minor": 1 }