{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning BERT (base or large) on a question-answering task using the Transformers (HF) and DeepSpeed (Microsoft) libraries" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Credit**: Hugging Face and Microsoft\n", "- **Author**: [Pierre GUILLOU](https://www.linkedin.com/in/pierreguillou/)\n", "- **Date**: 18/06/2021\n", "- **Blog post** (in Portuguese): [NLP | Como treinar um modelo de Question Answering em qualquer linguagem baseado no BERT large, melhorando o desempenho do modelo utilizando o BERT base? (estudo de caso em português)](https://medium.com/@pierre_guillou/nlp-como-treinar-um-modelo-de-question-answering-em-qualquer-linguagem-baseado-no-bert-large-1c899262dd96)\n", "- **Link to the model in the Model Hub of Hugging Face**: https://huggingface.co/pierreguillou/bert-large-cased-squad-v1.1-portuguese" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Context" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is an adaptation of the Hugging Face (HF) notebook [question_answering.ipynb](https://github.com/huggingface/notebooks/blob/master/examples/question_answering.ipynb) and script [run_qa.py](https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/run_qa.py). It fine-tunes a Masked Language Model (MLM) transformer such as BERT on the QA task with the [Portuguese SQuAD 1.1 dataset](https://forum.ailab.unb.br/t/datasets-em-portugues/251/4).\n", "\n", "To speed up the fine-tuning of the model on a single GPU, we use the [DeepSpeed](https://www.deepspeed.ai/) library with the configuration provided by HF in the notebook [transformers + deepspeed CLI](https://github.com/stas00/porting/blob/master/transformers/deepspeed/DeepSpeed_on_colab_CLI.ipynb).\n", "\n", "*Note: the section about causal language modeling (CLM) is not included in this notebook, and all the unnecessary code about masked language modeling (MLM) has been removed from the original notebook.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation" ] }, { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this notebook on Colab, you will probably need to install PyTorch, DeepSpeed, 🤗 Transformers and 🤗 Datasets. Uncomment the following cells and run them." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "# !pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DeepSpeed" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# !pip install git+https://github.com/microsoft/deepspeed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Datasets, Tokenizers, Transformers" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# !git clone https://github.com/huggingface/transformers\n", "# %cd transformers\n", "# # examples change a lot, so you can pin a sha that we know this notebook works with:\n", "# # uncomment the next line to pin it (keep it commented to use the master branch)\n", "# # !git checkout d2753dcbec712350\n", "# !pip install -e .\n", "# !pip install -r examples/pytorch/question-answering/requirements.txt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook's folder, you need to create symbolic links to 3 files of the transformers repository you just installed."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#!ln -s ~/transformers/examples/pytorch/question-answering/run_qa.py\n", "#!ln -s ~/transformers/examples/pytorch/question-answering/trainer_qa.py\n", "#!ln -s ~/transformers/examples/pytorch/question-answering/utils_qa.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check our installation." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "python: 3.8.10 (default, Jun 4 2021, 15:09:15) \n", "[GCC 7.5.0]\n", "Pytorch: 1.8.1+cu111\n", "transformers: 4.7.0.dev0\n", "tokenizers: 0.10.3\n", "datasets: 1.8.0\n", "deepspeed: 0.4.1+fa7921e\n" ] } ], "source": [ "import sys; print('python:',sys.version)\n", "import pathlib\n", "from pathlib import Path\n", "\n", "import torch; print('Pytorch:',torch.__version__)\n", "\n", "import transformers; print('transformers:',transformers.__version__)\n", "import tokenizers; print('tokenizers:',tokenizers.__version__)\n", "import datasets; print('datasets:',datasets.__version__)\n", "\n", "import deepspeed; print('deepspeed:',deepspeed.__version__)\n", "\n", "# Versions installed:\n", "# python: 3.8.10 (default, Jun 4 2021, 15:09:15) \n", "# [GCC 7.5.0]\n", "# Pytorch: 1.8.1+cu111\n", "# transformers: 4.7.0.dev0\n", "# tokenizers: 0.10.3\n", "# datasets: 1.8.0\n", "# deepspeed: 0.4.1+fa7921e" ] }, { "cell_type": "markdown", "metadata": { "id": "HFASsisvIrIb" }, "source": [ "If you're opening this notebook locally, make sure your environment has recent versions of these libraries installed (the versions used to run this notebook are printed above).\n", "\n", "You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/pytorch/question-answering)."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on a question answering task, which is the task of extracting the answer to a question from a given context. We will see how to easily load a dataset for this kind of task and use the `Trainer` API to fine-tune a model on it.\n", "\n", "![Widget inference representing the QA task](images/question_answering.png)\n", "\n", "**Note:** This notebook fine-tunes models that answer questions by extracting a substring of a context, not by generating new text." ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook is built to run on any question answering task with the same format as SQuAD (version 1 or 2), with any model checkpoint from the [Model Hub](https://huggingface.co/models), as long as that model has a version with a question answering head (for BERT, `BertForQuestionAnswering`) and a fast tokenizer (check [this table](https://huggingface.co/transformers/index.html#bigtable) to see if this is the case). It might need some small adjustments if you decide to use a different dataset than the one used here. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set these two parameters (the model checkpoint is set in the next section), then the rest of the notebook should run smoothly:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "# This flag is the difference between SQuAD v1 or 2 (if you're using another dataset, it indicates if impossible\n", "# answers are allowed or not).\n", "squad_v2 = False\n", "batch_size = 16" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## BERT model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model_name_or_path = \"neuralmind/bert-large-portuguese-cased\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "### Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "CKx2zKs5IrIq" }, "source": [ "For our example here, we'll use the [Portuguese SQuAD 1.1 dataset](https://forum.ailab.unb.br/t/datasets-em-portugues/251/4), which is a translation of the [English SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). The notebook should work with any question answering dataset provided by the 🤗 Datasets library. If you're using your own dataset defined from a JSON or CSV file (see the [Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files) on how to load them), it might need some adjustments in the names of the columns used."
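] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the expected record layout concrete, here is a minimal pure-Python sketch of what the conversion cell below does. It flattens the nested SQuAD v1.1 JSON (`data`, then `paragraphs`, then `qas`) into one record per question with the columns `id`, `title`, `context`, `question` and `answers` (lists `answer_start` and `text`). For duplicated ids, it keeps only the last entry.\n", "\n", "The helper `flatten_squad` and the tiny example are hypothetical and are not part of the dataset or of `run_qa.py`. Unlike the conversion cell, the sketch does not strip `context` or the answer texts, so each `answer_start` stays aligned with the context. The last loop checks this alignment." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical helper: flatten nested SQuAD v1.1 JSON into one record per question.\n", "# For duplicated ids, the last occurrence wins (as in the conversion cell below).\n", "def flatten_squad(squad_dict):\n", "    records = {}\n", "    for article in squad_dict['data']:\n", "        title = article['title'].strip()\n", "        for paragraph in article['paragraphs']:\n", "            context = paragraph['context']\n", "            for qa in paragraph['qas']:\n", "                records[qa['id']] = {\n", "                    'id': qa['id'],\n", "                    'title': title,\n", "                    'context': context,\n", "                    'question': qa['question'].strip(),\n", "                    'answers': {\n", "                        'answer_start': [a['answer_start'] for a in qa['answers']],\n", "                        'text': [a['text'] for a in qa['answers']],\n", "                    },\n", "                }\n", "    return list(records.values())\n", "\n", "\n", "# tiny hand-written example: the same id appears twice, the second answer is the corrected one\n", "example = {'data': [{\n", "    'title': 'Kathmandu',\n", "    'paragraphs': [{\n", "        'context': 'A Cidade Metropolitana de Catmandu (KMC) criou uma Secretaria.',\n", "        'qas': [\n", "            {'id': 'q1', 'question': 'De que KMC é um inicialismo?',\n", "             'answers': [{'answer_start': 2, 'text': 'Cidade Metropolitana'}]},\n", "            {'id': 'q1', 'question': 'De que KMC é um inicialismo?',\n", "             'answers': [{'answer_start': 2, 'text': 'Cidade Metropolitana de Catmandu'}]},\n", "        ],\n", "    }],\n", "}]}\n", "\n", "records = flatten_squad(example)\n", "\n", "# each answer_start must point at the answer text inside the context\n", "for r in records:\n", "    for start, text in zip(r['answers']['answer_start'], r['answers']['text']):\n", "        assert r['context'][start:start + len(text)] == text\n", "\n", "print(len(records), records[0]['answers']['text'])"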
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "dataset_name = \"squad11pt\"" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270, "referenced_widgets": [ "69caab03d6264fef9fc5649bffff5e20", "3f74532faa86412293d90d3952f38c4a", "50615aa59c7247c4804ca5cbc7945bd7", "fe962391292a413ca55dc932c4279fa7", "299f4b4c07654e53a25f8192bd1d7bbd", "ad04ed1038154081bbb0c1444784dcc2", "7c667ad22b5740d5a6319f1b1e3a8097", "46c2b043c0f84806978784a45a4e203b", "80e2943be35f46eeb24c8ab13faa6578", "de5956b5008d4fdba807bae57509c393", "931db1f7a42f4b46b7ff8c2e1262b994", "6c1db72efff5476e842c1386fadbbdba", "ccd2f37647c547abb4c719b75a26f2de", "d30a66df5c0145e79693e09789d96b81", "5fa26fc336274073abbd1d550542ee33", "2b34de08115d49d285def9269a53f484", "d426be871b424affb455aeb7db5e822e", "160bf88485f44f5cb6eaeecba5e0901f", "745c0d47d672477b9bb0dae77b926364", "d22ab78269cd4ccfbcf70c707057c31b", "d298eb19eeff453cba51c2804629d3f4", "a7204ade36314c86907c562e0a2158b8", "e35d42b2d352498ca3fc8530393786b2", "75103f83538d44abada79b51a1cec09e", "f6253931d90543e9b5fd0bb2d615f73a", "051aa783ff9e47e28d1f9584043815f5", "0984b2a14115454bbb009df71c1cf36f", "8ab9dfce29854049912178941ef1b289", "c9de740e007141958545e269372780a4", "cbea68b25d6d4ba09b2ce0f27b1726d5", "5781fc45cf8d486cb06ed68853b2c644", "d2a92143a08a4951b55bab9bc0a6d0d3", "a14c3e40e5254d61ba146f6ec88eae25", "c4ffe6f624ce4e978a0d9b864544941a", "1aca01c1d8c940dfadd3e7144bb35718", "9fbbaae50e6743f2aa19342152398186", "fea27ca6c9504fc896181bc1ff5730e5", "940d00556cb849b3a689d56e274041c2", "5cdf9ed939fb42d4bf77301c80b8afca", "94b39ccfef0b4b08bf2fb61bb0a657c1", "9a55087c85b74ea08b3e952ac1d73cbe", "2361ab124daf47cc885ff61f2899b2af", "1a65887eb37747ddb75dc4a40f7285f2", "3c946e2260704e6c98593136bd32d921", "50d325cdb9844f62a9ecc98e768cb5af", "aa781f0cfe454e9da5b53b93e9baabd8", "6bb68d3887ef43809eb23feb467f9723", 
"7e29a8b952cf4f4ea42833c8bf55342f", "dd5997d01d8947e4b1c211433969b89b", "2ace4dc78e2f4f1492a181bcd63304e7", "bbee008c2791443d8610371d1f16b62b", "31b1c8a2e3334b72b45b083688c1a20c", "7fb7c36adc624f7dbbcb4a831c1e4f63", "0b7c8f1939074794b3d9221244b1344d", "a71908883b064e1fbdddb547a8c41743", "2f5223f26c8541fc87e91d2205c39995" ] }, "id": "s_AY1ATSIrIq", "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b" }, "outputs": [], "source": [ "# %%time\n", "# if dataset_name == \"squad11pt\":\n", "#\n", "#     # create dataset folder\n", "#     root = Path.cwd()\n", "#     path_to_dataset = root.parent/'data'/dataset_name\n", "#     path_to_dataset.mkdir(parents=True, exist_ok=True)\n", "#\n", "#     # get the SQuAD dataset in Portuguese\n", "#     %cd {path_to_dataset}\n", "#     !wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1Q0IaIlv2h2BC468MwUFmUST0EyN7gNkn' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1Q0IaIlv2h2BC468MwUFmUST0EyN7gNkn\" -O squad-pt.tar.gz && rm -rf /tmp/cookies.txt\n", "#\n", "#     # unzip\n", "#     !tar -xvf squad-pt.tar.gz\n", "#\n", "#     # get the train and validation json files in the HF script format\n", "#     # inspiration: file squad.py at https://github.com/huggingface/datasets/tree/master/datasets/squad\n", "#\n", "#     import json\n", "#     files = ['squad-train-v1.1.json', 'squad-dev-v1.1.json']\n", "#\n", "#     for file in files:\n", "#\n", "#         # open the JSON file and load it as a dictionary\n", "#         with open(file, encoding=\"utf-8\") as f:\n", "#             data = json.load(f)\n", "#\n", "#         # iterate through the json list\n", "#         entry_list = list()\n", "#\n", "#         for row in data['data']:\n", "#             title = row['title']\n", "#\n", "#             for paragraph in row['paragraphs']:\n", "#                 context = paragraph['context']\n", "#\n", "#                 for qa in paragraph['qas']:\n", "#                     entry = {}\n", "#\n", "#                     qa_id = qa['id']\n", "#                     question = qa['question']\n", "#                     answers = qa['answers']\n", "#\n", "#                     entry['id'] = qa_id\n", "#                     entry['title'] = title.strip()\n", "#                     entry['context'] = context.strip()\n", "#                     entry['question'] = question.strip()\n", "#\n", "#                     answer_starts = [answer[\"answer_start\"] for answer in answers]\n", "#                     answer_texts = [answer[\"text\"].strip() for answer in answers]\n", "#                     entry['answers'] = {}\n", "#                     entry['answers']['answer_start'] = answer_starts\n", "#                     entry['answers']['text'] = answer_texts\n", "#\n", "#                     entry_list.append(entry)\n", "#\n", "#         reverse_entry_list = entry_list[::-1]\n", "#\n", "#         # for entries with the same id, keep only the last one (texts corrected by the group Deep Learning Brasil)\n", "#         # (a set makes the membership test O(1) instead of scanning a list)\n", "#         unique_ids = set()\n", "#         unique_entry_list = list()\n", "#         for entry in reverse_entry_list:\n", "#             qa_id = entry['id']\n", "#             if qa_id not in unique_ids:\n", "#                 unique_ids.add(qa_id)\n", "#                 unique_entry_list.append(entry)\n", "#\n", "#         new_dict = {}\n", "#         new_dict['data'] = unique_entry_list\n", "#\n", "#         file_name = 'pt_' + str(file)\n", "#         with open(file_name, 'w') as json_file:\n", "#             json.dump(new_dict, json_file)\n", "#\n", "#     %cd {root}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to load the data from the JSON files created above and to get the metric needed for evaluation (to compare our model to the benchmark). This is easily done with the functions `load_dataset` and `load_metric` (the metric itself is loaded inside the script `run_qa.py`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset, load_metric\n", "\n", "if dataset_name == \"squad11pt\":\n", "\n", "    # dataset folder\n", "    root = Path.cwd()\n", "    path_to_dataset = root.parent/'data'/dataset_name\n", "\n", "    # paths to files\n", "    train_file = str(path_to_dataset/'pt_squad-train-v1.1.json')\n", "    validation_file = str(path_to_dataset/'pt_squad-dev-v1.1.json')\n", "\n", "    datasets = load_dataset('json',\n", "                            data_files={'train': train_file,\n", "                                        'validation': validation_file},\n", "                            field='data')" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training set and one for the validation set." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "GWiVUF0jIrIv", "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'title', 'context', 'question', 'answers'],\n", " num_rows: 87510\n", " })\n", " validation: Dataset({\n", " features: ['answers', 'context', 'id', 'question', 'title'],\n", " num_rows: 10570\n", " })\n", "})" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "u3EtYfeHIrIz" }, "source": [ "To access an actual element, you need to select a split first, then give an index:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "X6HrpprwIrIz", "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" }, "outputs": [ { "data": { "text/plain": [ "{'id': '5735d259012e2f140011a0a1',\n", " 'title': 'Kathmandu',\n", " 'context': 'A Cidade Metropolitana de Catmandu (KMC), a fim de promover as relações internacionais, criou uma Secretaria de Relações Internacionais (IRC). O primeiro relacionamento internacional da KMC foi estabelecido em 1975 com a cidade de Eugene, Oregon, Estados Unidos. Essa atividade foi aprimorada ainda mais com o estabelecimento de relações formais com outras 8 cidades: Cidade de Motsumoto, Japão, Rochester, EUA, Yangon (antiga Rangum) de Mianmar, Xian da República Popular da China, Minsk da Bielorrússia e Pyongyang de República Democrática da Coréia. O esforço constante da KMC é aprimorar sua interação com os países da SAARC, outras agências internacionais e muitas outras grandes cidades do mundo para alcançar melhores programas de gestão urbana e desenvolvimento para Katmandu.',\n", " 'question': 'De que KMC é um inicialismo?',\n", " 'answers': {'answer_start': [2],\n", " 'text': ['Cidade Metropolitana de Catmandu']}}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets[\"train\"][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that each answer is given by its start position in the context (here, character 2) and its full text, which is a substring of the context as mentioned above." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DeepSpeed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's set up the `DeepSpeed` configuration."
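] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As background for the scheduler choice explained in the next cells, here is a minimal pure-Python sketch of the two learning-rate shapes. With `WarmupLR`, the learning rate stays constant after the warmup. With `WarmupDecayLR` (like the default `linear` scheduler of the HF `Trainer`), it decays linearly over the remaining training steps (to zero in this sketch).\n", "\n", "The function names and the 1000-step total are hypothetical. The warmup is assumed to be linear; DeepSpeed's schedulers also support a logarithmic warmup. With `warmup_steps = 0`, as in this notebook, only the post-warmup branch is used." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical helpers sketching the two schedules (a linear warmup is assumed).\n", "def lr_warmup_constant(step, max_lr, warmup_steps):\n", "    # WarmupLR-like: ramp up to max_lr, then keep it constant\n", "    if warmup_steps > 0 and step < warmup_steps:\n", "        return max_lr * step / warmup_steps\n", "    return max_lr\n", "\n", "\n", "def lr_warmup_linear_decay(step, max_lr, warmup_steps, total_steps):\n", "    # WarmupDecayLR-like: ramp up to max_lr, then decay linearly to 0 at total_steps\n", "    if warmup_steps > 0 and step < warmup_steps:\n", "        return max_lr * step / warmup_steps\n", "    return max_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))\n", "\n", "\n", "# learning rate of this notebook, no warmup, hypothetical total of 1000 steps\n", "max_lr, warmup, total = 5e-5, 0, 1000\n", "for step in (0, total // 2, total):\n", "    print(step, lr_warmup_constant(step, max_lr, warmup),\n", "          lr_warmup_linear_decay(step, max_lr, warmup, total))"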
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get a learning rate that decays linearly after the warmup (equivalent to the default scheduler of the HF `Trainer`), we replaced `WarmupLR` with `WarmupDecayLR` in the DeepSpeed configuration files. A copy of the original scheduler configuration is kept here:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "````\n", "# the lr stays constant after the warmup (this is not equivalent to the default scheduler of HF, which is linear)\n", "\"scheduler\": {\n", "    \"type\": \"WarmupLR\",\n", "    \"params\": {\n", "        \"warmup_min_lr\": \"auto\",\n", "        \"warmup_max_lr\": \"auto\",\n", "        \"warmup_num_steps\": \"auto\"\n", "    }\n", "},\n", "````" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ZeRO-2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "source: https://huggingface.co/transformers/master/main_classes/deepspeed.html#zero-2-example" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "cat <<'EOT' > ds_config_zero2.json\n", "{\n", "    \"fp16\": {\n", "        \"enabled\": \"auto\",\n", "        \"loss_scale\": 0,\n", "        \"loss_scale_window\": 1000,\n", "        \"initial_scale_power\": 16,\n", "        \"hysteresis\": 2,\n", "        \"min_loss_scale\": 1\n", "    },\n", "\n", "    \"optimizer\": {\n", "        \"type\": \"AdamW\",\n", "        \"params\": {\n", "            \"lr\": \"auto\",\n", "            \"betas\": \"auto\",\n", "            \"eps\": \"auto\",\n", "            \"weight_decay\": \"auto\"\n", "        }\n", "    },\n", "\n", "    \"scheduler\": {\n", "        \"type\": \"WarmupDecayLR\",\n", "        \"params\": {\n", "            \"last_batch_iteration\": -1,\n", "            \"total_num_steps\": \"auto\",\n", "            \"warmup_min_lr\": \"auto\",\n", "            \"warmup_max_lr\": \"auto\",\n", "            \"warmup_num_steps\": \"auto\"\n", "        }\n", "    },\n", "\n", "    \"zero_optimization\": {\n", "        \"stage\": 2,\n", "        \"offload_optimizer\": {\n", "            \"device\": \"cpu\",\n", "            \"pin_memory\": true\n", "        },\n", "        \"allgather_partitions\": true,\n", "        \"allgather_bucket_size\": 5e8,\n", "        \"overlap_comm\": true,\n", "        \"reduce_scatter\": true,\n", "        \"reduce_bucket_size\": 5e8,\n", "        \"contiguous_gradients\": true\n", "    },\n", "\n", "    \"gradient_accumulation_steps\": \"auto\",\n", "    \"gradient_clipping\": \"auto\",\n", "    \"steps_per_print\": 2000,\n", "    \"train_batch_size\": \"auto\",\n", "    \"train_micro_batch_size_per_gpu\": \"auto\",\n", "    \"wall_clock_breakdown\": false\n", "}\n", "EOT" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ZeRO-3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "source: https://huggingface.co/transformers/master/main_classes/deepspeed.html#zero-3-example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compared to `ZeRO-2`, the `ZeRO-3` configuration makes it possible to train larger models, but at the cost of a longer training time. For this reason, we do not use the `ZeRO-3` configuration here. Its file is still created so that you can switch to it with the `zero` variable of the training arguments." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "cat <<'EOT' > ds_config_zero3.json\n", "{\n", "    \"fp16\": {\n", "        \"enabled\": \"auto\",\n", "        \"loss_scale\": 0,\n", "        \"loss_scale_window\": 1000,\n", "        \"initial_scale_power\": 16,\n", "        \"hysteresis\": 2,\n", "        \"min_loss_scale\": 1\n", "    },\n", "\n", "    \"optimizer\": {\n", "        \"type\": \"AdamW\",\n", "        \"params\": {\n", "            \"lr\": \"auto\",\n", "            \"betas\": \"auto\",\n", "            \"eps\": \"auto\",\n", "            \"weight_decay\": \"auto\"\n", "        }\n", "    },\n", "\n", "    \"scheduler\": {\n", "        \"type\": \"WarmupDecayLR\",\n", "        \"params\": {\n", "            \"last_batch_iteration\": -1,\n", "            \"total_num_steps\": \"auto\",\n", "            \"warmup_min_lr\": \"auto\",\n", "            \"warmup_max_lr\": \"auto\",\n", "            \"warmup_num_steps\": \"auto\"\n", "        }\n", "    },\n", "\n", "    \"zero_optimization\": {\n", "        \"stage\": 3,\n", "        \"offload_optimizer\": {\n", "            \"device\": \"cpu\",\n", "            \"pin_memory\": true\n", "        },\n", "        \"offload_param\": {\n", "            \"device\": \"cpu\",\n", "            \"pin_memory\": true\n", "        },\n", "        \"overlap_comm\": true,\n", "        \"contiguous_gradients\": true,\n", "        \"sub_group_size\": 1e9,\n", "        \"reduce_bucket_size\": \"auto\",\n", "        \"stage3_prefetch_bucket_size\": \"auto\",\n", "        \"stage3_param_persistence_threshold\": \"auto\",\n", "        \"stage3_max_live_parameters\": 1e9,\n", "        \"stage3_max_reuse_distance\": 1e9,\n", "        \"stage3_gather_fp16_weights_on_model_save\": true\n", "    },\n", "\n", "    \"gradient_accumulation_steps\": \"auto\",\n", "    \"gradient_clipping\": \"auto\",\n", "    \"steps_per_print\": 2000,\n", "    \"train_batch_size\": \"auto\",\n", "    \"train_micro_batch_size_per_gpu\": \"auto\",\n", "    \"wall_clock_breakdown\": false\n", "}\n", "EOT" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training arguments" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's set up all the training arguments needed by the script `run_qa.py`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GPU" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "num_gpus = 1 # run the script on only one gpu\n", "gpu = 0 # select the gpu" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### model, dataset, sequence" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# set up the training arguments\n", "do_train = True # False\n", "do_eval = True\n", "\n", "if dataset_name == \"squad11pt\":\n", "\n", "    # dataset folder\n", "    root = Path.cwd()\n", "    path_to_dataset = root.parent/'data'/dataset_name\n", "\n", "    # paths to files\n", "    train_file = str(path_to_dataset/'pt_squad-train-v1.1.json')\n", "    validation_file = str(path_to_dataset/'pt_squad-dev-v1.1.json')\n", "\n", "# to test the trainer on a subset of the data, set the following variables\n", "# (they are only used if the corresponding lines are added to the training command)\n", "max_train_samples = 200 # None\n", "max_eval_samples = 50 # None\n", "\n", "# The maximum total input sequence length after tokenization.\n", "# Sequences longer than this will be truncated, sequences shorter will be padded.\n", "max_seq_length = 384\n", "\n", "# Whether to pad all samples to `max_seq_length`.\n", "# If False, will pad the samples dynamically when batching to the maximum length in the batch\n", "# (which can be faster on GPU but will be slower on TPU).\n", "pad_to_max_length = True\n", "\n", "# If true, some of the examples do not have an answer (same meaning as the squad_v2 flag above).\n", "version_2_with_negative = squad_v2\n", "\n", "# When splitting up a long document into chunks, how much stride to take between chunks\n", "# (run_qa.py passes it to the tokenizer as `stride`: the number of overlapping tokens between chunks).\n", "doc_stride = 128\n", "\n", "# The total number of n-best predictions to generate when looking for an answer.\n", "n_best_size = 20\n", "\n", "# The maximum length of an answer that can be generated. This is needed because the start\n", "# and end predictions are not conditioned on one another.\n", "max_answer_length = 30" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### training_args()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "source: https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you keep the default value 1e-8 for `adam_epsilon` in `fp16` mode, it rounds to zero. The smallest positive fp16 value is about 6e-8, so 1e-7 is the first power of ten that stays non-zero in this mode. After some testing, we found that `adam_epsilon = 1e-4` gives the best results." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the `ZeRO-2` mode for `DeepSpeed`."
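] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fp16 underflow behind the `adam_epsilon` choice can be checked with a minimal sketch that only uses the Python standard library. The `struct` format `'e'` rounds a Python float to IEEE 754 half precision (fp16). With it, `1e-8` comes back as `0.0`, while `1e-7` comes back as about `1.19e-7`.\n", "\n", "The helper name `to_fp16` is hypothetical. The sketch only illustrates fp16 rounding; it does not show where DeepSpeed stores the optimizer hyperparameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import struct\n", "\n", "\n", "def to_fp16(x):\n", "    # Hypothetical helper: round-trip a Python float through IEEE 754 half precision\n", "    return struct.unpack('e', struct.pack('e', x))[0]\n", "\n", "\n", "for eps in (1e-8, 1e-7, 1e-4):\n", "    print(eps, '->', to_fp16(eps))\n", "\n", "# smallest positive (subnormal) fp16 value\n", "print('smallest positive fp16 value:', 2.0 ** -24)"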
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# epochs, bs, GA\n", "evaluation_strategy = \"epoch\" # no\n", "BS = batch_size\n", "gradient_accumulation_steps = 1\n", "\n", "# optimizer (AdamW)\n", "learning_rate = 5e-5\n", "weight_decay = 0.01 # 0.0\n", "adam_beta1 = 0.9\n", "adam_beta2 = 0.999\n", "adam_epsilon = 1e-4 # 1e-08\n", "\n", "# epochs\n", "num_train_epochs = 3.\n", "\n", "# scheduler\n", "lr_scheduler_type = 'linear'\n", "warmup_ratio = 0.0\n", "warmup_steps = 0\n", "\n", "# logs\n", "logging_strategy = \"steps\"\n", "logging_first_step = True # False\n", "logging_steps = 500 # if strategy = \"steps\"\n", "eval_steps = logging_steps # logging_steps\n", "\n", "# checkpoints\n", "save_strategy = \"epoch\" # steps\n", "save_steps = 500 # if save_strategy = \"steps\"\n", "save_total_limit = 1 # None\n", "\n", "# no cuda, seed\n", "no_cuda = False\n", "seed = 42\n", "\n", "# fp16\n", "fp16 = True # False\n", "fp16_opt_level = 'O1'\n", "fp16_backend = \"auto\"\n", "fp16_full_eval = False\n", "\n", "# bar\n", "disable_tqdm = False # True\n", "remove_unused_columns = True\n", "#label_names (List[str], optional) \n", "\n", "# best model\n", "load_best_model_at_end = True # False\n", "metric_for_best_model = \"eval_f1\"\n", "greater_is_better = True\n", "\n", "# deepspeed\n", "zero = 2\n", "\n", "if zero == 2:\n", " deepspeed_config = \"ds_config_zero2.json\"\n", "elif zero == 3:\n", " deepspeed_config = \"ds_config_zero3.json\"" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# folder for training outputs\n", "outputs = model_name_or_path.replace('/','-') + '-' + dataset_name \\\n", "+ '_wd' + str(weight_decay) + '_eps' + str(adam_epsilon) \\\n", "+ '_epochs' + str(num_train_epochs) \\\n", "+ '-lr' + str(learning_rate)\n", "path_to_outputs = root/'models_outputs'/outputs\n", "\n", "# subfolder for model outputs\n", "output_dir = path_to_outputs/'output_dir' \n", 
"overwrite_output_dir = True # False\n", "\n", "# logs\n", "logging_dir = path_to_outputs/'logging_dir'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training + Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Update the system path with the virtual environment path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is needed to launch the `deepspeed` command in our server configuration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "PATH = os.getenv('PATH')\n", "%env PATH=/mnt/home/xxxx/anaconda3/envs/aaaa/bin:$PATH\n", " \n", "# xxxx is the folder name where anaconda was installed\n", "# aaaa is the virtual ambiente name within this notebook is run" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Setup environment variables " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The magic command `%env` corresponds to `export` in linux. 
It allows to setup the values of all arguments of the script `run_qa.py`.\n", "\n", "*Note: as we noticed that the script runs without environment variables in this notebook but with local ones, we do not use this magic command.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Delete the output_dir (if exists)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!rm -r {output_dir}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can launch the training :-) " ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# copy/paste/uncomment the 2 following lines in the following cell if you want to limit the number of data (useful for testing)\n", "# --max_train_samples $max_train_samples \\\n", "# --max_eval_samples $max_eval_samples \\" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "# !deepspeed --num_gpus=$num_gpus run_qa.py \\\n", "!deepspeed --include localhost:$gpu run_qa.py \\\n", "--model_name_or_path $model_name_or_path \\\n", "--train_file $train_file \\\n", "--do_train $do_train \\\n", "--do_eval $do_eval \\\n", "--validation_file $validation_file \\\n", "--max_seq_length $max_seq_length \\\n", "--pad_to_max_length $pad_to_max_length \\\n", "--version_2_with_negative $version_2_with_negative \\\n", "--doc_stride $doc_stride \\\n", "--n_best_size $n_best_size \\\n", "--max_answer_length $max_answer_length \\\n", "--output_dir $output_dir \\\n", "--overwrite_output_dir $overwrite_output_dir \\\n", "--evaluation_strategy $evaluation_strategy \\\n", "--per_device_train_batch_size $batch_size \\\n", "--per_device_eval_batch_size $batch_size \\\n", "--gradient_accumulation_steps $gradient_accumulation_steps \\\n", "--learning_rate $learning_rate \\\n", "--weight_decay $weight_decay \\\n", "--adam_beta1 $adam_beta1 \\\n", "--adam_beta2 $adam_beta2 \\\n", "--adam_epsilon 
$adam_epsilon \\\n", "--num_train_epochs $num_train_epochs \\\n", "--warmup_ratio $warmup_ratio \\\n", "--warmup_steps $warmup_steps \\\n", "--logging_dir $logging_dir \\\n", "--logging_strategy $logging_strategy \\\n", "--logging_first_step $logging_first_step \\\n", "--logging_steps $logging_steps \\\n", "--eval_steps $eval_steps \\\n", "--save_strategy $save_strategy \\\n", "--save_steps $save_steps \\\n", "--save_total_limit $save_total_limit \\\n", "--no_cuda $no_cuda \\\n", "--seed $seed \\\n", "--fp16 $fp16 \\\n", "--fp16_opt_level $fp16_opt_level \\\n", "--fp16_backend $fp16_backend \\\n", "--fp16_full_eval $fp16_full_eval \\\n", "--disable_tqdm $disable_tqdm \\\n", "--remove_unused_columns $remove_unused_columns \\\n", "--load_best_model_at_end $load_best_model_at_end \\\n", "--metric_for_best_model $metric_for_best_model \\\n", "--greater_is_better $greater_is_better \\\n", "--deepspeed $deepspeed_config" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Results**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "````\n", "[INFO|trainer_pt_utils.py:908] 2021-06-18 05:32:34,020 >> ***** eval metrics *****\n", "[INFO|trainer_pt_utils.py:913] 2021-06-18 05:32:34,020 >> epoch = 3.0\n", "[INFO|trainer_pt_utils.py:913] 2021-06-18 05:32:34,020 >> eval_exact_match = 72.6774\n", "[INFO|trainer_pt_utils.py:913] 2021-06-18 05:32:34,020 >> eval_f1 = 84.4315\n", "[INFO|trainer_pt_utils.py:913] 2021-06-18 05:32:34,020 >> eval_samples = 10917\n", "CPU times: user 5min 5s, sys: 51.3 s, total: 5min 56s\n", "Wall time: 3h 20min 36s\n", "````" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TensorBoard" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "#!pip install tensorboard" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": 
"display_data" } ], "source": [ "%load_ext tensorboard\n", "# %reload_ext tensorboard\n", "%tensorboard --logdir {logging_dir} --bind_all" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Getting The Model Weights Out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get back weights in fp32, read this: https://huggingface.co/transformers/master/main_classes/deepspeed.html#getting-the-model-weights-out" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.\n", "ERROR: could not open HSTS store at '/mnt/home/pierre/.wget-hsts'. HSTS will be disabled.\n", "--2021-06-18 12:02:06-- https://raw.githubusercontent.com/microsoft/DeepSpeed/master/deepspeed/utils/zero_to_fp32.py\n", "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...\n", "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 6468 (6.3K) [text/plain]\n", "Saving to: ‘zero_to_fp32.py.1’\n", "\n", "zero_to_fp32.py.1 100%[===================>] 6.32K --.-KB/s in 0s \n", "\n", "2021-06-18 12:02:06 (66.1 MB/s) - ‘zero_to_fp32.py.1’ saved [6468/6468]\n", "\n" ] } ], "source": [ "!wget https://raw.githubusercontent.com/microsoft/DeepSpeed/master/deepspeed/utils/zero_to_fp32.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# in this training, checkpoint-16734 contains the best model\n", "%cd {output_dir}/'checkpoint-16734'" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total 651872\r\n", "drwxrwxr-x 3 pierre pierre 4096 Jun 18 11:38 .\r\n", "drwxrwxr-x 3 pierre pierre 4096 Jun 18 11:38 ..\r\n", "-rw-rw-r-- 1 pierre pierre 855 Jun 18 11:38 config.json\r\n", "drwxrwxr-x 2 pierre pierre 4096 Jun 18 11:38 global_step16734\r\n", "-rw-rw-r-- 1 pierre pierre 16 Jun 18 11:38 latest\r\n", "-rw-rw-r-- 1 pierre pierre 666791233 Jun 18 11:38 pytorch_model.bin\r\n", "-rw-rw-r-- 1 pierre pierre 14657 Jun 18 11:38 rng_state_0.pth\r\n", "-rw-rw-r-- 1 pierre pierre 112 Jun 18 11:38 special_tokens_map.json\r\n", "-rw-rw-r-- 1 pierre pierre 438465 Jun 18 11:38 tokenizer.json\r\n", "-rw-rw-r-- 1 pierre pierre 506 Jun 18 11:38 tokenizer_config.json\r\n", "-rw-rw-r-- 1 pierre pierre 5051 Jun 18 11:38 trainer_state.json\r\n", "-rw-rw-r-- 1 pierre pierre 3951 Jun 18 11:38 training_args.bin\r\n", "-rw-rw-r-- 1 pierre pierre 209528 Jun 18 11:38 vocab.txt\r\n", "-rwxrw-r-- 1 pierre pierre 6468 Jun 18 11:38 zero_to_fp32.py\r\n" ] } ], "source": [ "!ls -al" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can observe that `pytorch_model.bin` has a size of 666 MB because the model weights were saved in fp16 format. Let's use the DeepSpeed script `zero_to_fp32.py` to convert them to fp32."
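, "\n",
"\n",
"For reference, the 666 MB figure is consistent with the parameter count that `zero_to_fp32.py` reports below (`total_numel=333348866`) at 2 bytes per parameter in fp16. The same count at 4 bytes per parameter in fp32 predicts a file of about 1.33 GB. Both estimates ignore the small serialization overhead of the file:\n",
"\n",
"$$333\\,348\\,866 \\times 2\\ \\text{bytes} \\approx 666.7\\ \\text{MB}, \\qquad 333\\,348\\,866 \\times 4\\ \\text{bytes} \\approx 1.33\\ \\text{GB}$$"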
] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "usage: zero_to_fp32.py [-h] checkpoint_dir output_file\r\n", "\r\n", "positional arguments:\r\n", " checkpoint_dir path to the deepspeed checkpoint folder, e.g.,\r\n", " path/checkpoint-1/global_step1\r\n", " output_file path to the pytorch fp32 state_dict output file (e.g.\r\n", " path/checkpoint-1/pytorch_model.bin)\r\n", "\r\n", "optional arguments:\r\n", " -h, --help show this help message and exit\r\n" ] } ], "source": [ "path_to_zero_to_fp32 = root/'zero_to_fp32.py'\n", "!python $path_to_zero_to_fp32 -h" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Processing zero checkpoint 'global_step16734'\n", "Detected checkpoint of type zero stage 2, world_size: 1\n", "Saving fp32 state dict to pytorch_model.bin (total_numel=333348866)\n" ] } ], "source": [ "!python $path_to_zero_to_fp32 global_step16734 pytorch_model.bin" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total 1302912\r\n", "drwxrwxr-x 3 pierre pierre 4096 Jun 18 11:38 .\r\n", "drwxrwxr-x 3 pierre pierre 4096 Jun 18 11:38 ..\r\n", "-rw-rw-r-- 1 pierre pierre 855 Jun 18 11:38 config.json\r\n", "drwxrwxr-x 2 pierre pierre 4096 Jun 18 11:38 global_step16734\r\n", "-rw-rw-r-- 1 pierre pierre 16 Jun 18 11:38 latest\r\n", "-rw-rw-r-- 1 pierre pierre 1333453496 Jun 18 14:28 pytorch_model.bin\r\n", "-rw-rw-r-- 1 pierre pierre 14657 Jun 18 11:38 rng_state_0.pth\r\n", "-rw-rw-r-- 1 pierre pierre 112 Jun 18 11:38 special_tokens_map.json\r\n", "-rw-rw-r-- 1 pierre pierre 438465 Jun 18 11:38 tokenizer.json\r\n", "-rw-rw-r-- 1 pierre pierre 506 Jun 18 11:38 tokenizer_config.json\r\n", "-rw-rw-r-- 1 pierre pierre 5051 Jun 18 11:38 trainer_state.json\r\n", "-rw-rw-r-- 1 pierre pierre 3951 Jun 18 11:38 
training_args.bin\r\n", "-rw-rw-r-- 1 pierre pierre 209528 Jun 18 11:38 vocab.txt\r\n", "-rwxrw-r-- 1 pierre pierre 6468 Jun 18 11:38 zero_to_fp32.py\r\n" ] } ], "source": [ "!ls -al" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's it! The model file has doubled in size to 1.3 GB, which confirms that the weights are now stored in fp32." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%cd {root}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save the model to HF format" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "import pathlib\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# model source\n", "source = root/output_dir/'checkpoint-16734/'\n", "\n", "# model destination\n", "dest = root/'HFmodels'\n", "fname_HF = 'bert-large-cased-squad-v1.1-portuguese'\n", "path_to_awesome_name_you_picked = dest/fname_HF\n", "path_to_awesome_name_you_picked.mkdir(exist_ok=True, parents=True)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "# copy model to destination\n", "!cp {source}/'config.json' {dest/fname_HF}\n", "!cp {source}/'pytorch_model.bin' {dest/fname_HF}\n", "\n", "# copy tokenizer to destination\n", "!cp {source}/'tokenizer_config.json' {dest/fname_HF}\n", "!cp {source}/'special_tokens_map.json' {dest/fname_HF}\n", "!cp {source}/'vocab.txt' {dest/fname_HF}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Make your model work on all frameworks** ([source](https://huggingface.co/transformers/model_sharing.html#make-your-model-work-on-all-frameworks))" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "from transformers import BertForQuestionAnswering\n", "pt_model = BertForQuestionAnswering.from_pretrained(str(path_to_awesome_name_you_picked))\n",
"pt_model.save_pretrained(str(path_to_awesome_name_you_picked))" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertForQuestionAnswering: ['bert.embeddings.position_ids']\n", "- This IS expected if you are initializing TFBertForQuestionAnswering from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing TFBertForQuestionAnswering from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n", "All the weights of TFBertForQuestionAnswering were initialized from the PyTorch model.\n", "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForQuestionAnswering for predictions without further training.\n" ] } ], "source": [ "import tensorflow\n", "from transformers import TFBertForQuestionAnswering\n", "\n", "tf_model = TFBertForQuestionAnswering.from_pretrained(str(path_to_awesome_name_you_picked), from_pt=True)\n", "tf_model.save_pretrained(str(path_to_awesome_name_you_picked))" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total 2605184\r\n", "drwxrwxr-x 2 pierre pierre 4096 Jun 18 14:37 .\r\n", "drwxrwxr-x 3 pierre pierre 4096 Jun 18 14:34 ..\r\n", "-rw-rw-r-- 1 pierre pierre 918 Jun 18 14:37 config.json\r\n", "-rw-rw-r-- 1 pierre pierre 1333560247 Jun 18 14:37 pytorch_model.bin\r\n", "-rw-rw-r-- 1 pierre pierre 112 Jun 18 14:34 special_tokens_map.json\r\n", "-rw-rw-r-- 1 pierre pierre 1333906712 Jun 18 14:37 tf_model.h5\r\n", "-rw-rw-r-- 1 pierre pierre 506 Jun 18 14:34 
tokenizer_config.json\r\n", "-rw-rw-r-- 1 pierre pierre 209528 Jun 18 14:34 vocab.txt\r\n" ] } ], "source": [ "!ls -al {path_to_awesome_name_you_picked}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model sharing and uploading to the HF Model Hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Don't forget to [upload your model](https://huggingface.co/transformers/model_sharing.html) to the [🤗 Model Hub](https://huggingface.co/models). You can then use it to generate results like the one shown in the first picture of this notebook!" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "id": "DcmsZKPizRrl" }, "source": [ "## Use our QA model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradio" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "source: [Using & Mixing Hugging Face Models with Gradio 2.0](https://huggingface.co/blog/gradio)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "\n", "iface = gr.Interface.load(\"huggingface/pierreguillou/bert-large-cased-squad-v1.1-portuguese\",server_name='xxxx')\n", "iface.launch()\n", "\n", "# xxxx is your server name (alias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true, "id": "Ckl2Fzn3is0F" }, "outputs": [], "source": [ "# imports\n", "import pathlib\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "hidden": true, "id": "IJA9CgOBis0F" }, "outputs": [], "source": [ "from transformers import AutoModelForQuestionAnswering, AutoTokenizer\n", "\n", "model_qa = AutoModelForQuestionAnswering.from_pretrained(path_to_awesome_name_you_picked)\n", "tokenizer_qa = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "hidden":
true, "id": "NiuPQTxuzRrm" }, "outputs": [], "source": [ "from transformers import pipeline\n", "nlp = pipeline(\"question-answering\", model=model_qa, tokenizer=tokenizer_qa)" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "# source: https://pt.wikipedia.org/wiki/Pandemia_de_COVID-19\n", "context = r\"\"\"A pandemia de COVID-19, também conhecida como pandemia de coronavírus, é uma pandemia em curso de COVID-19, \n", "uma doença respiratória causada pelo coronavírus da síndrome respiratória aguda grave 2 (SARS-CoV-2). \n", "O vírus tem origem zoonótica e o primeiro caso conhecido da doença remonta a dezembro de 2019 em Wuhan, na China. \n", "Em 20 de janeiro de 2020, a Organização Mundial da Saúde (OMS) classificou o surto \n", "como Emergência de Saúde Pública de Âmbito Internacional e, em 11 de março de 2020, como pandemia. \n", "Em 18 de junho de 2021, 177 349 274 casos foram confirmados em 192 países e territórios, \n", "com 3 840 181 mortes atribuídas à doença, tornando-se uma das pandemias mais mortais da história.\n", "Os sintomas de COVID-19 são altamente variáveis, variando de nenhum a doenças com risco de morte. \n", "O vírus se espalha principalmente pelo ar quando as pessoas estão perto umas das outras. \n", "Ele deixa uma pessoa infectada quando ela respira, tosse, espirra ou fala e entra em outra pessoa pela boca, nariz ou olhos.\n", "Ele também pode se espalhar através de superfícies contaminadas. 
\n", "As pessoas permanecem contagiosas por até duas semanas e podem espalhar o vírus mesmo se forem assintomáticas.\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 95, "metadata": { "hidden": true, "id": "t4_ezTchwKZl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: 'dezembro de 2019', score: 0.5087, start: 290, end: 306\n", "CPU times: user 1min 55s, sys: 7.79 s, total: 2min 2s\n", "Wall time: 3.52 s\n" ] } ], "source": [ "%%time\n", "question = \"Quando começou a pandemia de Covid-19 no mundo?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "code", "execution_count": 96, "metadata": { "hidden": true, "id": "a_JgLD8fxdwn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: 'dezembro de 2019', score: 0.4988, start: 290, end: 306\n", "CPU times: user 1min 56s, sys: 6.79 s, total: 2min 3s\n", "Wall time: 3.5 s\n" ] } ], "source": [ "%%time\n", "question = \"Qual é a data de início da pandemia Covid-19 em todo o mundo?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "code", "execution_count": 98, "metadata": { "hidden": true, "id": "NecR00-wzRrn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: 'O vírus tem origem zoonótica', score: 0.6001, start: 213, end: 241\n", "CPU times: user 1min 57s, sys: 13.8 s, total: 2min 11s\n", "Wall time: 3.76 s\n" ] } ], "source": [ "%%time\n", "question = \"A Covid-19 tem algo a ver com animais?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { 
"cell_type": "code", "execution_count": 99, "metadata": { "hidden": true, "id": "RcK4pn1hbLhL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: 'Wuhan, na China', score: 0.9415, start: 310, end: 325\n", "CPU times: user 1min 57s, sys: 9.3 s, total: 2min 6s\n", "Wall time: 3.62 s\n" ] } ], "source": [ "%%time\n", "question = \"Onde foi descoberta a Covid-19?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "code", "execution_count": 100, "metadata": { "hidden": true, "id": "7rFEmmsjzRrn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: '177 349 274', score: 0.828, start: 536, end: 547\n", "CPU times: user 1min 54s, sys: 11.6 s, total: 2min 6s\n", "Wall time: 3.62 s\n" ] } ], "source": [ "%%time\n", "question = \"Quantos casos houve?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "code", "execution_count": 101, "metadata": { "hidden": true, "id": "0v1TTQXDzRrn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: '3 840 181', score: 0.906, start: 606, end: 615\n", "CPU times: user 1min 58s, sys: 13.3 s, total: 2min 11s\n", "Wall time: 3.77 s\n" ] } ], "source": [ "%%time\n", "question = \"Quantos mortes?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "code", "execution_count": 102, "metadata": { "hidden": true, "id": "f78AggHLzRro" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: '192', score: 0.8958, start: 575, end: 578\n", "CPU times: user 
1min 54s, sys: 10 s, total: 2min 4s\n", "Wall time: 3.56 s\n" ] } ], "source": [ "%%time\n", "question = \"Quantos paises tiveram casos?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "code", "execution_count": 103, "metadata": { "hidden": true, "id": "J0AGoVc_xhdo" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: 'nenhum a doenças com risco de morte', score: 0.298, start: 761, end: 796\n", "CPU times: user 1min 56s, sys: 11.5 s, total: 2min 8s\n", "Wall time: 3.66 s\n" ] } ], "source": [ "%%time\n", "question = \"Quais são sintomas de COVID-19\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "code", "execution_count": 104, "metadata": { "hidden": true, "id": "YSQnntVgcHHq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: 'principalmente pelo ar quando as pessoas estão perto umas das outras', score: 0.3173, start: 818, end: 886\n", "CPU times: user 1min 52s, sys: 8.4 s, total: 2min 1s\n", "Wall time: 3.46 s\n" ] } ], "source": [ "%%time\n", "question = \"Como se espalha o vírus?\"\n", "\n", "result = nlp(question=question, context=context)\n", "\n", "print(f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# END" ] } ], "metadata": { "colab": { "name": "Question Answering on SQUAD", "provenance": [] }, "kernelspec": { "display_name": "hf_source", "language": "python", "name": "hf_source" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "051aa783ff9e47e28d1f9584043815f5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "0984b2a14115454bbb009df71c1cf36f": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_cbea68b25d6d4ba09b2ce0f27b1726d5", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_c9de740e007141958545e269372780a4", "value": 1 } }, "0b7c8f1939074794b3d9221244b1344d": { "model_module": "@jupyter-widgets/base", "model_name": 
"LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "160bf88485f44f5cb6eaeecba5e0901f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, 
"overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1a65887eb37747ddb75dc4a40f7285f2": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_aa781f0cfe454e9da5b53b93e9baabd8", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_50d325cdb9844f62a9ecc98e768cb5af", "value": 1 } }, "1aca01c1d8c940dfadd3e7144bb35718": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_940d00556cb849b3a689d56e274041c2", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_fea27ca6c9504fc896181bc1ff5730e5", "value": 1 } }, "2361ab124daf47cc885ff61f2899b2af": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, 
"grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "299f4b4c07654e53a25f8192bd1d7bbd": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "2ace4dc78e2f4f1492a181bcd63304e7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, 
"top": null, "visibility": null, "width": null } }, "2b34de08115d49d285def9269a53f484": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2f5223f26c8541fc87e91d2205c39995": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": 
null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "31b1c8a2e3334b72b45b083688c1a20c": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2f5223f26c8541fc87e91d2205c39995", "placeholder": "​", "style": "IPY_MODEL_a71908883b064e1fbdddb547a8c41743", "value": " 4.39k/? [00:00<00:00, 149kB/s]" } }, "3c946e2260704e6c98593136bd32d921": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7e29a8b952cf4f4ea42833c8bf55342f", "placeholder": "​", "style": "IPY_MODEL_6bb68d3887ef43809eb23feb467f9723", "value": " 1063/0 [00:00<00:00, 12337.52 examples/s]" } }, "3f74532faa86412293d90d3952f38c4a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, 
"grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "46c2b043c0f84806978784a45a4e203b": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "50615aa59c7247c4804ca5cbc7945bd7": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, 
"_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_ad04ed1038154081bbb0c1444784dcc2", "max": 7826, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_299f4b4c07654e53a25f8192bd1d7bbd", "value": 7826 } }, "50d325cdb9844f62a9ecc98e768cb5af": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "5781fc45cf8d486cb06ed68853b2c644": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5cdf9ed939fb42d4bf77301c80b8afca": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5fa26fc336274073abbd1d550542ee33": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", 
"description_width": "" } }, "69caab03d6264fef9fc5649bffff5e20": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_50615aa59c7247c4804ca5cbc7945bd7", "IPY_MODEL_fe962391292a413ca55dc932c4279fa7" ], "layout": "IPY_MODEL_3f74532faa86412293d90d3952f38c4a" } }, "6bb68d3887ef43809eb23feb467f9723": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "6c1db72efff5476e842c1386fadbbdba": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2b34de08115d49d285def9269a53f484", "placeholder": "​", "style": "IPY_MODEL_5fa26fc336274073abbd1d550542ee33", "value": " 28.7k/? 
[00:00<00:00, 571kB/s]" } }, "745c0d47d672477b9bb0dae77b926364": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: 100%", "description_tooltip": null, "layout": "IPY_MODEL_a7204ade36314c86907c562e0a2158b8", "max": 376971, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d298eb19eeff453cba51c2804629d3f4", "value": 376971 } }, "75103f83538d44abada79b51a1cec09e": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7c667ad22b5740d5a6319f1b1e3a8097": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "7e29a8b952cf4f4ea42833c8bf55342f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7fb7c36adc624f7dbbcb4a831c1e4f63": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "80e2943be35f46eeb24c8ab13faa6578": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": 
"HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_931db1f7a42f4b46b7ff8c2e1262b994", "IPY_MODEL_6c1db72efff5476e842c1386fadbbdba" ], "layout": "IPY_MODEL_de5956b5008d4fdba807bae57509c393" } }, "8ab9dfce29854049912178941ef1b289": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_d2a92143a08a4951b55bab9bc0a6d0d3", "placeholder": "​", "style": "IPY_MODEL_5781fc45cf8d486cb06ed68853b2c644", "value": " 8551/0 [00:00<00:00, 25108.88 examples/s]" } }, "931db1f7a42f4b46b7ff8c2e1262b994": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_d30a66df5c0145e79693e09789d96b81", "max": 4473, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_ccd2f37647c547abb4c719b75a26f2de", "value": 4473 } }, "940d00556cb849b3a689d56e274041c2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, 
"align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "94b39ccfef0b4b08bf2fb61bb0a657c1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9a55087c85b74ea08b3e952ac1d73cbe": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_1a65887eb37747ddb75dc4a40f7285f2", "IPY_MODEL_3c946e2260704e6c98593136bd32d921" ], "layout": "IPY_MODEL_2361ab124daf47cc885ff61f2899b2af" } }, "9fbbaae50e6743f2aa19342152398186": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_94b39ccfef0b4b08bf2fb61bb0a657c1", "placeholder": "​", "style": "IPY_MODEL_5cdf9ed939fb42d4bf77301c80b8afca", "value": " 1043/0 [00:00<00:00, 13590.50 examples/s]" } }, "a14c3e40e5254d61ba146f6ec88eae25": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_1aca01c1d8c940dfadd3e7144bb35718", "IPY_MODEL_9fbbaae50e6743f2aa19342152398186" ], "layout": "IPY_MODEL_c4ffe6f624ce4e978a0d9b864544941a" } }, "a71908883b064e1fbdddb547a8c41743": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, 
"a7204ade36314c86907c562e0a2158b8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "aa781f0cfe454e9da5b53b93e9baabd8": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, 
"min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ad04ed1038154081bbb0c1444784dcc2": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bbee008c2791443d8610371d1f16b62b": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "Downloading: ", "description_tooltip": null, "layout": "IPY_MODEL_0b7c8f1939074794b3d9221244b1344d", "max": 1586, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7fb7c36adc624f7dbbcb4a831c1e4f63", "value": 1586 } }, 
"c4ffe6f624ce4e978a0d9b864544941a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c9de740e007141958545e269372780a4": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "cbea68b25d6d4ba09b2ce0f27b1726d5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": 
null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ccd2f37647c547abb4c719b75a26f2de": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "d22ab78269cd4ccfbcf70c707057c31b": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_75103f83538d44abada79b51a1cec09e", "placeholder": "​", "style": "IPY_MODEL_e35d42b2d352498ca3fc8530393786b2", "value": " 377k/377k [00:00<00:00, 703kB/s]" } }, "d298eb19eeff453cba51c2804629d3f4": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": 
"1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "d2a92143a08a4951b55bab9bc0a6d0d3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d30a66df5c0145e79693e09789d96b81": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, 
"left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d426be871b424affb455aeb7db5e822e": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_745c0d47d672477b9bb0dae77b926364", "IPY_MODEL_d22ab78269cd4ccfbcf70c707057c31b" ], "layout": "IPY_MODEL_160bf88485f44f5cb6eaeecba5e0901f" } }, "dd5997d01d8947e4b1c211433969b89b": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_bbee008c2791443d8610371d1f16b62b", "IPY_MODEL_31b1c8a2e3334b72b45b083688c1a20c" ], "layout": "IPY_MODEL_2ace4dc78e2f4f1492a181bcd63304e7" } }, "de5956b5008d4fdba807bae57509c393": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, 
"grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e35d42b2d352498ca3fc8530393786b2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f6253931d90543e9b5fd0bb2d615f73a": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_0984b2a14115454bbb009df71c1cf36f", "IPY_MODEL_8ab9dfce29854049912178941ef1b289" ], "layout": "IPY_MODEL_051aa783ff9e47e28d1f9584043815f5" } }, "fe962391292a413ca55dc932c4279fa7": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_46c2b043c0f84806978784a45a4e203b", "placeholder": 
"​", "style": "IPY_MODEL_7c667ad22b5740d5a6319f1b1e3a8097", "value": " 28.7k/? [00:00<00:00, 652kB/s]" } }, "fea27ca6c9504fc896181bc1ff5730e5": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } } } } }, "nbformat": 4, "nbformat_minor": 1 }