{"cells":[{"metadata":{},"cell_type":"markdown","source":["#### Dependencies"]},{"metadata":{"trusted":true},"cell_type":"code","source":["# !pip install torch pytorch_lightning datasets wandb"],"execution_count":3,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["import os\n","import torch\n","import tqdm\n","import torch.nn.functional as F\n","from torch.utils.data import Dataset, DataLoader\n","from transformers import BertModel, BertConfig, BertTokenizer\n","from datasets import load_dataset\n","import pytorch_lightning as pl\n","import wandb"],"execution_count":4,"outputs":[{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.\n","name":"stderr"}]},{"metadata":{},"cell_type":"markdown","source":["### Custom Dataset Class"]},{"metadata":{"trusted":true},"cell_type":"code","source":["# custom dataset class \n","class SentimentDataset(Dataset):\n"," def __init__(self, tokenizer, text, target, max_len=180):\n"," self.tokenizer = tokenizer\n"," self.text = text\n"," self.target = target\n"," self.max_len = max_len\n"," \n"," def __len__(self):\n"," return len(self.text)\n"," \n"," def __getitem__(self, idx):\n"," text = self.text[idx]\n"," target = self.target[idx]\n"," \n"," # encode the text and target into tensors return the attention masks as well\n"," encoding = self.tokenizer.encode_plus(\n"," text=text,\n"," add_special_tokens=True,\n"," max_length=self.max_len,\n"," return_token_type_ids=False,\n"," pad_to_max_length=True,\n"," return_attention_mask=True,\n"," return_tensors='pt',\n"," truncation=True\n"," )\n"," \n"," return {\n"," 'text': text,\n"," 'input_ids': encoding['input_ids'].flatten(),\n"," 'attention_mask': encoding['attention_mask'].flatten(),\n"," 'targets': torch.tensor(target, dtype=torch.long)\n"," }\n"," "],"execution_count":5,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":["### BERTModel PyTorch"]},{"metadata":{"trusted":true},"cell_type":"code","source":["class BertClassifier(torch.nn.Module):\n"," \n"," def __init__(self, config, model, dim=256, num_classes=2):\n"," super(BertClassifier, self).__init__()\n"," \n"," # create the model config and BERT initialize the pretrained BERT, also layers wise outputs\n"," self.config = config\n"," self.base = model\n"," \n"," # classifier head [not useful]\n"," self.head = torch.nn.Sequential(*[\n"," torch.nn.Dropout(p=self.config.hidden_dropout_prob),\n"," torch.nn.Linear(in_features=self.config.hidden_size, out_features=dim),\n"," torch.nn.ReLU(),\n"," torch.nn.Dropout(p=self.config.hidden_dropout_prob),\n"," torch.nn.Linear(in_features=dim, out_features=num_classes)\n"," ])\n"," \n"," def forward(self, input_ids, attention_mask=None):\n"," \n"," # first output is top layer output, second output is context of input seq and third output will be layerwise token embeddings\n"," top_layer, pooled, layers = self.base(input_ids, attention_mask)\n"," outputs = self.head(pooled)\n"," return top_layer, outputs, layers\n"," "],"execution_count":6,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":["### Lightning Model"]},{"metadata":{"trusted":true},"cell_type":"code","source":["class BertFinetuner(pl.LightningModule):\n"," \n"," def __init__(self, model=None, tokenizer=None, data_file=\"./data/twitter/train.csv\", use_cols=['review_text', 'sentiment'], batch_size=32):\n"," super(BertFinetuner, self).__init__()\n"," \n"," # initialize the BERT model c\n"," self.model = model\n"," self.data_file = data_file\n"," self.use_cols = use_cols\n"," self.batch_size = batch_size\n"," self.tokenizer = tokenizer\n"," \n"," self.f_score= Fbeta()\n"," \n"," def accuracy(self, outputs, targets):\n"," correct = 0\n"," for i in range(outputs.shape[0]):\n"," if outputs[i]==targets[i]:\n"," correct+=1\n"," return correct/outputs.shape[0]\n"," \n"," \n"," def forward(self, input_ids, attention_mask=None):\n"," top_layer, outputs, layers = self.model(input_ids, attention_mask)\n"," return top_layer, outputs, layers\n"," \n"," \n"," def configure_optimizers(self):\n"," return torch.optim.Adam(params=self.parameters(), lr=1e-5)\n"," \n"," def train_dataloader(self):\n"," # first 30% data reserved for validation\n"," train = load_dataset(\"csv\", data_files=self.data_file, split='train[20%:]')\n"," text, target = train['review_text'], train['sentiment']\n"," dataset = SentimentDataset(tokenizer=self.tokenizer, text=text, target=target)\n"," loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)\n"," return loader\n"," \n"," def training_step(self, batch, batch_idx):\n"," input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']\n"," _, logits, _ = self(input_ids, attention_mask)\n"," loss = F.cross_entropy(logits, targets)\n"," acc = self.accuracy(logits.argmax(dim=1), targets)\n"," wandb.log({\"Loss\": loss, \"Accuracy\": torch.tensor(acc)})\n"," return {\"loss\": loss, \"accuracy\": torch.tensor(acc)}\n"," \n"," def val_dataloader(self):\n"," # first 30% data reserved for validation\n"," val = load_dataset(\"csv\", data_files=self.data_file, split='train[:20%]')\n"," text, target = val['review_text'], val['sentiment']\n"," dataset = SentimentDataset(tokenizer=self.tokenizer, text=text, target=target)\n"," loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)\n"," return loader\n"," \n"," def validation_step(self, batch, batch_idx):\n"," input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']\n"," _, logits, _ = self(input_ids, attention_mask)\n"," loss = F.cross_entropy(logits, targets)\n"," acc = self.accuracy(logits.argmax(dim=1), targets)\n","# wandb.log({\"val_loss\":loss, \"val_accuracy\":acc})\n"," self.f_score(logits.argmax(dim=1), targets)\n"," return {\"val_loss\": loss, \"val_accuracy\": torch.tensor(acc)}\n","\n"," def validation_epoch_end(self, outputs):\n"," avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n"," avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()\n"," avg_f_score = self.f_score.compute()\n"," \n"," wandb.log({\"val_loss\":avg_loss, \"val_accuracy\":avg_acc, \"val_fb\":avg_f_score})\n"," return {'val_accuracy': avg_loss, 'val_accuracy': avg_acc, \"val_fb\":avg_f_score}\n"," "],"execution_count":7,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":["### Training "]},{"metadata":{"trusted":true},"cell_type":"code","source":["from pytorch_lightning.loggers import WandbLogger\n","from pytorch_lightning.metrics import Fbeta \n","from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar"],"execution_count":8,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["ROOT_DIR = \"../input/amazonproductsreview/amazon-review/\"\n","DATASET = \"dvd\"\n","NUM_CLASSES = 2\n","BATCH_SIZE = 32\n","EPOCH = 20"],"execution_count":20,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["# logger \n","logger = WandbLogger(\n"," name=DATASET,\n"," save_dir=\"../working/\",\n"," project=\"domain-adaptation\",\n"," log_model = True\n",")"],"execution_count":21,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["# callbacks\n","early_stopping = EarlyStopping(\n"," monitor=\"val_accuracy\",\n",")\n","model_checkpoint = ModelCheckpoint(\n"," filepath=\"{epoch}-{val_accuracy:.2f}-{val_loss:.2f}\",\n"," monitor=\"val_accuracy\",\n"," save_top_k=1,\n",")\n","progress_bar = ProgressBar()"],"execution_count":22,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["# create the BERTConfig, BERTTokenizer, and BERTModel \n","model_name = \"bert-base-uncased\"\n","config = BertConfig.from_pretrained(model_name, output_hidden_states=True)\n","tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)\n","bert = BertModel.from_pretrained(model_name, config=config)\n","classifier = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)"],"execution_count":23,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["model = BertFinetuner(\n"," model=classifier,\n"," data_file=os.path.join(ROOT_DIR, DATASET+\".csv\"),\n"," tokenizer=tokenizer,\n"," batch_size=BATCH_SIZE\n",")"],"execution_count":24,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["tuner = pl.Trainer(\n"," logger=logger,\n"," gpus=[0],\n"," checkpoint_callback=model_checkpoint,\n"," max_epochs=EPOCH,\n",")"],"execution_count":25,"outputs":[{"output_type":"stream","text":"GPU available: True, used: True\nTPU available: False, using: 0 TPU cores\nLOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n","name":"stderr"}]},{"metadata":{"trusted":true},"cell_type":"code","source":["tuner.fit(model)"],"execution_count":26,"outputs":[{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Calling wandb.login() after wandb.init() has no effect.\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"<br/>Waiting for W&B process to finish, PID 249<br/>Program ended successfully."},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\\r'), FloatProgress(value=1.0, max=1.0)…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Find user logs for this run at: <code>../working/wandb/run-20201026_093116-2jdo8tzd/logs/debug.log</code>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Find internal logs for this run at: <code>../working/wandb/run-20201026_093116-2jdo8tzd/logs/debug-internal.log</code>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"<h3>Run summary:</h3><br/><style>\n table.wandb td:nth-child(1) { padding: 0 10px; text-align: right }\n </style><table class=\"wandb\">\n<tr><td>val_loss</td><td>0.65587</td></tr><tr><td>val_accuracy</td><td>0.83062</td></tr><tr><td>val_fb</td><td>0.83688</td></tr><tr><td>_step</td><td>1020</td></tr><tr><td>_runtime</td><td>746</td></tr><tr><td>_timestamp</td><td>1603705422</td></tr><tr><td>Loss</td><td>0.00673</td></tr><tr><td>Accuracy</td><td>1.0</td></tr></table>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"<h3>Run history:</h3><br/><style>\n table.wandb td:nth-child(1) { padding: 0 10px; text-align: right }\n </style><table class=\"wandb\">\n<tr><td>val_loss</td><td>█▄▁▁▂▁▃▃▃▄▄▅▆▄▆▅▅▆▅▆▇</td></tr><tr><td>val_accuracy</td><td>▁▇█████████▇▇███████▇</td></tr><tr><td>val_fb</td><td>▁▆██▇██▇███▇▇███████▇</td></tr><tr><td>_step</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>_runtime</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>_timestamp</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>Loss</td><td>██▅▅▄▃▄▄▂▅▁▄▂▂▁▁▁▁▁▂▁▄▁▁▁▁▃▁▁▃▁▁▁▁▁▁▁▃▃▁</td></tr><tr><td>Accuracy</td><td>▁▃▇▆▇█▆▇█▆█▇█████████▇██████████████████</td></tr></table><br/>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"\n <br/>Synced <strong style=\"color:#cdcd00\">books</strong>: <a href=\"https://wandb.ai/macab/domain-adaptation/runs/2jdo8tzd\" target=\"_blank\">https://wandb.ai/macab/domain-adaptation/runs/2jdo8tzd</a><br/>\n "},"metadata":{}},{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: wandb version 0.10.8 is available! To upgrade, please run:\n\u001b[34m\u001b[1mwandb\u001b[0m: $ pip install wandb --upgrade\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"\n Tracking run with wandb version 0.10.7<br/>\n Syncing run <strong style=\"color:#cdcd00\">dvd</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n Project page: <a href=\"https://wandb.ai/macab/domain-adaptation\" target=\"_blank\">https://wandb.ai/macab/domain-adaptation</a><br/>\n Run page: <a href=\"https://wandb.ai/macab/domain-adaptation/runs/1rj1ncqv\" target=\"_blank\">https://wandb.ai/macab/domain-adaptation/runs/1rj1ncqv</a><br/>\n Run data is saved locally in <code>../working/wandb/run-20201026_094759-1rj1ncqv</code><br/><br/>\n "},"metadata":{}},{"output_type":"stream","text":"\n | Name | Type | Params\n-------------------------------------------\n0 | model | BertClassifier | 109 M \n1 | f_score | Fbeta | 0 \nUsing custom data configuration default\n","name":"stderr"},{"output_type":"stream","text":"Downloading and preparing dataset csv/default-343faf1a87cc9b22 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4...\n","name":"stdout"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4. Subsequent calls will reuse this data.\n","name":"stdout"},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.\n warnings.warn(*args, **kwargs)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The validation_epoch_end should not return anything as of 9.1.to log, use self.log(...) or self.write(...) directly in the LightningModule\n warnings.warn(*args, **kwargs)\nUsing custom data configuration default\nReusing dataset csv (/root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"fb340a96fe2c404db157e6b09a2acf4f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"\n","name":"stdout"},{"output_type":"execute_result","execution_count":26,"data":{"text/plain":"1"},"metadata":{}}]},{"metadata":{},"cell_type":"markdown","source":["#### Save trained state dictionary\n","- See section 4 : https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html"]},{"metadata":{},"cell_type":"markdown","source":["#### 1. Books"]},{"metadata":{"trusted":true},"cell_type":"code","source":["PATH = DATASET+\".pt\"\n","# save the model \n","torch.save(classifier.state_dict(), PATH)\n"],"execution_count":16,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["### Load from state dictionary\n","classifier_trained = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)\n","classifier_trained.load_state_dict(torch.load(PATH))\n","\n","# you can evaluate the model on top 20% data"],"execution_count":19,"outputs":[{"output_type":"execute_result","execution_count":19,"data":{"text/plain":"<All keys matched successfully>"},"metadata":{}}]},{"metadata":{},"cell_type":"markdown","source":["#### DVD"]},{"metadata":{"trusted":true},"cell_type":"code","source":["PATH = DATASET+\".pt\"\n","# save the model \n","torch.save(classifier.state_dict(), PATH)"],"execution_count":27,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["### Load from state dictionary\n","classifier_dvd = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)\n","classifier_dvd.load_state_dict(torch.load(PATH))\n","\n","# you can evaluate the model on top 20% data"],"execution_count":null,"outputs":[]},{"metadata":{"trusted":false},"cell_type":"code","source":["## There you go "],"execution_count":null,"outputs":[]},{"metadata":{"trusted":false},"cell_type":"code","source":[],"execution_count":null,"outputs":[]},{"metadata":{"trusted":false},"cell_type":"code","source":[],"execution_count":null,"outputs":[]}],"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"language_info":{"name":"python","version":"3.7.6","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat":4,"nbformat_minor":4}