How does the running time of the code in this section change if subsampling is not used?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0012d647-c837-437b-ad22-6c76632e6691",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import collections\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import torch\n",
    "import warnings\n",
    "import sys\n",
    "import pandas as pd\n",
    "sys.path.append('/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/')\n",
    "import d2l\n",
    "from torchsummary import summary\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',\n",
    "                       '319d85e578af0cdc590547f26231e4e31cdf1e42')\n",
    "\n",
    "#@save\n",
    "class RandomGenerator:\n",
    "    \"\"\"Randomly draw among {1, ..., n} according to n sampling weights.\n",
    "\n",
    "    Draws are made `k` at a time with `random.choices` and served from\n",
    "    that cache until it is exhausted.\n",
    "    \"\"\"\n",
    "    def __init__(self, sampling_weights, k=10000):\n",
    "        # A non-positive `k` would make `draw` index into an empty cache\n",
    "        if k < 1:\n",
    "            raise ValueError('k must be a positive integer')\n",
    "        # Population is 1..n: index 0 (the unknown token) is excluded\n",
    "        self.population = list(range(1, len(sampling_weights) + 1))\n",
    "        self.sampling_weights = sampling_weights\n",
    "        self.candidates = []\n",
    "        self.i = 0\n",
    "        self.k = k\n",
    "\n",
    "    def draw(self):\n",
    "        if self.i == len(self.candidates):\n",
    "            # Cache `k` random sampling results\n",
    "            self.candidates = random.choices(\n",
    "                self.population, self.sampling_weights, k=self.k)\n",
    "            self.i = 0\n",
    "        self.i += 1\n",
    "        return self.candidates[self.i - 1]\n",
    "\n",
    "#@save\n",
    "def subsample(sentences, vocab, flag=True):\n",
    "    \"\"\"Subsample high-frequency words.\n",
    "\n",
    "    Unknown tokens are always removed. If `flag` is False, no further\n",
    "    subsampling is applied. Returns `(sentences, counter)`, where `counter`\n",
    "    counts the tokens before subsampling.\n",
    "    \"\"\"\n",
    "    # Exclude unknown tokens ('<unk>')\n",
    "    sentences = [[token for token in line if vocab[token] != vocab.unk]\n",
    "                 for line in sentences]\n",
    "    counter = collections.Counter([\n",
    "        token for line in sentences for token in line])\n",
    "    num_tokens = sum(counter.values())\n",
    "\n",
    "    # Return True if `token` is kept during subsampling\n",
    "    def keep(token):\n",
    "        return(random.uniform(0, 1) <\n",
    "               math.sqrt(1e-4 / counter[token] * num_tokens))\n",
    "\n",
    "    if flag:\n",
    "        return ([[token for token in line if keep(token)] for line in sentences],\n",
    "                counter)\n",
    "    return (sentences, counter)\n",
    "\n",
    "#@save\n",
    "def get_centers_and_contexts(corpus, max_window_size):\n",
    "    \"\"\"Return center words and context words in skip-gram.\"\"\"\n",
    "    centers, contexts = [], []\n",
    "    for line in corpus:\n",
    "        # To form a \"center word--context word\" pair, each sentence needs to\n",
    "        # have at least 2 words\n",
    "        if len(line) < 2:\n",
    "            continue\n",
    "        centers += line\n",
    "        for i in range(len(line)):  # Context window centered at `i`\n",
    "            window_size = random.randint(1, max_window_size)\n",
    "            indices = list(range(max(0, i - window_size),\n",
    "                                 min(len(line), i + 1 + window_size)))\n",
    "            # Exclude the center word from the context words\n",
    "            indices.remove(i)\n",
    "            contexts.append([line[idx] for idx in indices])\n",
    "    return centers, contexts\n",
    "\n",
    "#@save\n",
    "def read_ptb():\n",
    "    \"\"\"Load the PTB dataset into a list of text lines.\"\"\"\n",
    "    data_dir = d2l.download_extract('ptb')\n",
    "    # Read the training set\n",
    "    with open(os.path.join(data_dir, 'ptb.train.txt')) as f:\n",
    "        raw_text = f.read()\n",
    "    return [line.split() for line in raw_text.split('\\n')]\n",
    "\n",
    "#@save\n",
    "def get_negatives(all_contexts, vocab, counter, K, k=10000):\n",
    "    \"\"\"Return noise words in negative sampling.\n",
    "\n",
    "    `K` noise words are drawn per context word; `k` is the cache size of\n",
    "    the `RandomGenerator` that produces the draws.\n",
    "    \"\"\"\n",
    "    # Sampling weights for words with indices 1, 2, ... (index 0 is the\n",
    "    # excluded unknown token) in the vocabulary\n",
    "    sampling_weights = [counter[vocab.to_tokens(i)]**0.75\n",
    "                        for i in range(1, len(vocab))]\n",
    "    all_negatives, generator = [], RandomGenerator(sampling_weights, k)\n",
    "    for contexts in all_contexts:\n",
    "        negatives = []\n",
    "        while len(negatives) < len(contexts) * K:\n",
    "            neg = generator.draw()\n",
    "            # Noise words cannot be context words\n",
    "            if neg not in contexts:\n",
    "                negatives.append(neg)\n",
    "        all_negatives.append(negatives)\n",
    "    return all_negatives\n",
    "\n",
    "#@save\n",
    "def batchify(data):\n",
    "    \"\"\"Return a minibatch of examples for skip-gram with negative sampling.\"\"\"\n",
    "    max_len = max(len(c) + len(n) for _, c, n in data)\n",
    "    centers, contexts_negatives, masks, labels = [], [], [], []\n",
    "    for center, context, negative in data:\n",
    "        cur_len = len(context) + len(negative)\n",
    "        centers += [center]\n",
    "        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]\n",
    "        masks += [[1] * cur_len + [0] * (max_len - cur_len)]\n",
    "        labels += [[1] * len(context) + [0] * (max_len - len(context))]\n",
    "    return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(\n",
    "        contexts_negatives), torch.tensor(masks), torch.tensor(labels))\n",
    "\n",
    "#@save\n",
    "def load_data_ptb(batch_size, max_window_size, num_noise_words, flag=True, k=10000):\n",
    "    \"\"\"Download the PTB dataset and then load it into memory.\n",
    "\n",
    "    `flag` toggles subsampling of high-frequency words (see `subsample`) and\n",
    "    `k` is the cache size of the negative-sampling `RandomGenerator`.\n",
    "    Pass both by keyword: positionally, the 4th argument is `flag`.\n",
    "    \"\"\"\n",
    "    sentences = read_ptb()\n",
    "    vocab = d2l.Vocab(sentences, min_freq=10)\n",
    "    subsampled, counter = subsample(sentences, vocab, flag)\n",
    "    corpus = [vocab[line] for line in subsampled]\n",
    "    all_centers, all_contexts = get_centers_and_contexts(\n",
    "        corpus, max_window_size)\n",
    "    all_negatives = get_negatives(\n",
    "        all_contexts, vocab, counter, num_noise_words, k=k)\n",
    "\n",
    "    class PTBDataset(torch.utils.data.Dataset):\n",
    "        def __init__(self, centers, contexts, negatives):\n",
    "            assert len(centers) == len(contexts) == len(negatives)\n",
    "            self.centers = centers\n",
    "            self.contexts = contexts\n",
    "            self.negatives = negatives\n",
    "\n",
    "        def __getitem__(self, index):\n",
    "            return (self.centers[index], self.contexts[index],\n",
    "                    self.negatives[index])\n",
    "\n",
    "        def __len__(self):\n",
    "            return len(self.centers)\n",
    "\n",
    "    dataset = PTBDataset(all_centers, all_contexts, all_negatives)\n",
    "\n",
    "    data_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,\n",
    "                                            collate_fn=batchify)\n",
    "    return data_iter, vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a8bac86a-ce43-447e-a63b-0e2c179245e5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "centers shape: torch.Size([512, 1])\n",
      "contexts_negatives shape: torch.Size([512, 60])\n",
      "masks shape: torch.Size([512, 60])\n",
      "labels shape: torch.Size([512, 60])\n"
     ]
    }
   ],
   "source": [
    "data_iter, vocab = load_data_ptb(512, 5, 5)\n",
    "names = ['centers', 'contexts_negatives', 'masks', 'labels']\n",
    "for batch in data_iter:\n",
    "    for name, data in zip(names, batch):\n",
    "        print(name, 'shape:', data.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4a83edea-015e-4713-be0f-a3723f9b351e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.802619218826294"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loading time (seconds) with subsampling enabled\n",
    "t0 = time.perf_counter()\n",
    "data_iter, vocab = load_data_ptb(512, 5, 5)\n",
    "t1 = time.perf_counter()\n",
    "t1 - t0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3fa20ade-55cd-44d3-91d3-097b5a17b952",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "23.945112943649292"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loading time (seconds) with subsampling disabled\n",
    "t0 = time.perf_counter()\n",
    "data_iter, vocab = load_data_ptb(512, 5, 5, flag=False)\n",
    "t1 = time.perf_counter()\n",
    "t1 - t0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35a62073-32e3-4936-bc6e-b57d1495930f",
   "metadata": {},
   "source": [
    "# 2. The RandomGenerator class caches k random sampling results. Set k to other values and see how it affects the data loading speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9197aedd-1bdb-4d97-bf64-6327b77a0194",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
"
      ],
      "text/plain": [
       " k time\n",
       "0 10 10.338631\n",
       "1 100 9.933641\n",
       "2 1000 9.871675\n",
       "3 10000 10.212862\n",
       "4 100000 10.313871"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `k` must be passed by keyword: the 4th positional parameter of\n",
    "# `load_data_ptb` is `flag`, not `k`. The recorded output of this cell\n",
    "# predates this fix (every run used the default k=10000); re-run to refresh.\n",
    "ts = []\n",
    "k_list = [10, 100, 1000, 10000, 100000]\n",
    "for k in k_list:\n",
    "    t0 = time.perf_counter()\n",
    "    data_iter, vocab = load_data_ptb(512, 5, 5, k=k)\n",
    "    t1 = time.perf_counter()\n",
    "    ts.append(t1 - t0)\n",
    "df = pd.DataFrame({'k': k_list, 'time': ts})\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b388088-50b4-4007-8a4c-96ff9e7e5a2c",
   "metadata": {},
   "source": [
    "# 3. What other hyperparameters in the code of this section may affect the data loading speed?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d203c92d-53eb-41ce-a7d2-5995618c13b5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
"
      ],
      "text/plain": [
       " num_noise_words time\n",
       "0 2 6.078225\n",
       "1 5 9.767658\n",
       "2 10 16.298754\n",
       "3 15 22.715422\n",
       "4 20 28.570359\n",
       "5 25 35.331429\n",
       "6 30 41.231029"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loading time (seconds) as a function of the number of noise words\n",
    "ts = []\n",
    "noise_list = [2, 5, 10, 15, 20, 25, 30]\n",
    "for num_noise_words in noise_list:\n",
    "    t0 = time.perf_counter()\n",
    "    data_iter, vocab = load_data_ptb(512, 5, num_noise_words)\n",
    "    t1 = time.perf_counter()\n",
    "    ts.append(t1 - t0)\n",
    "df = pd.DataFrame({'num_noise_words': noise_list, 'time': ts})\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "bd9570cb-5dde-4674-9ea0-5f3a19bfe897",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
