{"cells":[{"metadata":{},"cell_type":"markdown","source":"#### Dependencies"},{"metadata":{"trusted":true},"cell_type":"code","source":"# !pip install datasets wandb","execution_count":2,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":"# utils\nimport os\nimport torch\nimport tqdm\n\n# data\nfrom datasets import load_dataset\nfrom torch.utils.data import Dataset, DataLoader\nfrom transformers import AutoConfig, AutoModel, AutoTokenizer\n\n# model \nimport torch.nn as nn\n\n# training and evaluation\nimport wandb\nimport torch.nn.functional as F\nimport pytorch_lightning as pl\nfrom pytorch_lightning.loggers import WandbLogger\nfrom pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\nfrom sklearn.metrics import accuracy_score, f1_score, classification_report","execution_count":3,"outputs":[{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.\n","name":"stderr"}]},{"metadata":{"trusted":true},"cell_type":"code","source":"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(device)","execution_count":4,"outputs":[{"output_type":"stream","text":"cuda\n","name":"stdout"}]},{"metadata":{},"cell_type":"markdown","source":"### Custom Dataset Class"},{"metadata":{"trusted":true},"cell_type":"code","source":"# custom dataset class \nclass ReviewDataset(Dataset):\n def __init__(self, tokenizer, data, text_field='text', label_field='label', max_len=512):\n self.tokenizer = tokenizer\n self.data = data\n self.text_field = text_field\n self.label_field = label_field\n self.max_len = max_len\n \n def __len__(self):\n return len(self.data[self.text_field])\n \n def __getitem__(self, idx):\n text = self.data[self.text_field][idx]\n target = self.data[self.label_field][idx]\n \n \n # encode the text and target into tensors return the attention masks as well\n encoding = self.tokenizer.encode_plus(\n text=text,\n add_special_tokens=True,\n max_length=self.max_len,\n return_token_type_ids=False,\n pad_to_max_length=True,\n return_attention_mask=True,\n return_tensors='pt',\n truncation=True,\n padding='max_length'\n )\n \n return {\n 'text': text,\n 'input_ids': encoding['input_ids'].flatten(),\n 'attention_mask': encoding['attention_mask'].flatten(),\n 'targets': torch.tensor(target, dtype=torch.long)\n }\n ","execution_count":5,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":"### Classifier "},{"metadata":{"trusted":true},"cell_type":"code","source":"class Classifier(torch.nn.Module):\n \n def __init__(self, model_name, num_classes=2):\n super(Classifier, self).__init__()\n \n # create the model config and BERT initialize the pretrained BERT, also layers wise outputs\n self.config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_name)\n self.base = AutoModel.from_pretrained(pretrained_model_name_or_path=model_name)\n \n # classifier head [not useful]\n self.head = nn.Sequential(*[\n nn.Linear(in_features=self.config.hidden_size, out_features=256),\n nn.ReLU(),\n nn.Linear(in_features=256, out_features=num_classes)\n ])\n \n \n def forward(self, input_ids, attention_mask=None):\n \n # first output is top layer output, second output is context of input seq and third output will be layerwise token embeddings\n top_layer, pooled = None, self.base(input_ids, attention_mask)[0][:, 0]\n logits = self.head(pooled)\n return logits, pooled, top_layer\n ","execution_count":7,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":"","execution_count":null,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":"### Lightning Model"},{"metadata":{"trusted":true},"cell_type":"code","source":"class Finetuner(pl.LightningModule):\n \n def __init__(self, config):\n super(Finetuner, self).__init__()\n \n # initialize the BERT model\n self.config = config\n self.model = Classifier(model_name=self.config['model_name'], num_classes=self.config['num_classes'])\n self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.config['model_name'])\n \n def forward(self, input_ids, attention_mask=None):\n logits, _, _ = self.model(input_ids, attention_mask)\n return logits\n \n \n def configure_optimizers(self):\n return torch.optim.Adam(params=self.parameters(), lr=self.config['lr'])\n \n def train_dataloader(self):\n # first 10% data reserved for validation\n data = load_dataset(\"csv\", data_files=self.config['root_dir']+self.config['source'], split='train[10%:]')\n dataset = ReviewDataset(tokenizer=self.tokenizer, data=data, text_field=self.config['text_field'], label_field=self.config['label_field'], max_len=self.config['max_len']) \n loader = DataLoader(dataset=dataset, batch_size=self.config['batch_size'], shuffle=True)\n return loader\n \n def training_step(self, batch, batch_idx): \n input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets'] \n logits = self(input_ids, attention_mask)\n loss = F.cross_entropy(logits, targets) \n acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())\n f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu()) \n wandb.log({\"Loss\": loss, \"Accuracy\": torch.tensor([acc]), \"F1\":torch.tensor([f1])})\n return {\"loss\": loss, \"accuracy\": torch.tensor([acc]), \"f1\":torch.tensor([f1])}\n \n def val_dataloader(self):\n # first 10% data reserved for validation\n data = load_dataset(\"csv\", data_files=self.config['root_dir']+self.config['source'], split='train[:10%]')\n dataset = ReviewDataset(tokenizer=self.tokenizer, data=data, text_field=self.config['text_field'], label_field=self.config['label_field'], max_len=self.config['max_len']) \n loader = DataLoader(dataset=dataset, batch_size=self.config['batch_size'], shuffle=False)\n return loader\n \n def validation_step(self, batch, batch_idx):\n input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']\n logits = self(input_ids, attention_mask)\n loss = F.cross_entropy(logits, targets)\n acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())\n f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu()) \n wandb.log({\"Val_loss\": loss, \"Val_accuracy\": torch.tensor([acc]), \"Val_f1\":torch.tensor([f1])})\n return {\"val_loss\": loss, \"val_accuracy\": torch.tensor([acc]), \"val_f1\":torch.tensor([f1])}\n\n def validation_epoch_end(self, outputs):\n avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()\n avg_f_score = torch.stack([x['val_f1'] for x in outputs]).mean()\n \n wandb.log({\"Val_loss\":avg_loss, \"Val_accuracy\":avg_acc, \"Val_f1\":avg_f_score})\n return {'val_loss': avg_loss, 'val_accuracy': avg_acc, \"val_f1\":avg_f_score}\n \n def test_dataloader(self):\n # test data is same as validation data\n data = load_dataset(\"csv\", data_files=self.config['root_dir']+self.config['source'], split='train[:10%]')\n dataset = ReviewDataset(tokenizer=self.tokenizer, data=data, text_field=self.config['text_field'], label_field=self.config['label_field'], max_len=self.config['max_len']) \n loader = DataLoader(dataset=dataset, batch_size=self.config['batch_size'], shuffle=False)\n return loader\n \n def test_step(self, batch, batch_idx):\n input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']\n logits = self(input_ids, attention_mask)\n loss = F.cross_entropy(logits, targets)\n acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())\n f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu())\n return {\"test_loss\":loss, \"test_accuracy\":torch.tensor([acc]), \"test_f1\":torch.tensor([f1])}\n \n def test_epoch_end(self, outputs):\n avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()\n avg_acc = torch.stack([x['test_accuracy'] for x in outputs]).mean()\n avg_f1 = torch.stack([x['test_f1'] for x in outputs]).mean()\n return {\"test_loss\":avg_loss, \"test_accuracy\":avg_acc, \"test_f1\":avg_f1}\n\n ","execution_count":8,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":"### Training "},{"metadata":{"trusted":true},"cell_type":"code","source":"config = {\n \n # data\n \"root_dir\":\"../input/amazonproductsreview/amazon-review/\",\n \"source\":'books.csv',\n \"targets\":[\"dvd.csv\", \"electronics.csv\", \"kitchen_housewares.csv\"],\n \"max_len\":512,\n \"batch_size\":8,\n \"num_classes\":2,\n \"text_field\":\"review_text\",\n \"label_field\":\"sentiment\",\n \n # model\n \"model_name\":'xlnet-base-cased',\n \n \n # training\n \"lr\":1e-5,\n \"epochs\":20,\n \n # logger and checkpoints\n \"project\":\"pretrained-model-robustness\",\n \"run_name\":\"xlnet\",\n \"monitor\":\"val_accuracy\",\n \"min_delta\":0.001,\n \"filepath\":\"../working/{epoch}-{val_accuracy:4f}\",\n \"save_dir\":\"../working/\",\n \n}","execution_count":10,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":"logger = WandbLogger(\n name=config[\"run_name\"],\n save_dir=config[\"save_dir\"],\n project=config[\"project\"],\n log_model=True,\n)\nearly_stopping = EarlyStopping(\n monitor=config[\"monitor\"],\n min_delta=config[\"min_delta\"],\n patience=5,\n)\ncheckpoints = ModelCheckpoint(\n filepath=config[\"filepath\"],\n monitor=config[\"monitor\"],\n save_top_k=1\n)\n","execution_count":11,"outputs":[{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: Checkpoint directory /kaggle/working exists and is not empty. With save_top_k=1, all files in this directory will be deleted when a checkpoint is saved!\n warnings.warn(*args, **kwargs)\n","name":"stderr"}]},{"metadata":{"trusted":true},"cell_type":"code","source":"trainer = pl.Trainer(\n logger=logger,\n gpus=[0],\n checkpoint_callback=checkpoints,\n default_root_dir=\"../working/\",\n max_epochs=config[\"epochs\"],\n callbacks=[early_stopping]\n)\n","execution_count":13,"outputs":[{"output_type":"stream","text":"GPU available: True, used: True\nTPU available: False, using: 0 TPU cores\nLOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n","name":"stderr"}]},{"metadata":{"trusted":true},"cell_type":"code","source":"model = Finetuner(config)","execution_count":14,"outputs":[{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=760.0, style=ProgressStyle(description_…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"a84008c60b1f4fdea5e439834b1990dc"}},"metadata":{}},{"output_type":"stream","text":"\n","name":"stdout"},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/transformers/configuration_xlnet.py:212: FutureWarning: This config doesn't use attention memories, a core feature of XLNet. Consider setting `mem_len` to a non-zero value, for example `xlnet = XLNetLMHeadModel.from_pretrained('xlnet-base-cased'', mem_len=1024)`, for accurate training performance as well as an order of magnitude faster inference. Starting from version 3.5.0, the default parameter will be 1024, following the implementation in https://arxiv.org/abs/1906.08237\n FutureWarning,\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=467042463.0, style=ProgressStyle(descri…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"3ea66bd8817546f598a11dd93bb0b2f0"}},"metadata":{}},{"output_type":"stream","text":"\n","name":"stdout"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=798011.0, style=ProgressStyle(descripti…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"64784c59b4084cf79c4454d8aabf89b3"}},"metadata":{}},{"output_type":"stream","text":"\n","name":"stdout"}]},{"metadata":{"trusted":true},"cell_type":"code","source":"trainer.fit(model)","execution_count":15,"outputs":[{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n","name":"stderr"},{"output_type":"stream","name":"stdout","text":"wandb: Paste an API key from your profile and hit enter: ········\n"},{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n\u001b[34m\u001b[1mwandb\u001b[0m: wandb version 0.10.11 is available! To upgrade, please run:\n\u001b[34m\u001b[1mwandb\u001b[0m: $ pip install wandb --upgrade\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"","text/html":"\n Tracking run with wandb version 0.10.10
\n Syncing run xlnet to Weights & Biases (Documentation).
\n Project page: https://wandb.ai/macab/pretrained-model-robustness
\n Run page: https://wandb.ai/macab/pretrained-model-robustness/runs/3bz466av
\n Run data is saved locally in ../working/wandb/run-20201128_021800-3bz466av

\n "},"metadata":{}},{"output_type":"stream","text":"\n | Name | Type | Params\n-------------------------------------\n0 | model | Classifier | 116 M \n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1562.0, style=ProgressStyle(description…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"ef867d44003141f8a05671f26520d1ac"}},"metadata":{}},{"output_type":"stream","text":"Using custom data configuration default\n","name":"stderr"},{"output_type":"stream","text":"\nDownloading and preparing dataset csv/default-af39d8bf1cb33848 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-af39d8bf1cb33848/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2...\n","name":"stdout"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-af39d8bf1cb33848/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2. Subsequent calls will reuse this data.\n","name":"stdout"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The validation_epoch_end should not return anything as of 9.1.to log, use self.log(...) or self.write(...) directly in the LightningModule\n warnings.warn(*args, **kwargs)\nUsing custom data configuration default\nReusing dataset csv (/root/.cache/huggingface/datasets/csv/default-af39d8bf1cb33848/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"0f601a652c94430883ee50253b41f3db"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1465: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n average, \"true nor predicted\", 'F-score is', len(true_sum)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Symlinked 0 file into the W&B run directory, call wandb.save again to sync new files.\n","name":"stderr"},{"output_type":"stream","text":"\n","name":"stdout"},{"output_type":"execute_result","execution_count":15,"data":{"text/plain":"1"},"metadata":{}}]},{"metadata":{"trusted":true},"cell_type":"code","source":"trainer.test(model)","execution_count":16,"outputs":[{"output_type":"stream","text":"Using custom data configuration default\nReusing dataset csv (/root/.cache/huggingface/datasets/csv/default-af39d8bf1cb33848/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e03ceb2763d14db099c0f371ffa3cb1a"}},"metadata":{}},{"output_type":"stream","text":"--------------------------------------------------------------------------------\nDATALOADER:0 TEST RESULTS\n{'test_accuracy': tensor(0.8800, dtype=torch.float64),\n 'test_f1': tensor(0.8675, dtype=torch.float64),\n 'test_loss': tensor(0.4570, device='cuda:0')}\n--------------------------------------------------------------------------------\n\n","name":"stdout"},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The testing_epoch_end should not return anything as of 9.1.to log, use self.log(...) or self.write(...) directly in the LightningModule\n warnings.warn(*args, **kwargs)\n","name":"stderr"},{"output_type":"execute_result","execution_count":16,"data":{"text/plain":"[{'test_loss': 0.45704737305641174,\n 'test_accuracy': 0.88,\n 'test_f1': 0.8674862914862916}]"},"metadata":{}}]},{"metadata":{},"cell_type":"markdown","source":"#### Load from checkpoint and test"},{"metadata":{"trusted":true},"cell_type":"code","source":"l = torch.load(f=\"../working/epoch=6-val_accuracy=0.915000.ckpt\")\nmodel.load_state_dict(l['state_dict'])","execution_count":17,"outputs":[{"output_type":"error","ename":"FileNotFoundError","evalue":"[Errno 2] No such file or directory: '../working/epoch=10-val_accuracy=0.895000.ckpt'","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"../working/epoch=10-val_accuracy=0.895000.ckpt\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'state_dict'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'encoding'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 571\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 572\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0;31m# The zipfile reader is going to advance the current file position.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 229\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 230\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'w'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_open_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../working/epoch=10-val_accuracy=0.895000.ckpt'"]}]},{"metadata":{"trusted":true},"cell_type":"code","source":"trainer.test(model)","execution_count":16,"outputs":[{"output_type":"stream","text":"Using custom data configuration default\nReusing dataset csv (/root/.cache/huggingface/datasets/csv/default-c2d784cc01c6211e/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"25c833221693428ebb6bd9a8f32c794d"}},"metadata":{}},{"output_type":"stream","text":"--------------------------------------------------------------------------------\nDATALOADER:0 TEST RESULTS\n{'test_accuracy': tensor(0.8950, dtype=torch.float64),\n 'test_f1': tensor(0.8734, dtype=torch.float64),\n 'test_loss': tensor(0.4835, device='cuda:0')}\n--------------------------------------------------------------------------------\n\n","name":"stdout"},{"output_type":"execute_result","execution_count":16,"data":{"text/plain":"[{'test_loss': 0.4835154712200165,\n 'test_accuracy': 0.895,\n 'test_f1': 0.8734305694305696}]"},"metadata":{}}]},{"metadata":{},"cell_type":"markdown","source":"#### Evaluating the model on different target distribution"},{"metadata":{"trusted":true},"cell_type":"code","source":"def load_data(file, toknizer):\n data = load_dataset(\"csv\", data_files=config['root_dir']+file)\n dataset = ReviewDataset(tokenizer=tokenizer, data=data['train'], text_field=config['text_field'], label_field=config['label_field'])\n data_loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], shuffle=False)\n return data_loader\n\ndef test_fn(model, loader):\n \n y_true = []\n y_pred = []\n model.eval()\n for batch in tqdm.tqdm(loader):\n input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']\n logits = model(input_ids.to(device), attention_mask.to(device))\n y_true += targets.tolist()\n y_pred += logits.argmax(dim=1).cpu().tolist()\n \n return classification_report(y_true, y_pred)\n \n ","execution_count":17,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":"tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=config['model_name'])","execution_count":18,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":"for target in config['targets']:\n loader = load_data(file=target, toknizer=tokenizer)\n report = test_fn(model.to(device), loader)\n print(f'Target Domain Name: {target}')\n print(report)\n del loader\n ","execution_count":19,"outputs":[{"output_type":"stream","text":"Using custom data configuration default\n","name":"stderr"},{"output_type":"stream","text":"Downloading and preparing dataset csv/default-884b52687aa4816e (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-884b52687aa4816e/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2...\n","name":"stdout"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"\r 0%| | 0/248 [00:00