{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear regression with `tf.GradientTape`\n",
    "\n",
    "Fit a scalar linear model $y = Wx + b$ to noisy synthetic data generated from known parameters ($W = 3$, $b = 2$), using full-batch gradient descent with TensorFlow's eager-mode `GradientTape`.\n",
    "\n",
    "**Expected result:** after 30 epochs the fitted parameters land close to the true ones, and the MSE loss approaches about 1.0. That floor is the variance of the unit-normal noise added to the targets. An earlier (unseeded) run reported `w_pred = 2.98`, `b_pred = 1.98`, `loss = 1.00`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the model and the loss\n",
    "\n",
    "class Model(object):\n",
    "    \"\"\"Scalar linear model y = W * x + b with trainable W and b.\"\"\"\n",
    "\n",
    "    def __init__(self, w_init=10.0, b_init=-5.0):\n",
    "        # Start far from the true parameters so training progress is visible.\n",
    "        # Pass floats: tf.Variable infers its dtype from the initial value.\n",
    "        self.W = tf.Variable(w_init)\n",
    "        self.b = tf.Variable(b_init)\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        return self.W * inputs + self.b\n",
    "\n",
    "\n",
    "def compute_loss(y_true, y_pred):\n",
    "    \"\"\"Mean squared error between y_true and y_pred.\"\"\"\n",
    "    return tf.reduce_mean(tf.square(y_true - y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The true weight and bias used to generate the training data\n",
    "\n",
    "TRUE_W = 3.0\n",
    "TRUE_b = 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthesize training data: points on the true line plus unit-normal noise.\n",
    "# A seeded NumPy generator makes every Restart & Run All produce the same data;\n",
    "# casting to float32 matches the dtype of the model's variables.\n",
    "\n",
    "SEED = 42\n",
    "NUM_EXAMPLES = 1000\n",
    "\n",
    "rng = np.random.RandomState(SEED)\n",
    "inputs = tf.constant(rng.normal(size=NUM_EXAMPLES).astype(np.float32))\n",
    "noise = tf.constant(rng.normal(size=NUM_EXAMPLES).astype(np.float32))\n",
    "outputs = inputs * TRUE_W + TRUE_b + noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the model's predictions (red) against the training data (blue).\n",
    "\n",
    "def plot(epoch):\n",
    "    \"\"\"Scatter the training data and the current predictions of `model`.\n",
    "\n",
    "    Reads the notebook-level `model`, `inputs` and `outputs`; the title shows\n",
    "    `epoch` and the current MSE loss.\n",
    "    \"\"\"\n",
    "    predictions = model(inputs)\n",
    "    loss = compute_loss(outputs, predictions).numpy()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.scatter(inputs, outputs, c='b', label='training data')\n",
    "    ax.scatter(inputs, predictions, c='r', label='model predictions')\n",
    "    ax.set(title='epoch %2d, loss = %.4f' % (epoch, loss), xlabel='x', ylabel='y')\n",
    "    # Labelled artists give the legend its entries (an unlabelled legend() only warns).\n",
    "    ax.legend()\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where the model stands before any training step.\n",
    "plot(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop: full-batch gradient descent on the MSE loss.\n",
    "\n",
    "LEARNING_RATE = 0.1\n",
    "EPOCHS = 30\n",
    "PLOT_EVERY = 10\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = compute_loss(outputs, model(inputs))\n",
    "\n",
    "    dW, db = tape.gradient(loss, [model.W, model.b])\n",
    "\n",
    "    model.W.assign_sub(LEARNING_RATE * dW)\n",
    "    model.b.assign_sub(LEARNING_RATE * db)\n",
    "\n",
    "    # NOTE: `loss` was computed before this step's update, so it lags the\n",
    "    # printed w_pred / b_pred by one step.\n",
    "    print(\"=> epoch %2d: w_true= %.2f, w_pred= %.2f; b_true= %.2f, b_pred= %.2f, loss= %.2f\" % (\n",
    "        epoch + 1, TRUE_W, model.W.numpy(), TRUE_b, model.b.numpy(), loss.numpy()))\n",
    "\n",
    "    # Plot every PLOT_EVERY epochs, and always after the last one so the final fit is shown.\n",
    "    if (epoch + 1) % PLOT_EVERY == 0 or epoch + 1 == EPOCHS:\n",
    "        plot(epoch + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the fitted parameters should be close to the true ones.\n",
    "assert abs(model.W.numpy() - TRUE_W) < 0.2, model.W.numpy()\n",
    "assert abs(model.b.numpy() - TRUE_b) < 0.2, model.b.numpy()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}