{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# HuggingFace Datasets library demo\n", "\n", "Quick summary:\n", "\n", "- 50+ NLP datasets + super easy to add new ones (like Transformers models)\n", "- Simple and fast API to download and pre-process the datasets\n", "- Super easy to tokenize and process them in an efficient way\n", "- All dataset memory mapped on drive (no RAM limitation)\n", "- Smart caching on drive, process once, reuse everytime\n", "\n", "Soon: datasets streaming for huge datasets and 100+ datasets" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import logging\n", "logging.basicConfig(level=logging.INFO)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.utils.file_utils:PyTorch version 1.4.0 available.\n" ] } ], "source": [ "# Let's import the library\n", "import nlp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Currently available 54 datasets (not tested yet for most of them):\n", "- aeslc\n", "- amazon_us_reviews\n", "- big_patent\n", "- billsum\n", "- blimp\n", "- c4\n", "- cfq\n", "- civil_comments\n", "- cnn_dailymail\n", "- cos_e\n", "- definite_pronoun_resolution\n", "- eraser_multi_rc\n", "- esnli\n", "- flores\n", "- forest_fires\n", "- gap\n", "- german_credit_numeric\n", "- gigaword\n", "- glue\n", "- higgs\n", "- imdb\n", "- iris\n", "- librispeech_lm\n", "- lm1b\n", "- math_dataset\n", "- movie_rationales\n", "- multi_news\n", "- multi_nli\n", "- multi_nli_mismatch\n", "- natural_questions\n", "- newsroom\n", "- opinosis\n", "- para_crawl\n", "- qa4mre\n", "- reddit_tifu\n", "- rock_you\n", "- scan\n", "- scicite\n", "- scientific_papers\n", "- snli\n", "- squad\n", "- super_glue\n", "- ted_hrlr\n", "- ted_multi\n", "- tiny_shakespeare\n", "- titanic\n", "- trivia_qa\n", "- wiki40b\n", "- wikihow\n", "- wikipedia\n", "- wmt\n", "- xnli\n", "- xsum\n", "- yelp_polarity" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## An example with SQuAD" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.load:Dataset script /Users/thomwolf/.cache/huggingface/datasets/ee43d2be6898ebb9c2afefda4455306911d308bcf924d21c975796832cc7c114.e7d8881147e5da61c98918c61832c7f1c88b33b51a082c464e70e119bb24983d already found in datasets directory at /Users/thomwolf/Documents/GitHub/datasets/src/nlp/datasets/686d79c021d7dcd78da4d67fe01fbe30dfecabcd4bd02d06aa9d51edab713144/squad.py, returning it. Use `force_reload=True` to override.\n", "INFO:nlp.builder:No config specified, defaulting to first: squad/plain_text\n", "INFO:nlp.builder:Overwrite dataset info from restored data version.\n", "INFO:nlp.info:Loading Dataset info from /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0\n", "INFO:nlp.builder:Reusing dataset squad (/Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0)\n", "INFO:nlp.builder:Constructing Dataset for split validation[:10%], from /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0\n" ] } ], "source": [ "# Downloading and loading a dataset is a one-liner\n", "\n", "dataset = nlp.load('squad', split='validation[:10%]')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This call to `nlp.load()` does the following steps under the hood:\n", "\n", "1. 
Download and import into the library the **SQuAD python processing script** from our S3 if it's not already stored in the library. You can find the SQuAD processing script [here](https://s3.amazonaws.com/datasets.huggingface.co/nlp/squad/squad.py) for instance.\n", "\n", " Processing scripts are small python scripts that define the info and format of the dataset, contain the URL to the original SQuAD JSON files and the code to load examples from them.\n", "\n", "\n", "2. Run the SQuAD python processing script which will:\n", " - **Download the SQuAD dataset** from the original URL (see the script) if it's not already downloaded and cached.\n", " - **Process and cache** all of SQuAD in a structured Arrow table for each standard split, stored on the drive.\n", "\n", " Arrow tables are arbitrarily long tables, typed with types that can be mapped to numpy/pandas/python standard types, and they can store nested objects. They can be directly accessed from drive, loaded in RAM or even streamed over the web.\n", " \n", "\n", "3. Return a **dataset built from the splits** asked by the user (default: all); in the above example we create a dataset with the first 10% of the validation split." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DatasetInfo(\n", " name='squad',\n", " version=1.0.0,\n", " description='Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", "',\n", " homepage='https://rajpurkar.github.io/SQuAD-explorer/',\n", " features=struct<id: string, title: string, context: string, question: string, answers: struct<text: list<item: string>, answer_start: list<item: int32>>>,\n", " total_num_examples=98169,\n", " splits={\n", " 'train': 87599,\n", " 'validation': 10570,\n", " },\n", " supervised_keys=None,\n", " citation=\"\"\"@article{2016arXiv160605250R,\n", " author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n", " Konstantin and {Liang}, Percy},\n", " title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n", " journal = {arXiv e-prints},\n", " year = 2016,\n", " eid = {arXiv:1606.05250},\n", " pages = {arXiv:1606.05250},\n", " archivePrefix = {arXiv},\n", " eprint = {1606.05250},\n", " }\"\"\",\n", " license=None,\n", ")\n", "\n" ] } ], "source": [ "# General information on the dataset is provided in the `.info` property\n", "print(dataset.info)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspecting the dataset: elements, slices and columns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The returned `Dataset` object is a memory-mapped dataset that behaves similarly to a normal map-style dataset. It is backed by an Apache Arrow table which allows many interesting features." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset(schema: {'id': 'string', 'title': 'string', 'context': 'string', 'question': 'string', 'answers': 'struct<text: list<item: string>, answer_start: list<item: int32>>'}, num_rows: 1057)\n" ] } ], "source": [ "print(dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can query its length and get items or slices like you would normally do with a python mapping." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset len(dataset): 1057\n", "First item:\n", "{'answers': {'answer_start': [177, 177, 177],\n", " 'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']},\n", " 'context': 'Super Bowl 50 was an American football game to determine the '\n", " 'champion of the National Football League (NFL) for the 2015 '\n", " 'season. The American Football Conference (AFC) champion Denver '\n", " 'Broncos defeated the National Football Conference (NFC) champion '\n", " 'Carolina Panthers 24–10 to earn their third Super Bowl title. The '\n", " \"game was played on February 7, 2016, at Levi's Stadium in the San \"\n", " 'Francisco Bay Area at Santa Clara, California. As this was the '\n", " '50th Super Bowl, the league emphasized the \"golden anniversary\" '\n", " 'with various gold-themed initiatives, as well as temporarily '\n", " 'suspending the tradition of naming each Super Bowl game with '\n", " 'Roman numerals (under which the game would have been known as '\n", " '\"Super Bowl L\"), so that the logo could prominently feature the '\n", " 'Arabic numerals 50.',\n", " 'id': '56be4db0acb8001400a502ec',\n", " 'question': 'Which NFL team represented the AFC at Super Bowl 50?',\n", " 'title': 'Super_Bowl_50'}\n", "Slice of the first two items:\n", "{'answers': [{'answer_start': [177, 177, 177],\n", " 'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']},\n", " {'answer_start': [249, 249, 249],\n", " 'text': ['Carolina Panthers',\n", " 'Carolina Panthers',\n", " 'Carolina Panthers']}],\n", " 'context': ['Super Bowl 50 was an American football game to determine the '\n", " 'champion of the National Football League (NFL) for the 2015 '\n", " 'season. The American Football Conference (AFC) champion Denver '\n", " 'Broncos defeated the National Football Conference (NFC) champion '\n", " 'Carolina Panthers 24–10 to earn their third Super Bowl title. '\n", " \"The game was played on February 7, 2016, at Levi's Stadium in \"\n", " 'the San Francisco Bay Area at Santa Clara, California. As this '\n", " 'was the 50th Super Bowl, the league emphasized the \"golden '\n", " 'anniversary\" with various gold-themed initiatives, as well as '\n", " 'temporarily suspending the tradition of naming each Super Bowl '\n", " 'game with Roman numerals (under which the game would have been '\n", " 'known as \"Super Bowl L\"), so that the logo could prominently '\n", " 'feature the Arabic numerals 50.',\n", " 'Super Bowl 50 was an American football game to determine the '\n", " 'champion of the National Football League (NFL) for the 2015 '\n", " 'season. The American Football Conference (AFC) champion Denver '\n", " 'Broncos defeated the National Football Conference (NFC) champion '\n", " 'Carolina Panthers 24–10 to earn their third Super Bowl title. '\n", " \"The game was played on February 7, 2016, at Levi's Stadium in \"\n", " 'the San Francisco Bay Area at Santa Clara, California. 
As this '\n", " 'was the 50th Super Bowl, the league emphasized the \"golden '\n", " 'anniversary\" with various gold-themed initiatives, as well as '\n", " 'temporarily suspending the tradition of naming each Super Bowl '\n", " 'game with Roman numerals (under which the game would have been '\n", " 'known as \"Super Bowl L\"), so that the logo could prominently '\n", " 'feature the Arabic numerals 50.'],\n", " 'id': ['56be4db0acb8001400a502ec', '56be4db0acb8001400a502ed'],\n", " 'question': ['Which NFL team represented the AFC at Super Bowl 50?',\n", " 'Which NFL team represented the NFC at Super Bowl 50?'],\n", " 'title': ['Super_Bowl_50', 'Super_Bowl_50']}\n" ] } ], "source": [ "from pprint import pprint\n", "\n", "print(f\"Dataset len(dataset): {len(dataset)}\")\n", "print(\"First item:\")\n", "pprint(dataset[0])\n", "print(\"Slice of the first two items:\")\n", "pprint(dataset[:2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can get a full column of the dataset by indexing with its name as a string:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Which NFL team represented the AFC at Super Bowl 50?', 'Which NFL team represented the NFC at Super Bowl 50?', 'Where did Super Bowl 50 take place?', 'Which NFL team won Super Bowl 50?', 'What color was used to emphasize the 50th anniversary of the Super Bowl?', 'What was the theme of Super Bowl 50?', 'What day was the game played on?', 'What is the AFC short for?', 'What was the theme of Super Bowl 50?', 'What does AFC stand for?']\n" ] } ], "source": [ "print(dataset['question'][:10])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Items are returned as dict of element.\n", "\n", "Slices are returned as dict of lists of elements.\n", "\n", "Columns are returned as a list." 
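, "\n", "\nAs a minimal sketch (using the `dataset` loaded above), the three access patterns and their return types look like this:\n", "\n", "```python\n", "item = dataset[0]             # dict of values: item['question'] is a single string\n", "batch = dataset[:3]           # dict of lists: batch['question'] is a list of 3 strings\n", "column = dataset['question']  # plain list with one value per example\n", "```"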
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can thus permute slice, index and columns indexings with identical results:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n" ] } ], "source": [ "print(dataset[0]['question'] == dataset['question'][0])\n", "print(dataset[10:20]['context'] == dataset['context'][10:20])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['id', 'title', 'context', 'question', 'answers']\n", "id: string\n", "title: string\n", "context: string\n", "question: string\n", "answers: struct, answer_start: list>\n", " child 0, text: list\n", " child 0, item: string\n", " child 1, answer_start: list\n", " child 0, item: int32\n" ] } ], "source": [ "# The underlying table is typed (int/float/strings/lists/dict) and structured \n", "print(dataset.column_names)\n", "print(dataset.schema)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Additional misc properties" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The number of bytes allocated on the drive is 10472672\n", "For comparison, here is the number of bytes allocated in memory which can be\n", "accessed with `nlp.total_allocated_bytes()`: 0\n", "The number of rows 1057\n", "The number of columns 5\n", "The shape (rows, columns) (1057, 5)\n" ] } ], "source": [ "# Datasets also have a bunch of properties you can access\n", "print(\"The number of bytes allocated on the drive is \", dataset.nbytes)\n", "print(\"For comparison, here is the number of bytes allocated in memory which can be\")\n", "print(\"accessed with `nlp.total_allocated_bytes()`: \", nlp.total_allocated_bytes())\n", "print(\"The number of rows\", dataset.num_rows)\n", "print(\"The number of columns\", dataset.num_columns)\n", "print(\"The shape (rows, columns)\", dataset.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Additional misc methods" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Super_Bowl_50', 'Warsaw']\n" ] } ], "source": [ "# We can list the unique elements in a column. This is done by the backend (so fast!)\n", "print(dataset.unique('title'))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['title', 'context', 'question', 'answers']\n" ] } ], "source": [ "# This will drop the column 'id'\n", "dataset.drop('id') # Remove column 'id'\n", "print(dataset.column_names)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['title', 'context', 'question', 'answers.text', 'answers.answer_start']\n" ] } ], "source": [ "# This will flatten the nested columns in 'answers'\n", "dataset.flatten()\n", "print(dataset.column_names)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# We can also \"dictionnary encode\" a column if many of it's elements are similar\n", "# This will reduce it's size by only storing the distinct elements (e.g. 
string)\n", "# It only has effect on the internal storage (no difference from a user point of view)\n", "dataset.dictionary_encode_column('title')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cache files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can check the current cache files backing the dataset with the `.cache_file` property" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'filename': '/Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/squad-validation.arrow',\n", " 'skip': 0,\n", " 'take': 1057},)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.cache_files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can clean up the cache files for in the current dataset directory with the `.cleanup_cache_files()`.\n", "\n", "Be careful that no other process is using these cache files when running this command." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Listing files in /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0\n", "INFO:nlp.arrow_dataset:Removing /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-2b0c4368cd1b9d9ab7dd158754adb501.arrow\n", "INFO:nlp.arrow_dataset:Removing /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-fef84cefe794447d6dc0b28596974c80.arrow\n", "INFO:nlp.arrow_dataset:Removing /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-b9d042be98ac7ed20cb12b2e9d65d208.arrow\n", "INFO:nlp.arrow_dataset:Removing /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-0d81cced63f868bf1a233bffb4c94b85.arrow\n", "INFO:nlp.arrow_dataset:Removing /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-fdd554f8e6ee8230941052eceac92e0f.arrow\n", "INFO:nlp.arrow_dataset:Removing /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-2d5d9f6d0f564bbd27c91aee95cfc0dc.arrow\n", "INFO:nlp.arrow_dataset:Removing /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-79ea07cbbe2ddf3afe1d0c6ac0269cc3.arrow\n" ] }, { "data": { "text/plain": [ "7" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.cleanup_cache_files() # Returns the number of removed cache files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modifying the dataset with `dataset.map`\n", "\n", "There is a powerful method `.map()` that you can use to apply a function to each examples, independantly or in batch." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "1057it [00:00, 10624.60it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,179,179,179,179,179,179,179,179,179,179,179,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,704,704,704,704,704,704,704,704,704,704,704,704,704,704,353,353,353,353,353,353,353,353,353,353,353,353,353,353,353,464,464,464,464,464,464,464,464,464,464,464,464,464,464,464,464,306,306,306,306,306,306,306,306,306,306,306,306,372,372,372,372,372,372,372,372,372,372,372,372,372,372,372,372,372,496,496,496,496,496,496,496,496,496,496,496,496,496,496,496,260,260,260,260,260,260,260,260,260,874,874,874,874,874,874,874,874,874,874,874,874,874,874,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,782,782,782,782,782,782,782,782,782,782,782,782,782,782,782,782,536,536,536,536,536,536,536,536,536,666,666,666,666,666,666,666,666,666,666,666,666,666,666,666,666,666,495,495,495,495,495,495,495,495,495,495,495,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,441,441,441,441,441,441,441,441,441,441,441,357,357,357,357,357,357,357,357,357,296,296,296,296,296,296,296,296,296,296,644,644,644,644,644,644,644,644,644,644,644,644,644,644,644,644,644,804,804,804,804,804,804,804,804,804,804,804,397,397,397,397,397,397,397,397,397,397,397,397,397,397,360,360,360,360,360,360,360,973,973,973,973,973,973,973,973,973,973,973,973,973,973,263,263,263,263,263,263,263,263,263,263,263,568,568,568,568,568,568,568,568,568,568,568,264,264,264,264,264,264,264,264,264,264,264,264,264,264,264,892,892,892,892,892,892,892,892,892,892,892,206,206,206,206,206,489,489,489,489,489,489,489,489,489,489,489,489,489,181,181,181,181,181,181,181,181,181,181,181,181,531,531,531,531,531,531,531,531,531,531,531,531,664,664,664,664,664,664,664,664,664,664,664,664,664,664,672,672,672,672,672,672,672,672,672,672,672,672,672,672,858,858,858,858,858,858,858,858,858,858,858,858,634,634,634,634,634,634,634,634,634,634,634,634,634,634,891,891,891,891,891,891,891,891,891,891,891,891,891,488,488,488,488,488,488,488,488,488,488,488,488,942,942,942,942,942,942,942,942,942,942,942,942,942,942,942,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1353,1353,1353,1353,1353,1353,1353,1353,1353,1353,13
53,1353,1353,1353,522,522,522,522,522,1643,1643,1643,1643,1643,628,628,628,628,628,758,758,758,758,758,883,883,883,883,883,559,559,559,559,559,603,603,603,603,631,631,631,631,631,626,626,626,626,626,541,541,541,541,541,795,795,795,795,795,591,591,591,591,591,568,568,568,568,568,536,536,536,536,536,575,575,575,575,575,571,571,571,571,571,641,641,641,641,641,665,665,665,665,665,1088,1088,1088,1088,1088,1619,1619,1619,1619,1619,939,939,939,939,939,865,865,865,865,865,711,711,711,711,711,831,831,831,831,831,501,501,501,501,501,676,676,676,676,676,854,854,854,854,854,784,784,784,784,784,641,641,641,641,641,544,544,544,544,544,918,918,918,918,918,763,763,763,763,763,906,906,906,906,906,632,632,632,632,632,869,869,869,869,869,1044,1044,1044,1044,1044,760,760,760,760,760,715,715,715,715,715,838,838,838,838,838,881,881,881,881,881,940,940,940,940,940,618,618,618,618,618,1205,1205,1205,534,534,534,534,534,757,757,757,757,757,1239,1239,1239,1239,1239,609,609,609,609,609,798,798,798,798,798,613,613,613,613,613,613,613,613,613,613," ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "Dataset(schema: {'title': 'string', 'context': 'string', 'question': 'string', 'answers.text': 'list', 'answers.answer_start': 'list'}, num_rows: 1057)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# `.map()` takes a callable accepting a dict as argument\n", "# (same dict as returned by dataset[i])\n", "# and iterate over the dataset by calling the function with each example.\n", "\n", "# Let's print the length of each `context` string in our subset of the dataset\n", "# (10% of the validation i.e. 1057 examples)\n", "\n", "dataset.map(lambda example: print(len(example['context']), end=','))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is basically the same as doing\n", "\n", "```python\n", "for example in dataset:\n", " function(example)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above example had no effect on the dataset because our function supplied to `.map()` didn't return a `dict` or a `abc.Mapping` that could be used to update the examples in the dataset. `.map()` then just return the same dataset (`self`).\n", "\n", "Now let's see how to use a function that can modify the dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Modifying the dataset example by example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main interest of `.map()` is to update and modify the content of the table.\n", "\n", "To use `.map()` to update elements in the table you should provide a function with the following signature: `function(example: dict) -> dict`." 
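, "\n", "\nA small sketch contrasting a function that returns nothing with one that returns a dict (the `strip()` call is only illustrative):\n", "\n", "```python\n", "dataset.map(lambda example: print(example['title']))              # returns None -> dataset unchanged\n", "dataset.map(lambda example: {'title': example['title'].strip()})  # returns a dict -> column updated\n", "```"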
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Caching processed dataset at /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-ed8b1249a765df5c159965379e685e44.arrow\n", "1057it [00:00, 21208.28it/s]\n", "INFO:nlp.arrow_writer:Done writing 1057 examples in 906626 bytes /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-ed8b1249a765df5c159965379e685e44.arrow.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['My cute title: Super_Bowl_50', 'My cute title: Warsaw']\n" ] } ], "source": [ "# Let's add a prefix 'My cute title: ' to each of our titles\n", "\n", "def add_prefix_to_title(example):\n", " example['title'] = 'My cute title: ' + example['title']\n", " return example\n", "\n", "dataset = dataset.map(add_prefix_to_title)\n", "\n", "print(dataset.unique('title'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This call to `.map()` compute and return the updated table. It will also store the updated table in a cache file indexed by the current state and the mapped function. A subsequent call to `.map()` (even in another python session) will reuse the cached file instead of recomputing the operation (this caching may not work in jupyter notebooks yet).\n", "\n", "The returned updated dataset is (again) directly memory mapped from drive and not allocated in RAM." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Your function should accept an input with the format of an item of the dataset: `function(dataset[0])` and return a python dict.\n", "\n", "The columns and type of the outputs can be different than the input dict. In this case the new keys will be added as additional columns in the dataset.\n", "\n", "The example is `updated()` with the output dictionary: `examples.update(function(example))`." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Caching processed dataset at /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-17091682b8ed78e55221838b2595bbd5.arrow\n", "1057it [00:00, 24103.23it/s]\n", "INFO:nlp.arrow_writer:Done writing 1057 examples in 924595 bytes /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-17091682b8ed78e55221838b2595bbd5.arrow.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['My cutest title: My cute title: Super_Bowl_50', 'My cutest title: My cute title: Warsaw']\n" ] } ], "source": [ "# Since the input example is updated with our function output,\n", "# we can actually just return the updated 'title' field\n", "dataset = dataset.map(lambda example: {'title': 'My cutest title: ' + example['title']})\n", "\n", "print(dataset.unique('title'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Removing columns\n", "You can also remove columns when running map with the `remove_columns=List[str]` argument." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Caching processed dataset at /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-9741cad18be490ab827b103119d5c732.arrow\n", "1057it [00:00, 25135.67it/s]\n", "INFO:nlp.arrow_writer:Done writing 1057 examples in 934108 bytes /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-9741cad18be490ab827b103119d5c732.arrow.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['context', 'question', 'answers.text', 'answers.answer_start', 'new_title']\n", "['Wouhahh: My cutest title: My cute title: Super_Bowl_50', 'Wouhahh: My cutest title: My cute title: Warsaw']\n" ] } ], "source": [ "# This will select the 'title' input to send to our function (as only field in the input)\n", "# and replace it with the output of the method as a 'new_title' field\n", "dataset = dataset.map(lambda example: {'new_title': 'Wouhahh: ' + example['title']},\n", " remove_columns=['title'])\n", "\n", "print(dataset.column_names)\n", "print(dataset.unique('new_title'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Using examples indices\n", "With `with_indices=True`, dataset indices (from `0` to `len(dataset)`) will be supplied to the function which must thus have the following signature: `function(example: dict, indice: int) -> dict`" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Caching processed dataset at /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-93827205b7769e301be275a794040d51.arrow\n", "1057it [00:00, 24952.75it/s]\n", "INFO:nlp.arrow_writer:Done writing 1057 examples in 939340 bytes /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-93827205b7769e301be275a794040d51.arrow.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0: Which NFL team represented the AFC at Super Bowl 50?\n", "1: Which NFL team represented the NFC at Super Bowl 50?\n", "2: Where did Super Bowl 50 take place?\n", "3: Which NFL team won Super Bowl 50?\n", "4: What color was used to emphasize the 50th anniversary of the Super Bowl?\n" ] } ], "source": [ "# This will add the index in the dataset to the 'question' field\n", "dataset = dataset.map(lambda example, idx: {'question': f'{idx}: ' + example['question']},\n", " with_indices=True)\n", "\n", "print('\\n'.join(dataset['question'][:5]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Modifying the dataset with batched updates" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`.map()` can also work with batch of examples (slices of the dataset).\n", "\n", "This is particularly interesting if you have a function that can handle batch of inputs like the tokenizers of HuggingFace `tokenizers`.\n", "\n", "To work on batched inputs set `batched=True` when calling `.map()` and supply a function with the following signature: `function(examples: Dict[List]) -> Dict[List]` or, if you use indices, `function(examples: Dict[List], indices: List[int]) -> Dict[List]`).\n", "\n", "Your function should accept an input with the format of a slice of the dataset: e.g. `function(dataset[:10])`." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:transformers.file_utils:PyTorch version 1.4.0 available.\n", "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /Users/thomwolf/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n" ] } ], "source": [ "# Let's import a fast tokenizer that can work on batched inputs\n", "# (the 'Fast' tokenizers in HuggingFace)\n", "from transformers import BertTokenizerFast\n", "\n", "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Caching processed dataset at /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-09ed6375515654521b963025766295d1.arrow\n", "100%|██████████| 2/2 [00:00<00:00, 18.20it/s]\n", "INFO:nlp.arrow_writer:Done writing 1057 examples in 4811564 bytes /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-09ed6375515654521b963025766295d1.arrow.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "dataset[0] {'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\\'s Stadium in the San Francisco Bay Area at Santa Clara, California. 
As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.', 'question': '0: Which NFL team represented the AFC at Super Bowl 50?', 'answers.text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answers.answer_start': [177, 177, 177], 'new_title': 'Wouhahh: My cutest title: My cute title: Super_Bowl_50', 'input_ids': [101, 3198, 5308, 1851, 1108, 1126, 1237, 1709, 1342, 1106, 4959, 1103, 3628, 1104, 1103, 1305, 2289, 1453, 113, 4279, 114, 1111, 1103, 1410, 1265, 119, 1109, 1237, 2289, 3047, 113, 10402, 114, 3628, 7068, 14722, 2378, 1103, 1305, 2289, 3047, 113, 24743, 114, 3628, 2938, 13598, 1572, 782, 1275, 1106, 7379, 1147, 1503, 3198, 5308, 1641, 119, 1109, 1342, 1108, 1307, 1113, 1428, 128, 117, 1446, 117, 1120, 12388, 112, 188, 3339, 1107, 1103, 1727, 2948, 2410, 3894, 1120, 3364, 10200, 117, 1756, 119, 1249, 1142, 1108, 1103, 13163, 3198, 5308, 117, 1103, 2074, 13463, 1103, 107, 5404, 5453, 107, 1114, 1672, 2284, 118, 12005, 11751, 117, 1112, 1218, 1112, 7818, 28117, 20080, 16264, 1103, 3904, 1104, 10505, 1296, 3198, 5308, 1342, 1114, 2264, 183, 15447, 16179, 113, 1223, 1134, 1103, 1342, 1156, 1138, 1151, 1227, 1112, 107, 3198, 5308, 149, 107, 114, 117, 1177, 1115, 1103, 7998, 1180, 15199, 2672, 1103, 4944, 183, 15447, 16179, 1851, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" ] } ], "source": [ "# Now let's batch tokenize our dataset 'context'\n", "dataset = dataset.map(lambda example: tokenizer.batch_encode_plus(example['context']),\n", " batched=True)\n", "\n", "print(\"dataset[0]\", dataset[0])" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['context', 'question', 'answers.text', 'answers.answer_start', 'new_title', 'input_ids', 'token_type_ids', 'attention_mask']\n" ] } ], "source": [ "# we have added additional columns\n", "# we could have replaced the dataset with `remove_columns=True`\n", "print(dataset.column_names)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Caching processed dataset at /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-b01e56f9216a5c4e04189ae568585041.arrow\n", "100%|██████████| 2/2 
[00:00<00:00, 6.16it/s]\n", "INFO:nlp.arrow_writer:Done writing 1057 examples in 21999734 bytes /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-b01e56f9216a5c4e04189ae568585041.arrow.\n" ] } ], "source": [ "# Let's show a more complex processing step with the full preparation of the SQuAD dataset\n", "# for training a model from Transformers\n", "def convert_to_features(batch):\n", "    # Tokenize contexts and questions (as pairs of inputs)\n", "    # keep offset mappings for evaluation\n", "    input_pairs = list(zip(batch['context'], batch['question']))\n", "    encodings = tokenizer.batch_encode_plus(input_pairs,\n", "                                            pad_to_max_length=True,\n", "                                            return_offsets_mapping=True)\n", "\n", "    # Compute start and end tokens for labels\n", "    start_positions, end_positions = [], []\n", "    for i, (text, start) in enumerate(zip(batch['answers.text'], batch['answers.answer_start'])):\n", "        first_char = start[0]\n", "        last_char = first_char + len(text[0]) - 1\n", "        start_positions.append(encodings.char_to_token(i, first_char))\n", "        end_positions.append(encodings.char_to_token(i, last_char))\n", "\n", "    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})\n", "    return encodings\n", "\n", "dataset = dataset.map(convert_to_features, batched=True)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "column_names ['context', 'question', 'answers.text', 'answers.answer_start', 'new_title', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions']\n", "start_positions [34, 45, 80, 34, 98]\n" ] } ], "source": [ "# Now our dataset comprises the labels for the start and end positions\n", "# as well as the offsets for converting tokens back\n", "# into spans of the original string for evaluation\n", "print(\"column_names\", dataset.column_names)\n", "print(\"start_positions\", dataset[:5]['start_positions'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Formatting outputs for numpy/torch/tensorflow\n", "\n", "Now that we have all our tokenized inputs, we would like to use this dataset in a `torch.utils.data.DataLoader` or a `tf.data.Dataset`.\n", "\n", "To be able to do this we need to tweak two things:\n", "\n", "- format the indexing (`__getitem__`) to return numpy/torch/tensorflow tensors, instead of python objects, and\n", "- format the indexing (`__getitem__`) to return only the subset of the columns that we need for our model inputs.\n", "\n", " We don't want the columns `id` or `title` as inputs to train our model, but we might still want to keep them in the dataset, for instance for the evaluation of the model.\n", " \n", "This is handled by the `.set_format(type: Union[None, str], columns: Union[None, str, List[str]])` method, where:\n", "\n", "- `type` defines the return type for our dataset `__getitem__` method and is one of `[None, 'numpy', 'torch', 'tensorflow']` (`None` means return python objects), and\n", "- `columns` defines the columns returned by `__getitem__` and takes the name of a column in the dataset or a list of columns to return (`None` means return all columns)." 
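, "\n", "\nFor example, a numpy variant would look like this (a sketch; the column list is only illustrative):\n", "\n", "```python\n", "dataset.set_format(type='numpy',\n", "                   columns=['input_ids', 'token_type_ids', 'attention_mask'])\n", "```"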
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch and filter ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'] columns (when key is int or slice).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "input_ids torch.Size([10, 451])\n", "token_type_ids torch.Size([10, 451])\n", "attention_mask torch.Size([10, 451])\n", "start_positions torch.Size([10])\n", "end_positions torch.Size([10])\n" ] } ], "source": [ "columns_to_return = ['input_ids', 'token_type_ids', 'attention_mask',\n", " 'start_positions', 'end_positions']\n", "\n", "dataset.set_format(type='torch',\n", " columns=columns_to_return)\n", "\n", "# Our dataset indexing output is now ready for being used in a pytorch dataloader\n", "print('\\n'.join([' '.join((n, str(type(t)), str(t.shape))) for n, t in dataset[:10].items()]))" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['context', 'question', 'answers.text', 'answers.answer_start', 'new_title', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions']\n" ] } ], "source": [ "# Note that the columns are not removed from the dataset,\n", "# just not returned when calling __getitem__\n", "print(dataset.column_names)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.arrow_dataset:Set __getitem__(key) output type to python objects and filter no columns (when key is int or slice).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "context \n", "question \n", "answers.text \n", "answers.answer_start \n", "new_title \n", "input_ids \n", "token_type_ids \n", "attention_mask \n", "offset_mapping \n", "start_positions \n", "end_positions \n" ] } ], "source": [ "# We can remove the formating with `.reset_format()`\n", "# or, identically, a call to `.set_format()` with no arguments\n", "dataset.reset_format()\n", "\n", "print('\\n'.join([' '.join((n, str(type(t)))) for n, t in dataset[:10].items()]))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'type': 'python',\n", " 'columns': ['context',\n", " 'question',\n", " 'answers.text',\n", " 'answers.answer_start',\n", " 'new_title',\n", " 'input_ids',\n", " 'token_type_ids',\n", " 'attention_mask',\n", " 'offset_mapping',\n", " 'start_positions',\n", " 'end_positions']}" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The current format can be checked with `.format`,\n", "# which is a dict of the type and formating\n", "dataset.format" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Wrapping this all up\n", "\n", "Let's wrap this all up with the full code to load and prepare SQuAD for training a PyTorch model." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:nlp.load:Dataset script /Users/thomwolf/.cache/huggingface/datasets/ee43d2be6898ebb9c2afefda4455306911d308bcf924d21c975796832cc7c114.e7d8881147e5da61c98918c61832c7f1c88b33b51a082c464e70e119bb24983d already found in datasets directory at /Users/thomwolf/Documents/GitHub/datasets/src/nlp/datasets/686d79c021d7dcd78da4d67fe01fbe30dfecabcd4bd02d06aa9d51edab713144/squad.py, returning it. Use `force_reload=True` to override.\n", "INFO:nlp.builder:No config specified, defaulting to first: squad/plain_text\n", "INFO:nlp.builder:Overwrite dataset info from restored data version.\n", "INFO:nlp.info:Loading Dataset info from /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0\n", "INFO:nlp.builder:Reusing dataset squad (/Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0)\n", "INFO:nlp.builder:Constructing Dataset for split None, from /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0\n", "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /Users/thomwolf/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n", "INFO:nlp.arrow_dataset:Caching processed dataset at /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-7008a010d9ded38b9f1f7e5bfe57c19a.arrow\n", "100%|██████████| 88/88 [00:15<00:00, 5.83it/s]\n", "INFO:nlp.arrow_writer:Done writing 87599 examples in 1114822607 bytes /Users/thomwolf/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-7008a010d9ded38b9f1f7e5bfe57c19a.arrow.\n", "INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch and filter ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'] columns (when key is int or slice).\n" ] } ], "source": [ "import nlp\n", "import torch \n", "from transformers import BertTokenizerFast\n", "\n", "# Load our training dataset and tokenizer\n", "dataset = nlp.load('squad')\n", "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')\n", "\n", "# Tokenize our training dataset\n", "def convert_to_features(example_batch):\n", " # Tokenize contexts and questions (as pairs of inputs)\n", " input_pairs = list(zip(example_batch['context'], example_batch['question']))\n", " encodings = tokenizer.batch_encode_plus(input_pairs, pad_to_max_length=True)\n", "\n", " # Compute start and end tokens for labels\n", " start_positions, end_positions = [], []\n", " for i, answer in enumerate(example_batch['answers']):\n", " first_char = answer['answer_start'][0]\n", " last_char = first_char + len(answer['text'][0]) - 1\n", " start_positions.append(encodings.char_to_token(i, first_char))\n", " end_positions.append(encodings.char_to_token(i, last_char))\n", "\n", " encodings.update({'start_positions': start_positions,\n", " 'end_positions': end_positions})\n", " return encodings\n", "\n", "dataset['train'] = dataset['train'].map(convert_to_features, batched=True)\n", "\n", "# Format our outputs to train a pytorch model\n", "columns = ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions']\n", "dataset['train'].set_format(type='torch', columns=columns)\n", "\n", "# Instantiate a PyTorch Dataloader around our dataset\n", "dataloader = torch.utils.data.DataLoader(dataset['train'], batch_size=8)" 
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /Users/thomwolf/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.3d5adf10d3445c36ce131f4c6416aa62e9b58e1af56b97664773f4858a46286e\n", "INFO:transformers.configuration_utils:Model config BertConfig {\n", " \"_num_labels\": 2,\n", " \"architectures\": [\n", " \"BertForMaskedLM\"\n", " ],\n", " \"attention_probs_dropout_prob\": 0.1,\n", " \"bad_words_ids\": null,\n", " \"bos_token_id\": null,\n", " \"decoder_start_token_id\": null,\n", " \"do_sample\": false,\n", " \"early_stopping\": false,\n", " \"eos_token_id\": null,\n", " \"finetuning_task\": null,\n", " \"hidden_act\": \"gelu\",\n", " \"hidden_dropout_prob\": 0.1,\n", " \"hidden_size\": 768,\n", " \"id2label\": {\n", " \"0\": \"LABEL_0\",\n", " \"1\": \"LABEL_1\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"intermediate_size\": 3072,\n", " \"is_decoder\": false,\n", " \"is_encoder_decoder\": false,\n", " \"label2id\": {\n", " \"LABEL_0\": 0,\n", " \"LABEL_1\": 1\n", " },\n", " \"layer_norm_eps\": 1e-12,\n", " \"length_penalty\": 1.0,\n", " \"max_length\": 20,\n", " \"max_position_embeddings\": 512,\n", " \"min_length\": 0,\n", " \"model_type\": \"bert\",\n", " \"no_repeat_ngram_size\": 0,\n", " \"num_attention_heads\": 12,\n", " \"num_beams\": 1,\n", " \"num_hidden_layers\": 12,\n", " \"num_return_sequences\": 1,\n", " \"output_attentions\": false,\n", " \"output_hidden_states\": false,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"prefix\": null,\n", " \"pruned_heads\": {},\n", " \"repetition_penalty\": 1.0,\n", " \"task_specific_params\": null,\n", " \"temperature\": 1.0,\n", " \"top_k\": 50,\n", " \"top_p\": 1.0,\n", " \"torchscript\": false,\n", " \"type_vocab_size\": 2,\n", " \"use_bfloat16\": false,\n", " \"vocab_size\": 28996\n", "}\n", "\n", "INFO:transformers.modeling_utils:loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin from cache at /Users/thomwolf/.cache/torch/transformers/35d8b9d36faaf46728a0192d82bf7d00137490cd6074e8500778afed552a67e5.3fadbea36527ae472139fe84cddaa65454d7429f12d543d80bfc3ad70de55ac2\n", "INFO:transformers.modeling_utils:Weights of BertForQuestionAnswering not initialized from pretrained model: ['qa_outputs.weight', 'qa_outputs.bias']\n", "INFO:transformers.modeling_utils:Weights from pretrained model not used in BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n" ] } ], "source": [ "# Let's load a pretrained Bert model and a simple optimizer\n", "from transformers import BertForQuestionAnswering\n", "\n", "model = BertForQuestionAnswering.from_pretrained('bert-base-cased')\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 0 - loss: 6.26\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Documents/GitHub/transformers/src/transformers/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, start_positions, end_positions)\u001b[0m\n\u001b[1;32m 1478\u001b[0m \u001b[0mposition_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1479\u001b[0m 
\u001b[0mhead_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhead_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1480\u001b[0;31m \u001b[0minputs_embeds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs_embeds\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1481\u001b[0m )\n\u001b[1;32m 1482\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Documents/GitHub/transformers/src/transformers/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[0mhead_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhead_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 789\u001b[0m \u001b[0mencoder_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mencoder_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 790\u001b[0;31m \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mencoder_extended_attention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 791\u001b[0m )\n\u001b[1;32m 792\u001b[0m \u001b[0msequence_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mencoder_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Documents/GitHub/transformers/src/transformers/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)\u001b[0m\n\u001b[1;32m 405\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 406\u001b[0m layer_outputs = layer_module(\n\u001b[0;32m--> 407\u001b[0;31m \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhead_mask\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoder_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 408\u001b[0m )\n\u001b[1;32m 409\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m 
\u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Documents/GitHub/transformers/src/transformers/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)\u001b[0m\n\u001b[1;32m 366\u001b[0m \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 367\u001b[0m ):\n\u001b[0;32m--> 368\u001b[0;31m \u001b[0mself_attention_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhead_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 369\u001b[0m \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself_attention_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 370\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself_attention_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# add self attentions if we output attention weights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Documents/GitHub/transformers/src/transformers/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)\u001b[0m\n\u001b[1;32m 312\u001b[0m ):\n\u001b[1;32m 313\u001b[0m self_outputs = self.self(\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhead_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoder_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m )\n\u001b[1;32m 316\u001b[0m \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Documents/GitHub/transformers/src/transformers/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0;31m# Normalize the attention scores to probabilities.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 241\u001b[0;31m \u001b[0mattention_probs\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0;31m# This is actually dropping out entire tokens to attend to, which might\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/modules/activation.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 1016\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1018\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_stacklevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1019\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m~/miniconda2/envs/datasets/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36msoftmax\u001b[0;34m(input, dim, _stacklevel, dtype)\u001b[0m\n\u001b[1;32m 1229\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_softmax_dim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'softmax'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_stacklevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1230\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1231\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1232\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1233\u001b[0m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# Now let's train our model\n", "\n", "model.train()\n", "for i, batch in enumerate(dataloader):\n", " outputs = model(**batch)\n", " loss = outputs[0]\n", " loss.backward()\n", " optimizer.step()\n", " model.zero_grad()\n", " print(f'Step {i} - loss: {loss:.3}')\n", " if i > 3:\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "file_extension": ".py", "kernelspec": { "display_name": "Python 3.7.7 64-bit ('datasets': conda)", "language": "python", "name": "python37764bitdatasetscondae5d8ff60608e4c5c953d6bb643d8ebc5" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" }, "mimetype": "text/x-python", "name": "python", "npconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 3 }, "nbformat": 4, "nbformat_minor": 2 }