{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "NJZUsvUMhFtU",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# **Fine-tuning for Image Classification with 🤗 Transformers**\n",
"\n",
"This notebook shows how to fine-tune any pretrained Vision model for Image Classification on a custom dataset. The idea is to add a randomly initialized classification head on top of a pre-trained encoder, and fine-tune the whole model on a labeled dataset.\n",
"\n",
"## ImageFolder\n",
"\n",
"This notebook leverages the [ImageFolder](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature to easily run the notebook on a custom dataset (namely, [EuroSAT](https://github.com/phelber/EuroSAT) in this tutorial). You can either load a `Dataset` from local folders or from local/remote files, like zip or tar.\n",
"\n",
"## Any model\n",
"\n",
"This notebook is built to run on any image classification dataset with any vision model checkpoint from the [Model Hub](https://huggingface.co/) as long as that model has a version with an Image Classification head, such as:\n",
"* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTForImageClassification)\n",
"* [Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin#transformers.SwinForImageClassification)\n",
"* [ConvNeXT](https://huggingface.co/docs/transformers/master/en/model_doc/convnext#transformers.ConvNextForImageClassification)\n",
"\n",
"In short, any model supported by [AutoModelForImageClassification](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForImageClassification) will work.\n",
"\n",
"## Data augmentation\n",
"\n",
"This notebook leverages Kornia's [image augmentations](https://kornia.readthedocs.io/en/latest/augmentation.module.html) for applying data augmentation - note that we do provide alternative notebooks which leverage other libraries, including:\n",
"\n",
"* [Torchvision](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb)\n",
"* [Albumentations](https://github.com/huggingface/notebooks/blob/main/examples/image_classification_albumentations.ipynb)\n",
"* [imgaug](https://github.com/huggingface/notebooks/blob/main/examples/image_classification_imgaug.ipynb). \n",
"\n",
"---\n",
"\n",
"Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set the two parameters in the next cell (`model_checkpoint` and `batch_size`), and the rest of the notebook should run smoothly.\n",
"\n",
"In this notebook, we'll fine-tune from the [microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224) checkpoint, but note that there are many, many more available on the [hub](https://huggingface.co/models?other=vision). We will also use the [datasets](https://huggingface.co/docs/datasets/installation) library to load an image dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wvLDfqzdhFtb",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model_checkpoint = \"microsoft/swin-tiny-patch4-window7-224\" # pre-trained model from which to fine-tune\n",
"batch_size = 32 # batch size for training and evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WOynCHJWhFtc",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Before we start, let's install the `kornia`, `datasets` and `transformers` libraries. We'll also install `evaluate` to measure our model's accuracy during and after training, which requires `scikit-learn`. Since we'll be working with images, we'll also ensure that `Pillow` is installed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "krONvnn0hFtd",
"outputId": "67318c2f-65f1-4fde-be66-4c6f66cd33c5",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"!pip install -q kornia datasets transformers evaluate scikit-learn Pillow"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dEPOo0jnhFtd",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers.utils import send_example_telemetry\n",
"\n",
"send_example_telemetry(\"image_classification_kornia_notebook\", framework=\"pytorch\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Km4rEvjJhFtd",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Fine-tuning a model on an image classification task"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NknH0OJFhFte",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) vision models on an Image Classification dataset.\n",
"\n",
"Given an image, the goal is to predict an appropriate class for it, like \"tiger\". The screenshot below is taken from a [ViT fine-tuned on ImageNet-1k](https://huggingface.co/google/vit-base-patch16-224) - try out the inference widget!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XtNzED6hhFtf",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"*(Screenshot of the model's inference widget.)*\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMOBzRmOhFtf",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Loading the dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sin4A8CwhFtf",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download our dataset into a `DatasetDict`.\n",
"\n",
"In this case, the EuroSAT dataset is hosted on the Hub, so we simply pass its repository name to `load_dataset`. Alternatively, the [ImageFolder](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature lets you load remote files with the `data_files` argument, or local folders with images using the `data_dir` argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 248,
"referenced_widgets": [
"7058d1e258de40ebad13ff790915cd3c",
"01ea35ba78b440d392b2c5e75ea74464",
"e2e16d974fb249dd94a11507f9224a2d",
"58f000f788c14a268cc102d5237b8c04",
"377190442ab84770b7e3d27199ac8563",
"30a8bd848ee1497691b7605d4dd32115",
"c1537b4150ab4cfca215d70f6095a110",
"aca7889b9fd34829bc7c7ed72c467aa1",
"32e6e84f4de54b26a7218bf720324752",
"f6693950a02d4f9d933c2e32e323f521",
"adaaa2e702174b2b87894c0d3fcb1a57",
"4c4fb79433cf49faac5105c5403f4bc8",
"e19aa06a93034787a1a422c8297c50cf",
"45cc5f862ebc482d9186d3d4349cf602",
"c43a2cf9f9ee42cb8a114ced6ecce524",
"f1ab9e0ecd59406b825926098912b968",
"fb8dda387d89412388702e248a9514df",
"d37300abc9474836803687ab739be99d",
"8926f783fec742ef954ceb979f5abc70",
"c81f799609794db09cb4f5d19c39333c",
"78fa30c22d474c3da827b6af042b1e7d",
"ffaa450ad4e24d66a466ab53b424260f",
"17ad910846c647eea2967a556fa5d636",
"d19adaf170c1447f90538b85b2546a7a",
"b850715786e741fab64246dd616e4b0b",
"775c74c822b94d55897501502b8b3890",
"992069d6fbdd4046b6aa734b39a23ed5",
"4ba73b8913d8440ea8ddc4c0e08ed145",
"f5ded311e5ad49a4b7f77c5da177ffb9",
"3653830200414a2b9a5150eea235fcf1",
"2b479dce1b3d427483504e57fe4acd7a",
"1aa408392e434cde9bdc43e0f670617d",
"0a6cef96ab1841068e9027a30e76ed1e",
"c6108fb0bf124c3fa22f85bbd4bf9266",
"c3d49b63b8204a04ab075188f49b0d93",
"603016ca5f7f4e939e5a7f1f57b8dd33",
"a73b73fbcc164a81a3a4912f9682ab00",
"f05a4a836fb84bf7a5f2e75440bb026d",
"6da041a6484b4d1cbfa58cfcd265f88e",
"d9c9f4a8aca44e09a8136331e545ea88",
"f22ca12d18354eec80befa03d312fc44",
"c8eb2750eb2244699dc1492e658a9ccb",
"6f52fb52f05d465090227ede5cf2a30c",
"4ffe6ca0522542f2a96cd3191d6a99c6",
"825be3a2828c46398b79e6279c45236a",
"e32780c830504c508b85e67343ab88e9",
"0fd3d466727b4ac3bdbe1b32c67b0c26",
"8c70d4279bac4e1e9efdc3d3b9d45e34",
"a5dabc9865f34c9089b0eeca3897e467",
"f0fc57814beb464aa5c0fa560f3fa8a6",
"1453dc44aa334c7097608220933e46c2",
"432c7d96cdc840d6b54567474d09ab2f",
"72adfd5afe784624a673284e17a813c8",
"c07bfe12d2c04a0f9c5696797a1d67a5",
"fbb13fc905f240dab44f438f11421f89",
"9c22864ff042460584353d870e79b5fb",
"2ceb2fa333a94f5fb5207b5303bf0105",
"3b13fa319738413892aa0175b8d2c79a",
"bd3fb4774f2544abbfe84a8f77297022",
"e161abad994748e7be3655cffa4f7468",
"1aed4493a06e459193d06161e5dfe09e",
"36d59e2f979449fa8e27c58f44c1f801",
"c243a77c77eb4e51b217b0513c2d8546",
"2c21a4a4464343729719d8f3125a7535",
"dfae0ee344db4147a758e6b50daa66c8",
"7cfa331622ff4187a167c4a981491457"
]
},
"id": "6rfbAPGehFtf",
"outputId": "402e4107-7880-4174-b75e-09245376cad1",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"# load the EuroSAT dataset from the Hugging Face Hub by repository name\n",
"# (for local/remote files or folders, use the ImageFolder feature described below)\n",
"dataset = load_dataset(\"jonathan-roberts1/EuroSAT\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0UfLRkTJhFtg",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"**Note:** You can also provide several splits:\n",
"\n",
"```python\n",
"dataset = load_dataset(\"imagefolder\", data_files={\"train\": [\"path/to/file1\", \"path/to/file2\"], \"test\": [\"path/to/file3\", \"path/to/file4\"]})\n",
"```\n",
"\n",
"Or load a dataset from a local folder:\n",
"\n",
"```python\n",
"dataset = load_dataset(\"imagefolder\", data_dir=\"path_to_folder\")\n",
"```\n",
"\n",
"If you want to share your dataset publicly or collaborate with others privately, you can also push your dataset to the hub very easily (and reload afterwards using load_dataset)!\n",
"\n",
"```python\n",
"dataset.push_to_hub(\"nielsr/eurosat\")\n",
"# Or use private=True to create a private dataset\n",
"dataset.push_to_hub(\"nielsr/eurosat\", private=True)\n",
"```\n",
"\n",
"**Datasets Hub:**\n",
"\n",
"At Hugging Face, we host a variety of NLP, vision, speech and multi-modal datasets. Simply head over to our Hub and search for datasets. If you find a dataset you would like to use, you can simply pass in the repository name to load the dataset from the [hub](https://huggingface.co/datasets). As an example, let's load the [CIFAR-10 dataset](https://huggingface.co/datasets/cifar10) from the hub.\n",
"\n",
"```python\n",
"dataset = load_dataset(\"cifar10\")\n",
"```"
]
},
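{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `imagefolder` loader infers class labels from directory names, so a local dataset is usually laid out with one sub-folder per class. The sketch below only illustrates that layout with the standard library; the folder and file names are made up, and `infer_labels` is a hypothetical helper, not part of 🤗 Datasets.\n",
"\n",
"```python\n",
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"def infer_labels(root):\n",
"    # Hypothetical helper: one class name per sub-folder of root, sorted\n",
"    return sorted(p.name for p in Path(root).iterdir() if p.is_dir())\n",
"\n",
"# Build a throwaway folder tree: root/<class_name>/<image files>\n",
"root = Path(tempfile.mkdtemp())\n",
"for class_name in ['Forest', 'River', 'Highway']:\n",
"    class_dir = root / class_name\n",
"    class_dir.mkdir()\n",
"    (class_dir / 'example_0.jpg').touch()  # empty placeholder, not a real image\n",
"\n",
"class_names = infer_labels(root)\n",
"print(class_names)\n",
"```\n",
"\n",
"With real images in place of the placeholders, `load_dataset(\"imagefolder\", data_dir=...)` as shown above would read this structure."
]
},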
{
"cell_type": "markdown",
"metadata": {
"id": "5w7UCssAhFth",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Let us also load the Accuracy metric, which we'll use to evaluate our model both during and after training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 49,
"referenced_widgets": [
"f237d39322cc4b9e95fac2f5f23141cb",
"164e22e2e1c64ffc99b77f787cad3a22",
"83dd806f4a2b45d7bd34fc0c106a0eac",
"7398432c4cc74d2d8f9d61c40b786021",
"fe2eda6b84194d29a110b91a59cd0417",
"c79df4f5086a4cb4b251d8614928ebe2",
"71323ca5991e4c0399cb106cb59193fc",
"6f7c2ec995714255be700b3ffaba9f9e",
"b2a6342c762f45dda3b08ad3d9d0a6ad",
"d74c3e6888b84fb781f2b343534c369a",
"0f5d923b2b534881bd495dcaf47b94ec"
]
},
"id": "fFUVl4WehFth",
"outputId": "3b5ce754-f213-4e19-df16-5f9c0989002e",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import evaluate\n",
"\n",
"metric = evaluate.load(\"accuracy\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BAHRArlEhFth",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The `dataset` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key per split (in this case, only \"train\" for a training split)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_Z3Aj8rQhFth",
"outputId": "b82d317a-fec8-4a55-8c78-343ed2304c7f",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VjvxlD7LhFti",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"To access an actual element, you need to select a split first, then give an index:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zGxlnwxuhFti",
"outputId": "f9f741f1-d09e-43ab-ef3e-b4d9e9e27307",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"example = dataset[\"train\"][10]\n",
"example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z2kp7LEMhFti",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Each example consists of an image and a corresponding label. We can also verify this by checking the features of the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "591Jg2hAhFtj",
"outputId": "0e1ad54e-beb5-4c64-af45-8f578074fa4c",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataset[\"train\"].features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uGJf6VXihFtj",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The cool thing is that we can directly view the image (as the 'image' field is an [Image feature](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Image)), as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "TZnB2NiHhFtj",
"outputId": "ed180e6e-9c03-433d-b61f-94925965c0cc",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"example['image']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1X9_obDDhFtj",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Let's make it a little bigger as the images in the EuroSAT dataset are of low resolution (64x64 pixels):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 217
},
"id": "-_DqwhyLhFtj",
"outputId": "532fe0ab-e681-4743-c0d3-ce4d91abbaf2",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"example['image'].resize((200, 200))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ReZsyX1HhFtk",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Let's print the corresponding label:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DAKkVypyhFtk",
"outputId": "dc8847e6-e181-405e-a8b4-450f227c4894",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"example['label']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TYcebJ9thFtk",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"As you can see, the `label` field is not an actual string label. By default the `ClassLabel` fields are encoded into integers for convenience:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "n-nFQ_LAhFtk",
"outputId": "bb172e98-aa01-4548-d597-aefed4f8622d",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataset[\"train\"].features[\"label\"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9rDFBVD_hFtk",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Let's create an `id2label` dictionary to decode them back to strings and see what they are. The inverse `label2id` will be useful too, when we load the model later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 37
},
"id": "_JF2bw65hFtl",
"outputId": "f5c5f6b3-0b87-456c-9ec9-0f85dedd1185",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"labels = dataset[\"train\"].features[\"label\"].names\n",
"label2id, id2label = dict(), dict()\n",
"for i, label in enumerate(labels):\n",
"    label2id[label] = i\n",
"    id2label[i] = label\n",
"\n",
"id2label[2]"
]
},
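{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same two mappings can also be written with dict comprehensions. A minimal sketch, using a hypothetical list of class names rather than the real dataset features:\n",
"\n",
"```python\n",
"# Hypothetical class names, for illustration only\n",
"labels = ['AnnualCrop', 'Forest', 'River']\n",
"\n",
"label2id = {label: i for i, label in enumerate(labels)}\n",
"id2label = {i: label for i, label in enumerate(labels)}\n",
"\n",
"print(id2label[2])\n",
"print(label2id['Forest'])\n",
"```\n",
"\n",
"Both forms produce the same dictionaries; the comprehension version just avoids the explicit loop."
]
},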
{
"cell_type": "markdown",
"metadata": {
"id": "IBseqwgnhFtl",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Sharing your model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i6QkMriphFtl",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n",
"\n",
"First, get your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!), then execute the following cell and enter your token:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300,
"referenced_widgets": [
"9a3fd883fbcd4192972daee5e9b724e8",
"e4e975ccdf71499e9e7e2628b4d12381",
"2263ea69fe2748e59e241121c0ee85db",
"ce8ff720c1194f9fa576fa0ab72b9434",
"a245a58b01304077b758257436e69b1a",
"92f5865c94ce4986a36d71e2cf1db91e",
"aa59cf09e9bf4aba951ab61b47c1ed40",
"ff6552e277994529b20848de6ad45b99",
"6e51c6d618164aeda48942aa1850ecf0",
"2fa7bbee9cee40479fd62976954dabe9",
"bb520a5722484f39bb93face96e36d69",
"907b655aeed94f1c97fb564525e04d01",
"013ca1d8272343a2a50a8bc49c22a76b",
"98245a2bb52d42369036d025ad91efdd"
]
},
"id": "kdSJmmmxhFtm",
"outputId": "0d4a1991-cd56-4caf-d516-f8479a8f6238",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O5Q0EhXehFtm",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"\n",
"Then you need to install Git-LFS to upload your model checkpoints:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "txxHusu1hFtm",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"%%capture\n",
"!sudo apt -qq install git-lfs\n",
"!git config --global credential.helper store"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7vwcNjvhFtn",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Preprocessing the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WXwah-dIhFtn",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"All Hugging Face model classes expect the input to be correctly preprocessed, so we need to preprocess these images before we can feed them to our model. Preprocessing images typically comes down to:\n",
"\n",
"1. **Resizing** them to a particular size\n",
"2. **Normalizing** the color channels (R,G,B) using a mean and standard deviation.\n",
"\n",
"These are referred to as **image transformations**.\n",
"\n",
"In addition, one typically performs what is called **data augmentation** during training (like random cropping and flipping) to make the model more robust and achieve higher accuracy. Data augmentation is also a great technique to increase the size of the training data.\n",
"\n",
"We will use `Kornia` for the image transformations/data augmentation in this tutorial, but note that one can use any other package (like [torchvision](https://pytorch.org/vision/stable/transforms.html), [albumentations](https://albumentations.ai/), [imgaug](https://github.com/aleju/imgaug), etc.).\n",
"\n",
"To make sure we (1) resize to the appropriate size (2) use the appropriate image mean and standard deviation for the model architecture we are going to use, we instantiate what is called a feature extractor with the `AutoFeatureExtractor.from_pretrained` method.\n",
"\n",
"This feature extractor is a minimal preprocessor that can be used to prepare images for inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 338,
"referenced_widgets": [
"43ed5eea6e854c929400cacd09009d5a",
"df6fd8bcf0bd4cbf8b29faa94bd2b34b",
"528c167c7f6342acad2fa73034d9024c",
"7e29403666f8487daf0825acbb63636e",
"4be0b71eee094946a00786fd6bb23f25",
"e91ac28da9484c6398381c48d8154183",
"ad6684672345402e9140caa4e623c2b5",
"4e5e098964ab44a8aee3fd95ff6ec0e1",
"e9883ae85eac458fad0f966229e9e513",
"ec88dc30ea1d4c9c9ca71a1026220a47",
"d3a7387094884e619cc9b17eafb6d549"
]
},
"id": "G1bX4lGAO_d9",
"outputId": "c77a2eab-ef02-43eb-90bd-f7f84c62583a",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from transformers import AutoFeatureExtractor\n",
"\n",
"feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)\n",
"feature_extractor"
]
},
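{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the normalization step concrete, here is a small NumPy sketch of channel-wise normalization on a random image. The mean and standard deviation below are the commonly used ImageNet statistics and are an assumption here; the values to use in practice are the ones reported by `feature_extractor.image_mean` and `feature_extractor.image_std`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Assumed ImageNet statistics; read the real ones from the feature extractor\n",
"mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)\n",
"std = np.array([0.229, 0.224, 0.225], dtype=np.float32)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"image = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)  # HxWxC random image\n",
"\n",
"scaled = image.astype(np.float32) / 255.0  # rescale to [0, 1]\n",
"normalized = (scaled - mean) / std         # broadcasts over the channel axis\n",
"chw = normalized.transpose(2, 0, 1)        # CxHxW layout used by PyTorch models\n",
"\n",
"print(chw.shape, chw.dtype)\n",
"```\n",
"\n",
"Resizing is left out of this sketch; in the notebook both steps are handled by the Kornia transforms defined below."
]
},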
{
"cell_type": "markdown",
"metadata": {
"id": "qUtxmoMvqml1",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The Datasets library is made for processing data very easily. We can write custom functions, which can then be applied on an entire dataset (either using [`.map()`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=map#datasets.Dataset.map) or [`.set_transform()`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=set_transform#datasets.Dataset.set_transform)).\n",
"\n",
"Here we define 2 separate functions, one for training (which includes data augmentation) and one for validation (which only includes resizing, center cropping and normalizing). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4O_p3WrpRyej",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import kornia as K\n",
"from torch import nn\n",
"import torch\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"class PreProcess(nn.Module):\n",
"    \"\"\"Module to perform pre-processing using Kornia on torch tensors.\"\"\"\n",
"    def __init__(self) -> None:\n",
"        super().__init__()\n",
"\n",
"    @torch.no_grad()  # disable gradients for efficiency\n",
"    def forward(self, x: Image.Image) -> torch.Tensor:\n",
"        x_tmp: np.ndarray = np.array(x)  # HxWxC\n",
"        x_out: torch.Tensor = K.image_to_tensor(x_tmp, keepdim=True)  # CxHxW\n",
"        return x_out.float() / 255.0\n",
"\n",
"train_transforms = nn.Sequential(\n",
"    PreProcess(),\n",
"    K.augmentation.Resize(size=224, side=\"short\"),\n",
"    K.augmentation.CenterCrop(size=224),\n",
"    K.augmentation.RandomHorizontalFlip(p=0.5),\n",
"    K.augmentation.ColorJiggle(),\n",
"    K.augmentation.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),\n",
")\n",
"\n",
"val_transforms = nn.Sequential(\n",
"    PreProcess(),\n",
"    K.augmentation.Resize(size=224, side=\"short\"),\n",
"    K.augmentation.CenterCrop(size=224),\n",
"    K.augmentation.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),\n",
")\n",
"\n",
"def preprocess_train(example_batch):\n",
"    \"\"\"Apply train_transforms across a batch.\"\"\"\n",
"    example_batch[\"pixel_values\"] = [train_transforms(image).squeeze() for image in example_batch[\"image\"]]\n",
"    return example_batch\n",
"\n",
"def preprocess_val(example_batch):\n",
"    \"\"\"Apply val_transforms across a batch.\"\"\"\n",
"    example_batch[\"pixel_values\"] = [val_transforms(image).squeeze() for image in example_batch[\"image\"]]\n",
"    return example_batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RF4O0KFBGXir",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Next, we can preprocess our dataset by applying these functions. We will use the `set_transform` functionality, which applies the functions above on-the-fly (meaning that they are only applied when the images are loaded into RAM)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P13tqfFTZ_F4",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# split up training into training + validation\n",
"splits = dataset[\"train\"].train_test_split(test_size=0.1)\n",
"train_ds = splits['train']\n",
"val_ds = splits['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TUs56-mprQi1",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"train_ds.set_transform(preprocess_train)\n",
"val_ds.set_transform(preprocess_val)"
]
},
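{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, `set_transform` stores the function and calls it only when rows are accessed. The toy class below mimics that behaviour in plain Python; `LazyDataset` and `double_pixels` are hypothetical stand-ins for illustration, not the 🤗 Datasets implementation.\n",
"\n",
"```python\n",
"class LazyDataset:\n",
"    # Toy dataset that applies a transform on access (hypothetical)\n",
"\n",
"    def __init__(self, rows):\n",
"        self.rows = rows\n",
"        self.transform = None\n",
"\n",
"    def set_transform(self, fn):\n",
"        self.transform = fn  # nothing is computed yet\n",
"\n",
"    def __getitem__(self, index):\n",
"        # Wrap the row as a batch of size one, since the transform works on batches\n",
"        batch = {key: [value] for key, value in self.rows[index].items()}\n",
"        if self.transform is not None:\n",
"            batch = self.transform(batch)\n",
"        return {key: values[0] for key, values in batch.items()}\n",
"\n",
"calls = []  # records how many rows each transform call received\n",
"\n",
"def double_pixels(batch):\n",
"    calls.append(len(batch['image']))\n",
"    batch['pixel_values'] = [value * 2 for value in batch['image']]\n",
"    return batch\n",
"\n",
"ds = LazyDataset([{'image': 1, 'label': 0}, {'image': 5, 'label': 1}])\n",
"ds.set_transform(double_pixels)\n",
"print(len(calls))  # no row has been transformed at this point\n",
"row = ds[1]        # the transform runs now, for this row only\n",
"print(row)\n",
"```\n",
"\n",
"Because the transform is re-run on every access, random augmentations produce a different result each time a training image is loaded."
]
},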
{
"cell_type": "markdown",
"metadata": {
"id": "MMw_wQS58a7o",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Let's access an element to see that we've added a \"pixel_values\" feature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ng9TAlDV8d7r",
"outputId": "045cb11f-d5e6-4552-8ef9-918f867b828f",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"train_ds[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IYZhy_zOswNE",
"outputId": "ea2ad8ad-6144-4e25-82bc-b390f82c7599",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"train_ds[0]['pixel_values'].shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HOXmyPQ76Qv9",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Training the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0a-2YT7O6ayC",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Now that our data is ready, we can download the pretrained model and fine-tune it. For classification we use the `AutoModelForImageClassification` class. Calling the `from_pretrained` method on it will download and cache the weights for us. As the label ids and the number of labels are dataset dependent, we pass `label2id`, and `id2label` alongside the `model_checkpoint` here. This will make sure a custom classification head will be created (with a custom number of output neurons).\n",
"\n",
"NOTE: in case you're planning to fine-tune an already fine-tuned checkpoint, like [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) (which has already been fine-tuned on ImageNet-1k), then you need to provide the additional argument `ignore_mismatched_sizes=True` to the `from_pretrained` method. This will make sure the output head (with 1000 output neurons) is thrown away and replaced by a new, randomly initialized classification head that includes a custom number of output neurons. You don't need to specify this argument in case the pre-trained model doesn't include a head. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 203,
"referenced_widgets": [
"9e1895e9f61f435682c89d3262ea441c",
"20c3f774d6d9484e8f1c948f7ca5e046",
"c221cfb9f0274eaebe5bd732f3c6b148",
"8ffd71d3f46947ba831098300ab5fb94",
"a544c18d3b994b14aa4557c718088d4b",
"10ea68a58d754fe6bdeff82dc343b52b",
"9ffcf926bbf04f82a716ad2ed3a8b6d1",
"c0791510a0e74674a5dc5b46b9c2c87e",
"b20d698d71e8495c9499b836a1559f3a",
"1db4970b87b348148a72525014e2d4dc",
"6abd984a1ae442df8d0b43277620fd67",
"56d6a39887064be39b64c8c21ce31581",
"3bca40ad7c1846348a696ef66f350da7",
"3bea54844d55496bab8de6e0e553243a",
"dd59d93d59fc4710867bfd69ccae5e6d",
"d6d134de78c14314b7627ba0f3b6891e",
"268807b9219e4dae9f9073b37d9487fa",
"90267331af2d4d91bfd703557e3ff6d4",
"257a1e82772c4cfd90c0daef86eed031",
"ff9e39b94d064b8d96f1b338c8e9123f",
"5b12316940c948f7a443537d7f7d3dcd",
"3a237d1e47cb4432aa3bdaeac2c586d1"
]
},
"id": "X9DDujL0q1ac",
"outputId": "844dbe9d-b72b-4b8e-d514-3c2deeccd593",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from transformers import AutoModelForImageClassification, TrainingArguments, Trainer\n",
"\n",
"model = AutoModelForImageClassification.from_pretrained(\n",
"    model_checkpoint,\n",
"    label2id=label2id,\n",
"    id2label=id2label,\n",
"    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U8EmET_f6458",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The warning is telling us we are throwing away some weights (the weights and bias of the `classifier` layer) and randomly initializing some others (the weights and bias of a new `classifier` layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FEfyuq1U8hDT",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"#### Trainer\n",
"\n",
"The [Trainer class](https://huggingface.co/docs/transformers/main_classes/trainer) provides an API for feature-complete training in PyTorch for most standard use cases. To instantiate a `Trainer`, we will need to define the training configuration and the evaluation metric. The most important is the [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model.\n",
"\n",
"Most of the training arguments are pretty self-explanatory, but one that is quite important here is `remove_unused_columns=False`. By default this argument is `True`, which makes the `Trainer` drop any features not used by the model's call function; that is usually ideal, since it makes it easier to unpack inputs into the model's call function. In our case, however, we need the unused features ('image' in particular) in order to create 'pixel_values', so we set it to `False`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xc_MTm0Ks3DF",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"model_name = model_checkpoint.split(\"/\")[-1]\n",
"\n",
"args = TrainingArguments(\n",
"    f\"{model_name}-finetuned-eurosat-kornia\",\n",
"    remove_unused_columns=False,\n",
"    eval_strategy=\"epoch\",\n",
"    save_strategy=\"epoch\",\n",
"    learning_rate=5e-5,\n",
"    per_device_train_batch_size=batch_size,\n",
"    gradient_accumulation_steps=4,\n",
"    per_device_eval_batch_size=batch_size,\n",
"    num_train_epochs=3,\n",
"    warmup_ratio=0.1,\n",
"    logging_steps=10,\n",
"    load_best_model_at_end=True,\n",
"    metric_for_best_model=\"accuracy\",\n",
"    push_to_hub=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xi6JYNYs8lJO",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, use the `batch_size` defined at the top of the notebook and customize the number of epochs for training, as well as the warmup ratio. Since the best model might not be the one at the end of training, we ask the `Trainer` to load the best model it saved (according to `metric_for_best_model`) at the end of training.\n",
"\n",
"The last argument `push_to_hub` allows the Trainer to push the model to the [Hub](https://huggingface.co/models) regularly during training. Remove it if you didn't follow the setup steps in the \"Sharing your model\" section above. If you want to save your model locally with a name that is different from the name of the repository, or if you want to push your model under an organization and not your namespace, use the `hub_model_id` argument to set the repo name (it needs to be the full name, including your namespace: for instance `\"nielsr/vit-finetuned-cifar10\"` or `\"huggingface/vit-finetuned-cifar10\"`)."
]
},
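{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `gradient_accumulation_steps=4`, gradients from four batches are accumulated before each optimizer update, so the effective batch size per device is four times `batch_size`. The arithmetic below is a rough sketch: `num_train_examples` is an assumed figure, and the step counts are approximations rather than the exact values the `Trainer` reports.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"batch_size = 32\n",
"gradient_accumulation_steps = 4\n",
"num_train_epochs = 3\n",
"warmup_ratio = 0.1\n",
"num_train_examples = 24_300  # assumption: size of the training split\n",
"\n",
"# Examples seen per optimizer update on a single device\n",
"effective_batch_size = batch_size * gradient_accumulation_steps\n",
"\n",
"# Approximate number of optimizer updates over the whole run\n",
"batches_per_epoch = math.ceil(num_train_examples / batch_size)\n",
"updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)\n",
"total_updates = updates_per_epoch * num_train_epochs\n",
"warmup_updates = round(total_updates * warmup_ratio)\n",
"\n",
"print(effective_batch_size, updates_per_epoch, total_updates, warmup_updates)\n",
"```\n",
"\n",
"If you lower `batch_size` to fit a smaller GPU, raising `gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged."
]
},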
{
"cell_type": "markdown",
"metadata": {
"id": "2VE_HSha9RZk",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Next, we need to define a function for how to compute the metrics from the predictions, which will just use the `metric` we loaded earlier. The only preprocessing we have to do is to take the argmax of our predicted logits:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EVWfiBuv2uCS",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# the compute_metrics function takes a Named Tuple as input:\n",
"# predictions, which are the logits of the model as Numpy arrays,\n",
"# and label_ids, which are the ground-truth labels as Numpy arrays.\n",
"def compute_metrics(eval_pred):\n",
" \"\"\"Computes accuracy on a batch of predictions\"\"\"\n",
" predictions = np.argmax(eval_pred.predictions, axis=1)\n",
" return metric.compute(predictions=predictions, references=eval_pred.label_ids)"
]
},
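{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of what `compute_metrics` does, the sketch below applies the same argmax step to made-up logits and computes accuracy by hand with NumPy. The `EvalPred` named tuple is a hypothetical stand-in for the object the `Trainer` passes in, and the numbers are invented for illustration:\n",
"\n",
"```python\n",
"from collections import namedtuple\n",
"\n",
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for the object passed to compute_metrics.\n",
"EvalPred = namedtuple('EvalPred', ['predictions', 'label_ids'])\n",
"\n",
"# Made-up logits for 4 images and 3 classes (one row per image).\n",
"logits = np.array([\n",
"    [2.0, 0.1, -1.0],\n",
"    [0.3, 1.5, 0.2],\n",
"    [0.0, 0.2, 3.1],\n",
"    [1.2, 0.9, 0.4],\n",
"])\n",
"labels = np.array([0, 1, 2, 1])\n",
"eval_pred = EvalPred(predictions=logits, label_ids=labels)\n",
"\n",
"# argmax over the class axis turns each row of logits into a class index.\n",
"predicted = np.argmax(eval_pred.predictions, axis=1)\n",
"\n",
"# Accuracy is the fraction of predictions that match the labels.\n",
"accuracy = float((predicted == eval_pred.label_ids).mean())\n",
"print(predicted.tolist(), accuracy)\n",
"```\n",
"\n",
"Three of the four made-up predictions match their labels, so the accuracy here is 0.75."
]
},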
{
"cell_type": "markdown",
"metadata": {
"id": "Y0PqjzHQVutb",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"We also define a `collate_fn`, which will be used to batch examples together.\n",
"Each batch consists of 2 keys, namely `pixel_values` and `labels`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u0WcwsX7rW9w",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def collate_fn(examples):\n",
" pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
" labels = torch.tensor([example[\"label\"] for example in examples])\n",
" return {\"pixel_values\": pixel_values, \"labels\": labels}"
]
},
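{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what a collated batch looks like, here is a small sketch that runs the same function on two made-up examples. The function is repeated so the snippet is self-contained, and the 3x224x224 tensors are an assumption standing in for preprocessed images:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def collate_fn(examples):\n",
"    # Stack per-example image tensors into one (batch, channels, height, width) tensor.\n",
"    pixel_values = torch.stack([example['pixel_values'] for example in examples])\n",
"    # Gather the integer class labels into a 1-D tensor.\n",
"    labels = torch.tensor([example['label'] for example in examples])\n",
"    return {'pixel_values': pixel_values, 'labels': labels}\n",
"\n",
"# Two made-up examples shaped like preprocessed RGB images.\n",
"examples = [\n",
"    {'pixel_values': torch.zeros(3, 224, 224), 'label': 0},\n",
"    {'pixel_values': torch.ones(3, 224, 224), 'label': 7},\n",
"]\n",
"batch = collate_fn(examples)\n",
"print(batch['pixel_values'].shape)\n",
"print(batch['labels'])\n",
"```\n",
"\n",
"The batch dimension comes first, which is the layout the model expects for `pixel_values`."
]
},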
{
"cell_type": "markdown",
"metadata": {
"id": "yTF0dWw49fB9",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Then we just need to pass all of this along with our datasets to the `Trainer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "McVoaCPr3Cj-",
"outputId": "184c1d80-00af-49c1-f137-e7287acdbdeb",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model,\n",
" args,\n",
" train_dataset=train_ds,\n",
" eval_dataset=val_ds,\n",
" tokenizer=feature_extractor,\n",
" compute_metrics=compute_metrics,\n",
" data_collator=collate_fn,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ltokP9mO9pjI",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"You might wonder why we pass along the `feature_extractor` as a tokenizer when we already preprocessed our data. This is only to make sure the feature extractor configuration file (stored as JSON) will also be uploaded to the repo on the hub."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9j6VNsGP97LG",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Now we can fine-tune our model by calling the `train` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"a926f15ee449457fbf325846e6a8c738",
"31ba50fb1fa7400ba28117a6f808f2fa",
"76e7eb0630044940a7e20010884c90b2",
"1920117d682e4174a850d0e1b24c9551",
"746a5b6710914aebab7654f6c20435c4",
"69a8fa38e31543faa011edc197c60931",
"3e91e14cbeaa4f718960b11bed8298d6",
"1500a6d5a104448e86029bedda65c487",
"f5de8ebae0894b009aac08ee86160e58",
"a5952ee376e1418bbb659609c9a23e12",
"eaf7c8f4e9884734978bde6b797c3d7a",
"82b7c8ecbd2f4b9ab09d3c7ee9f17352",
"ef00e621a4b04f2b9550704ecb47b541",
"c6cc09847f394b3b974b15a026b7dfd4",
"5d184fb5a4284c979f6671ec01b13c32",
"97d4ca9d132e4ed6b0f33ac7bef8ae41",
"c2831860c6cb4e9ea52432f45f60202d",
"69087e2c4bdf45f2b93377061028d6b6",
"7ed3b32b51df4d5a90787700228e2add",
"72ac166cd3fe4244acdcf7732269cfe2",
"4c93dbdc25b7444dbb67f17e7905d59e",
"d3f6d6367d694b4290aa3901458c6e1a"
]
},
"id": "Pps61vF_4QaH",
"outputId": "776f658a-21ec-4d1c-a490-2fcc335c53ee",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"train_results = trainer.train()\n",
"# rest is optional but nice to have\n",
"trainer.save_model()\n",
"trainer.log_metrics(\"train\", train_results.metrics)\n",
"trainer.save_metrics(\"train\", train_results.metrics)\n",
"trainer.save_state()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vyb-58x_-A0e",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"We can check with the `evaluate` method that our `Trainer` did reload the best model properly (if it was not the last one):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 216
},
"id": "niniUAnb5IrR",
"outputId": "d6042a8c-5766-4542-f3bd-50e17a6a88dc",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"metrics = trainer.evaluate()\n",
"# some nice to haves:\n",
"trainer.log_metrics(\"eval\", metrics)\n",
"trainer.save_metrics(\"eval\", metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymwN-SIR-NDF",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"You can now upload the result of the training to the Hub by executing the instruction below (note that the Trainer will automatically create a model card as well as Tensorboard logs - see the \"Training metrics\" tab - amazing, isn't it?):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 246,
"referenced_widgets": [
"a7a8d06dc4d146dcb67c768e12bcdc53",
"32955c73047f4fc7b656e381852614c7",
"0f2c3be3663e4357935be76395e9abd3",
"9503d58abcab45eea56984f00cb8562c",
"7a23bbe2ce824c699691120ea84220f4",
"ce93940729774f6aa246606efe7cc5eb",
"db8ca3f6861d4b6eb619d33306aeeda8",
"22921fe1bb3b4a9cb4cbf23eb4ea4930",
"c21ffa7c0ba64aaa8f1187923f788572",
"13952af7202349c2bce9a66aec24c452",
"d3d37d996cbc4fbf8b0e032af39644e2"
]
},
"id": "4aNMErFz-GzX",
"outputId": "773c09ff-cee2-4662-93ce-ea56717d28cd",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"trainer.push_to_hub()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cZQnNUsI-Q4S",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"You can now share this model with all your friends, family and favorite pets: they can all load it with the identifier `\"your-username/the-name-you-picked\"`, for instance:\n",
"\n",
"```python\n",
"from transformers import AutoModelForImageClassification, AutoFeatureExtractor\n",
"\n",
"feature_extractor = AutoFeatureExtractor.from_pretrained(\"nielsr/my-awesome-model\")\n",
"model = AutoModelForImageClassification.from_pretrained(\"nielsr/my-awesome-model\")\n",
"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "049gH1wt-Akp",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Inference\n",
"\n",
"Let's say you have a new image, on which you'd like to make a prediction. Let's load a satellite image of a highway, and see how the model does."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "UX6dwmT7GP91",
"outputId": "29c1a967-680a-477d-80c9-b2536ad00787",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from PIL import Image\n",
"import requests\n",
"\n",
"url = \"https://datasets-server.huggingface.co/assets/nielsr/eurosat-demo/--/nielsr--eurosat-demo/train/0/image/image.jpg\"\n",
"image = Image.open(requests.get(url, stream=True).raw)\n",
"image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "91-Ibh1--oI3",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"We'll load the feature extractor and model from the hub (here, we use the [Auto Classes](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForImageClassification), which will make sure the appropriate classes will be loaded automatically based on the `config.json` and `preprocessor_config.json` files of the repo on the hub):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xzwvix8X-st3",
"outputId": "ddb4e9fc-7237-4aac-f9e9-e5e94735e906",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from transformers import AutoModelForImageClassification, AutoFeatureExtractor\n",
"\n",
"# full repo name on the hub, including the namespace\n",
"repo_name = \"nielsr/swin-tiny-patch4-window7-224-finetuned-eurosat-kornia\"\n",
"\n",
"feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name)\n",
"model = AutoModelForImageClassification.from_pretrained(repo_name)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7oDoe_38AY3X",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"We'll apply the exact same transformations as we did for validation. This involves 1) rescaling 2) resizing the shorter edge 3) center cropping 4) normalizing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OOKhRKmh9tsw",
"outputId": "82290932-b5ce-4ee7-e3ea-f5f21b757eda",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# prepare image for the model\n",
"pixel_values = val_transforms(image.convert(\"RGB\"))\n",
"print(pixel_values.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33E44G86_RtL",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# forward pass\n",
"with torch.no_grad():\n",
" outputs = model(pixel_values)\n",
" logits = outputs.logits"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4ctUvqfs_Yyn",
"outputId": "67d2c1c5-5eae-4a9d-cfde-d0a3d788c257",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"predicted_class_idx = logits.argmax(-1).item()\n",
"print(\"Predicted class:\", model.config.id2label[predicted_class_idx])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N3yJFIIP_k01",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Looks like our model got it correct! "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-2A5W8dF_qYt",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Pipeline API\n",
"\n",
"An alternative way to quickly perform inference with any model on the hub is to use the [Pipeline API](https://huggingface.co/docs/transformers/main_classes/pipelines), which abstracts away all the steps we did manually above. It performs the preprocessing, forward pass and postprocessing in a single object.\n",
"\n",
"Let's showcase this for our trained model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "I7mz7QTo_jWa",
"outputId": "066ae7a2-ce84-4ba0-be72-7ea1f8140d4d",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"pipe = pipeline(\"image-classification\", \"nielsr/swin-tiny-patch4-window7-224-finetuned-eurosat-kornia\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fPiuLDx3_9SY",
"outputId": "180093bc-ee15-4c05-c305-6a09c8e140b9",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"pipe(image)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BVXM6-g4AJmy",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"As we can see, it returns not only the class label with the highest probability, but the top 5 labels along with their corresponding scores. Note that the pipelines also work with local models and feature extractors:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B8kmO1NMAAXs",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"pipe = pipeline(\"image-classification\", \n",
" model=model,\n",
" feature_extractor=feature_extractor)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NfFH9eLMAdCX",
"outputId": "a1b96de3-285b-4208-9fe2-16b6802c6e20",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"pipe(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zO4XGe8_Ao5-",
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "image_classification_kornia.ipynb",
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"013ca1d8272343a2a50a8bc49c22a76b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"01ea35ba78b440d392b2c5e75ea74464": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_30a8bd848ee1497691b7605d4dd32115",
"placeholder": "​",
"style": "IPY_MODEL_c1537b4150ab4cfca215d70f6095a110",
"value": "Downloading data files: "
}
},
"0a6cef96ab1841068e9027a30e76ed1e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"0f2c3be3663e4357935be76395e9abd3": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_22921fe1bb3b4a9cb4cbf23eb4ea4930",
"max": 363,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_c21ffa7c0ba64aaa8f1187923f788572",
"value": 363
}
},
"0f5d923b2b534881bd495dcaf47b94ec": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"0fd3d466727b4ac3bdbe1b32c67b0c26": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "info",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_432c7d96cdc840d6b54567474d09ab2f",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_72adfd5afe784624a673284e17a813c8",
"value": 1
}
},
"10ea68a58d754fe6bdeff82dc343b52b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"13952af7202349c2bce9a66aec24c452": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1453dc44aa334c7097608220933e46c2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"1500a6d5a104448e86029bedda65c487": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"164e22e2e1c64ffc99b77f787cad3a22": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c79df4f5086a4cb4b251d8614928ebe2",
"placeholder": "​",
"style": "IPY_MODEL_71323ca5991e4c0399cb106cb59193fc",
"value": "Downloading builder script: 100%"
}
},
"17ad910846c647eea2967a556fa5d636": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_d19adaf170c1447f90538b85b2546a7a",
"IPY_MODEL_b850715786e741fab64246dd616e4b0b",
"IPY_MODEL_775c74c822b94d55897501502b8b3890"
],
"layout": "IPY_MODEL_992069d6fbdd4046b6aa734b39a23ed5"
}
},
"1920117d682e4174a850d0e1b24c9551": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a5952ee376e1418bbb659609c9a23e12",
"placeholder": "​",
"style": "IPY_MODEL_eaf7c8f4e9884734978bde6b797c3d7a",
"value": " 105M/105M [02:22<00:00, 1.35MB/s]"
}
},
"1aa408392e434cde9bdc43e0f670617d": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1aed4493a06e459193d06161e5dfe09e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1db4970b87b348148a72525014e2d4dc": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"20c3f774d6d9484e8f1c948f7ca5e046": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_10ea68a58d754fe6bdeff82dc343b52b",
"placeholder": "​",
"style": "IPY_MODEL_9ffcf926bbf04f82a716ad2ed3a8b6d1",
"value": "Downloading config.json: 100%"
}
},
"2263ea69fe2748e59e241121c0ee85db": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "PasswordModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "PasswordModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "PasswordView",
"continuous_update": true,
"description": "Token:",
"description_tooltip": null,
"disabled": false,
"layout": "IPY_MODEL_6e51c6d618164aeda48942aa1850ecf0",
"placeholder": "​",
"style": "IPY_MODEL_2fa7bbee9cee40479fd62976954dabe9",
"value": ""
}
},
"22921fe1bb3b4a9cb4cbf23eb4ea4930": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"257a1e82772c4cfd90c0daef86eed031": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"268807b9219e4dae9f9073b37d9487fa": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2b479dce1b3d427483504e57fe4acd7a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"2c21a4a4464343729719d8f3125a7535": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"2ceb2fa333a94f5fb5207b5303bf0105": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1aed4493a06e459193d06161e5dfe09e",
"placeholder": "​",
"style": "IPY_MODEL_36d59e2f979449fa8e27c58f44c1f801",
"value": "100%"
}
},
"2fa7bbee9cee40479fd62976954dabe9": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"30a8bd848ee1497691b7605d4dd32115": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"31ba50fb1fa7400ba28117a6f808f2fa": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_69a8fa38e31543faa011edc197c60931",
"placeholder": "​",
"style": "IPY_MODEL_3e91e14cbeaa4f718960b11bed8298d6",
"value": "Upload file pytorch_model.bin: 100%"
}
},
"32955c73047f4fc7b656e381852614c7": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ce93940729774f6aa246606efe7cc5eb",
"placeholder": "​",
"style": "IPY_MODEL_db8ca3f6861d4b6eb619d33306aeeda8",
"value": "Upload file runs/Aug29_08-52-09_cc75a613d50e/events.out.tfevents.1661765112.cc75a613d50e.286.2: 100%"
}
},
"32e6e84f4de54b26a7218bf720324752": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"3653830200414a2b9a5150eea235fcf1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"36d59e2f979449fa8e27c58f44c1f801": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"377190442ab84770b7e3d27199ac8563": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3a237d1e47cb4432aa3bdaeac2c586d1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"3b13fa319738413892aa0175b8d2c79a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c243a77c77eb4e51b217b0513c2d8546",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_2c21a4a4464343729719d8f3125a7535",
"value": 1
}
},
"3bca40ad7c1846348a696ef66f350da7": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_268807b9219e4dae9f9073b37d9487fa",
"placeholder": "​",
"style": "IPY_MODEL_90267331af2d4d91bfd703557e3ff6d4",
"value": "Downloading pytorch_model.bin: 100%"
}
},
"3bea54844d55496bab8de6e0e553243a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_257a1e82772c4cfd90c0daef86eed031",
"max": 113476015,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_ff9e39b94d064b8d96f1b338c8e9123f",
"value": 113476015
}
},
"3e91e14cbeaa4f718960b11bed8298d6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"432c7d96cdc840d6b54567474d09ab2f": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": "20px"
}
},
"43ed5eea6e854c929400cacd09009d5a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_df6fd8bcf0bd4cbf8b29faa94bd2b34b",
"IPY_MODEL_528c167c7f6342acad2fa73034d9024c",
"IPY_MODEL_7e29403666f8487daf0825acbb63636e"
],
"layout": "IPY_MODEL_4be0b71eee094946a00786fd6bb23f25"
}
},
"45cc5f862ebc482d9186d3d4349cf602": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8926f783fec742ef954ceb979f5abc70",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_c81f799609794db09cb4f5d19c39333c",
"value": 1
}
},
"4ba73b8913d8440ea8ddc4c0e08ed145": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4be0b71eee094946a00786fd6bb23f25": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4c4fb79433cf49faac5105c5403f4bc8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_e19aa06a93034787a1a422c8297c50cf",
"IPY_MODEL_45cc5f862ebc482d9186d3d4349cf602",
"IPY_MODEL_c43a2cf9f9ee42cb8a114ced6ecce524"
],
"layout": "IPY_MODEL_f1ab9e0ecd59406b825926098912b968"
}
},
"4c93dbdc25b7444dbb67f17e7905d59e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4e5e098964ab44a8aee3fd95ff6ec0e1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4ffe6ca0522542f2a96cd3191d6a99c6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"528c167c7f6342acad2fa73034d9024c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4e5e098964ab44a8aee3fd95ff6ec0e1",
"max": 255,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_e9883ae85eac458fad0f966229e9e513",
"value": 255
}
},
"56d6a39887064be39b64c8c21ce31581": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_3bca40ad7c1846348a696ef66f350da7",
"IPY_MODEL_3bea54844d55496bab8de6e0e553243a",
"IPY_MODEL_dd59d93d59fc4710867bfd69ccae5e6d"
],
"layout": "IPY_MODEL_d6d134de78c14314b7627ba0f3b6891e"
}
},
"58f000f788c14a268cc102d5237b8c04": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f6693950a02d4f9d933c2e32e323f521",
"placeholder": "​",
"style": "IPY_MODEL_adaaa2e702174b2b87894c0d3fcb1a57",
"value": " 0/0 [00:00<?, ?it/s]"
}
},
"5b12316940c948f7a443537d7f7d3dcd": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5d184fb5a4284c979f6671ec01b13c32": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4c93dbdc25b7444dbb67f17e7905d59e",
"placeholder": "​",
"style": "IPY_MODEL_d3f6d6367d694b4290aa3901458c6e1a",
"value": " 14.2k/14.2k [02:21<00:00, 78.8B/s]"
}
},
"603016ca5f7f4e939e5a7f1f57b8dd33": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f22ca12d18354eec80befa03d312fc44",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_c8eb2750eb2244699dc1492e658a9ccb",
"value": 1
}
},
"69087e2c4bdf45f2b93377061028d6b6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"69a8fa38e31543faa011edc197c60931": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6abd984a1ae442df8d0b43277620fd67": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"6da041a6484b4d1cbfa58cfcd265f88e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6e51c6d618164aeda48942aa1850ecf0": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6f52fb52f05d465090227ede5cf2a30c": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6f7c2ec995714255be700b3ffaba9f9e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7058d1e258de40ebad13ff790915cd3c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_01ea35ba78b440d392b2c5e75ea74464",
"IPY_MODEL_e2e16d974fb249dd94a11507f9224a2d",
"IPY_MODEL_58f000f788c14a268cc102d5237b8c04"
],
"layout": "IPY_MODEL_377190442ab84770b7e3d27199ac8563"
}
},
"71323ca5991e4c0399cb106cb59193fc": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"72ac166cd3fe4244acdcf7732269cfe2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"72adfd5afe784624a673284e17a813c8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"7398432c4cc74d2d8f9d61c40b786021": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d74c3e6888b84fb781f2b343534c369a",
"placeholder": "​",
"style": "IPY_MODEL_0f5d923b2b534881bd495dcaf47b94ec",
"value": " 4.20k/4.20k [00:00<00:00, 9.59kB/s]"
}
},
"746a5b6710914aebab7654f6c20435c4": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"76e7eb0630044940a7e20010884c90b2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1500a6d5a104448e86029bedda65c487",
"max": 110417455,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_f5de8ebae0894b009aac08ee86160e58",
"value": 110417455
}
},
"775c74c822b94d55897501502b8b3890": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1aa408392e434cde9bdc43e0f670617d",
"placeholder": "​",
"style": "IPY_MODEL_0a6cef96ab1841068e9027a30e76ed1e",
"value": " 94.3M/94.3M [00:10<00:00, 10.8MB/s]"
}
},
"78fa30c22d474c3da827b6af042b1e7d": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7a23bbe2ce824c699691120ea84220f4": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7cfa331622ff4187a167c4a981491457": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"7e29403666f8487daf0825acbb63636e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ec88dc30ea1d4c9c9ca71a1026220a47",
"placeholder": "​",
"style": "IPY_MODEL_d3a7387094884e619cc9b17eafb6d549",
"value": " 255/255 [00:00<00:00, 2.97kB/s]"
}
},
"7ed3b32b51df4d5a90787700228e2add": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"825be3a2828c46398b79e6279c45236a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_e32780c830504c508b85e67343ab88e9",
"IPY_MODEL_0fd3d466727b4ac3bdbe1b32c67b0c26",
"IPY_MODEL_8c70d4279bac4e1e9efdc3d3b9d45e34"
],
"layout": "IPY_MODEL_a5dabc9865f34c9089b0eeca3897e467"
}
},
"82b7c8ecbd2f4b9ab09d3c7ee9f17352": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_ef00e621a4b04f2b9550704ecb47b541",
"IPY_MODEL_c6cc09847f394b3b974b15a026b7dfd4",
"IPY_MODEL_5d184fb5a4284c979f6671ec01b13c32"
],
"layout": "IPY_MODEL_97d4ca9d132e4ed6b0f33ac7bef8ae41"
}
},
"83dd806f4a2b45d7bd34fc0c106a0eac": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6f7c2ec995714255be700b3ffaba9f9e",
"max": 4203,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b2a6342c762f45dda3b08ad3d9d0a6ad",
"value": 4203
}
},
"8926f783fec742ef954ceb979f5abc70": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8c70d4279bac4e1e9efdc3d3b9d45e34": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c07bfe12d2c04a0f9c5696797a1d67a5",
"placeholder": "​",
"style": "IPY_MODEL_fbb13fc905f240dab44f438f11421f89",
"value": " 26705/0 [00:02<00:00, 10567.37 examples/s]"
}
},
"8ffd71d3f46947ba831098300ab5fb94": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1db4970b87b348148a72525014e2d4dc",
"placeholder": "​",
"style": "IPY_MODEL_6abd984a1ae442df8d0b43277620fd67",
"value": " 70.1k/70.1k [00:00<00:00, 80.5kB/s]"
}
},
"90267331af2d4d91bfd703557e3ff6d4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"907b655aeed94f1c97fb564525e04d01": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ButtonStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ButtonStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"button_color": null,
"font_weight": ""
}
},
"92f5865c94ce4986a36d71e2cf1db91e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": "center",
"align_self": null,
"border": null,
"bottom": null,
"display": "flex",
"flex": null,
"flex_flow": "column",
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": "50%"
}
},
"9503d58abcab45eea56984f00cb8562c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_13952af7202349c2bce9a66aec24c452",
"placeholder": "​",
"style": "IPY_MODEL_d3d37d996cbc4fbf8b0e032af39644e2",
"value": " 363/363 [00:02<?, ?B/s]"
}
},
"97d4ca9d132e4ed6b0f33ac7bef8ae41": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"98245a2bb52d42369036d025ad91efdd": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"992069d6fbdd4046b6aa734b39a23ed5": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9a3fd883fbcd4192972daee5e9b724e8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "VBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "VBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "VBoxView",
"box_style": "",
"children": [
"IPY_MODEL_e4e975ccdf71499e9e7e2628b4d12381",
"IPY_MODEL_2263ea69fe2748e59e241121c0ee85db",
"IPY_MODEL_ce8ff720c1194f9fa576fa0ab72b9434",
"IPY_MODEL_a245a58b01304077b758257436e69b1a"
],
"layout": "IPY_MODEL_92f5865c94ce4986a36d71e2cf1db91e"
}
},
"9c22864ff042460584353d870e79b5fb": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_2ceb2fa333a94f5fb5207b5303bf0105",
"IPY_MODEL_3b13fa319738413892aa0175b8d2c79a",
"IPY_MODEL_bd3fb4774f2544abbfe84a8f77297022"
],
"layout": "IPY_MODEL_e161abad994748e7be3655cffa4f7468"
}
},
"9e1895e9f61f435682c89d3262ea441c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_20c3f774d6d9484e8f1c948f7ca5e046",
"IPY_MODEL_c221cfb9f0274eaebe5bd732f3c6b148",
"IPY_MODEL_8ffd71d3f46947ba831098300ab5fb94"
],
"layout": "IPY_MODEL_a544c18d3b994b14aa4557c718088d4b"
}
},
"9ffcf926bbf04f82a716ad2ed3a8b6d1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"a245a58b01304077b758257436e69b1a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_013ca1d8272343a2a50a8bc49c22a76b",
"placeholder": "​",
"style": "IPY_MODEL_98245a2bb52d42369036d025ad91efdd",
"value": "\nPro Tip: If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks. "
}
},
"a544c18d3b994b14aa4557c718088d4b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a5952ee376e1418bbb659609c9a23e12": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a5dabc9865f34c9089b0eeca3897e467": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a73b73fbcc164a81a3a4912f9682ab00": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6f52fb52f05d465090227ede5cf2a30c",
"placeholder": "​",
"style": "IPY_MODEL_4ffe6ca0522542f2a96cd3191d6a99c6",
"value": " 1/1 [00:07<00:00, 7.30s/it]"
}
},
"a7a8d06dc4d146dcb67c768e12bcdc53": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_32955c73047f4fc7b656e381852614c7",
"IPY_MODEL_0f2c3be3663e4357935be76395e9abd3",
"IPY_MODEL_9503d58abcab45eea56984f00cb8562c"
],
"layout": "IPY_MODEL_7a23bbe2ce824c699691120ea84220f4"
}
},
"a926f15ee449457fbf325846e6a8c738": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_31ba50fb1fa7400ba28117a6f808f2fa",
"IPY_MODEL_76e7eb0630044940a7e20010884c90b2",
"IPY_MODEL_1920117d682e4174a850d0e1b24c9551"
],
"layout": "IPY_MODEL_746a5b6710914aebab7654f6c20435c4"
}
},
"aa59cf09e9bf4aba951ab61b47c1ed40": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"aca7889b9fd34829bc7c7ed72c467aa1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": "20px"
}
},
"ad6684672345402e9140caa4e623c2b5": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"adaaa2e702174b2b87894c0d3fcb1a57": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"b20d698d71e8495c9499b836a1559f3a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"b2a6342c762f45dda3b08ad3d9d0a6ad": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"b850715786e741fab64246dd616e4b0b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_3653830200414a2b9a5150eea235fcf1",
"max": 94280567,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_2b479dce1b3d427483504e57fe4acd7a",
"value": 94280567
}
},
"bb520a5722484f39bb93face96e36d69": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bd3fb4774f2544abbfe84a8f77297022": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_dfae0ee344db4147a758e6b50daa66c8",
"placeholder": "​",
"style": "IPY_MODEL_7cfa331622ff4187a167c4a981491457",
"value": " 1/1 [00:00<00:00, 26.07it/s]"
}
},
"c0791510a0e74674a5dc5b46b9c2c87e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c07bfe12d2c04a0f9c5696797a1d67a5": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c1537b4150ab4cfca215d70f6095a110": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c21ffa7c0ba64aaa8f1187923f788572": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"c221cfb9f0274eaebe5bd732f3c6b148": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c0791510a0e74674a5dc5b46b9c2c87e",
"max": 71813,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b20d698d71e8495c9499b836a1559f3a",
"value": 71813
}
},
"c243a77c77eb4e51b217b0513c2d8546": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c2831860c6cb4e9ea52432f45f60202d": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c3d49b63b8204a04ab075188f49b0d93": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6da041a6484b4d1cbfa58cfcd265f88e",
"placeholder": "​",
"style": "IPY_MODEL_d9c9f4a8aca44e09a8136331e545ea88",
"value": "Extracting data files: 100%"
}
},
"c43a2cf9f9ee42cb8a114ced6ecce524": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_78fa30c22d474c3da827b6af042b1e7d",
"placeholder": "​",
"style": "IPY_MODEL_ffaa450ad4e24d66a466ab53b424260f",
"value": " 1/1 [00:11<00:00, 11.85s/it]"
}
},
"c6108fb0bf124c3fa22f85bbd4bf9266": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_c3d49b63b8204a04ab075188f49b0d93",
"IPY_MODEL_603016ca5f7f4e939e5a7f1f57b8dd33",
"IPY_MODEL_a73b73fbcc164a81a3a4912f9682ab00"
],
"layout": "IPY_MODEL_f05a4a836fb84bf7a5f2e75440bb026d"
}
},
"c6cc09847f394b3b974b15a026b7dfd4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_7ed3b32b51df4d5a90787700228e2add",
"max": 14573,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_72ac166cd3fe4244acdcf7732269cfe2",
"value": 14573
}
},
"c79df4f5086a4cb4b251d8614928ebe2": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c81f799609794db09cb4f5d19c39333c": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"c8eb2750eb2244699dc1492e658a9ccb": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"ce8ff720c1194f9fa576fa0ab72b9434": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ButtonModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ButtonModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ButtonView",
"button_style": "",
"description": "Login",
"disabled": false,
"icon": "",
"layout": "IPY_MODEL_bb520a5722484f39bb93face96e36d69",
"style": "IPY_MODEL_907b655aeed94f1c97fb564525e04d01",
"tooltip": ""
}
},
"ce93940729774f6aa246606efe7cc5eb": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d19adaf170c1447f90538b85b2546a7a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4ba73b8913d8440ea8ddc4c0e08ed145",
"placeholder": "​",
"style": "IPY_MODEL_f5ded311e5ad49a4b7f77c5da177ffb9",
"value": "Downloading data: 100%"
}
},
"d37300abc9474836803687ab739be99d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"d3a7387094884e619cc9b17eafb6d549": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"d3d37d996cbc4fbf8b0e032af39644e2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"d3f6d6367d694b4290aa3901458c6e1a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"d6d134de78c14314b7627ba0f3b6891e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d74c3e6888b84fb781f2b343534c369a": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d9c9f4a8aca44e09a8136331e545ea88": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"db8ca3f6861d4b6eb619d33306aeeda8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"dd59d93d59fc4710867bfd69ccae5e6d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5b12316940c948f7a443537d7f7d3dcd",
"placeholder": "​",
"style": "IPY_MODEL_3a237d1e47cb4432aa3bdaeac2c586d1",
"value": " 108M/108M [00:01<00:00, 61.6MB/s]"
}
},
"df6fd8bcf0bd4cbf8b29faa94bd2b34b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e91ac28da9484c6398381c48d8154183",
"placeholder": "​",
"style": "IPY_MODEL_ad6684672345402e9140caa4e623c2b5",
"value": "Downloading preprocessor_config.json: 100%"
}
},
"dfae0ee344db4147a758e6b50daa66c8": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e161abad994748e7be3655cffa4f7468": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e19aa06a93034787a1a422c8297c50cf": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fb8dda387d89412388702e248a9514df",
"placeholder": "​",
"style": "IPY_MODEL_d37300abc9474836803687ab739be99d",
"value": "Downloading data files: 100%"
}
},
"e2e16d974fb249dd94a11507f9224a2d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_aca7889b9fd34829bc7c7ed72c467aa1",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_32e6e84f4de54b26a7218bf720324752",
"value": 0
}
},
"e32780c830504c508b85e67343ab88e9": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f0fc57814beb464aa5c0fa560f3fa8a6",
"placeholder": "​",
"style": "IPY_MODEL_1453dc44aa334c7097608220933e46c2",
"value": "Generating train split: "
}
},
"e4e975ccdf71499e9e7e2628b4d12381": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_aa59cf09e9bf4aba951ab61b47c1ed40",
"placeholder": "​",
"style": "IPY_MODEL_ff6552e277994529b20848de6ad45b99",
"value": "