{ "cells": [ { "cell_type": "markdown", "id": "01e602b7", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Parameter Initialization\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "d2a5e8cb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:11.084761Z", "iopub.status.busy": "2023-08-18T19:43:11.083735Z", "iopub.status.idle": "2023-08-18T19:43:11.122261Z", "shell.execute_reply": "2023-08-18T19:43:11.121411Z" }, "origin_pos": 11, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from torch import nn\n", "\n", "net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))\n", "X = torch.rand(size=(2, 4))\n", "net(X).shape" ] }, { "cell_type": "markdown", "id": "5ed55d32", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Built-in Initialization" ] }, { "cell_type": "code", "execution_count": 3, "id": "6059e0fb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:11.127859Z", "iopub.status.busy": "2023-08-18T19:43:11.127125Z", "iopub.status.idle": "2023-08-18T19:43:11.135507Z", "shell.execute_reply": "2023-08-18T19:43:11.134596Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(tensor([-0.0129, -0.0007, -0.0033, 0.0276]), tensor(0.))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def init_normal(module):\n", "    if type(module) == nn.Linear:\n", "        nn.init.normal_(module.weight, mean=0, std=0.01)\n", "        nn.init.zeros_(module.bias)\n", "\n", "net.apply(init_normal)\n", "net[0].weight.data[0], net[0].bias.data[0]" ] }, { "cell_type": "code", "execution_count": 4, "id": "d2007d64", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:11.138851Z", "iopub.status.busy": "2023-08-18T19:43:11.138302Z", "iopub.status.idle": 
"2023-08-18T19:43:11.145695Z", "shell.execute_reply": "2023-08-18T19:43:11.144862Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(tensor([1., 1., 1., 1.]), tensor(0.))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def init_constant(module):\n", "    if type(module) == nn.Linear:\n", "        nn.init.constant_(module.weight, 1)\n", "        nn.init.zeros_(module.bias)\n", "\n", "net.apply(init_constant)\n", "net[0].weight.data[0], net[0].bias.data[0]" ] }, { "cell_type": "markdown", "id": "7b7da2fb", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "We can also apply different initializers for certain blocks" ] }, { "cell_type": "code", "execution_count": 5, "id": "4734e6eb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:11.149040Z", "iopub.status.busy": "2023-08-18T19:43:11.148497Z", "iopub.status.idle": "2023-08-18T19:43:11.155752Z", "shell.execute_reply": "2023-08-18T19:43:11.154840Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-0.0974, 0.1707, 0.5840, -0.5032])\n", "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n" ] } ], "source": [ "def init_xavier(module):\n", "    if type(module) == nn.Linear:\n", "        nn.init.xavier_uniform_(module.weight)\n", "\n", "def init_42(module):\n", "    if type(module) == nn.Linear:\n", "        nn.init.constant_(module.weight, 42)\n", "\n", "net[0].apply(init_xavier)\n", "net[2].apply(init_42)\n", "print(net[0].weight.data[0])\n", "print(net[2].weight.data)" ] }, { "cell_type": "markdown", "id": "e2e9ff3b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Custom Initialization" ] }, { "cell_type": "code", "execution_count": 6, "id": "334b9bed", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:11.159032Z", "iopub.status.busy": "2023-08-18T19:43:11.158501Z", "iopub.status.idle": "2023-08-18T19:43:11.166911Z", 
"shell.execute_reply": "2023-08-18T19:43:11.166067Z" }, "origin_pos": 35, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Init weight torch.Size([8, 4])\n", "Init weight torch.Size([1, 8])\n" ] }, { "data": { "text/plain": [ "tensor([[ 0.0000, -7.6364, -0.0000, -6.1206],\n", " [ 9.3516, -0.0000, 5.1208, -8.4003]], grad_fn=<SliceBackward0>)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def my_init(module):\n", "    if type(module) == nn.Linear:\n", "        print(\"Init\", *[(name, param.shape)\n", "                        for name, param in module.named_parameters()][0])\n", "        nn.init.uniform_(module.weight, -10, 10)\n", "        module.weight.data *= module.weight.data.abs() >= 5\n", "\n", "net.apply(my_init)\n", "net[0].weight[:2]" ] }, { "cell_type": "code", "execution_count": 7, "id": "e38feecc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:11.170212Z", "iopub.status.busy": "2023-08-18T19:43:11.169683Z", "iopub.status.idle": "2023-08-18T19:43:11.176291Z", "shell.execute_reply": "2023-08-18T19:43:11.175385Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([42.0000, -6.6364, 1.0000, -5.1206])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net[0].weight.data[:] += 1\n", "net[0].weight.data[0, 0] = 42\n", "net[0].weight.data[0]" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "\n", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }