{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/03_linear.ipynb)\n",
    "\n",
    "# 🟡 Medium: Simple Linear Layer\n",
    "\n",
    "Implement a fully-connected linear layer: **y = xW^T + b**\n",
    "\n",
    "### Signature\n",
    "```python\n",
    "class SimpleLinear:\n",
    "    def __init__(self, in_features: int, out_features: int): ...\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor: ...\n",
    "```\n",
    "\n",
    "### Requirements\n",
    "- `self.weight`: shape `(out_features, in_features)`, init with `randn * (1/√in_features)`\n",
    "- `self.bias`: shape `(out_features,)`, init as zeros\n",
    "- Both must have `requires_grad=True`\n",
    "- `forward(x)` computes `x @ W^T + b`\n",
    "- Do **NOT** use `torch.nn.Linear`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
    "try:\n",
    "    import google.colab  # ImportError outside Colab -> skip the install\n",
    "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
    "except ImportError:\n",
    "    pass\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math"
   ],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ✏️ YOUR IMPLEMENTATION HERE\n",
    "\n",
    "class SimpleLinear:\n",
    "    \"\"\"Fully-connected layer computing y = x @ W^T + b without torch.nn.Linear.\"\"\"\n",
    "\n",
    "    def __init__(self, in_features: int, out_features: int):\n",
    "        \"\"\"Create self.weight (out_features, in_features) and self.bias (out_features,), both with requires_grad=True.\"\"\"\n",
    "        pass  # Initialize weight and bias\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Return x @ W^T + b using self.weight and self.bias.\"\"\"\n",
    "        pass  # Compute y = x @ W^T + b"
   ],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 🧪 Debug\n",
    "layer = SimpleLinear(8, 4)\n",
    "print(\"W shape:\", layer.weight.shape)  # should be (4, 8)\n",
    "print(\"b shape:\", layer.bias.shape)  # should be (4,)\n",
    "\n",
    "x = torch.randn(2, 8)\n",
    "y = layer.forward(x)\n",
    "print(\"Output shape:\", y.shape)  # should be (2, 4)"
   ],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ✅ SUBMIT\n",
    "from torch_judge import check\n",
    "check(\"linear\")"
   ],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}