{ "cells": [
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load libraries\n", "import torch\n", "import torch.nn as nn\n", "from res.plot_lib import set_default, show_scatterplot, plot_bases\n", "from matplotlib.pyplot import plot, title, axis, figure, gca, gcf, rcParams, show\n", "from numpy import clip\n", "rcParams['figure.max_open_warning'] = 100" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set style (needs to be in a new cell)\n", "%matplotlib inline\n", "set_default()\n", "torch.manual_seed(0)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# generate some points in 2-D space\n", "n_points = 1_000\n", "X = torch.randn(n_points, 2).to(device)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# colors [0 – 511]^2\n", "x_min = -1.5 #X.min(0)[0] #+ 1\n", "x_max = +1.5 #X.max(0)[0] #- 1\n", "colors = (X - x_min) / (x_max - x_min)\n", "# move to CPU first: .numpy() raises on a CUDA tensor\n", "colors = (colors * 511).short().cpu().numpy()\n", "colors = clip(colors, 0, 511)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "figure().add_axes([0, 0, 1, 1])\n", "show_scatterplot(X, colors, title='X')\n", "OI = torch.cat((torch.zeros(2, 2), torch.eye(2))).to(device)\n", "plot_bases(OI)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualizing Linear Transformations\n", "\n", "* Generate a random matrix $W$\n", "\n", "$\n", "\\begin{equation}\n", " W = U\n", " \\left[ {\\begin{array}{cc}\n", " s_1 & 0 \\\\\n", " 0 & s_2 \\\\\n", " \\end{array} } \\right]\n", " V^\\top\n", "\\end{equation}\n", "$\n", "* Compute $y = Wx$\n", "* Larger singular values stretch the points\n", "* Smaller singular values push them together\n", "* $U, V$ rotate/reflect" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_scatterplot(X, colors, title='X')\n", "plot_bases(OI)\n", "\n", "for i in range(10):\n", "    figure()\n", "    # create a random matrix\n", "    W = torch.randn(2, 2).to(device)\n", "    # transform points: each row x of X is mapped to W x\n", "    Y = X @ W.t()\n", "    # only the singular values are needed (torch.svd is deprecated)\n", "    S = torch.linalg.svdvals(W)\n", "    # plot transformed points\n", "    show_scatterplot(Y, colors, title='y = Wx, singular values : [{:.3f}, {:.3f}]'.format(S[0], S[1]))\n", "    # transform the basis exactly like the points (same as model(OI) further down)\n", "    new_OI = OI @ W.t()\n", "    # plot old and new basis\n", "    plot_bases(OI)\n", "    plot_bases(new_OI)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear transformation with PyTorch" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", "    nn.Linear(2, 2, bias=False)\n", ")\n", "model.to(device)\n", "with torch.no_grad():\n", "    Y = model(X)\n", "    figure()\n", "    show_scatterplot(Y, colors)\n", "    plot_bases(model(OI))" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Non-linear Transform: Map Points to a Square\n", "\n", "* Linear transforms can rotate, reflect, stretch and compress, but cannot curve\n", "* We need non-linearities for this\n", "* Can (approximately) map points to a square by first stretching out by a factor $s$, then squashing with a tanh function\n", "\n", "$\n", " f(x)= \\tanh \\left(\n", " \\left[ {\\begin{array}{cc}\n", " s & 0 \\\\\n", " 0 & s \\\\\n", " \\end{array} } \\right] \n", " x\n", " \\right)\n", "$" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "z = torch.linspace(-10, 10, 101)\n", "s = torch.tanh(z)\n", "plot(z.numpy(), s.numpy())\n", "title('tanh() non linearity');" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_scatterplot(X, colors, title='X')\n",
"plot_bases(OI)\n", "\n", "model = nn.Sequential(\n", "    nn.Linear(2, 2, bias=False),\n", "    nn.Tanh()\n", ")\n", "\n", "model.to(device)\n", "\n", "for s in range(1, 6):\n", "    figure()\n", "    W = s * torch.eye(2)\n", "    # no_grad instead of .data: set the weight to s*I and run the forward pass untracked\n", "    with torch.no_grad():\n", "        model[0].weight.copy_(W)\n", "        Y = model(X)\n", "    show_scatterplot(Y, colors, title=f'f(x), s={s}')\n", "    plot_bases(OI, width=0.01)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualize Functions Represented by Random Neural Networks" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_scatterplot(X, colors, title='x')\n", "n_hidden = 5\n", "\n", "# NL = nn.ReLU() # ()^+\n", "NL = nn.Tanh()\n", "\n", "models = list()\n", "\n", "for i in range(5):\n", "    # create 1-layer neural networks with random weights\n", "    model = nn.Sequential(\n", "        nn.Linear(2, n_hidden),\n", "        NL,\n", "        nn.Linear(n_hidden, 2)\n", "    )\n", "    model.to(device)\n", "    models.append(model)\n", "    with torch.no_grad():\n", "        Y = model(X)\n", "    figure()\n", "    show_scatterplot(Y, colors, title='f(x)')\n", "# plot_bases(OI)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# deeper network with random weights\n", "show_scatterplot(X, colors, title='x')\n", "n_hidden = 5\n", "\n", "NL = nn.ReLU()\n", "# NL = nn.Tanh()\n", "\n", "for i in range(5):\n", "    model = nn.Sequential(\n", "        nn.Linear(2, n_hidden),\n", "        NL,\n", "        nn.Linear(n_hidden, n_hidden),\n", "        NL,\n", "        nn.Linear(n_hidden, n_hidden),\n", "        NL,\n", "        nn.Linear(n_hidden, n_hidden),\n", "        NL,\n", "        nn.Linear(n_hidden, 2)\n", "    )\n", "    model.to(device)\n", "    with torch.no_grad():\n", "        Y = model(X)  # already untracked under no_grad, no .detach() needed\n", "    figure()\n", "    show_scatterplot(Y, colors, title='f(x)', axis=False)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_scatterplot(X, colors, title='x')\n", "with torch.no_grad():\n", "    Y = models[2](X)\n", "figure()\n",
"show_scatterplot(Y, colors, title='f(x)')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def interpolate(X_in, X_out, steps, p=1/50, plotting_grid=False, ratio='1:1'):\n", "    \"\"\"Animate the blend a * X_out + (1 - a) * X_in on the current axes.\n", "\n", "    Over `steps` frames the weight a rises from 0 to 1 along the ease curve\n", "    a = ((p + 1)**u - 1) / p with u = t / (steps - 1); as p -> 0 this tends to a = u.\n", "    Reads the global `colors`; `ratio` is accepted but currently unused.\n", "    \"\"\"\n", "    N = 1000  # rows from N onward go to plot_grid (presumably appended grid points)\n", "    for t in range(steps):\n", "        # a single frame would divide by zero: show the final state instead\n", "        u = t / (steps - 1) if steps > 1 else 1.0\n", "        # p == 0 is the linear limit of the ease curve and would divide by zero\n", "        a = ((p + 1)**u - 1) / p if p else u\n", "        gca().cla()\n", "        show_scatterplot(a * X_out + (1 - a) * X_in, colors, title='f(x)')\n", "        # NOTE(review): plot_grid is not imported in this notebook, so this branch\n", "        # raises NameError unless plot_grid is defined elsewhere\n", "        if plotting_grid:\n", "            plot_grid(a * X_out[N:] + (1 - a) * X_in[N:])\n", "        gcf().canvas.draw()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib widget" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# fig = figure(figsize=(19.2, 10.8)) # resolution is 100 px per inch => 1920 x 1080\n", "fig = figure(figsize=(10, 5))\n", "ax = fig.add_axes([0, 0, 1, 1]) # stretched the plot area to the whole figure\n", "axis('off');" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Animate input/output\n", "steps = 150\n", "# steps = 1500\n", "interpolate(X, Y, steps, p=.001)" ] }
 ],
 "metadata": { "jupytext": { "formats": "ipynb,py:percent" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.2" } }, "nbformat": 4, "nbformat_minor": 4 }