{ "cells": [ { "cell_type": "markdown", "id": "04ae7e4c-7642-49a3-827c-452df0b17108", "metadata": {}, "source": [ "# 1. Using the trained model, find semantically similar words for other input words. Can you improve the results by tuning hyperparameters?" ] }, { "cell_type": "code", "execution_count": 226, "id": "b12770d2-c299-49a7-916c-a9e85c4673dd", "metadata": { "tags": [] }, "outputs": [], "source": [ "import time\n", "import collections\n", "import math\n", "import os\n", "import random\n", "import torch\n", "import warnings\n", "import sys\n", "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "sys.path.append('/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/')\n", "import d2l\n", "from torchsummary import summary\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "#@save\n", "d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',\n", " '319d85e578af0cdc590547f26231e4e31cdf1e42')\n", "#@save\n", "class RandomGenerator:\n", " \"\"\"Randomly draw among {1, ..., n} according to n sampling weights.\"\"\"\n", " def __init__(self, sampling_weights,k=10000):\n", " # Exclude\n", " self.population = list(range(1, len(sampling_weights) + 1))\n", " self.sampling_weights = sampling_weights\n", " self.candidates = []\n", " self.i = 0\n", " self.k = k\n", "\n", " def draw(self):\n", " if self.i == len(self.candidates):\n", " # Cache `k` random sampling results\n", " self.candidates = random.choices(\n", " self.population, self.sampling_weights, k=self.k)\n", " self.i = 0\n", " self.i += 1\n", " return self.candidates[self.i - 1]\n", " \n", "#@save\n", "def subsample(sentences, vocab,flag=True):\n", " \"\"\"Subsample high-frequency words.\"\"\"\n", " # Exclude unknown tokens ('')\n", " sentences = [[token for token in line if vocab[token] != vocab.unk]\n", " for line in sentences]\n", " counter = collections.Counter([\n", " token for line in sentences for token in line])\n", " num_tokens = sum(counter.values())\n", "\n", " # Return True if `token` is kept during subsampling\n", " def keep(token):\n", " return(random.uniform(0, 1) <\n", " math.sqrt(1e-4 / counter[token] * num_tokens))\n", " if flag:\n", " return ([[token for token in line if keep(token)] for line in sentences],\n", " counter)\n", " return (sentences,counter)\n", "\n", "#@save\n", "def get_centers_and_contexts(corpus, max_window_size):\n", " \"\"\"Return center words and context words in skip-gram.\"\"\"\n", " centers, contexts = [], []\n", " for line in corpus:\n", " # To form a \"center word--context word\" pair, each sentence needs to\n", " # have at least 2 words\n", " if len(line) < 2:\n", " continue\n", " centers += line\n", " for i in range(len(line)): # Context window centered at `i`\n", " window_size = random.randint(1, max_window_size)\n", " indices = list(range(max(0, i - window_size),\n", " min(len(line), i + 1 + window_size)))\n", " # Exclude the center word from the context words\n", " indices.remove(i)\n", " contexts.append([line[idx] for idx in indices])\n", " return centers, contexts\n", "\n", "#@save\n", "def read_ptb():\n", " \"\"\"Load the PTB dataset into a list of text lines.\"\"\"\n", " data_dir = d2l.download_extract('ptb')\n", " # Read the training set\n", " with open(os.path.join(data_dir, 'ptb.train.txt')) as f:\n", " raw_text = f.read()\n", " return [line.split() for line in raw_text.split('\\n')]\n", "\n", "#@save\n", "def get_negatives(all_contexts, vocab, counter, K, k=10000):\n", " \"\"\"Return noise words in negative sampling.\"\"\"\n", " # Sampling weights for words with 
indices 1, 2, ... (index 0 is the\n", " # excluded unknown token) in the vocabulary\n", " sampling_weights = [counter[vocab.to_tokens(i)]**0.75\n", " for i in range(1, len(vocab))]\n", " all_negatives, generator = [], RandomGenerator(sampling_weights,k)\n", " for contexts in all_contexts:\n", " negatives = []\n", " while len(negatives) < len(contexts) * K:\n", " neg = generator.draw()\n", " # Noise words cannot be context words\n", " if neg not in contexts:\n", " negatives.append(neg)\n", " all_negatives.append(negatives)\n", " return all_negatives\n", "\n", "#@save\n", "def batchify(data):\n", " \"\"\"Return a minibatch of examples for skip-gram with negative sampling.\"\"\"\n", " max_len = max(len(c) + len(n) for _, c, n in data)\n", " centers, contexts_negatives, masks, labels = [], [], [], []\n", " for center, context, negative in data:\n", " cur_len = len(context) + len(negative)\n", " centers += [center]\n", " contexts_negatives += [context + negative + [0] * (max_len - cur_len)]\n", " masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n", " labels += [[1] * len(context) + [0] * (max_len - len(context))]\n", " return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(\n", " contexts_negatives), torch.tensor(masks), torch.tensor(labels))\n", "\n", "#@save\n", "def load_data_ptb(batch_size, max_window_size, num_noise_words, flag=True, k=10000):\n", " \"\"\"Download the PTB dataset and then load it into memory.\"\"\"\n", " # num_workers = d2l.get_dataloader_workers()\n", " sentences = read_ptb()\n", " vocab = d2l.Vocab(sentences, min_freq=10)\n", " subsampled, counter = subsample(sentences, vocab, flag)\n", " corpus = [vocab[line] for line in subsampled]\n", " all_centers, all_contexts = get_centers_and_contexts(\n", " corpus, max_window_size)\n", " all_negatives = get_negatives(\n", " all_contexts, vocab, counter, num_noise_words, k=k)\n", "\n", " class PTBDataset(torch.utils.data.Dataset):\n", " def __init__(self, centers, contexts, negatives):\n", " assert len(centers) == len(contexts) == len(negatives)\n", " self.centers = centers\n", " self.contexts = contexts\n", " self.negatives = negatives\n", "\n", " def __getitem__(self, index):\n", " return (self.centers[index], self.contexts[index],\n", " self.negatives[index])\n", "\n", " def __len__(self):\n", " return len(self.centers)\n", "\n", " dataset = PTBDataset(all_centers, all_contexts, all_negatives)\n", "\n", " data_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,\n", " collate_fn=batchify)\n", " return data_iter, vocab\n", "\n", "def skip_gram(center, contexts_and_negatives, embed_v, embed_u):\n", " v = embed_v(center)\n", " u = embed_u(contexts_and_negatives)\n", " pred = torch.bmm(v, u.permute(0, 2, 1))\n", " return pred\n", "\n", "class SigmoidBCELoss(nn.Module):\n", " # Binary cross-entropy loss with masking\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def forward(self, inputs, target, mask=None):\n", " out = nn.functional.binary_cross_entropy_with_logits(\n", " inputs, target, weight=mask, reduction=\"none\")\n", " return out.mean(dim=1)\n", "\n", "def train(net, data_iter, lr, num_epochs, device='cpu'):\n", " def init_weights(module):\n", " if type(module) == nn.Embedding:\n", " nn.init.xavier_uniform_(module.weight)\n", " net.apply(init_weights)\n", " net = net.to(device)\n", " optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n", " animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n", " xlim=[1, num_epochs])\n", " # Sum of normalized losses, no. 
of normalized losses\n",
"    metric = d2l.Accumulator(2)\n",
"    loss = SigmoidBCELoss()\n",
"    for epoch in range(num_epochs):\n",
"        timer, num_batches = d2l.Timer(), len(data_iter)\n",
"        for i, batch in enumerate(data_iter):\n",
"            optimizer.zero_grad()\n",
"            center, context_negative, mask, label = [\n",
"                data.to(device) for data in batch]\n",
"\n",
"            pred = skip_gram(center, context_negative, net[0], net[1])\n",
"            l = (loss(pred.reshape(label.shape).float(), label.float(), mask)\n",
"                 / mask.sum(axis=1) * mask.shape[1])\n",
"            l.sum().backward()\n",
"            optimizer.step()\n",
"            metric.add(l.sum(), l.numel())\n",
"            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n",
"                animator.add(epoch + (i + 1) / num_batches,\n",
"                             (metric[0] / metric[1],))\n",
"    print(f'loss {metric[0] / metric[1]:.3f}, '\n",
"          f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')\n",
"    return metric[0] / metric[1]\n",
"\n",
"def get_similar_tokens(query_token, k, embed):\n",
"    W = embed.weight.data\n",
"    x = W[vocab[query_token]]\n",
"    # Compute the cosine similarity. Add 1e-9 for numerical stability\n",
"    cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) *\n",
"                                      torch.sum(x * x) + 1e-9)\n",
"    topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')\n",
"    for i in topk[1:]:  # Remove the input words\n",
"        print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')" ] },
{ "cell_type": "code", "execution_count": 8, "id": "eaa1e620-8a88-4b7a-bc0e-2a020db974c2", "metadata": { "tags": [] }, "outputs": [
{ "name": "stdout", "output_type": "stream", "text": [ "loss 0.410, 55995.3 tokens/sec on cpu\n" ] },
{ "data": { "image/svg+xml": [ "[Figure: training loss curve, x-axis epoch, y-axis loss]" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data_iter, vocab = load_data_ptb(512, 5, 5)\n", "lr, num_epochs = 0.002, 5\n", "embed_size = 100\n", "net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab),\n", " embedding_dim=embed_size),\n", " nn.Embedding(num_embeddings=len(vocab),\n", " embedding_dim=embed_size))\n", "train(net, data_iter, lr, num_epochs)" ] }, { "cell_type": "code", "execution_count": 11, "id": "ba4f941f-a8e4-493a-be24-2d78884ecd0e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.384, 36166.7 tokens/sec on cpu\n" ] }, { "data": { "text/plain": [ "0.3843912484866804" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-09-28T12:35:10.325423\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr, num_epochs = 0.002, 5\n", "embed_size = 200\n", "net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab),\n", " embedding_dim=embed_size),\n", " nn.Embedding(num_embeddings=len(vocab),\n", " embedding_dim=embed_size))\n", "train(net, data_iter, lr, num_epochs)" ] }, { "cell_type": "code", "execution_count": 14, "id": "d2391db9-f3f4-456c-ad03-b4fa9e9359ab", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cosine sim=0.551: workstations\n", "cosine sim=0.537: microprocessor\n", "cosine sim=0.525: compaq\n" ] } ], "source": [ "get_similar_tokens('intel', 3, net[0])" ] }, { "cell_type": "markdown", "id": "793a6be3-a982-40b3-8063-6b68585e44a4", "metadata": {}, "source": [ "# 2. When a training corpus is huge, we often sample context words and noise words for the center words in the current minibatch when updating model parameters. In other words, the same center word may have different context words or noise words in different training epochs. What are the benefits of this method? Try to implement this training method." ] }, { "cell_type": "markdown", "id": "a3ced8d4-728e-41db-8a46-2f7a6b20fd5b", "metadata": {}, "source": [ "The benefit of sampling context words and noise words is that it allows the model to capture more information from different contexts and handle larger training corpora. By varying the context words and noise words for the same center word in different epochs, the model can learn more diverse and robust representations of words that reflect their meanings in various situations4. This can also help the model avoid overfitting to a specific context or corpus." ] }, { "cell_type": "code", "execution_count": 223, "id": "6d4f22f5-3316-4382-9a11-e20899f21518", "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "import random\n", "\n", "class Vocab:\n", " \"\"\"Vocabulary for text.\"\"\"\n", " def __init__(self, filename, min_freq=0, reserved_tokens=[]):\n", " # Flatten a 2D list if needed\n", " # if tokens and isinstance(tokens[0], list):\n", " # tokens = [token for line in tokens for token in line]\n", " self.token_freqs = collections.Counter([])\n", " with open(filename, 'rb') as f:\n", " for line in f:\n", " counter = collections.Counter(line.decode().strip().split())\n", " self.token_freqs.update(counter)\n", " # Count token frequencies\n", " \n", " # counter = collections.Counter(tokens)\n", " # self.token_freqs = sorted(counter.items(), key=lambda x: x[1],\n", " # reverse=True)\n", " # The list of unique tokens\n", " filter_tokens = list(set([''] + reserved_tokens + [\n", " token for token in self.token_freqs if self.token_freqs[token] >= min_freq]))\n", " self.idx_to_token = {}\n", " self.token_to_idx = {}\n", " for idx, token in enumerate(filter_tokens):\n", " self.idx_to_token[idx]=token\n", " self.token_to_idx[token]=idx\n", " self.cur_id = len(self.idx_to_token)\n", "\n", " def __len__(self):\n", " return len(self.idx_to_token)\n", "\n", " def __getitem__(self, tokens):\n", " if not isinstance(tokens, (list, tuple)):\n", " return self.token_to_idx.get(tokens, self.unk)\n", " return [self.__getitem__(token) for token in tokens]\n", "\n", " def to_tokens(self, indices):\n", " if hasattr(indices, '__len__') and len(indices) > 1:\n", " return [self.idx_to_token[int(index)] for index in indices]\n", " return self.idx_to_token[indices]\n", " \n", " def update(self, tokens):\n", " # 统计tokens中每个词的出现次数,并更新到原有的词频字典中\n", " counter = 
collections.Counter(tokens)\n",
"        self.token_freqs.update(counter)\n",
"        # Assign an id to each newly seen token and update both mappings\n",
"        for token in counter:\n",
"            if token not in self.token_to_idx:\n",
"                self.token_to_idx[token] = self.cur_id\n",
"                self.idx_to_token[self.cur_id] = token\n",
"                self.cur_id += 1\n",
"\n",
"    @property\n",
"    def unk(self):  # Index for the unknown token\n",
"        return self.token_to_idx['<unk>']\n",
"\n",
"\n",
"def get_random_line(filepath, n):\n",
"    \"\"\"Sample n random tokenized lines from the file at filepath.\"\"\"\n",
"    file_size = os.path.getsize(filepath)\n",
"    res = []\n",
"    row = 0\n",
"    with open(filepath, 'rb') as f:\n",
"        while True:\n",
"            pos = random.randint(0, file_size)\n",
"            if not pos:  # start of the file: return what has been collected\n",
"                return res\n",
"            f.seek(pos)  # seek to random position\n",
"            f.readline()  # skip possibly incomplete line\n",
"            line = f.readline()  # read next (full) line\n",
"            if line:\n",
"                res.append(line.decode().strip().split())\n",
"                row += 1\n",
"            if row >= n:\n",
"                return res\n",
"\n",
"#@save\n",
"def subsample(sentences, vocab, flag=True):\n",
"    \"\"\"Subsample high-frequency words.\"\"\"\n",
"    # Exclude unknown tokens ('<unk>')\n",
"    sentences = [[token for token in line if vocab[token] != vocab.unk]\n",
"                 for line in sentences]\n",
"    counter = collections.Counter([\n",
"        token for line in sentences for token in line])\n",
"    num_tokens = sum(counter.values())\n",
"\n",
"    # Return True if `token` is kept during subsampling\n",
"    def keep(token):\n",
"        return (random.uniform(0, 1) <\n",
"                math.sqrt(1e-4 / counter[token] * num_tokens))\n",
"    if flag:\n",
"        return ([[token for token in line if keep(token)] for line in sentences],\n",
"                counter)\n",
"    return (sentences, counter)\n",
"\n",
"def batchify(all_centers, all_contexts, all_negatives):\n",
"    \"\"\"Return a minibatch of examples for skip-gram with negative sampling.\"\"\"\n",
"    max_len = max(len(c) + len(n) for c, n in zip(all_contexts, all_negatives))\n",
"    centers, contexts_negatives, masks, labels = [], [], [], []\n",
"    for center, context, negative in zip(all_centers, all_contexts, all_negatives):\n",
"        cur_len = len(context) + len(negative)\n",
"        centers += [center]\n",
"        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]\n",
"        masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n",
"        labels += [[1] * len(context) + [0] * (max_len - len(context))]\n",
"    return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(\n",
"        contexts_negatives), torch.tensor(masks), torch.tensor(labels))\n",
"\n",
"def random_batch(filename, vocab, n, max_window_size, num_noise_words):\n",
"    \"\"\"Draw a fresh minibatch: sample lines, contexts and negatives on the fly.\"\"\"\n",
"    sentences = get_random_line(filename, n)\n",
"    subsampled, counter = subsample(sentences, vocab, False)\n",
"    corpus = [vocab[line] for line in subsampled]\n",
"    all_centers, all_contexts = get_centers_and_contexts(\n",
"        corpus, max_window_size)\n",
"    all_negatives = get_negatives(all_contexts, vocab, counter, num_noise_words, k=100)\n",
"    return batchify(all_centers, all_contexts, all_negatives)\n",
"\n",
"def train(net, filename, vocab, n, max_window_size, num_noise_words, lr, num_epochs, num_batches, device='cpu'):\n",
"    def init_weights(module):\n",
"        if type(module) == nn.Embedding:\n",
"            nn.init.xavier_uniform_(module.weight)\n",
"    net.apply(init_weights)\n",
"    net = net.to(device)\n",
"    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
"    animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
"                            xlim=[1, num_epochs])\n",
"    # Sum of normalized losses, no. 
of normalized losses\n",
"    metric = d2l.Accumulator(2)\n",
"    loss = SigmoidBCELoss()\n",
"    for epoch in range(num_epochs):\n",
"        timer = d2l.Timer()\n",
"        for i in range(num_batches):\n",
"            optimizer.zero_grad()\n",
"            center, context_negative, mask, label = random_batch(\n",
"                filename, vocab, n, max_window_size, num_noise_words)\n",
"            pred = skip_gram(center, context_negative, net[0], net[1])\n",
"            l = (loss(pred.reshape(label.shape).float(), label.float(), mask)\n",
"                 / mask.sum(axis=1) * mask.shape[1])\n",
"            l.sum().backward()\n",
"            optimizer.step()\n",
"            metric.add(l.sum(), l.numel())\n",
"            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n",
"                animator.add(epoch + (i + 1) / num_batches,\n",
"                             (metric[0] / metric[1],))\n",
"    print(f'loss {metric[0] / metric[1]:.3f}, '\n",
"          f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')\n",
"    return metric[0] / metric[1]" ] },
{ "cell_type": "code", "execution_count": 155, "id": "3bd83c85-3d63-418d-a5e4-0e73220a2fe0", "metadata": { "tags": [] }, "outputs": [], "source": [
"data_dir = d2l.download_extract('ptb')\n",
"filename = os.path.join(data_dir, 'ptb.train.txt')\n",
"vocab = Vocab(filename=filename)\n",
"# sentences = get_random_line(filename,100)\n",
"# sentences" ] },
{ "cell_type": "code", "execution_count": 224, "id": "78975b9d-45cf-4364-9f58-8373c793cbad", "metadata": { "tags": [] }, "outputs": [
{ "name": "stdout", "output_type": "stream", "text": [ "loss 0.444, 24885.4 tokens/sec on cpu\n" ] },
{ "data": { "text/plain": [ "0.4439971089160742" ] }, "execution_count": 224, "metadata": {}, "output_type": "execute_result" },
{ "data": { "image/svg+xml": [ "[Figure: training loss curve, x-axis epoch, y-axis loss]" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# data_iter, vocab = load_data_ptb(512, 5, 5)\n", "lr, num_epochs = 0.002, 10\n", "embed_size = 100\n", "n = 100\n", "max_window_size = 5\n", "num_noise_words = 5\n", "num_batches = 100\n", "net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab),\n", " embedding_dim=embed_size),\n", " nn.Embedding(num_embeddings=len(vocab),\n", " embedding_dim=embed_size))\n", "# train(net, data_iter, lr, num_epochs)\n", "train(net, filename,vocab,n,max_window_size,num_noise_words, lr, num_epochs, num_batches)" ] }, { "cell_type": "code", "execution_count": 225, "id": "56389e7f-b616-4b29-ac5d-e1bea3ed80f3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cosine sim=0.981: lighting\n", "cosine sim=0.977: reliance\n", "cosine sim=0.976: vanguard\n" ] } ], "source": [ "get_similar_tokens('intel', 3, net[0])" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:d2l]", "language": "python", "name": "conda-env-d2l-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }