'})"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"model_checkpoint = \"Helsinki-NLP/opus-mt-de-en\"\n",
"pretrained_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"pretrained_tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can pass a single record, or a list of records, to Hugging Face's tokenizer. Depending on the model, the returned dictionary may contain different keys. Here, for example, we get:\n",
"\n",
"- `input_ids`: The tokenizer converted our raw input text into numerical ids.\n",
"- `attention_mask`: Mask to avoid performing attention on padding token ids. Since we haven't padded yet, every value is 1, meaning no token is masked."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [3303, 5338, 17270, 2843, 70, 49, 14991, 5, 9, 1413, 10949, 14243, 3351, 3, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_tokenizer(dataset_dict['train']['de'][0])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"''"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# notice the last token id is 0, the end of sentence special token\n",
"pretrained_tokenizer.convert_ids_to_tokens(0)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [[3303, 5338, 17270, 2843, 70, 49, 14991, 5, 9, 1413, 10949, 14243, 3351, 3, 0], [20520, 2843, 30, 1235, 19116, 15, 14570, 53, 17992, 3013, 1947, 3, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_tokenizer(dataset_dict['train']['de'][0:2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We apply the tokenizer to the entire raw dataset up front, so this preprocessing only has to happen once. Passing the tokenizing function to our dataset dict's `map` method applies the same tokenizing step to every split in our data."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"source_lang = \"de\"\n",
"target_lang = \"en\"\n",
"# truncation caps (in tokens) for the source and target sequences\n",
"max_source_length = 128\n",
"max_target_length = 128\n",
"\n",
"\n",
"def batch_tokenize_fn(examples):\n",
"    \"\"\"\n",
"    Tokenize a batch of translation pairs inside a huggingface Dataset/DatasetDict `map` call.\n",
"\n",
"    Source-language texts become `input_ids` / `attention_mask`; target-language texts are\n",
"    tokenized in target mode and their ids are stored under `labels`, the field huggingface\n",
"    expects the target token ids in.\n",
"\n",
"    Both sides are truncated to their max length. Padding is deliberately left to a data\n",
"    collator later on, so each batch is padded to its own longest example instead of\n",
"    padding the whole dataset.\n",
"    \"\"\"\n",
"    model_inputs = pretrained_tokenizer(\n",
"        examples[source_lang], max_length=max_source_length, truncation=True\n",
"    )\n",
"\n",
"    # temporarily switch the tokenizer into its target-side mode while encoding the labels\n",
"    with pretrained_tokenizer.as_target_tokenizer():\n",
"        target_encodings = pretrained_tokenizer(\n",
"            examples[target_lang], max_length=max_target_length, truncation=True\n",
"        )\n",
"\n",
"    model_inputs[\"labels\"] = target_encodings[\"input_ids\"]\n",
"    return model_inputs"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d6051597a5354fbead9671e9a5bccb6f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#6', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "acc94d858ce0407882a13057d3b5bce1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#4', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fe777f652f014c35811968583403ec69",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#2', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "16c079f92a2046bc8365cc2bfdc981a6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#5', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5bed431304fc4659920397ca15f4c0b7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#7', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e7b912cb771347b9b7a203f8120657a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#0', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "669c9c28025449e198f12073b693d7be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#3', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a6b7b057b1f4db5a53b36ec58f9d6b9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#1', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3f80b0a6ac6f476285fce127effcf32a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#7', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88c87815207b455f91259103e40b985f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#0', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0eba2dd1fac493da1e331ac32290328",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#1', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b637a68d5144a769e45deeed7620d67",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#2', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "070ee1706e8e444e9a322daff89462e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#6', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "958be04a8bba411785218ca4ef7d34cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#3', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "08f7a57bda1845ba8724bac074eb12fb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#4', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "669c85b3732f4569a69e04bca6d9f8ea",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#5', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fdd25ec407504d439aaa5739c0dff855",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#1', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0808b32962f244798dbdbd5880bb1d35",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#0', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ca71776f60de4316a9245815c0546b59",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#4', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "961b5e48dac04ef5a97717237288a8c3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#7', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "85493ad173a641d694272d41bd24585f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#6', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "893f8957876c45ff94f9ecc82c510dcf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#3', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "134c8a6c68f84a369b4d72562200c7ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#5', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d64c31da50984057b82bd9ca7d8bd804",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#2', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['attention_mask', 'de', 'en', 'input_ids', 'labels'],\n",
" num_rows: 29000\n",
" })\n",
" val: Dataset({\n",
" features: ['attention_mask', 'de', 'en', 'input_ids', 'labels'],\n",
" num_rows: 1014\n",
" })\n",
" test: Dataset({\n",
" features: ['attention_mask', 'de', 'en', 'input_ids', 'labels'],\n",
" num_rows: 1000\n",
" })\n",
"})"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenized_dataset_dict = dataset_dict.map(batch_tokenize_fn, batched=True, num_proc=8)\n",
"tokenized_dataset_dict"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
" 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',\n",
" 'en': 'Two young, White males are outside near many bushes.',\n",
" 'input_ids': [3303,\n",
" 5338,\n",
" 17270,\n",
" 2843,\n",
" 70,\n",
" 49,\n",
" 14991,\n",
" 5,\n",
" 9,\n",
" 1413,\n",
" 10949,\n",
" 14243,\n",
" 3351,\n",
" 3,\n",
" 0],\n",
" 'labels': [4386, 1296, 2, 3380, 25020, 48, 2060, 1656, 374, 45315, 3, 0]}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# printing out the tokenized data, to check for the newly added fields\n",
"tokenized_dataset_dict['train'][0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Having prepared our dataset, we'll load the pre-trained model. As with the tokenizer, we use the `.from_pretrained` method and specify a valid Hugging Face model checkpoint."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# of parameters: 74410496\n"
]
},
{
"data": {
"text/plain": [
"MarianMTModel(\n",
" (model): MarianModel(\n",
" (shared): Embedding(58101, 512, padding_idx=58100)\n",
" (encoder): MarianEncoder(\n",
" (embed_tokens): Embedding(58101, 512, padding_idx=58100)\n",
" (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)\n",
" (layers): ModuleList(\n",
" (0): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (1): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (2): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (3): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (4): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (5): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (decoder): MarianDecoder(\n",
" (embed_tokens): Embedding(58101, 512, padding_idx=58100)\n",
" (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)\n",
" (layers): ModuleList(\n",
" (0): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (1): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (2): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (3): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (4): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (5): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
" (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
" (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (lm_head): Linear(in_features=512, out_features=58101, bias=False)\n",
")"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoModelForSeq2SeqLM\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)\n",
"print('# of parameters: ', pretrained_model.num_parameters())\n",
"pretrained_model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can directly use this model to `generate` the translations, and eyeball the results."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def generate_translation(model, tokenizer, example, **generate_kwargs):\n",
"    \"\"\"\n",
"    Translate a single tokenized example and print its source, target and prediction.\n",
"\n",
"    Args:\n",
"        model: seq2seq model exposing `generate` and `device`.\n",
"        tokenizer: tokenizer used to decode the generated ids back into text.\n",
"        example: one record of the tokenized dataset, holding the raw `source_lang` /\n",
"            `target_lang` texts and the `input_ids` of the source sentence.\n",
"        **generate_kwargs: optional extra arguments forwarded to `model.generate`\n",
"            (e.g. `num_beams`, `max_length`); omit them to keep the model's defaults.\n",
"\n",
"    Returns:\n",
"        None; the three texts are printed.\n",
"    \"\"\"\n",
"    source = example[source_lang]\n",
"    target = example[target_lang]\n",
"    # batch of one: shape (1, seq_len), built directly on the model's device\n",
"    input_ids = torch.tensor(\n",
"        example['input_ids'], dtype=torch.long, device=model.device\n",
"    ).unsqueeze(0)\n",
"    generated_ids = model.generate(input_ids, **generate_kwargs)\n",
"    prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
"\n",
"    print('source: ', source)\n",
"    print('target: ', target)\n",
"    print('prediction: ', prediction)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source: Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n",
"target: Two young, White males are outside near many bushes.\n",
"prediction: Two young white men are outdoors near many bushes.\n"
]
}
],
"source": [
"example = tokenized_dataset_dict['train'][0]\n",
"generate_translation(pretrained_model, pretrained_tokenizer, example)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source: Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.\n",
"target: A man in an orange hat starring at something.\n",
"prediction: A man with an orange hat staring at something.\n"
]
}
],
"source": [
"example = tokenized_dataset_dict['test'][0]\n",
"generate_translation(pretrained_model, pretrained_tokenizer, example)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Model From Scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next section shows the steps for training the model parameters from scratch. Instead of directly instantiating the model with the `.from_pretrained` method, we use the `.from_config` method, where we specify the configuration for a particular model architecture. The configuration itself is created with `.from_pretrained`, overriding some of its hyperparameters so that we get a smaller model for faster iteration."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MarianConfig {\n",
" \"_num_labels\": 3,\n",
" \"activation_dropout\": 0.0,\n",
" \"activation_function\": \"swish\",\n",
" \"add_bias_logits\": false,\n",
" \"add_final_layer_norm\": false,\n",
" \"architectures\": [\n",
" \"MarianMTModel\"\n",
" ],\n",
" \"attention_dropout\": 0.0,\n",
" \"bad_words_ids\": [\n",
" [\n",
" 58100\n",
" ]\n",
" ],\n",
" \"bos_token_id\": 0,\n",
" \"classif_dropout\": 0.0,\n",
" \"classifier_dropout\": 0.0,\n",
" \"d_model\": 256,\n",
" \"decoder_attention_heads\": 8,\n",
" \"decoder_ffn_dim\": 512,\n",
" \"decoder_layerdrop\": 0.0,\n",
" \"decoder_layers\": 3,\n",
" \"decoder_start_token_id\": 58100,\n",
" \"dropout\": 0.1,\n",
" \"encoder_attention_heads\": 8,\n",
" \"encoder_ffn_dim\": 512,\n",
" \"encoder_layerdrop\": 0.0,\n",
" \"encoder_layers\": 6,\n",
" \"eos_token_id\": 0,\n",
" \"gradient_checkpointing\": false,\n",
" \"id2label\": {\n",
" \"0\": \"LABEL_0\",\n",
" \"1\": \"LABEL_1\",\n",
" \"2\": \"LABEL_2\"\n",
" },\n",
" \"init_std\": 0.02,\n",
" \"is_encoder_decoder\": true,\n",
" \"label2id\": {\n",
" \"LABEL_0\": 0,\n",
" \"LABEL_1\": 1,\n",
" \"LABEL_2\": 2\n",
" },\n",
" \"max_length\": 128,\n",
" \"max_position_embeddings\": 128,\n",
" \"model_type\": \"marian\",\n",
" \"normalize_before\": false,\n",
" \"normalize_embedding\": false,\n",
" \"num_beams\": 4,\n",
" \"num_hidden_layers\": 6,\n",
" \"pad_token_id\": 58100,\n",
" \"scale_embedding\": true,\n",
" \"static_position_embeddings\": true,\n",
" \"transformers_version\": \"4.3.0\",\n",
" \"use_cache\": true,\n",
" \"vocab_size\": 58101\n",
"}"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import (\n",
"    AutoConfig,\n",
"    AutoModelForSeq2SeqLM,\n",
"    DataCollatorForSeq2Seq,\n",
"    EarlyStoppingCallback,\n",
"    Seq2SeqTrainingArguments,\n",
"    Seq2SeqTrainer\n",
")\n",
"\n",
"# overrides that shrink the architecture relative to the pretrained checkpoint\n",
"# (which uses d_model=512, ffn_dim=2048, 6 decoder layers) for faster iteration\n",
"config_params = dict(\n",
"    d_model=256,\n",
"    encoder_layers=6,\n",
"    encoder_attention_heads=8,\n",
"    encoder_ffn_dim=512,\n",
"    decoder_layers=3,\n",
"    decoder_attention_heads=8,\n",
"    decoder_ffn_dim=512,\n",
"    max_length=128,\n",
"    max_position_embeddings=128,\n",
")\n",
"\n",
"model_checkpoint = \"Helsinki-NLP/opus-mt-de-en\"\n",
"config = AutoConfig.from_pretrained(model_checkpoint, **config_params)\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# of parameters: 20474368\n"
]
},
{
"data": {
"text/plain": [
"MarianMTModel(\n",
" (model): MarianModel(\n",
" (shared): Embedding(58101, 256, padding_idx=58100)\n",
" (encoder): MarianEncoder(\n",
" (embed_tokens): Embedding(58101, 256, padding_idx=58100)\n",
" (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)\n",
" (layers): ModuleList(\n",
" (0): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (1): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (2): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (3): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (4): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (5): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (decoder): MarianDecoder(\n",
" (embed_tokens): Embedding(58101, 256, padding_idx=58100)\n",
" (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)\n",
" (layers): ModuleList(\n",
" (0): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (1): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (2): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (lm_head): Linear(in_features=256, out_features=58101, bias=False)\n",
")"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = AutoModelForSeq2SeqLM.from_config(config)\n",
"print('# of parameters: ', model.num_parameters())\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The huggingface library offers pre-built functionality to avoid writing the training logic from scratch. This step can be swapped out with other higher level trainer packages or even implementing our own logic. We setup the:\n",
"\n",
"- `Seq2SeqTrainingArguments` a class that contains all the attributes to customize the training. At the bare minimum, it requires one folder name, which will be used to save model checkpoint.\n",
"- `DataCollatorForSeq2Seq` a helper class provided to batch our examples. Where the padding logic resides."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 128\n",
"args = Seq2SeqTrainingArguments(\n",
" output_dir=\"test-translation\",\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=0.0005,\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
" weight_decay=0.01,\n",
" save_total_limit=3,\n",
" num_train_epochs=20,\n",
" load_best_model_at_end=True,\n",
" predict_with_generate=True,\n",
" remove_unused_columns=True,\n",
" fp16=True\n",
")\n",
"\n",
"data_collator = DataCollatorForSeq2Seq(pretrained_tokenizer)\n",
"\n",
"callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"trainer = Seq2SeqTrainer(\n",
" model,\n",
" args,\n",
" data_collator=data_collator,\n",
" train_dataset=tokenized_dataset_dict[\"train\"],\n",
" eval_dataset=tokenized_dataset_dict[\"val\"],\n",
" callbacks=callbacks\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can take a look at the batched examples. Understanding the output can be beneficial if we wish to customize the data collate function later.\n",
"\n",
"- `attention_mask` Padded tokens will be masked out with 0..\n",
"- `input_ids`. Input ids are padded with the padding special tokens.\n",
"- `labels`. By default -100 will be automatically ignored by PyTorch loss functions, hence we will use that particular id when padding our labels."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]]),\n",
" 'input_ids': tensor([[ 246, 1155, 5, ..., 58100, 58100, 58100],\n",
" [ 525, 788, 2, ..., 58100, 58100, 58100],\n",
" [ 246, 2902, 18756, ..., 58100, 58100, 58100],\n",
" ...,\n",
" [ 3303, 2843, 2, ..., 58100, 58100, 58100],\n",
" [ 246, 2902, 1251, ..., 58100, 58100, 58100],\n",
" [ 246, 5324, 8055, ..., 58100, 58100, 58100]]),\n",
" 'labels': tensor([[ 93, 175, 5, ..., -100, -100, -100],\n",
" [ 93, 2950, 19, ..., -100, -100, -100],\n",
" [ 93, 4040, 5074, ..., -100, -100, -100],\n",
" ...,\n",
" [ 4386, 1135, 25345, ..., -100, -100, -100],\n",
" [ 93, 4040, 2047, ..., -100, -100, -100],\n",
" [ 93, 839, 6799, ..., -100, -100, -100]])}"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataloader_train = trainer.get_train_dataloader()\n",
"batch = next(iter(dataloader_train))\n",
"batch"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" \n",
"
\n",
" [2497/4540 07:34 < 06:12, 5.49 it/s, Epoch 11/20]\n",
"
\n",
" \n",
" \n",
" \n",
" Epoch | \n",
" Training Loss | \n",
" Validation Loss | \n",
" Runtime | \n",
" Samples Per Second | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" No log | \n",
" 3.701715 | \n",
" 0.654800 | \n",
" 1548.522000 | \n",
"
\n",
" \n",
" 2 | \n",
" No log | \n",
" 2.559382 | \n",
" 0.669000 | \n",
" 1515.684000 | \n",
"
\n",
" \n",
" 3 | \n",
" 3.932700 | \n",
" 2.114859 | \n",
" 0.746600 | \n",
" 1358.109000 | \n",
"
\n",
" \n",
" 4 | \n",
" 3.932700 | \n",
" 1.917877 | \n",
" 0.940700 | \n",
" 1077.914000 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.959900 | \n",
" 1.787804 | \n",
" 0.582700 | \n",
" 1740.033000 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.959900 | \n",
" 1.738364 | \n",
" 0.497900 | \n",
" 2036.748000 | \n",
"
\n",
" \n",
" 7 | \n",
" 1.408800 | \n",
" 1.711116 | \n",
" 0.647000 | \n",
" 1567.237000 | \n",
"
\n",
" \n",
" 8 | \n",
" 1.408800 | \n",
" 1.710581 | \n",
" 0.572500 | \n",
" 1771.124000 | \n",
"
\n",
" \n",
" 9 | \n",
" 1.106000 | \n",
" 1.723755 | \n",
" 0.717300 | \n",
" 1413.717000 | \n",
"
\n",
" \n",
" 10 | \n",
" 1.106000 | \n",
" 1.741999 | \n",
" 0.558800 | \n",
" 1814.496000 | \n",
"
\n",
" \n",
" 11 | \n",
" 1.106000 | \n",
" 1.731579 | \n",
" 0.590900 | \n",
" 1716.157000 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=2497, training_loss=1.8638566963950491, metrics={'train_runtime': 455.0706, 'train_samples_per_second': 9.976, 'total_flos': 1390184690368512, 'epoch': 11.0})"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer_output = trainer.train()\n",
"trainer_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similar to what we did before, we can use this model to `generate` the translations, and eyeball the results."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def generate_translation(model, tokenizer, example):\n",
" \"\"\"print out the source, target and predicted raw text.\"\"\"\n",
" source = example[source_lang]\n",
" target = example[target_lang]\n",
" input_ids = tokenizer(source)['input_ids']\n",
" input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)\n",
" generated_ids = model.generate(input_ids)\n",
" prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
"\n",
" print('source: ', source)\n",
" print('target: ', target)\n",
" print('prediction: ', prediction)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source: Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n",
"target: Two young, White males are outside near many bushes.\n",
"prediction: Two young white men are outside near many bushes.\n"
]
}
],
"source": [
"example = dataset_dict['train'][0]\n",
"generate_translation(model, pretrained_tokenizer, example)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source: Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.\n",
"target: A man in an orange hat starring at something.\n",
"prediction: A man with an orange hat looking at something.\n"
]
}
],
"source": [
"example = dataset_dict['test'][0]\n",
"generate_translation(model, pretrained_tokenizer, example)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Tokenizer and Model From Scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From our raw pair, we need to use or train a tokenizer to convert them into numerical indices. Here we'll be training our tokenizer from scratch using Huggingface's [tokenizer](https://github.com/huggingface/tokenizers). Feel free to swap this step out with other tokenization procedures, what's important is to leave rooms for special tokens such as the init token that represents the beginning of a sentence, the end of sentence token that represents the end of a sentence, unknown token, and padding token that pads sentence batches into equivalent length."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"multi30k/train.de multi30k/train.en\n"
]
}
],
"source": [
"# use only the training set to train our tokenizer\n",
"split = 'train'\n",
"source_input_path = os.path.join(directory, f'{split}.{source_lang}')\n",
"target_input_path = os.path.join(directory, f'{split}.{target_lang}')\n",
"print(source_input_path, target_input_path)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"elapsed: 6.176600694656372\n",
"source vocab size: 5000\n",
"target vocab size: 5000\n"
]
}
],
"source": [
"bos_token = ''\n",
"unk_token = ''\n",
"eos_token = ''\n",
"pad_token = ''\n",
"special_tokens = [unk_token, bos_token, eos_token, pad_token]\n",
"\n",
"tokenizer_params = {\n",
" 'min_frequency': 2,\n",
" 'vocab_size': 5000,\n",
" 'show_progress': False,\n",
" 'special_tokens': special_tokens\n",
"}\n",
"\n",
"start_time = time.time()\n",
"source_tokenizer = ByteLevelBPETokenizer(lowercase=True)\n",
"source_tokenizer.train(source_input_path, **tokenizer_params)\n",
"\n",
"target_tokenizer = ByteLevelBPETokenizer(lowercase=True)\n",
"target_tokenizer.train(target_input_path, **tokenizer_params)\n",
"end_time = time.time()\n",
"\n",
"print('elapsed: ', end_time - start_time)\n",
"print('source vocab size: ', source_tokenizer.get_vocab_size())\n",
"print('target vocab size: ', target_tokenizer.get_vocab_size())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll perform this tokenization step for all our dataset up front, so we can do as little preprocessing as possible while feeding our dataset to model. Note that we do not perform the padding step at this stage."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"pad_token_id = source_tokenizer.token_to_id(pad_token)\n",
"eos_token_id = source_tokenizer.token_to_id(eos_token)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"def batch_encode_fn(examples):\n",
" sources = examples[source_lang]\n",
" targets = examples[target_lang]\n",
"\n",
" input_ids = [encoding.ids + [eos_token_id] for encoding in source_tokenizer.encode_batch(sources)]\n",
" labels = [encoding.ids + [eos_token_id] for encoding in target_tokenizer.encode_batch(targets)]\n",
"\n",
" examples['input_ids'] = input_ids\n",
" examples['labels'] = labels\n",
" return examples"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f2d828507cf14fa69dc5c9b5fcc9c9d0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#0', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99ff4522218c4a6ca655b6af23a61664",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#5', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2b12ecb5ba78486f8ed99fb3656c1e4a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#3', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "66820df017aa45fcaefd047a1f048c32",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#2', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f4b0dc4256654110ab9766fc7b9a8136",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#7', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fba0cf327a7c415faa99565caf65c38a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#6', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88529b83cc904f1298ae3d67fcf1760e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#4', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5746e4c8bde444d097402a50f2bc0f9d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#1', max=4.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a855f19d5574137a9be298799c99242",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#1', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d2a8ad44c0f6449fb959627f594e9c31",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#0', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d8b120665d544c15aa2eb829eb6bcfdf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#2', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4b1b4777e4434d3b98bf2868bd98484d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#4', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a16aff563a34f80a915bb06ec9e62e7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#3', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c48d89b75aee42d09b91c0b7d17c0c92",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#5', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5e9df442169417a9fc1725aad2e3548",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#6', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "25caa37e83d54b11bda4f45a0e394f46",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#7', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c8830a7ffd8b4d0c8e66029c2587c63c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#2', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "92294f6c31c1434eb5775dfbd1ce7c59",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#0', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0e3021ddb16447269d830d6e0adb8793",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#1', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c2a9343eaa54b42a26d31adf5a0e5b1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#4', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "58d96b4173924c08bccb071198a36a7f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#3', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "68f2feee414f4bd4b95ec144ba86dd05",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#5', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "681bd1ba0e2149b3b48b8c7c3dd0e22a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#6', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6ff22d92e1e04905b23d4a750a5acdbc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='#7', max=1.0, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['de', 'en', 'input_ids', 'labels'],\n",
" num_rows: 29000\n",
" })\n",
" val: Dataset({\n",
" features: ['de', 'en', 'input_ids', 'labels'],\n",
" num_rows: 1014\n",
" })\n",
" test: Dataset({\n",
" features: ['de', 'en', 'input_ids', 'labels'],\n",
" num_rows: 1000\n",
" })\n",
"})"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset_dict_encoded = dataset_dict.map(batch_encode_fn, batched=True, num_proc=8)\n",
"dataset_dict_encoded"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',\n",
" 'en': 'Two young, White males are outside near many bushes.',\n",
" 'input_ids': [344,\n",
" 378,\n",
" 1191,\n",
" 413,\n",
" 649,\n",
" 349,\n",
" 660,\n",
" 281,\n",
" 327,\n",
" 726,\n",
" 1284,\n",
" 263,\n",
" 728,\n",
" 707,\n",
" 17,\n",
" 2],\n",
" 'labels': [336, 373, 15, 370, 2182, 321, 494, 557, 1203, 3158, 17, 2]}"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset_train = dataset_dict_encoded['train']\n",
"dataset_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given the custom tokenizer, we can also custom our data collate class that does the padding for input and labels."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqDataCollator:\n",
" \n",
" def __init__(\n",
" self,\n",
" max_length: int,\n",
" pad_token_id: int,\n",
" pad_label_token_id: int = -100\n",
" ):\n",
" self.max_length = max_length\n",
" self.pad_token_id = pad_token_id\n",
" self.pad_label_token_id = pad_label_token_id\n",
" \n",
" def __call__(self, batch):\n",
" source_batch = []\n",
" source_len = []\n",
" target_batch = []\n",
" target_len = []\n",
" for example in batch:\n",
" source = example['input_ids']\n",
" source_len.append(len(source))\n",
" source_batch.append(source)\n",
"\n",
" target = example['labels']\n",
" target_len.append(len(target))\n",
" target_batch.append(target)\n",
"\n",
" source_padded = self.process_encoded_text(source_batch, source_len, self.pad_token_id)\n",
" target_padded = self.process_encoded_text(target_batch, target_len, self.pad_label_token_id)\n",
" attention_mask = generate_attention_mask(source_padded, self.pad_token_id)\n",
" return {\n",
" 'input_ids': source_padded,\n",
" 'labels': target_padded,\n",
" 'attention_mask': attention_mask\n",
" }\n",
"\n",
" def process_encoded_text(self, sequences, sequences_len, pad_token_id):\n",
" sequences_max_len = np.max(sequences_len)\n",
" max_length = min(sequences_max_len, self.max_length)\n",
" padded_sequences = pad_sequences(sequences, max_length, pad_token_id)\n",
" return torch.LongTensor(padded_sequences)\n",
"\n",
"\n",
"def generate_attention_mask(input_ids, pad_token_id):\n",
" return (input_ids != pad_token_id).long()\n",
"\n",
" \n",
"def pad_sequences(sequences, max_length, pad_token_id):\n",
" \"\"\"\n",
" Pad the list of sequences (numerical token ids) to the same length.\n",
" Sequence that are shorter than the specified ``max_len`` will be appended\n",
" with the specified ``pad_token_id``. Those that are longer will be truncated.\n",
"\n",
" Parameters\n",
" ----------\n",
" sequences : list[int]\n",
" List of numerical token ids.\n",
"\n",
" max_length : int\n",
" Maximum length that all sequences will be truncated/padded to.\n",
"\n",
" pad_token_id : int\n",
" Padding token index.\n",
"\n",
" Returns\n",
" -------\n",
" padded_sequences : 1d ndarray\n",
" \"\"\"\n",
" num_samples = len(sequences)\n",
" padded_sequences = np.full((num_samples, max_length), pad_token_id)\n",
" for i, sequence in enumerate(sequences):\n",
" sequence = np.array(sequence)[:max_length]\n",
" padded_sequences[i, :len(sequence)] = sequence\n",
"\n",
" return padded_sequences"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given that we are using our own tokenizer instead of the pre-trained ones, we need to update a couple of other parameters in our config. The one that's worth pointing out is that this model starts generating with `pad_token_id`, that's why the `decoder_start_token_id` is the same as the `pad_token_id`.\n",
"\n",
"Then rest of model training code should be the same as the ones in the previous section."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MarianConfig {\n",
" \"_num_labels\": 3,\n",
" \"activation_dropout\": 0.0,\n",
" \"activation_function\": \"swish\",\n",
" \"add_bias_logits\": false,\n",
" \"add_final_layer_norm\": false,\n",
" \"architectures\": [\n",
" \"MarianMTModel\"\n",
" ],\n",
" \"attention_dropout\": 0.0,\n",
" \"bad_words_ids\": [\n",
" [\n",
" 3\n",
" ]\n",
" ],\n",
" \"bos_token_id\": 0,\n",
" \"classif_dropout\": 0.0,\n",
" \"classifier_dropout\": 0.0,\n",
" \"d_model\": 256,\n",
" \"decoder_attention_heads\": 8,\n",
" \"decoder_ffn_dim\": 512,\n",
" \"decoder_layerdrop\": 0.0,\n",
" \"decoder_layers\": 3,\n",
" \"decoder_start_token_id\": 3,\n",
" \"dropout\": 0.1,\n",
" \"encoder_attention_heads\": 8,\n",
" \"encoder_ffn_dim\": 512,\n",
" \"encoder_layerdrop\": 0.0,\n",
" \"encoder_layers\": 6,\n",
" \"eos_token_id\": 2,\n",
" \"gradient_checkpointing\": false,\n",
" \"id2label\": {\n",
" \"0\": \"LABEL_0\",\n",
" \"1\": \"LABEL_1\",\n",
" \"2\": \"LABEL_2\"\n",
" },\n",
" \"init_std\": 0.02,\n",
" \"is_encoder_decoder\": true,\n",
" \"label2id\": {\n",
" \"LABEL_0\": 0,\n",
" \"LABEL_1\": 1,\n",
" \"LABEL_2\": 2\n",
" },\n",
" \"max_length\": 128,\n",
" \"max_position_embeddings\": 128,\n",
" \"model_type\": \"marian\",\n",
" \"normalize_before\": false,\n",
" \"normalize_embedding\": false,\n",
" \"num_beams\": 4,\n",
" \"num_hidden_layers\": 6,\n",
" \"pad_token_id\": 3,\n",
" \"scale_embedding\": true,\n",
" \"static_position_embeddings\": true,\n",
" \"transformers_version\": \"4.3.0\",\n",
" \"use_cache\": true,\n",
" \"vocab_size\": 5000\n",
"}"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shrink the pretrained Marian architecture and wire it to our own tokenizers.\n",
"# Keyword arguments passed to AutoConfig.from_pretrained override the matching\n",
"# fields of the config downloaded for model_checkpoint (compare the printed config).\n",
"config_params = {\n",
" 'd_model': 256,\n",
" 'decoder_layers': 3,\n",
" 'decoder_attention_heads': 8,\n",
" 'decoder_ffn_dim': 512,\n",
" 'encoder_layers': 6,\n",
" 'encoder_attention_heads': 8,\n",
" 'encoder_ffn_dim': 512,\n",
" 'max_length': 128,  # default maximum output length used by generate()\n",
" 'max_position_embeddings': 128,  # size of the sinusoidal position table, see the model repr\n",
" 'eos_token_id': eos_token_id,\n",
" 'pad_token_id': pad_token_id,\n",
" 'decoder_start_token_id': pad_token_id,  # decoding is seeded with the pad token\n",
" \"bad_words_ids\": [  # forbid generate() from ever emitting the pad token\n",
" [\n",
" pad_token_id\n",
" ]\n",
" ],\n",
" # NOTE(review): this single size is used for the shared encoder/decoder embeddings and lm_head;\n",
" # it assumes the target tokenizer's vocabulary is no larger than the source one - confirm.\n",
" 'vocab_size': source_tokenizer.get_vocab_size()\n",
"}\n",
"\n",
"model_config = AutoConfig.from_pretrained(model_checkpoint, **config_params)\n",
"model_config"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# of parameters: 6880512\n"
]
},
{
"data": {
"text/plain": [
"MarianMTModel(\n",
" (model): MarianModel(\n",
" (shared): Embedding(5000, 256, padding_idx=3)\n",
" (encoder): MarianEncoder(\n",
" (embed_tokens): Embedding(5000, 256, padding_idx=3)\n",
" (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)\n",
" (layers): ModuleList(\n",
" (0): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (1): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (2): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (3): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (4): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (5): MarianEncoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (decoder): MarianDecoder(\n",
" (embed_tokens): Embedding(5000, 256, padding_idx=3)\n",
" (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)\n",
" (layers): ModuleList(\n",
" (0): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (1): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (2): MarianDecoderLayer(\n",
" (self_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): MarianAttention(\n",
" (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): Linear(in_features=256, out_features=512, bias=True)\n",
" (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
" (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (lm_head): Linear(in_features=256, out_features=5000, bias=False)\n",
")"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# from_config builds the architecture with freshly initialized weights; nothing is\n",
"# loaded from the pretrained checkpoint, so the model below is trained from scratch.\n",
"transformers_model = AutoModelForSeq2SeqLM.from_config(model_config)\n",
"print('# of parameters: ', transformers_model.num_parameters())\n",
"transformers_model"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 128\n",
"args = Seq2SeqTrainingArguments(\n",
" output_dir=\"test-translation\",\n",
" evaluation_strategy=\"epoch\",  # evaluate on eval_dataset at the end of every epoch\n",
" learning_rate=0.0005,\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
" weight_decay=0.01,\n",
" save_total_limit=3,\n",
" num_train_epochs=20,\n",
" load_best_model_at_end=True,  # NOTE(review): some transformers releases put the best model in a NEW trainer.model object - confirm\n",
" predict_with_generate=True,  # evaluation calls generate() instead of a plain forward pass\n",
" remove_unused_columns=True,\n",
" fp16=True  # mixed-precision training; presumably requires a CUDA GPU\n",
")\n",
"\n",
"# Seq2SeqDataCollator is defined earlier in this notebook; the sample batch below shows it\n",
"# pads inputs with pad_token_id and labels with -100.\n",
"data_collator = Seq2SeqDataCollator(model_config.max_length, pad_token_id)\n",
"\n",
"# stop once the monitored eval metric (validation loss by default) fails to improve 3 evaluations in a row\n",
"callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# Combine the model, training arguments, encoded train/val splits, collator and\n",
"# early-stopping callback into a single Seq2SeqTrainer.\n",
"trainer = Seq2SeqTrainer(\n",
" transformers_model,\n",
" args,\n",
" train_dataset=dataset_dict_encoded[\"train\"],\n",
" eval_dataset=dataset_dict_encoded[\"val\"],\n",
" data_collator=data_collator,\n",
" callbacks=callbacks\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[262, 294, 281, ..., 3, 3, 3],\n",
" [297, 318, 15, ..., 3, 3, 3],\n",
" [262, 386, 672, ..., 3, 3, 3],\n",
" ...,\n",
" [344, 413, 15, ..., 3, 3, 3],\n",
" [262, 386, 546, ..., 3, 3, 3],\n",
" [262, 563, 378, ..., 3, 3, 3]]),\n",
" 'labels': tensor([[ 68, 292, 271, ..., -100, -100, -100],\n",
" [ 68, 326, 293, ..., -100, -100, -100],\n",
" [ 68, 376, 662, ..., -100, -100, -100],\n",
" ...,\n",
" [ 336, 401, 560, ..., -100, -100, -100],\n",
" [ 68, 376, 1130, ..., -100, -100, -100],\n",
" [ 68, 505, 385, ..., -100, -100, -100]]),\n",
" 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]])}"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: inspect one collated training batch. Inputs are padded with\n",
"# pad_token_id (3 here) and labels with -100, the index the loss ignores.\n",
"dataloader_train = trainer.get_train_dataloader()\n",
"next(iter(dataloader_train))"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" \n",
"
\n",
" [3405/4540 07:43 < 02:34, 7.34 it/s, Epoch 15/20]\n",
"
\n",
" \n",
" \n",
" \n",
" Epoch | \n",
" Training Loss | \n",
" Validation Loss | \n",
" Runtime | \n",
" Samples Per Second | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" No log | \n",
" 3.598508 | \n",
" 0.477500 | \n",
" 2123.690000 | \n",
"
\n",
" \n",
" 2 | \n",
" No log | \n",
" 2.741445 | \n",
" 0.630800 | \n",
" 1607.404000 | \n",
"
\n",
" \n",
" 3 | \n",
" 3.840200 | \n",
" 2.316170 | \n",
" 0.509800 | \n",
" 1989.048000 | \n",
"
\n",
" \n",
" 4 | \n",
" 3.840200 | \n",
" 2.078891 | \n",
" 0.713500 | \n",
" 1421.158000 | \n",
"
\n",
" \n",
" 5 | \n",
" 2.274600 | \n",
" 1.941849 | \n",
" 0.540400 | \n",
" 1876.244000 | \n",
"
\n",
" \n",
" 6 | \n",
" 2.274600 | \n",
" 1.841438 | \n",
" 0.608600 | \n",
" 1666.216000 | \n",
"
\n",
" \n",
" 7 | \n",
" 1.767100 | \n",
" 1.781287 | \n",
" 0.657200 | \n",
" 1543.026000 | \n",
"
\n",
" \n",
" 8 | \n",
" 1.767100 | \n",
" 1.747373 | \n",
" 0.599300 | \n",
" 1691.906000 | \n",
"
\n",
" \n",
" 9 | \n",
" 1.486600 | \n",
" 1.719654 | \n",
" 0.743000 | \n",
" 1364.803000 | \n",
"
\n",
" \n",
" 10 | \n",
" 1.486600 | \n",
" 1.704974 | \n",
" 0.617300 | \n",
" 1642.552000 | \n",
"
\n",
" \n",
" 11 | \n",
" 1.486600 | \n",
" 1.701151 | \n",
" 0.575000 | \n",
" 1763.431000 | \n",
"
\n",
" \n",
" 12 | \n",
" 1.294600 | \n",
" 1.692111 | \n",
" 0.519600 | \n",
" 1951.319000 | \n",
"
\n",
" \n",
" 13 | \n",
" 1.294600 | \n",
" 1.693845 | \n",
" 0.487700 | \n",
" 2079.081000 | \n",
"
\n",
" \n",
" 14 | \n",
" 1.136700 | \n",
" 1.702049 | \n",
" 0.508200 | \n",
" 1995.281000 | \n",
"
\n",
" \n",
" 15 | \n",
" 1.136700 | \n",
" 1.706282 | \n",
" 0.527000 | \n",
" 1923.935000 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=3405, training_loss=1.8559266660357012, metrics={'train_runtime': 463.7701, 'train_samples_per_second': 9.789, 'total_flos': 643119585140736, 'epoch': 15.0})"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer_output = trainer.train()\n",
"trainer_output"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"def generate_translation(model, source_tokenizer, target_tokenizer, example):\n",
" source = example[source_lang]\n",
" target = example[target_lang]\n",
" input_ids = source_tokenizer.encode(source).ids\n",
" input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)\n",
" generated_ids = model.generate(input_ids)\n",
" generated_ids = generated_ids[0].detach().cpu().numpy()\n",
"\n",
" prediction = target_tokenizer.decode(generated_ids)\n",
" print('source: ', source)\n",
" print('target: ', target)\n",
" print('prediction: ', prediction)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source: Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n",
"target: Two young, White males are outside near many bushes.\n",
"prediction: two young white men are outside near many bushes.\n"
]
}
],
"source": [
"# A training sentence: the model has seen this pair, so a close match is expected.\n",
"example = dataset_dict['train'][0]\n",
"generate_translation(transformers_model, source_tokenizer, target_tokenizer, example)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source: Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.\n",
"target: A man in an orange hat starring at something.\n",
"prediction: a man in an orange hat who is looking at something.\n"
]
}
],
"source": [
"# A held-out test sentence: a fairer check of how well the model generalizes.\n",
"example = dataset_dict['test'][0]\n",
"generate_translation(transformers_model, source_tokenizer, target_tokenizer, example)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we confirm that saving the model and loading it back gives identical predictions."
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: this rebinds model_checkpoint from the pretrained hub id used earlier to a local\n",
"# folder, so re-running earlier cells that read model_checkpoint (e.g. the AutoConfig cell)\n",
"# would now load from this directory instead.\n",
"model_checkpoint = 'transformers_model'\n",
"transformers_model.save_pretrained(model_checkpoint)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# from_pretrained is a classmethod: calling it through the instance builds a brand-new model\n",
"# from the saved folder. device is assumed to be defined earlier in the notebook.\n",
"transformers_model_loaded = transformers_model.from_pretrained(model_checkpoint).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"source: Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.\n",
"target: A man in an orange hat starring at something.\n",
"prediction: a man in an orange hat who is looking at something.\n"
]
}
],
"source": [
"# Should print exactly the same prediction as the in-memory model did above.\n",
"example = dataset_dict['test'][0]\n",
"generate_translation(transformers_model_loaded, source_tokenizer, target_tokenizer, example)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the last step, we'll write an inference function that performs batch scoring on a given dataset. It generates the predictions and stores them in a pandas DataFrame alongside the source and target sentences."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1000"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# number of examples in the encoded test split, i.e. the rows produced by the batch scoring below\n",
"len(dataset_dict_encoded['test'])"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# We use a different data collator than the one used for training and evaluation:\n",
"# the extra third argument makes it pad the labels with pad_token_id instead of -100,\n",
"# because -100 is not a real token id and cannot be decoded during inference.\n",
"# The DataLoader is not shuffled, so rows come out in test-set order.\n",
"data_collator = Seq2SeqDataCollator(model_config.max_length, pad_token_id, pad_token_id)\n",
"data_loader = DataLoader(dataset_dict_encoded['test'], collate_fn=data_collator, batch_size=64)\n",
"data_loader"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"elapsed: 12.964367628097534\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" source | \n",
" target | \n",
" prediction | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" ein mann mit einem orangefarbenen hut, der etw... | \n",
" a man in an orange hat starring at something. | \n",
" a man in an orange hat looking at something. | \n",
"
\n",
" \n",
" 1 | \n",
" ein boston terrier läuft über saftig-grünes gr... | \n",
" a boston terrier is running on lush green gras... | \n",
" a boststst skier runs in front of a white fenc... | \n",
"
\n",
" \n",
" 2 | \n",
" ein mädchen in einem karateanzug bricht einen ... | \n",
" a girl in karate uniform breaking a stick with... | \n",
" a girl in a karate uniform gets a stick with a... | \n",
"
\n",
" \n",
" 3 | \n",
" fünf leute in winterjacken und mit helmen steh... | \n",
" five people wearing winter jackets and helmets... | \n",
" five people wearing winter jackets and helmets... | \n",
"
\n",
" \n",
" 4 | \n",
" leute reparieren das dach eines hauses. | \n",
" people are fixing the roof of a house. | \n",
" people are fixing the roof of a house. | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 995 | \n",
" marathonläufer laufen auf einer städtischen st... | \n",
" marathon runners are racing on a city street, ... | \n",
" a marathon runner running around a city street... | \n",
"
\n",
" \n",
" 996 | \n",
" asiatische frau trägt einen sonnenhut beim fah... | \n",
" asian woman wearing a sunhat while riding a bike. | \n",
" asian woman wearing a sunhat riding a bicycle. | \n",
"
\n",
" \n",
" 997 | \n",
" ein paar kinder sind im freien und spielen auf... | \n",
" some children are outside playing in the dirt ... | \n",
" a couple of children are outside playing on th... | \n",
"
\n",
" \n",
" 998 | \n",
" ein älterer mann spielt ein videospiel. | \n",
" an older man is playing a video arcade game. | \n",
" an older man is playing a video game. | \n",
"
\n",
" \n",
" 999 | \n",
" ein mädchen an einer küste mit einem berg im h... | \n",
" a girl at the shore of a beach with a mountain... | \n",
" a girl at a shore with a mountain in the backg... | \n",
"
\n",
" \n",
"
\n",
"
1000 rows × 3 columns
\n",
"
"
],
"text/plain": [
" source \\\n",
"0 ein mann mit einem orangefarbenen hut, der etw... \n",
"1 ein boston terrier läuft über saftig-grünes gr... \n",
"2 ein mädchen in einem karateanzug bricht einen ... \n",
"3 fünf leute in winterjacken und mit helmen steh... \n",
"4 leute reparieren das dach eines hauses. \n",
".. ... \n",
"995 marathonläufer laufen auf einer städtischen st... \n",
"996 asiatische frau trägt einen sonnenhut beim fah... \n",
"997 ein paar kinder sind im freien und spielen auf... \n",
"998 ein älterer mann spielt ein videospiel. \n",
"999 ein mädchen an einer küste mit einem berg im h... \n",
"\n",
" target \\\n",
"0 a man in an orange hat starring at something. \n",
"1 a boston terrier is running on lush green gras... \n",
"2 a girl in karate uniform breaking a stick with... \n",
"3 five people wearing winter jackets and helmets... \n",
"4 people are fixing the roof of a house. \n",
".. ... \n",
"995 marathon runners are racing on a city street, ... \n",
"996 asian woman wearing a sunhat while riding a bike. \n",
"997 some children are outside playing in the dirt ... \n",
"998 an older man is playing a video arcade game. \n",
"999 a girl at the shore of a beach with a mountain... \n",
"\n",
" prediction \n",
"0 a man in an orange hat looking at something. \n",
"1 a boststst skier runs in front of a white fenc... \n",
"2 a girl in a karate uniform gets a stick with a... \n",
"3 five people wearing winter jackets and helmets... \n",
"4 people are fixing the roof of a house. \n",
".. ... \n",
"995 a marathon runner running around a city street... \n",
"996 asian woman wearing a sunhat riding a bicycle. \n",
"997 a couple of children are outside playing on th... \n",
"998 an older man is playing a video game. \n",
"999 a girl at a shore with a mountain in the backg... \n",
"\n",
"[1000 rows x 3 columns]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"start = time.time()\n",
"rows = []\n",
"for example in data_loader:\n",
" input_ids = example['input_ids']\n",
" generated_ids = transformers_model.generate(input_ids.to(transformers_model.device))\n",
" generated_ids = generated_ids.detach().cpu().numpy()\n",
" predictions = target_tokenizer.decode_batch(generated_ids)\n",
"\n",
" labels = example['labels'].detach().cpu().numpy()\n",
" targets = target_tokenizer.decode_batch(labels)\n",
"\n",
" sources = source_tokenizer.decode_batch(input_ids.detach().cpu().numpy())\n",
" for source, target, prediction in zip(sources, targets, predictions):\n",
" row = [source, target, prediction]\n",
" rows.append(row)\n",
"\n",
"end = time.time()\n",
"print('elapsed: ', end - start)\n",
"df_rows = pd.DataFrame(rows, columns=['source', 'target', 'prediction'])\n",
"df_rows"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- [Colab: Fine-tuning a model on a translation task](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/translation.ipynb)\n",
"- [Github: Huggingface Transformers Translation Example](https://github.com/huggingface/transformers/tree/master/examples/pytorch/translation)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": true,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "274.797px"
},
"toc_section_display": true,
"toc_window_display": true
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}