{ "cells": [ { "cell_type": "code", "execution_count": 86, "id": "77bf42ce-506c-4f24-bcef-3f55f8d7190a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[NbConvertApp] WARNING | Config option `kernel_spec_manager_class` not recognized by `NbConvertApp`.\n", "[NbConvertApp] Converting notebook 4_7_7_Exercises.ipynb to markdown\n", "[NbConvertApp] Support files will be in 4_7_7_Exercises_files/\n", "[NbConvertApp] Making directory 4_7_7_Exercises_files\n", "[NbConvertApp] Writing 18738 bytes to 4_7_7_Exercises.md\n" ] } ], "source": [ "!jupyter nbconvert --to markdown 4_7_7_Exercises.ipynb" ] }, { "cell_type": "markdown", "id": "7771fa01-13f8-4387-869a-0ed8b9972b6a", "metadata": {}, "source": [ "# 4.7.7. Exercises" ] }, { "cell_type": "markdown", "id": "2f37df87-5ba9-416e-a076-fdab3ae76c8a", "metadata": {}, "source": [ "## 1. What could happen when we change the behavior of a search engine? What might the users do? What about the advertisers?" ] }, { "cell_type": "markdown", "id": "33cabb3f-4caa-4f65-a90e-a0243f1fc11d", "metadata": {}, "source": [ "When the behavior of a search engine changes, several outcomes can occur, impacting both users and advertisers:\n", "\n", "1. **Users' Reactions:**\n", " - Improved Search Experience: If the changes result in more accurate and relevant search results, users might have a better experience and find the information they're looking for more easily.\n", " - Frustration and Discontent: On the other hand, if the changes lead to less relevant results, users might become frustrated and dissatisfied with the search engine's performance.\n", " - Change in User Habits: Users might change their search behaviors, such as using different search engines or altering their search queries to adapt to the new behavior.\n", "\n", "2. **Advertisers' Reactions:**\n", " - Changes in Ad Performance: Altering the behavior of the search engine could affect the performance of advertisements. 
Advertisers might see variations in click-through rates, conversion rates, and overall campaign success.\n", " - Adjustment of Advertising Strategies: Advertisers might need to modify their advertising strategies, keywords, and targeting parameters to align with the new search engine behavior.\n", " - Financial Impact: If the changes lead to decreased ad performance, advertisers might experience reduced return on investment (ROI) and might reconsider their advertising budgets on that platform.\n", "\n", "3. **Algorithmic Impact:**\n", " - Changes in Ranking: Search engine behavior often revolves around algorithms that determine how content is ranked and displayed. Algorithmic changes could lead to shifting rankings, affecting the visibility of websites and content.\n", " - SEO Practices: Search engine optimization (SEO) strategies might need to be adjusted to match the new algorithms, potentially affecting how websites are optimized for better search engine visibility.\n", "\n", "4. **Search Engine Market Share:**\n", " - Changes in User Base: Search engine behavior changes could influence user preferences, leading to shifts in market share among search engines. Users might migrate to other search engines if they prefer their new behaviors.\n", "\n", "5. **Ethical and Legal Considerations:**\n", " - Privacy Concerns: Changes in behavior could impact user privacy and data usage, leading to ethical and legal concerns regarding user data collection and tracking.\n", " - Regulatory Compliance: Changes in behavior might need to comply with data protection and privacy regulations in different regions.\n", "\n", "Overall, any changes to a search engine's behavior can have far-reaching effects on user satisfaction, user habits, advertiser performance, and even the competitive landscape. Careful consideration and testing are essential before implementing significant changes to ensure a positive impact on both users and advertisers." 
] }, { "cell_type": "markdown", "id": "ed780c08-f505-499a-ab64-c74bd9997daa", "metadata": {}, "source": [ "## 2. Implement a covariate shift detector. Hint: build a classifier." ] }, { "cell_type": "markdown", "id": "f9b3353f-e7ed-4b3a-b861-7e05c8843ed1", "metadata": {}, "source": [ "A covariate shift detector identifies whether there is a distribution shift between the training data and the test data. One approach is to build a classifier that tries to distinguish training samples from test samples.\n", "- In the `SameFashionMNIST` experiment, both pools come from `FashionMNIST`, and the shift detector's AUC is **0.5**, which implies there is **no covariate shift**.\n", "- In the `CovarFashionMNIST` experiment, the training pool comes from `FashionMNIST` and the test pool from `MNIST`, and the detector's AUC is **0.99**, which implies the test set is **easily distinguishable** from the training set, i.e. a strong covariate shift." ] }, { "cell_type": "code", "execution_count": 1, "id": "1f4f4cc3-288b-4b92-b56a-f509de36efa0", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/d2l.py:122: SyntaxWarning: assertion is always true, perhaps remove parentheses?\n", " assert(self, 'net'), 'Neural network is defined'\n", "/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/d2l.py:126: SyntaxWarning: assertion is always true, perhaps remove parentheses?\n", " assert(self, 'trainer'), 'trainer is not inited'\n" ] } ], "source": [ "import numpy as np\n", "import torch\n", "import torchvision\n", "from torchvision import transforms\n", "from torch.nn import functional as F\n", "from sklearn.metrics import roc_auc_score\n", "import sys\n", "sys.path.append('/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/')\n", "import d2l\n", "\n", "\n", "class SameFashionMNIST(d2l.DataModule):\n", " def __init__(self, batch_size=64, resize=(28, 28)):\n", " 
super().__init__()\n", " self.save_hyperparameters()\n", " trans = transforms.Compose([transforms.Resize(resize),\n", " transforms.ToTensor()])\n", " train = torchvision.datasets.FashionMNIST(\n", " root=self.root, train=True, transform=trans, download=True)\n", " val = torchvision.datasets.FashionMNIST(\n", " root=self.root, train=False, transform=trans, download=True)\n", " num_rows_to_select = val.data.shape[0]//2\n", " # randomly select rows from the training set\n", " random_rows_indices = torch.randperm(train.data.size(0))\n", " # label 1 = sample from the training split, 0 = sample from the test split\n", " self.train_X = torch.cat(\n", " (train.data[random_rows_indices[:num_rows_to_select]],\n", " val.data[:num_rows_to_select]), dim=0).type(torch.float32)\n", " self.train_y = torch.cat((torch.ones(num_rows_to_select),\n", " torch.zeros(num_rows_to_select)), dim=0).type(torch.int64)\n", " self.val_X = torch.cat((train.data[\n", " random_rows_indices[num_rows_to_select:2*num_rows_to_select]],\n", " val.data[num_rows_to_select:2*num_rows_to_select]),\n", " dim=0).type(torch.float32)\n", " self.val_y = torch.cat((torch.ones(num_rows_to_select),\n", " torch.zeros(num_rows_to_select)), dim=0).type(torch.int64)\n", " self.train = torch.utils.data.TensorDataset(self.train_X, self.train_y)\n", " self.val = torch.utils.data.TensorDataset(self.val_X, self.val_y)\n", " \n", " def get_dataloader(self, train):\n", " data = self.train if train else self.val\n", " return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train,\n", " num_workers=self.num_workers)\n", " \n", " def visualize(self, batch, nrows=1, ncols=8, labels=[]):\n", " \"\"\"Defined in :numref:`sec_fashion_mnist`\"\"\"\n", " X, y = batch\n", " d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)" ] }, { "cell_type": "code", "execution_count": 8, "id": "3721e70f-8c63-4e35-a12b-be21546ef060", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0., 0., 0., 
..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]) torch.Size([2, 784]) tensor([ 22816.5000, -22816.5000])\n" ] } ], "source": [ "a = detector.net.parameters()\n", "for g in a:\n", " # bias gradients are 1-D, so only reduce over the last axis for 2-D weights\n", " print(g.grad, g.grad.shape, g.grad.sum(axis=-1) if len(g.grad.shape) > 1 else g.grad)" ] }, { "cell_type": "code", "execution_count": 2, "id": "ddbf5de7-73e8-4de1-9a6b-b8c34008f5d4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "auc: 0.5019\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/d2l/lib/python3.11/site-packages/IPython/core/pylabtools.py:152: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"facecolor\" which is no longer supported as of 3.3 and will become an error two minor releases later\n", " fig.canvas.print_figure(bytes_io, **kw)\n", "/opt/conda/envs/d2l/lib/python3.11/site-packages/IPython/core/pylabtools.py:152: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"edgecolor\" which is no longer supported as of 3.3 and will become an error two minor releases later\n", " fig.canvas.print_figure(bytes_io, **kw)\n", "/opt/conda/envs/d2l/lib/python3.11/site-packages/IPython/core/pylabtools.py:152: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"orientation\" which is no longer supported as of 3.3 and will become an error two minor releases later\n", " fig.canvas.print_figure(bytes_io, **kw)\n" ] }, { "data": { "text/plain": [ "Matplotlib v3.4.0 figure, 2023-08-21T03:22:40 (SVG output stripped)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Matplotlib v3.4.0 figure, 2023-08-21T03:22:40 (SVG output stripped)
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = SameFashionMNIST(batch_size=256)\n", "detector = d2l.SoftmaxRegression(num_outputs=2, lr=0.1)\n", "trainer = d2l.Trainer(max_epochs=2)\n", "trainer.fit(detector, data)\n", "y_hat=detector.net(data.val_X)\n", "preds = y_hat.argmax(axis=1)\n", "print(f'auc: {roc_auc_score(data.val_y, preds)}')\n", "index = data.val_y == 0\n", "data.visualize([data.val_X[index],data.val_y])" ] }, { "cell_type": "code", "execution_count": 82, "id": "ff15e4a4-8e3c-4614-9bfb-d8d648488917", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "auc: 0.995\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/d2l/lib/python3.11/site-packages/IPython/core/pylabtools.py:152: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"facecolor\" which is no longer supported as of 3.3 and will become an error two minor releases later\n", " fig.canvas.print_figure(bytes_io, **kw)\n", "/opt/conda/envs/d2l/lib/python3.11/site-packages/IPython/core/pylabtools.py:152: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"edgecolor\" which is no longer supported as of 3.3 and will become an error two minor releases later\n", " fig.canvas.print_figure(bytes_io, **kw)\n", "/opt/conda/envs/d2l/lib/python3.11/site-packages/IPython/core/pylabtools.py:152: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"orientation\" which is no longer supported as of 3.3 and will become an error two minor releases later\n", " fig.canvas.print_figure(bytes_io, **kw)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-20T12:44:48.908953\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-20T12:44:49.016626\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class CovarFashionMNIST(d2l.DataModule):\n", " def __init__(self, batch_size=64, resize=(28, 28)):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " trans = transforms.Compose([transforms.Resize(resize),\n", " transforms.ToTensor()])\n", " train = torchvision.datasets.FashionMNIST(\n", " root=self.root, train=True, transform=trans, download=True)\n", " val = torchvision.datasets.MNIST(\n", " root=self.root, train=False, transform=trans, download=True)\n", " num_rows_to_select = val.data.shape[0]//2\n", " # 从数组中随机选择行\n", " random_rows_indices = torch.randperm(train.data.size(0))\n", " self.train_X = torch.cat(\n", " (train.data[random_rows_indices[:num_rows_to_select]],\n", " val.data[:num_rows_to_select]), dim=0).type(torch.float32)#.unsqueeze(dim=1)\n", " self.train_y = torch.cat((torch.ones(num_rows_to_select),\n", " torch.zeros(num_rows_to_select)), dim=0).type(torch.int64)\n", " self.val_X = torch.cat((train.data[\n", " random_rows_indices[num_rows_to_select:2*num_rows_to_select]],\n", " val.data[num_rows_to_select:2*num_rows_to_select]),\n", " dim=0).type(torch.float32)#.unsqueeze(dim=1)\n", " self.val_y = torch.cat((torch.ones(num_rows_to_select),\n", " torch.zeros(num_rows_to_select)), dim=0).type(torch.int64)\n", " self.train = torch.utils.data.TensorDataset(self.train_X, self.train_y)\n", " self.val = torch.utils.data.TensorDataset(self.val_X, self.val_y)\n", " \n", " def get_dataloader(self, train):\n", " data = self.train if train else self.val\n", " return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train\n", " , num_workers=self.num_workers)\n", " \n", " def visualize(self, batch, nrows=1, ncols=8, labels=[]):\n", " \"\"\"Defined in :numref:`sec_fashion_mnist`\"\"\"\n", " X, y = batch\n", " d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)\n", " \n", "data = CovarFashionMNIST(batch_size=256)\n", "model = d2l.SoftmaxRegression(num_outputs=2, 
lr=0.1)\n", "trainer = d2l.Trainer(max_epochs=10)\n", "trainer.fit(model, data)\n", "y_hat = model.net(data.val_X)\n", "preds = y_hat.argmax(axis=1)\n", "print(f'auc: {roc_auc_score(data.val_y, preds)}')\n", "index = data.val_y == 0\n", "data.visualize([data.val_X[index], data.val_y])" ] }, { "cell_type": "markdown", "id": "57330a3f-a0df-430c-9a94-f73fc487b073", "metadata": {}, "source": [ "## 3. Implement a covariate shift corrector." ] }, { "cell_type": "markdown", "id": "9c07122f-4fc7-4af6-a879-092d8566e05a", "metadata": {}, "source": [ "Covariate shift correction mitigates the effect of distribution differences between training and test data. A common technique is importance weighting: each training sample is weighted by an estimate of the density ratio $p_{test}(x)/p_{train}(x)$, so that the reweighted training distribution matches the test distribution." ] }, { "cell_type": "code", "execution_count": 116, "id": "bf968f35-bd4b-47cc-8115-4ce16ac9f6a1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "auc: 0.98\n" ] }, { "data": { "text/plain": [ "Matplotlib v3.4.0 figure, 2023-08-20T13:59:08 (SVG output stripped)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Matplotlib v3.4.0 figure, 2023-08-20T13:59:08 (SVG output stripped)
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class CovarFashionMNIST(d2l.DataModule):\n", " def __init__(self, batch_size=64, resize=(28, 28)):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " trans = transforms.Compose([transforms.Resize(resize),\n", " transforms.ToTensor()])\n", " train = torchvision.datasets.MNIST(\n", " root=self.root, train=True, transform=trans, download=True)\n", " val = torchvision.datasets.MNIST(\n", " root=self.root, train=False, transform=trans, download=True)\n", " # val = torchvision.datasets.EMNIST(\n", " # root=self.root, train=False, transform=trans, download=True,\n", " # split = 'digits')\n", " num_rows_to_select = val.data.shape[0]//2\n", " # 从数组中随机选择行\n", " random_rows_indices = torch.randperm(train.data.size(0))\n", " self.train_X = torch.cat(\n", " (train.data[random_rows_indices[:num_rows_to_select]],\n", " val.data[:num_rows_to_select]+torch.ones(1)*10), dim=0).type(torch.float32)#.unsqueeze(dim=1)\n", " self.train_y = torch.cat((torch.ones(num_rows_to_select),\n", " torch.zeros(num_rows_to_select)), dim=0).type(torch.int64)\n", " self.val_X = torch.cat((train.data[\n", " random_rows_indices[num_rows_to_select:2*num_rows_to_select]],\n", " val.data[num_rows_to_select:2*num_rows_to_select]+torch.ones(1)*10),\n", " dim=0).type(torch.float32)#.unsqueeze(dim=1)\n", " self.val_y = torch.cat((torch.ones(num_rows_to_select),\n", " torch.zeros(num_rows_to_select)), dim=0).type(torch.int64)\n", " self.train = torch.utils.data.TensorDataset(self.train_X, self.train_y)\n", " self.val = torch.utils.data.TensorDataset(self.val_X, self.val_y)\n", " \n", " def get_dataloader(self, train):\n", " data = self.train if train else self.val\n", " return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train\n", " , num_workers=self.num_workers)\n", " \n", " def visualize(self, batch, nrows=1, ncols=8, labels=[]):\n", " \"\"\"Defined in :numref:`sec_fashion_mnist`\"\"\"\n", " X, y = batch\n", " 
d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)\n", " \n", "data = CovarFashionMNIST(batch_size=256)\n", "detector = d2l.SoftmaxRegression(num_outputs=2, lr=0.1)\n", "trainer = d2l.Trainer(max_epochs=10)\n", "trainer.fit(detector, data)\n", "y_hat=detector.net(data.val_X)\n", "# Use the continuous class-1 scores for AUC; hard argmax labels would coarsen it\n", "preds = y_hat[:, 1].detach()\n", "print(f'auc: {roc_auc_score(data.val_y, preds)}')\n", "index = data.val_y == 0\n", "data.visualize([data.val_X[index],data.val_y])" ] }, { "cell_type": "code", "execution_count": 112, "id": "10a38316-34c5-4ca5-b523-822d71ef18cf", "metadata": { "tags": [] }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class CorrFashionMNIST(d2l.DataModule):\n", " def __init__(self, detector, batch_size=64, resize=(28, 28)):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " trans = transforms.Compose([transforms.Resize(resize),\n", " transforms.ToTensor()])\n", " # train = torchvision.datasets.FashionMNIST(\n", " # root=self.root, train=True, transform=trans, download=True)\n", " # val = torchvision.datasets.FashionMNIST(\n", " # root=self.root, train=False, transform=trans, download=True)\n", " train = torchvision.datasets.MNIST(\n", " root=self.root, train=True, transform=trans, download=True)\n", " val = torchvision.datasets.MNIST(\n", " root=self.root, train=False, transform=trans, download=True)\n", " # val = torchvision.datasets.EMNIST(\n", " # root=self.root, train=False, transform=trans, download=True,\n", " # split = 'digits')\n", " self.train_X = train.data.type(torch.float32)#.unsqueeze(dim=1)\n", " self.train_y = train.targets.type(torch.int64)\n", " self.train_weight = self.stat_weight(self.train_X)\n", " self.val_X = val.data.type(torch.float32)*100#.unsqueeze(dim=1)\n", " self.val_y = val.targets.type(torch.int64)\n", " self.val_weight = self.stat_weight(self.val_X)\n", " # print(self.train_weight.shape, self.train_X.shape, self.train_y.shape)\n", " self.train = torch.utils.data.TensorDataset(self.train_weight, 
self.train_X, self.train_y)\n", " self.val = torch.utils.data.TensorDataset(self.val_weight, self.val_X, self.val_y)\n", " \n", " def get_dataloader(self, train):\n", " data = self.train if train else self.val\n", " return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train\n", " , num_workers=self.num_workers)\n", " \n", " def visualize(self, batch, nrows=1, ncols=8, labels=[]):\n", " \"\"\"Defined in :numref:`sec_fashion_mnist`\"\"\"\n", " X, y = batch\n", " d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)\n", "\n", " def stat_weight(self, X):\n", " h = self.detector(X)\n", " weight = torch.exp(h[:,1]-h[:,0])\n", " weight[weight==torch.inf] = 10\n", " # weight.requires_grad = True\n", " weight = weight.detach()\n", " return weight\n", "\n", "\n", "class CorrSoftmaxRegression(d2l.Classifier):\n", " def __init__(self, num_outputs, lr):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " self.net = nn.Sequential(nn.Flatten(),\n", " nn.LazyLinear(num_outputs))\n", " \n", " def forward(self, X):\n", " return self.net(X)\n", "\n", " def loss(self, y_hat, y, weight=None):\n", " y_hat = y_hat.reshape((-1, y_hat.shape[-1]))\n", " y = y.reshape((-1,))\n", " l = F.cross_entropy(y_hat, y, reduction='none')\n", " # weight = torch.ones(l.shape[0])\n", " if weight is not None:\n", " # print(weight.shape, weight.sum())\n", " l = l*weight\n", " # weight = weight.reshape(-1,1) + torch.zeros(1, self.num_outputs)\n", " # print(weight.shape)\n", " return l.mean()\n", " \n", "\n", " def training_step(self, batch, plot_flag=True):\n", " # print(\"training\")\n", " y_hat = self(*batch[1:-1])\n", " l = self.loss(y_hat, batch[-1], batch[0])\n", " # l = self.loss(y_hat, batch[-1])\n", " # auc = torch.tensor(roc_auc_score(batch[-1].detach().numpy() , y_hat[:,1].detach().numpy()))\n", " if plot_flag:\n", " # self.plot('loss', l, train=True)\n", " # self.plot('auc', auc, train=True)\n", " self.plot('acc', self.accuracy(y_hat, batch[-1]), train=True)\n", " 
return l\n", "\n", " def validation_step(self, batch, plot_flag=True):\n", " y_hat = self(*batch[1:-1])\n", " l = self.loss(y_hat, batch[-1])\n", " # auc = torch.tensor(roc_auc_score(batch[-1].detach().numpy() , y_hat[:,1].detach().numpy()))\n", " if plot_flag:\n", " # self.plot('loss', l, train=False)\n", " # self.plot('auc', auc, train=True)\n", " self.plot('acc', self.accuracy(y_hat, batch[-1]), train=False)\n", " return l" ] }, { "cell_type": "code", "execution_count": 113, "id": "72bc59ce-544b-4ac7-843c-c8b98efe9e13", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "acc:0.10\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-20T13:52:26.198683\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" 
], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = CorrFashionMNIST(detector=detector, batch_size=256)\n", "model = CorrSoftmaxRegression(num_outputs=10, lr=0.1)\n", "trainer = d2l.Trainer(max_epochs=10)\n", "trainer.fit(model, data)\n", "y_hat=model.net(data.val_X)\n", "print(f'acc:{model.accuracy(y_hat, data.val_y):.2f}')" ] }, { "cell_type": "code", "execution_count": 20, "id": "a622ddf3-cc37-47e2-8b64-e36c1a0c0ef1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "acc:0.79\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-20T11:56:18.890190\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = d2l.FashionMNIST(batch_size=256)\n", "model = d2l.SoftmaxRegression(num_outputs=10, lr=0.1)\n", "trainer = d2l.Trainer(max_epochs=10)\n", "trainer.fit(model, data)\n", "y_hat=model.net(data.val.data.type(torch.float32))\n", "print(f'acc:{model.accuracy(y_hat, data.val.targets):.2f}')" ] }, { "cell_type": "markdown", "id": "718c94d2-d8d1-48a8-8929-e2e10ca04b31", "metadata": {}, "source": [ "## 4. Besides distribution shift, what else could affect how the empirical risk approximates the risk?" ] }, { "cell_type": "markdown", "id": "7e1f27b7-fcb5-41b5-bee5-ff0fcb99cbc8", "metadata": {}, "source": [ "Several factors can affect how well the empirical risk approximates the true risk in machine learning models beyond distribution shift:\n", "\n", "1. **Sample Size:** The size of the training dataset plays a crucial role in how well the empirical risk approximates the risk. Smaller sample sizes may lead to higher variance and less accurate risk estimation.\n", "\n", "2. **Sample Quality:** The quality of the training data matters. If the training dataset contains noise, outliers, or mislabeled examples, the empirical risk may be skewed.\n", "\n", "3. **Bias and Fairness:** If the training dataset is biased in terms of representation or contains biased labels, it can lead to biased models and inaccurate risk estimates, especially when deployed in different contexts.\n", "\n", "4. **Feature Quality:** The quality and relevance of features used in the model affect the model's ability to generalize to new data. Irrelevant or redundant features can contribute to overfitting.\n", "\n", "5. **Model Complexity:** Highly complex models can fit the training data closely but may not generalize well to new data, leading to overfitting.\n", "\n", "6. 
**Regularization:** Regularization techniques, such as L1 and L2 penalties, influence how well the model generalizes and therefore how closely the empirical risk tracks the true risk.\n", "\n", "7. **Data Augmentation:** Augmenting the training data through techniques like rotation, translation, or flipping can improve generalization and reduce overfitting.\n", "\n", "8. **Hyperparameters and Their Tuning:** Choices such as learning rate, batch size, and regularization strength significantly affect model performance; over-tuning them on the validation set can produce optimistic risk estimates and poor generalization to new data.\n", "\n", "9. **Model Selection:** The choice of model architecture (e.g., linear regression, neural networks) affects generalization, since different models encode different biases and assumptions about the data.\n", "\n", "10. **Data Leakage:** Leakage of information from the test set into the training process leads to optimistic risk estimates.\n", "\n", "11. **Ensemble Methods:** Combining predictions from multiple models often improves generalization and yields more stable risk estimates.\n", "\n", "12. **Evaluation Metrics:** The choice of evaluation metric shapes how the model's risk is estimated; different metrics emphasize different aspects of performance.\n", "\n", "Considering all of these factors when building and evaluating a model helps ensure that the empirical risk is a faithful estimate of the true risk on unseen data."
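, "\n", "\n", "As a minimal sketch of the sample-size factor (assuming a fixed, hypothetical classifier whose true 0/1 risk is $p = 0.2$), we can simulate how the spread of the empirical risk shrinks as the sample grows, roughly as $1/\\sqrt{n}$:\n", "\n", "```python\n", "import numpy as np\n", "\n", "rng = np.random.default_rng(0)\n", "p = 0.2  # assumed true risk of a fixed classifier\n", "\n", "def empirical_risk_std(n, trials=1000):\n", "    # spread of the empirical 0/1 risk across many i.i.d. samples of size n\n", "    losses = rng.binomial(n, p, size=trials) / n\n", "    return losses.std()\n", "\n", "# larger samples estimate the true risk more tightly\n", "print(empirical_risk_std(100) > empirical_risk_std(10000))\n", "```"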
] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:d2l]", "language": "python", "name": "conda-env-d2l-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }