class BertClassifier(torch.nn.Module):
    """Encoder wrapped with a small feed-forward classification head.

    The encoder's outputs are passed through unchanged next to the head's
    output, so callers can read the per-layer representations as well as the
    class scores from a single forward pass.

    Args:
        config: Object exposing ``hidden_dropout_prob`` and ``hidden_size``
            (both are read when the head is built).
        model: Encoder called as ``model(input_ids, attention_mask)``. Its
            return value must be indexable by position with at least three
            entries; per the original author's note these are the top-layer
            output, the pooled sequence representation and the layer-wise
            outputs.
        dim: Width of the head's hidden layer.
        num_classes: Number of values the head emits per input sequence.
    """

    def __init__(self, config, model, dim=256, num_classes=2):
        super().__init__()

        # Attribute names (`base`, `head`) and the Sequential layout determine
        # the state_dict keys, so they must stay as they are for previously
        # saved checkpoints to load.
        self.config = config
        self.base = model

        dropout_p = self.config.hidden_dropout_prob
        self.head = torch.nn.Sequential(
            torch.nn.Dropout(p=dropout_p),
            torch.nn.Linear(in_features=self.config.hidden_size, out_features=dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout_p),
            torch.nn.Linear(in_features=dim, out_features=num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder and the head.

        Returns:
            A 3-tuple ``(top_layer, outputs, layers)`` where ``top_layer`` and
            ``layers`` are the encoder's first and third return values and
            ``outputs`` is the head applied to the encoder's second value.
        """
        encoded = self.base(input_ids, attention_mask)

        # Index by position instead of tuple-unpacking: recent Hugging Face
        # models return a dict-like ModelOutput, and unpacking a dict-like
        # yields its *keys* (strings), not the tensors. Integer indexing works
        # for both plain tuples and ModelOutput.
        top_layer, pooled, layers = encoded[0], encoded[1], encoded[2]

        outputs = self.head(pooled)
        return top_layer, outputs, layers
# %%
class SentimentDataset(Dataset):
    """Dataset that tokenizes the ``review_text`` column of a DataFrame.

    Args:
        df: DataFrame with a ``review_text`` column.
        tokenizer: Object exposing ``encode_plus`` (a Hugging Face tokenizer
            in this notebook).
        max_len: Length every encoded sequence is padded / truncated to.
    """

    def __init__(self, df, tokenizer, max_len=512):
        self.tokenizer = tokenizer
        self.text = df.review_text.values
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        """Return the raw text plus flattened ``input_ids`` / ``attention_mask``."""
        text = self.text[idx]

        encoding = self.tokenizer.encode_plus(
            text=text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            # `pad_to_max_length=True` is deprecated; this is its documented
            # replacement and pads every sequence to `max_length`.
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )

        # return_tensors='pt' adds a leading batch axis; flatten drops it so
        # the DataLoader can stack samples itself.
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }


# %%
def model_predict(trained):
    """Dump per-layer embeddings of sampled reviews to ``./data/batch_<i>.npy``.

    Reads ``./amazon-review/dvd-UL.csv``, samples 1000 rows (fixed seed),
    runs them through ``trained`` in batches of 4 and, for every sample and
    every returned layer, stores the vector at token position 0 (presumably
    the [CLS] token -- confirm against the tokenizer) together with its
    1-based layer number. One ``.npy`` file (a pickled list of dicts with keys
    ``'embeddings'`` and ``'layers'``) is written per batch.

    Relies on the notebook-level ``tokenizer`` being defined before the call
    and on ``./data`` already existing.

    Args:
        trained: Model called as ``trained(input_ids, attention_mask)`` whose
            third return value is the sequence of layer-wise outputs. It is
            switched to eval mode as a side effect.
    """
    df = pd.read_csv("./amazon-review/dvd-UL.csv")
    df = df.sample(n=1000, random_state=42)  # number of samples
    df = df.reset_index(drop=True)

    dataset = SentimentDataset(df=df, tokenizer=tokenizer)

    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=4,
        shuffle=False,
        num_workers=8,
    )

    # Inference only: disable dropout and skip autograd bookkeeping.
    trained.eval()
    with torch.no_grad():
        for bi, batch in enumerate(tqdm(data_loader)):
            # Use the model that was passed in, not a notebook global.
            _, _, hidden_states = trained(batch["input_ids"], batch["attention_mask"])

            # Drop entry 0 (presumably the embedding-layer output -- confirm)
            # so that numbering starts at the first remaining layer.
            hidden_states = hidden_states[1:]

            batch_records = []
            for sample_idx in range(len(hidden_states[0])):
                # Iterate over however many layers the model returned instead
                # of a hard-coded 12.
                for layer_no, layer_out in enumerate(hidden_states, start=1):
                    batch_records.append({
                        'embeddings': layer_out[sample_idx][0].cpu().detach().numpy(),
                        'layers': layer_no,
                    })

            # np.save returns None, so its result must not be bound to the list.
            np.save(f"./data/batch_{bi}", batch_records, allow_pickle=True)


# %%
# Start every run from an empty ./data directory so stale batch files from a
# previous run are never mixed with fresh ones.
if os.path.exists("./data"):
    for stale_path in glob.glob('./data/*'):
        os.remove(stale_path)
else:
    os.makedirs("./data")

# %%
PATH = "books" + ".pt"  # change the model name here

model_name = "bert-base-uncased"
config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=True)
bert = BertModel.from_pretrained(model_name, config=config)

classifier_trained = BertClassifier(config=config, model=bert, num_classes=2)
# The model is never moved off the CPU, so map the checkpoint there as well;
# otherwise a checkpoint saved from a GPU fails to load on a CPU-only machine.
classifier_trained.load_state_dict(torch.load(PATH, map_location="cpu"))

model_predict(classifier_trained)

# %%
# Reload the per-batch dumps in batch order (glob order is arbitrary, so the
# files are addressed by index rather than iterated directly).
dictionary_list = []

files = glob.glob("./data/*.npy")

for batch_idx in range(len(files)):
    # allow_pickle is required for the object arrays written by model_predict;
    # these files were produced by this notebook, not by an external source.
    saved_rows = np.load(f"./data/batch_{batch_idx}.npy", allow_pickle=True)
    dictionary_list.extend(
        {'embeddings': row["embeddings"], 'layers': row["layers"]}
        for row in saved_rows
    )
| \n", " | embeddings | \n", "layers | \n", "
|---|---|---|
| 0 | \n", "[0.029436039, 0.06670721, -0.22471415, -0.2367... | \n", "1 | \n", "
| 1 | \n", "[-0.1554519, -0.21112284, -0.3408423, -0.20209... | \n", "2 | \n", "
| 2 | \n", "[-0.12095504, -0.36359823, -0.17967358, -0.109... | \n", "3 | \n", "
| 3 | \n", "[-0.21423775, -0.7461651, -0.6160757, -0.30794... | \n", "4 | \n", "
| 4 | \n", "[-0.4974339, -0.85912985, -0.42627215, -0.5099... | \n", "5 | \n", "
| ... | \n", "... | \n", "... | \n", "
| 11995 | \n", "[0.51743513, -0.6500383, -0.68353117, -0.22525... | \n", "8 | \n", "
| 11996 | \n", "[0.42627597, -0.63389504, -0.19636014, -0.2719... | \n", "9 | \n", "
| 11997 | \n", "[0.025212316, -0.5110682, 0.48476753, -0.35641... | \n", "10 | \n", "
| 11998 | \n", "[0.04342799, -0.75802934, 0.5390331, -0.213192... | \n", "11 | \n", "
| 11999 | \n", "[-0.29624316, -0.9558969, 0.48933977, -0.35488... | \n", "12 | \n", "
12000 rows × 2 columns
\n", "