{"cells":[{"cell_type":"markdown","metadata":{"id":"iOOZOgNUWmO9"},"source":["# Préparer des données (PyTorch)"]},{"cell_type":"markdown","metadata":{"id":"NP0--nriWmPA"},"source":["Installez les bibliothèques 🤗 *Transformers* et 🤗 *Datasets* pour exécuter ce *notebook*."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4lshhqWdWmPB"},"outputs":[],"source":["!pip install datasets transformers[sentencepiece]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kUeCRz_qWmPC"},"outputs":[],"source":["import torch\n","from torch.optim import AdamW\n","from transformers import AutoTokenizer, AutoModelForSequenceClassification\n","\n","# Same as before\n","checkpoint = \"camembert-base\"\n","tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n","model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n","sequences = [\n","    \"J'ai attendu un cours d'HuggingFace toute ma vie.\",\n","    \"Je déteste tellement ça !\",\n","]\n","batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n","\n","# This is new\n","batch[\"labels\"] = torch.tensor([1, 1])\n","\n","# AdamW is taken from torch.optim: the copy re-exported by transformers is deprecated\n","optimizer = AdamW(model.parameters())\n","# One optimization step: forward pass with labels, backward pass, parameter update\n","loss = model(**batch).loss\n","loss.backward()\n","optimizer.step()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"EQlfGavTWmPD"},"outputs":[],"source":["from datasets import load_dataset\n","\n","# Load the \"fr\" configuration of the paws-x dataset\n","raw_datasets = load_dataset(\"paws-x\", \"fr\")\n","raw_datasets"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dRbS6uuIWmPE"},"outputs":[],"source":["# Look at the first example of the training split\n","raw_train_dataset = raw_datasets[\"train\"]\n","raw_train_dataset[0]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4i0iOjQ0WmPF"},"outputs":[],"source":["raw_train_dataset.features"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"fF9p109xWmPF"},"outputs":[],"source":["from transformers import AutoTokenizer\n","\n","checkpoint = \"camembert-base\"\n","tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n","# Tokenize the two sentence columns independently of each other\n","tokenized_sentences_1 = 
tokenizer(raw_datasets[\"train\"][\"sentence1\"])\n","tokenized_sentences_2 = tokenizer(raw_datasets[\"train\"][\"sentence2\"])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"zycRmMuLWmPG"},"outputs":[],"source":["# Passing two texts to the tokenizer encodes them together as one sentence pair\n","inputs = tokenizer(\"C'est la première phrase.\", \"C'est la deuxième.\")\n","inputs"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"HUh1gWLDWmPH"},"outputs":[],"source":["tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"puyxiXf1WmPI"},"outputs":[],"source":["# Tokenize every training pair in a single call, with padding and truncation enabled\n","tokenized_dataset = tokenizer(\n","    raw_datasets[\"train\"][\"sentence1\"],\n","    raw_datasets[\"train\"][\"sentence2\"],\n","    padding=True,\n","    truncation=True,\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KSivd-YnWmPJ"},"outputs":[],"source":["def tokenize_function(example):\n","    \"\"\"Tokenize the sentence pair(s) held in `example`.\n","\n","    Args:\n","        example: mapping with \"sentence1\" and \"sentence2\" entries; when used\n","            through `map(..., batched=True)` each entry holds several sentences.\n","\n","    Returns:\n","        The tokenizer output for the pair(s), truncated but not padded\n","        (no `padding` argument is passed here).\n","    \"\"\"\n","    return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4qXjbD3HWmPJ"},"outputs":[],"source":["tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n","tokenized_datasets"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"U3w5hwUZWmPK"},"outputs":[],"source":["from transformers import DataCollatorWithPadding\n","\n","data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"_MijuhbYWmPL"},"outputs":[],"source":["# Columns to leave out of the batch: the raw sentences and the example index.\n","# The paws-x dataset card names the index column \"id\"; \"idx\" (GLUE naming) is kept for safety.\n","non_model_columns = [\"id\", \"idx\", \"sentence1\", \"sentence2\"]\n","samples = tokenized_datasets[\"train\"][:8]\n","samples = {k: v for k, v in samples.items() if k not in non_model_columns}\n","[len(x) for x in samples[\"input_ids\"]]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JLP_NVbzWmPL"},"outputs":[],"source":["# Collate the samples into one batch and show the shape of each entry\n","batch = data_collator(samples)\n","{k: v.shape for k, v in batch.items()}"]}],"metadata":{"colab":{"provenance":[],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 
3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.5"}},"nbformat":4,"nbformat_minor":0}