{"cells":[{"metadata":{},"cell_type":"markdown","source":["#### Dependencies"]},{"metadata":{"trusted":true},"cell_type":"code","source":["# !pip install torch pytorch_lightning datasets wandb"],"execution_count":3,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["import os\n","import torch\n","import tqdm\n","import torch.nn.functional as F\n","from torch.utils.data import Dataset, DataLoader\n","from transformers import BertModel, BertConfig, BertTokenizer\n","from datasets import load_dataset\n","import pytorch_lightning as pl\n","import wandb"],"execution_count":4,"outputs":[{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n","name":"stderr"}]},{"metadata":{},"cell_type":"markdown","source":["### Custom Dataset Class"]},{"metadata":{"trusted":true},"cell_type":"code","source":["# custom dataset class \n","class SentimentDataset(Dataset):\n","    def __init__(self, tokenizer, text, target, max_len=180):\n","        self.tokenizer = tokenizer\n","        self.text = text\n","        self.target = target\n","        self.max_len =  max_len\n","    \n","    def __len__(self):\n","        return len(self.text)\n","    \n","    def __getitem__(self, idx):\n","        text  = self.text[idx]\n","        target = self.target[idx]\n","        \n","        # encode the text and target into tensors return the attention masks as well\n","        encoding = self.tokenizer.encode_plus(\n","            text=text,\n","            add_special_tokens=True,\n","            max_length=self.max_len,\n","            return_token_type_ids=False,\n","            pad_to_max_length=True,\n","            return_attention_mask=True,\n","            return_tensors='pt',\n","            truncation=True\n","        )\n","        \n","        return {\n","          'text': text,\n","          'input_ids': encoding['input_ids'].flatten(),\n","          'attention_mask': encoding['attention_mask'].flatten(),\n","          'targets': torch.tensor(target, dtype=torch.long)\n","        }\n","        "],"execution_count":5,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":["### BERTModel PyTorch"]},{"metadata":{"trusted":true},"cell_type":"code","source":["class BertClassifier(torch.nn.Module):\n","    \n","    def __init__(self, config, model, dim=256, num_classes=2):\n","        super(BertClassifier, self).__init__()\n","        \n","        # create the model config and BERT initialize the pretrained BERT, also layers wise outputs\n","        self.config = config\n","        self.base = model\n","        \n","        # classifier head [not useful]\n","        self.head = torch.nn.Sequential(*[\n","            torch.nn.Dropout(p=self.config.hidden_dropout_prob),\n","            torch.nn.Linear(in_features=self.config.hidden_size, out_features=dim),\n","            torch.nn.ReLU(),\n","            torch.nn.Dropout(p=self.config.hidden_dropout_prob),\n","            torch.nn.Linear(in_features=dim, out_features=num_classes)\n","        ])\n","    \n","    def forward(self, input_ids, attention_mask=None):\n","        \n","        # first output is top layer output, second output is context of input seq and third output will be layerwise token embeddings\n","        top_layer, pooled, layers = self.base(input_ids, attention_mask)\n","        outputs = self.head(pooled)\n","        return top_layer, outputs, layers\n","        "],"execution_count":6,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":["### Lightning Model"]},{"metadata":{"trusted":true},"cell_type":"code","source":["class BertFinetuner(pl.LightningModule):\n","    \n","    def __init__(self, model=None, tokenizer=None, data_file=\"./data/twitter/train.csv\", use_cols=['review_text', 'sentiment'], batch_size=32):\n","        super(BertFinetuner, self).__init__()\n","        \n","        # initialize the BERT model c\n","        self.model = model\n","        self.data_file = data_file\n","        self.use_cols = use_cols\n","        self.batch_size = batch_size\n","        self.tokenizer = tokenizer\n","        \n","        self.f_score= Fbeta()\n","    \n","    def accuracy(self, outputs, targets):\n","        correct = 0\n","        for i in range(outputs.shape[0]):\n","            if outputs[i]==targets[i]:\n","                correct+=1\n","        return correct/outputs.shape[0]\n","    \n","    \n","    def forward(self, input_ids, attention_mask=None):\n","        top_layer, outputs, layers =  self.model(input_ids, attention_mask)\n","        return top_layer, outputs, layers\n","    \n","    \n","    def configure_optimizers(self):\n","        return torch.optim.Adam(params=self.parameters(), lr=1e-5)\n","    \n","    def train_dataloader(self):\n","        # first 30% data reserved for validation\n","        train = load_dataset(\"csv\", data_files=self.data_file, split='train[20%:]')\n","        text, target = train['review_text'], train['sentiment']\n","        dataset = SentimentDataset(tokenizer=self.tokenizer, text=text, target=target)\n","        loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)\n","        return loader\n","        \n","    def training_step(self, batch, batch_idx):\n","        input_ids, attention_mask, targets =  batch['input_ids'], batch['attention_mask'], batch['targets']\n","        _, logits, _ = self(input_ids, attention_mask)\n","        loss = F.cross_entropy(logits, targets)\n","        acc = self.accuracy(logits.argmax(dim=1), targets)\n","        wandb.log({\"Loss\": loss, \"Accuracy\": torch.tensor(acc)})\n","        return {\"loss\": loss, \"accuracy\": torch.tensor(acc)}\n","    \n","    def val_dataloader(self):\n","        # first 30% data reserved for validation\n","        val = load_dataset(\"csv\", data_files=self.data_file, split='train[:20%]')\n","        text, target = val['review_text'], val['sentiment']\n","        dataset = SentimentDataset(tokenizer=self.tokenizer, text=text, target=target)\n","        loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)\n","        return loader\n","        \n","    def validation_step(self, batch, batch_idx):\n","        input_ids, attention_mask, targets =  batch['input_ids'], batch['attention_mask'], batch['targets']\n","        _, logits, _ = self(input_ids, attention_mask)\n","        loss = F.cross_entropy(logits, targets)\n","        acc = self.accuracy(logits.argmax(dim=1), targets)\n","#         wandb.log({\"val_loss\":loss, \"val_accuracy\":acc})\n","        self.f_score(logits.argmax(dim=1), targets)\n","        return {\"val_loss\": loss, \"val_accuracy\": torch.tensor(acc)}\n","\n","    def validation_epoch_end(self, outputs):\n","        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n","        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()\n","        avg_f_score = self.f_score.compute()\n","        \n","        wandb.log({\"val_loss\":avg_loss, \"val_accuracy\":avg_acc, \"val_fb\":avg_f_score})\n","        return {'val_accuracy': avg_loss, 'val_accuracy': avg_acc, \"val_fb\":avg_f_score}\n","    "],"execution_count":7,"outputs":[]},{"metadata":{},"cell_type":"markdown","source":["### Training "]},{"metadata":{"trusted":true},"cell_type":"code","source":["from pytorch_lightning.loggers import WandbLogger\n","from pytorch_lightning.metrics import Fbeta \n","from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar"],"execution_count":8,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["ROOT_DIR = \"../input/amazonproductsreview/amazon-review/\"\n","DATASET = \"dvd\"\n","NUM_CLASSES = 2\n","BATCH_SIZE = 32\n","EPOCH = 20"],"execution_count":20,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["# logger \n","logger = WandbLogger(\n","    name=DATASET,\n","    save_dir=\"../working/\",\n","    project=\"domain-adaptation\",\n","    log_model = True\n",")"],"execution_count":21,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["# callbacks\n","early_stopping = EarlyStopping(\n","    monitor=\"val_accuracy\",\n",")\n","model_checkpoint = ModelCheckpoint(\n","    filepath=\"{epoch}-{val_accuracy:.2f}-{val_loss:.2f}\",\n","    monitor=\"val_accuracy\",\n","    save_top_k=1,\n",")\n","progress_bar = ProgressBar()"],"execution_count":22,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["# create the BERTConfig, BERTTokenizer, and BERTModel \n","model_name = \"bert-base-uncased\"\n","config = BertConfig.from_pretrained(model_name, output_hidden_states=True)\n","tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)\n","bert = BertModel.from_pretrained(model_name, config=config)\n","classifier = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)"],"execution_count":23,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["model = BertFinetuner(\n","    model=classifier,\n","    data_file=os.path.join(ROOT_DIR, DATASET+\".csv\"),\n","    tokenizer=tokenizer,\n","    batch_size=BATCH_SIZE\n",")"],"execution_count":24,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["tuner = pl.Trainer(\n","    logger=logger,\n","    gpus=[0],\n","    checkpoint_callback=model_checkpoint,\n","    max_epochs=EPOCH,\n",")"],"execution_count":25,"outputs":[{"output_type":"stream","text":"GPU available: True, used: True\nTPU available: False, using: 0 TPU cores\nLOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n","name":"stderr"}]},{"metadata":{"trusted":true},"cell_type":"code","source":["tuner.fit(model)"],"execution_count":26,"outputs":[{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Calling wandb.login() after wandb.init() has no effect.\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"<br/>Waiting for W&B process to finish, PID 249<br/>Program ended successfully."},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\\r'), FloatProgress(value=1.0, max=1.0)…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Find user logs for this run at: <code>../working/wandb/run-20201026_093116-2jdo8tzd/logs/debug.log</code>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Find internal logs for this run at: <code>../working/wandb/run-20201026_093116-2jdo8tzd/logs/debug-internal.log</code>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"<h3>Run summary:</h3><br/><style>\n    table.wandb td:nth-child(1) { padding: 0 10px; text-align: right }\n    </style><table class=\"wandb\">\n<tr><td>val_loss</td><td>0.65587</td></tr><tr><td>val_accuracy</td><td>0.83062</td></tr><tr><td>val_fb</td><td>0.83688</td></tr><tr><td>_step</td><td>1020</td></tr><tr><td>_runtime</td><td>746</td></tr><tr><td>_timestamp</td><td>1603705422</td></tr><tr><td>Loss</td><td>0.00673</td></tr><tr><td>Accuracy</td><td>1.0</td></tr></table>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"<h3>Run history:</h3><br/><style>\n    table.wandb td:nth-child(1) { padding: 0 10px; text-align: right }\n    </style><table class=\"wandb\">\n<tr><td>val_loss</td><td>█▄▁▁▂▁▃▃▃▄▄▅▆▄▆▅▅▆▅▆▇</td></tr><tr><td>val_accuracy</td><td>▁▇█████████▇▇███████▇</td></tr><tr><td>val_fb</td><td>▁▆██▇██▇███▇▇███████▇</td></tr><tr><td>_step</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>_runtime</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>_timestamp</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>Loss</td><td>██▅▅▄▃▄▄▂▅▁▄▂▂▁▁▁▁▁▂▁▄▁▁▁▁▃▁▁▃▁▁▁▁▁▁▁▃▃▁</td></tr><tr><td>Accuracy</td><td>▁▃▇▆▇█▆▇█▆█▇█████████▇██████████████████</td></tr></table><br/>"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"\n                    <br/>Synced <strong style=\"color:#cdcd00\">books</strong>: <a href=\"https://wandb.ai/macab/domain-adaptation/runs/2jdo8tzd\" target=\"_blank\">https://wandb.ai/macab/domain-adaptation/runs/2jdo8tzd</a><br/>\n                "},"metadata":{}},{"output_type":"stream","text":"\u001b[34m\u001b[1mwandb\u001b[0m: wandb version 0.10.8 is available!  To upgrade, please run:\n\u001b[34m\u001b[1mwandb\u001b[0m:  $ pip install wandb --upgrade\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"\n                Tracking run with wandb version 0.10.7<br/>\n                Syncing run <strong style=\"color:#cdcd00\">dvd</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n                Project page: <a href=\"https://wandb.ai/macab/domain-adaptation\" target=\"_blank\">https://wandb.ai/macab/domain-adaptation</a><br/>\n                Run page: <a href=\"https://wandb.ai/macab/domain-adaptation/runs/1rj1ncqv\" target=\"_blank\">https://wandb.ai/macab/domain-adaptation/runs/1rj1ncqv</a><br/>\n                Run data is saved locally in <code>../working/wandb/run-20201026_094759-1rj1ncqv</code><br/><br/>\n            "},"metadata":{}},{"output_type":"stream","text":"\n  | Name    | Type           | Params\n-------------------------------------------\n0 | model   | BertClassifier | 109 M \n1 | f_score | Fbeta          | 0     \nUsing custom data configuration default\n","name":"stderr"},{"output_type":"stream","text":"Downloading and preparing dataset csv/default-343faf1a87cc9b22 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4...\n","name":"stdout"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4. Subsequent calls will reuse this data.\n","name":"stdout"},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.\n  warnings.warn(*args, **kwargs)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The validation_epoch_end should not return anything as of 9.1.to log, use self.log(...) or self.write(...) directly in the LightningModule\n  warnings.warn(*args, **kwargs)\nUsing custom data configuration default\nReusing dataset csv (/root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4)\n","name":"stderr"},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"fb340a96fe2c404db157e6b09a2acf4f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":""}},"metadata":{}},{"output_type":"stream","text":"\n","name":"stdout"},{"output_type":"execute_result","execution_count":26,"data":{"text/plain":"1"},"metadata":{}}]},{"metadata":{},"cell_type":"markdown","source":["#### Save trained state dictionary\n","- See section 4 : https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html"]},{"metadata":{},"cell_type":"markdown","source":["#### 1. Books"]},{"metadata":{"trusted":true},"cell_type":"code","source":["PATH =  DATASET+\".pt\"\n","# save the model \n","torch.save(classifier.state_dict(), PATH)\n"],"execution_count":16,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["### Load from state dictionary\n","classifier_trained = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)\n","classifier_trained.load_state_dict(torch.load(PATH))\n","\n","# you can evaluate the model on top 20% data"],"execution_count":19,"outputs":[{"output_type":"execute_result","execution_count":19,"data":{"text/plain":"<All keys matched successfully>"},"metadata":{}}]},{"metadata":{},"cell_type":"markdown","source":["#### DVD"]},{"metadata":{"trusted":true},"cell_type":"code","source":["PATH =  DATASET+\".pt\"\n","# save the model \n","torch.save(classifier.state_dict(), PATH)"],"execution_count":27,"outputs":[]},{"metadata":{"trusted":true},"cell_type":"code","source":["### Load from state dictionary\n","classifier_dvd = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)\n","classifier_dvd.load_state_dict(torch.load(PATH))\n","\n","# you can evaluate the model on top 20% data"],"execution_count":null,"outputs":[]},{"metadata":{"trusted":false},"cell_type":"code","source":["## There you go "],"execution_count":null,"outputs":[]},{"metadata":{"trusted":false},"cell_type":"code","source":[],"execution_count":null,"outputs":[]},{"metadata":{"trusted":false},"cell_type":"code","source":[],"execution_count":null,"outputs":[]}],"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"language_info":{"name":"python","version":"3.7.6","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat":4,"nbformat_minor":4}