{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/20_weight_init.ipynb)\n", "\n", "# 🟢 Easy: Kaiming Initialization\n", "\n", "Implement **Kaiming (He) normal initialization** for weight tensors.\n", "\n", "$$W \\sim \\mathcal{N}(0, \\text{std}^2) \\quad \\text{where} \\quad \\text{std} = \\sqrt{\\frac{2}{\\text{fan\\_in}}}$$\n", "\n", "### Signature\n", "```python\n", "def kaiming_init(weight: Tensor) -> Tensor:\n", " # Initialize weight in-place with Kaiming normal\n", " # fan_in = weight.shape[1]\n", " # Returns the weight tensor\n", "```" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import math" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def kaiming_init(weight):\n", " pass # fill with normal(0, sqrt(2/fan_in))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "import math\n", "w = torch.empty(256, 512)\n", "kaiming_init(w)\n", "print(f'Mean: {w.mean():.4f} (expect ~0)')\n", "print(f'Std: {w.std():.4f} (expect {math.sqrt(2/512):.4f})')" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('weight_init')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python",
"version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }