{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# TensorFlow基础" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow.keras as keras\n", "import tensorflow.keras.layers as layers\n", "\n", "physical_devices = tf.config.experimental.list_physical_devices('GPU')\n", "assert len(physical_devices) > 0, \"Not enough GPU hardware devices available\"\n", "tf.config.experimental.set_memory_growth(physical_devices[0], True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 数据类型\n", "\n", "### 数值类型\n", "\n", "标量在 TensorFlow 是如何创建的" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(float, tensorflow.python.framework.ops.EagerTensor, True)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# python 语言方式创建标量\n", "a = 1.2 \n", "# TF 方式创建标量\n", "aa = tf.constant(1.2)\n", "\n", "type(a), type(aa), tf.is_tensor(aa)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "如果要使用 TensorFlow 提供的功能函数, 须通过 TensorFlow 规定的方式去创建张量,而不能使用 Python 语言的标准变量创建方式。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.constant([1,2.,3.3])\n", "# 打印 TF 张量的相关信息 \n", "x" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1. , 2. 
, 3.3], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Export the TF tensor's data as a NumPy array\n", "x.numpy() " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unlike a scalar, a vector must be passed to tf.constant() inside a List container.\n", "\n", "Create a vector with one element:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,\n", " TensorShape([1]))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a vector with one element\n", "a = tf.constant([1.2]) \n", "a, a.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a vector with 3 elements:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,\n", " TensorShape([3]))" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a vector with 3 elements\n", "a = tf.constant([1,2, 3.])\n", "a, a.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define a matrix:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(, TensorShape([2, 2]))" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a matrix with 2 rows and 2 columns\n", "a = tf.constant([[1,2],[3,4]]) \n", "a, a.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A 3-D tensor can be defined as:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a 3-D tensor\n", "tf.constant([[[1,2],[3,4]],[[5,6],[7,8]]]) " ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "### 字符串类型\n", "\n", "通过传入字符串对象即可创建字符串类型的张量" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建字符串\n", "a = tf.constant('Hello, Deep Learning.') \n", "a" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在 tf.strings 模块中,提供了常见的字符串类型的工具函数,如小写化 lower()、 拼接\n", "join()、 长度 length()、 切分 split()等。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 小写化字符串\n", "tf.strings.lower(a) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 布尔类型\n", "布尔类型的张量只需要传入 Python 语言的布尔类型数据,转换成 TensorFlow 内部布尔型即可。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建布尔类型标量\n", "tf.constant(True) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "创建布尔类型的向量" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ " # 创建布尔类型向量\n", "tf.constant([True, False])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "需要注意的是, TensorFlow 的布尔类型和 Python 语言的布尔类型并不等价,不能通用" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n", "tf.Tensor(True, shape=(), dtype=bool)\n" ] } ], "source": [ "# 创建 TF 布尔张量\n", "a = tf.constant(True) \n", "# TF 布尔类型张量与 python 布尔类型比较\n", "print(a is True) \n", "# 仅数值比较\n", "print(a == True) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 数值精度\n", "\n", 
"在创建张量时,可以指定张量的保存精度" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建指定精度的张量\n", "tf.constant(123456789, dtype=tf.int16)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.constant(123456789, dtype=tf.int32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "对于浮点数, 高精度的张量可以表示更精准的数据,例如采用 tf.float32 精度保存π时,实际保存的数据为 3.1415927" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "# 从 numpy 中导入 pi 常量\n", "np.pi \n", "# 32 位\n", "tf.constant(np.pi, dtype=tf.float32) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "如果采用 tf.float64 精度保存π,则能获得更高的精度" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.constant(np.pi, dtype=tf.float64) # 64 位" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 读取精度\n", "\n", "通过访问张量的 dtype 成员属性可以判断张量的保存精度" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "before: \n", "after : \n" ] } ], "source": [ "a = tf.constant(np.pi, dtype=tf.float16)\n", "\n", "# 读取原有张量的数值精度\n", "print('before:',a.dtype) \n", "# 如果精度不符合要求,则进行转换\n", "if a.dtype != tf.float32: \n", " # tf.cast 函数可以完成精度转换\n", " a = tf.cast(a,tf.float32) \n", "# 打印转换后的精度\n", "print('after :',a.dtype) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 类型转换\n", "系统的每个模块使用的数据类型、 数值精度可能各不相同, 对于不符合要求的张量的类型及精度, 需要通过 
tf.cast." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a low-precision tf.float16 tensor\n", "a = tf.constant(np.pi, dtype=tf.float16) \n", "# Convert to a high-precision tensor\n", "tf.cast(a, tf.double) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When casting, make sure the conversion is valid. For example, casting a high-precision tensor to a lower precision may overflow:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tf.constant(123456789, dtype=tf.int32)\n", "# Convert to a lower-precision integer type\n", "tf.cast(a, tf.int16) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conversion between boolean and integer types is also valid and is a fairly common operation:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tf.constant([True, False])\n", "# Boolean to integer\n", "tf.cast(a, tf.int32) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By convention 0 means False and 1 means True; in TensorFlow, every nonzero number is treated as True:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tf.constant([-1, 0, 1, 2])\n", "# Integer to boolean\n", "tf.cast(a, tf.bool) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensors to Be Optimized\n", "\n", "TensorFlow adds a dedicated data type, tf.Variable, to support recording gradient information. On top of the ordinary tensor type, tf.Variable adds attributes such as name and trainable to support building the computation graph." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('Variable:0', True)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a TF tensor\n", "a = tf.constant([-1, 0, 1, 2]) \n", "# Convert to the Variable 
类型\n", "aa = tf.Variable(a) \n", "# Variable 类型张量的属性\n", "aa.name, aa.trainable " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "name 属性用于命名计算图中的变量,这套命名体系是 TensorFlow 内部维护的, 一般不需要用户关注 name 属性; \n", "trainable属性表征当前张量是否需要被优化,创建 Variable 对象时是默认启用优化标志,可以设置trainable=False 来设置张量不需要优化。" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 直接创建 Variable 张量\n", "tf.Variable([[1,2],[3,4]]) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 创建张量\n", "\n", "### 从数组、列表对象创建" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过 tf.convert_to_tensor 函数可以创建新 Tensor,并将保存在 Python List 对象或者Numpy Array 对象中的数据导入到新 Tensor 中。" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 从列表创建张量\n", "tf.convert_to_tensor([1,2.]) " ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 从数组中创建张量\n", "tf.convert_to_tensor(np.array([[1,2.],[3,4]])) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 创建全0或全1张量" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,\n", " )" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建全 0,全 1 的标量\n", "tf.zeros([]),tf.ones([]) " ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,\n", " )" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建全 0,全 1 的向量\n", "tf.zeros([1]),tf.ones([1]) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "创建全 0 
的矩阵" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建全 0 矩阵,指定 shape 为 2 行 2 列\n", "tf.zeros([2,2]) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "创建全 1 的矩阵" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建全 1 矩阵,指定 shape 为 3 行 2 列\n", "tf.ones([3,2]) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过 tf.zeros_like, tf.ones_like 可以方便地新建与某个张量 shape 一致, 且内容为全 0 或全 1 的张量。" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建一个矩阵\n", "a = tf.ones([2,3]) \n", "# 创建一个与 a 形状相同,但是全 0 的新矩阵\n", "tf.zeros_like(a) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "创建与张量A形状一样的全 1 张量" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建一个矩阵\n", "a = tf.zeros([3,2]) \n", "# 创建一个与 a 形状相同,但是全 1 的新矩阵\n", "tf.ones_like(a) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 创建自定义数值张量\n", "\n", "通过 tf.fill(shape, value)可以创建全为自定义数值 value 的张量,形状由 shape 参数指定。" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建-1 的标量\n", "tf.fill([], -1) " ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建-1 
的向量\n", "tf.fill([1], -1) " ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建 2 行 2 列,元素全为 99 的矩阵\n", "tf.fill([2,2], 99) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 创建已知分布的张量" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过 tf.random.normal(shape, mean=0.0, stddev=1.0)可以创建形状为 shape,均值为mean,标准差为 stddev 的正态分布$\\mathcal{N}(mean, stddev^2)$。" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建标准正态分布的张量\n", "tf.random.normal([2,2]) " ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建均值为 1,标准差为 2 的正态分布的张量\n", "tf.random.normal([2,2], mean=1,stddev=2) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过 tf.random.uniform(shape, minval=0, maxval=None, dtype=tf.float32)可以创建采样自[minval, maxval)区间的均匀分布的张量" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建采样自[0,1)均匀分布的矩阵\n", "tf.random.uniform([3,2]) " ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建采样自[0,10)均匀分布的矩阵\n", "tf.random.uniform([2,2],maxval=10) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "如果需要均匀采样整形类型的数据,必须指定采样区间的最大值 maxval 参数,同时指定数据类型为 tf.int*型" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, 
"execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建采样自[0,100)均匀分布的整型矩阵\n", "tf.random.uniform([2,2],maxval=100,dtype=tf.int32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 创建序列\n", "\n", "tf.range(limit, delta=1)可以创建[0, limit)之间,步长为 delta 的整型序列,不包含 limit 本身。" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 0~10,不包含 10\n", "tf.range(10) " ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建 0~10,步长为 2 的整形序列\n", "tf.range(10,delta=2)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.range(1,10,delta=2) # 1~10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 张量的典型应用\n", "\n", "### 标量" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(0.28147948, shape=(), dtype=float32)\n" ] } ], "source": [ "# 随机模拟网络输出\n", "out = tf.random.uniform([4,10]) \n", "# 随机构造样本真实标签\n", "y = tf.constant([2,3,2,0]) \n", "# one-hot 编码\n", "y = tf.one_hot(y, depth=10) \n", "# 计算每个样本的 MSE\n", "loss = tf.keras.losses.mse(y, out) \n", "# 平均 MSE,loss 应是标量\n", "loss = tf.reduce_mean(loss) \n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 向量\n", "\n", "考虑 2 个输出节点的网络层, 我们创建长度为 2 的偏置向量b,并累加在每个输出节点上:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# z=wx,模拟获得激活函数的输入 z\n", "z = 
tf.random.normal([4,2])\n", "# 创建偏置向量\n", "b = tf.zeros([2])\n", "# 累加上偏置向量\n", "z = z + b \n", "z" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "创建输入节点数为 4,输出节点数为 3 的线性层网络,那么它的偏置向量 b 的长度应为 3" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建一层 Wx+b,输出节点为 3\n", "fc = tf.keras.layers.Dense(3) \n", "# 通过 build 函数创建 W,b 张量,输入节点为 4\n", "fc.build(input_shape=(2,4))\n", "# 查看偏置向量\n", "fc.bias " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 矩阵" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 2 个样本,特征长度为 4 的张量\n", "x = tf.random.normal([2,4]) \n", "# 定义 W 张量\n", "w = tf.ones([4,3])\n", "# 定义 b 张量\n", "b = tf.zeros([3]) \n", "# X@W+b 运算\n", "o = x@w+b \n", "o" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 定义全连接层的输出节点为 3\n", "fc = tf.keras.layers.Dense(3) \n", "# 定义全连接层的输入节点为 4\n", "fc.build(input_shape=(2,4)) \n", "# 查看权值矩阵 W\n", "fc.kernel " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 三维张量" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(25000, 80)" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 自动加载 IMDB 电影评价数据集\n", "(x_train,y_train),(x_test,y_test)=keras.datasets.imdb.load_data(num_words=10000)\n", "# 将句子填充、截断为等长 80 个单词的句子\n", "x_train = keras.preprocessing.sequence.pad_sequences(x_train,maxlen=80)\n", "x_train.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到 x_train 张量的 shape 为[25000,80],其中 25000 表示句子个数, 80 
is the number of words in each sentence, with every word represented by an integer code.\n", "\n", "We use a layers.Embedding layer to convert each integer-coded word into a word vector of length 100:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([25000, 80, 100])" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create the word-vector Embedding layer\n", "embedding = tf.keras.layers.Embedding(10000, 100)\n", "# Convert the integer-coded words to word vectors\n", "out = embedding(x_train)\n", "out.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After encoding by the Embedding layer, the shape of the sentence tensor becomes [25000,80,100], where 100 means each word is encoded as a vector of length 100." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4-D Tensors" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "TensorShape([4, 30, 30, 16])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create 4 color images of size 32x32 as input\n", "x = tf.random.normal([4,32,32,3])\n", "# Create a convolutional layer\n", "layer = layers.Conv2D(16, kernel_size=3)\n", "# Forward pass\n", "out = layer(x) \n", "# Output shape\n", "out.shape " ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([3, 3, 3, 16])" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Access the convolution kernel tensor\n", "layer.kernel.shape " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Indexing and Slicing\n", "### Indexing" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "# Create a 4-D tensor\n", "x = tf.random.normal([4,32,32,3]) " ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Take the data of the 1st image\n", "x[0]" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, 
"execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 取第 1 张图片的第 2 行\n", "x[0][1]" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 取第 1 张图片,第 2 行,第 3 列的数据\n", "x[0][1][2]" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 取第 3 张图片,第 2 行,第 1 列的像素, B 通道(第 2 个通道)颜色强度值\n", "x[2][1][0][1]" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 取第 2 张图片,第 10 行,第 3 列的数据\n", "x[1,9,2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 切片" ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 读取第 2,3 张图片\n", "x[1:3]" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 读取第一张图片\n", "x[0,::] " ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[:,0:28:2,0:28:2,:]" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 考虑一个 0~9 的简单序列向量, 逆序取到第 1 号元素,不包含第 1 号\n", "# 创建 0~9 向量\n", "x = tf.range(9) \n", "# 从 8 取到 0,逆序,不包含 0\n", 
"x[8:0:-1] " ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 逆序全部元素\n", "x[::-1] " ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 逆序间隔采样\n", "x[::-2] " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "读取每张图片的所有通道,其中行按着逆序隔行采样,列按着逆序隔行采样" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.random.normal([4,32,32,3])\n", "# 行、列逆序间隔采样\n", "x[0,::-2,::-2] " ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 取 G 通道数据\n", "x[:,:,:,1] " ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 读取第 1~2 张图片的 G/B 通道数据\n", "# 高宽维度全部采集\n", "x[0:2,...,1:] " ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 读取最后 2 张图片\n", "# 高、宽、通道维度全部采集,等价于 x[2:]\n", "x[2:,...] 
" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 读取 R/G 通道数据\n", "# 所有样本,所有高、宽的前 2 个通道\n", "x[...,:2] " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 维度变换\n", "\n", "### 改变视图" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 生成向量\n", "x=tf.range(96)\n", "# 改变 x 的视图,获得 4D 张量,存储并未改变\n", "x=tf.reshape(x,[2,4,4,3]) \n", "x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 改变视图" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们通过 tf.range()模拟生成一个向量数据,并通过 tf.reshape 视图改变函数产生不同的视图" ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 生成向量\n", "x = tf.range(96) \n", "# 改变 x 的视图,获得 4D 张量,存储并未改变\n", "x = tf.reshape(x,[2,4,4,3]) \n", "x" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(4, TensorShape([2, 4, 4, 3]))" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 获取张量的维度数和形状列表\n", "x.ndim,x.shape " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过 tf.reshape(x, new_shape),可以将张量的视图任意地合法改变" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.reshape(x,[2,-1])" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": 
[ " tf.reshape(x,[2,4,12])" ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.reshape(x,[2,-1,3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 增、删维度" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 产生矩阵\n", "x = tf.random.uniform([28,28],maxval=10,dtype=tf.int32)\n", "x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过 tf.expand_dims(x, axis)可在指定的 axis 轴前可以插入一个新的维度" ] }, { "cell_type": "code", "execution_count": 78, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# axis=2 表示宽维度后面的一个维度\n", "x = tf.expand_dims(x,axis=2) \n", "x" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.expand_dims(x,axis=0) # 高维度之前插入新维度\n", "x" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.squeeze(x, axis=0) # 删除图片数量维度\n", "x" ] }, { "cell_type": "code", "execution_count": 81, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.random.uniform([1,28,28,1],maxval=10,dtype=tf.int32)\n", "tf.squeeze(x) # 删除所有长度为 1 的维度" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 交换维度" ] }, { "cell_type": "code", "execution_count": 82, 
"metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.random.normal([2,32,32,3])\n", "# 交换维度\n", "tf.transpose(x,perm=[0,3,1,2]) " ] }, { "cell_type": "code", "execution_count": 83, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.random.normal([2,32,32,3])\n", "# 交换维度\n", "tf.transpose(x,perm=[0,2,1,3]) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 复制数据" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建向量 b\n", "b = tf.constant([1,2]) \n", "# 插入新维度,变成矩阵\n", "b = tf.expand_dims(b, axis=0) \n", "b" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 样本维度上复制一份\n", "b = tf.tile(b, multiples=[2,1]) \n", "b" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.range(4)\n", "# 创建 2 行 2 列矩阵\n", "x=tf.reshape(x,[2,2]) \n", "x" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 列维度复制一份\n", "x = tf.tile(x,multiples=[1,2]) \n", "x" ] }, { "cell_type": "code", "execution_count": 88, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 行维度复制一份\n", "x = 
tf.tile(x,multiples=[2,1]) \n", "x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Broadcasting" ] }, { "cell_type": "code", "execution_count": 89, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 创建矩阵\n", "A = tf.random.normal([32,1]) \n", "# 扩展为 4D 张量\n", "tf.broadcast_to(A, [2,32,32,3]) " ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Incompatible shapes: [32,2] vs. [2,32,32,4] [Op:BroadcastTo]\n" ] } ], "source": [ "A = tf.random.normal([32,2])\n", "# 不符合 Broadcasting 条件\n", "try: \n", " tf.broadcast_to(A, [2,32,32,4])\n", "except Exception as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 数学运算\n", "\n", "### 加、减、乘、除运算" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tf.range(5)\n", "b = tf.constant(2)\n", "# 整除运算\n", "a//b " ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 余除运算\n", "a%b " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 乘方运算" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.range(4)\n", "# 乘方运算\n", "tf.pow(x,3) " ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 乘方运算符\n", "x**2 " ] }, { "cell_type": "code", "execution_count": 95, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x=tf.constant([1.,4.,9.])\n", "# 平方根\n", "x**(0.5) " ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [], "source": [ "x = tf.range(5)\n", "# 转换为浮点数\n", "x = tf.cast(x, dtype=tf.float32) \n", "# 平方\n", "x = tf.square(x) " ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 平方根\n", "tf.sqrt(x) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 指数和对数运算" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.constant([1.,2.,3.])\n", "# 指数运算\n", "2**x " ] }, { "cell_type": "code", "execution_count": 99, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 自然指数运算\n", "tf.exp(1.)" ] }, { "cell_type": "code", "execution_count": 100, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.exp(3.)\n", "# 对数运算\n", "tf.math.log(x) " ] }, { "cell_type": "code", "execution_count": 101, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 101, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = tf.constant([1.,2.])\n", "x = 10**x\n", "# 换底公式\n", "tf.math.log(x)/tf.math.log(10.) 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 矩阵相乘运算" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tf.random.normal([4,3,28,32])\n", "b = tf.random.normal([4,3,32,2])\n", "# 批量形式的矩阵相乘\n", "a@b" ] }, { "cell_type": "code", "execution_count": 103, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = tf.random.normal([4,28,32])\n", "b = tf.random.normal([32,16])\n", "# 先自动扩展,再矩阵相乘\n", "tf.matmul(a,b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 前向传播实战" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import tensorflow.keras.datasets as datasets\n", "\n", "plt.rcParams['font.size'] = 16\n", "plt.rcParams['font.family'] = ['STKaiti']\n", "plt.rcParams['axes.unicode_minus'] = False" ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [], "source": [ "def load_data():\n", " # 加载 MNIST 数据集\n", " (x, y), (x_val, y_val) = datasets.mnist.load_data()\n", " # 转换为浮点张量, 并缩放到-1~1\n", " x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.\n", " # 转换为整形张量\n", " y = tf.convert_to_tensor(y, dtype=tf.int32)\n", " # one-hot 编码\n", " y = tf.one_hot(y, depth=10)\n", "\n", " # 改变视图, [b, 28, 28] => [b, 28*28]\n", " x = tf.reshape(x, (-1, 28 * 28))\n", "\n", " # 构建数据集对象\n", " train_dataset = tf.data.Dataset.from_tensor_slices((x, y))\n", " # 批量训练\n", " train_dataset = train_dataset.batch(200)\n", " return train_dataset" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [], "source": [ "def init_paramaters():\n", " # 每层的张量都需要被优化,故使用 Variable 类型,并使用截断的正太分布初始化权值张量\n", " # 偏置向量初始化为 0 即可\n", " # 第一层的参数\n", 
" w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))\n", " b1 = tf.Variable(tf.zeros([256]))\n", " # 第二层的参数\n", " w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))\n", " b2 = tf.Variable(tf.zeros([128]))\n", " # 第三层的参数\n", " w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))\n", " b3 = tf.Variable(tf.zeros([10]))\n", " return w1, b1, w2, b2, w3, b3" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [], "source": [ "def train_epoch(epoch, train_dataset, w1, b1, w2, b2, w3, b3, lr=0.001):\n", " for step, (x, y) in enumerate(train_dataset):\n", " with tf.GradientTape() as tape:\n", " # 第一层计算, [b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b,256] + [b, 256]\n", " h1 = x @ w1 + tf.broadcast_to(b1, (x.shape[0], 256))\n", " h1 = tf.nn.relu(h1) # 通过激活函数\n", "\n", " # 第二层计算, [b, 256] => [b, 128]\n", " h2 = h1 @ w2 + b2\n", " h2 = tf.nn.relu(h2)\n", " # 输出层计算, [b, 128] => [b, 10]\n", " out = h2 @ w3 + b3\n", "\n", " # 计算网络输出与标签之间的均方差, mse = mean(sum(y-out)^2)\n", " # [b, 10]\n", " loss = tf.square(y - out)\n", " # 误差标量, mean: scalar\n", " loss = tf.reduce_mean(loss)\n", "\n", " # 自动梯度,需要求梯度的张量有[w1, b1, w2, b2, w3, b3]\n", " grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])\n", "\n", " # 梯度更新, assign_sub 将当前值减去参数值,原地更新\n", " w1.assign_sub(lr * grads[0])\n", " b1.assign_sub(lr * grads[1])\n", " w2.assign_sub(lr * grads[2])\n", " b2.assign_sub(lr * grads[3])\n", " w3.assign_sub(lr * grads[4])\n", " b3.assign_sub(lr * grads[5]) \n", " \n", " return loss.numpy()" ] }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "def train(epochs):\n", " losses = []\n", " train_dataset = load_data()\n", " w1, b1, w2, b2, w3, b3 = init_paramaters()\n", " for epoch in range(epochs):\n", " loss = train_epoch(epoch, train_dataset, w1, b1, w2, b2, w3, b3, lr=0.001)\n", " print('epoch:', epoch, 'loss:', loss)\n", " losses.append(loss)\n", "\n", " x = [i for i in 
range(0, epochs)]\n", " # 绘制曲线\n", " plt.plot(x, losses, color='blue', marker='s', label='训练')\n", " plt.xlabel('Epoch')\n", " plt.ylabel('MSE')\n", " plt.legend()\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0 loss: 0.16430502\n", "epoch: 1 loss: 0.1473901\n", "epoch: 2 loss: 0.13449846\n", "epoch: 3 loss: 0.12467159\n", "epoch: 4 loss: 0.11701874\n", "epoch: 5 loss: 0.110820025\n", "epoch: 6 loss: 0.105716765\n", "epoch: 7 loss: 0.10139298\n", "epoch: 8 loss: 0.09769247\n", "epoch: 9 loss: 0.09452056\n", "epoch: 10 loss: 0.09176841\n", "epoch: 11 loss: 0.08933158\n", "epoch: 12 loss: 0.08714955\n", "epoch: 13 loss: 0.08515776\n", "epoch: 14 loss: 0.08333652\n", "epoch: 15 loss: 0.08166271\n", "epoch: 16 loss: 0.08011957\n", "epoch: 17 loss: 0.07869925\n", "epoch: 18 loss: 0.07738015\n", "epoch: 19 loss: 0.07615029\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEQCAYAAABxzUkqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXhU5fn/8fedEHZBtkBFUXFDRQVBxQ2xVotCVXADtRVoQTZFsIhV8Vu16tcdURaxWLAKSmur1qqUuvzUWsFAXeuXqohVZAkCKnuA+/fHmcAwTibJZGbOzOTzuq5zJfOcJ+fcM8bcPOsxd0dERCQZBWEHICIiuUtJREREkqYkIiIiSVMSERGRpCmJiIhI0uqEHUAmtWzZ0vfbb7+wwxARySkLFy5c7e6t4p2rVUlkv/32o6SkJOwwRERyipl9XtE5dWeJiEjSlERERCRpSiIiIpI0JREREUmakoiIiCRNSSSBNm3A7PtHmzZhRyYikh1q1RTf6lq5snrlIpJ6mzdvprS0lM2bN7Nt27aww8krRUVFFBcX06RJk6SvoSQiIlnrm2++YeXKlbRq1Yo2bdpQp04dzCzssPKCu7Np0yaWLVsGkHQiUXeWiGSt1atXs/fee9OsWTOKioqUQFLIzGjYsCFt27Zl1apVSV9HSUREstbWrVtp0KBB2GHktQYNGlBWVpb0zyuJiEhWU+sjvWr6+SqJJNC6dfXKRURqm7xIImZWbGbHpfq6K1aA+65j0CBo2BA++ijVdxIRyU0ZTyIWGGZmQ81slJkdnaBufTOba2Y9KjhfaGY3ApcD76cp5J1GjoSNG2HGjHTfSUQkN4TREhkFvO3uU939fmCwmTWNrWRmewL3AwfGu4iZ1QH+DHzu7re4+8Z0Bg3QuTOceCJMmgQ7dqT7biKS79avX5/w/KBBg5g2bVrccy+99BInn3wy33777W7lQ4YM4fbbb09ZjJXJaBIxs0Kgn7tHP9RjPjAgTvUzgTHAF
xVc7nZgpbvPTGmQlRg5Ej79FObOzeRdRSQVsmkXiu3btzNo0CA6dOjAt99+S2lp6ffqHHzwwXzwwQdxf37hwoV07tyZzZs371a+dOlS+vTpk5aY48l0S6Q7sDmmbDHQP7aiu8929w3xLmJmBwFDgRtSHmEl+vYNfuEefDDTdxaRmsqmXSgKCwuZM2cO48ePZ8uWLcyZM4fRo0ezYcMG5syZA0D9+vU54ogjvvezZWVlfPbZZ0yYMIERI0Ywa9YsduzYwYIFC1i+fDk33ngj7dq1o0OHDmzdujWt7yPTK9bbA2tiytZEyqvj58CLwFYzGwr0Aka5+5Kah5hY3bpw+eVw883wySdwYNzONhFJp6uugnfeSe01e/SoXv1OnWDChJrdc8uWLVxyySUAdOnShZKSEho2bMjo0aM577zzKCgoiLtOZvr06VxzzTUsWLCADRs20LNnT2bNmsXEiRPp3r07Xbp0YcyYMRx11FHUrVu3ZkFWItMtkWIgdlVLGdDSzKoTy0nA/3P3te4+FZgF/DZeRTMbYmYlZlYSr7mYjCFDoLAQpkxJyeVEpJb573//y6RJkyguLmbAgAEA7LHHHju3dTn22GMpLCzEzL63jmPq1KmMGTOGvn37Mm7cOJ566imaN29O//79ady4Md26daO0tJRu3bplZKFmplsiq4DYtFgErHb36gxVt2L3sZJ5wCwzaxTbBebu04BpAF27dvXqh/x9e+0F550HjzwStEgaNUrFVUWkqpJtASRaV/fqq8lds7qee+45HnjgAS677DJOO+00hg8fHonNqFMn+JNc/sc/XhK54IILuPjii3nhhRc4/fTTadCgAevWreOKK65g6dKlDBw4kIEDB2bmzZD5lsgSggQQrRnwWTWvswaoF/W6fHpCiyTjqraRI2HdOnj88UzdUUTyQe/evZk7dy4XX3wx27Zto2PHjgBs3LiRPfbYA4CCguBPs7vv/L5c8+bNGTduHFOnTuXKK69k6tSpANStW5exY8dy4YUXsmjRIk444QRatWrFlDR3mWS6JfI60MTMCqJaHocAs6t5nQ8iP1euBbAV+KrmIVbNiScGfaIPPgiDByf+F46IZIfWreMPooe1C0VZWRn16tXjySefpH79+uy9994AO7e83xFnLYGZ0bhxY+68806OOeYYAG6++WYOOOAAjjzySPbff38+/PBDrr766oy8h4y2RNx9GzAd6BlV3A2YYWbnm1lVp+s+APSPTBkGOA14OHL9jDALWiPvvw+vv56pu4pITcTuQlF+rFiR+Vh27NjBl19+Sb9+/TjyyCN57bXXOPbYY9m4cSPLly8HqPD5KXfddRdt2rRh4sSJ3HXXXYwfP56ePXty4okn0rx5c7Zs2cK//vUvFi9enPb3EcZiw4nAwWY23MxGAdPdfR3QFuhYXsnMGpjZEKATMNTMdp5z9/eA24AJZnYl0BUYl8k3AdC/PzRrpum+IlJ9Dz/8MIsXL2bEiBG0aNGCd999d+dgeHkLY9OmTbtN0V21ahXXX389xx9/PBdddBGHHXYYY8eO5dlnn+XRRx8Fgm6xzZs30759e3r16sWkSZPS+j4y/lAqd3fge8NikdXr90e93kQwIB53uaa7PwY8lqYwq6RhQ/j5z+G++2DZMmjbNsxoRCSXNG7cmLvvvpvu3bszaNAgJk+eDATdVXfffTcAy5Yto3nz5jt/pkWLFqxbt47Ro0dz4YUX7iy/9dZbeeihh9ixYwcrV65k/fr1NG3alCeeeIJBgwYxYsSItL2PvNiAMUzDhgVboDz0UNiRiEguueSSS7jyyitZuHAht956KwcffPD36ixdunS3cZHCwkImTZq0WwIBeOihh+jcuTO33HILQ4YM2fmUwq5duzJv3ry0vg8lkRpq3x569QqSyJYtYUcjIrnmmGOOoW0F3Rg/+tGPOOSQQ+Kei9a5c2cArrvuOg499NCdCxgBWqd51oCesZ4CI0dCz57wxz9C1H87EZEaq
e4Mq6KiIt566600RROfWiIpcPrpcNBBGmAXkdpHSSQFCgqC1shbb0FJSeX1RUTyhZJIilx2WbD9SZpn04mIZBUlkRRp2hR+9jOYPRtWrw47GpH8EawKkHSp6eerJJJCI0YEM7SmTw87EpH8ULduXTZt2hR2GHlt06ZNFBUVJf3zSiIpdPjhcOqpMHkybN8edjQiua9ly5Z8+eWXrFmzhrKyMrVKUsjd2bhxI8uWLaO4uDjp62iKb4qNHBlsE//cc3DOOWFHI5LbmjZtSr169SgtLeXrr7+ucC8pSU5RURGtW7feuTgxGUoiKXb22bD33vDAA0oiIqlQv3599tlnn7DDkAqoOyvF6tQJtkJ56SX46KOwoxERSS8lkTQYPDh4Frum+4pIvlMSSYNWraBfP5g5E779tvL6IiK5SkkkTUaOhPXrIbLFv4hIXlISSZNjjoFjjw3209KsRBHJV0oiaTRyJCxeHAyyi4jkIyWRNLrggmB85IEHwo5ERCQ9lETSaL/9oLQUnn0WzHYdbdqEHZmISGooiaTRypXVKxcRyTVKIiIikjQlERERSZqSiIiIJE1JREREkqYkkkatW8cvb9Eis3GIiKSLkkgarVgRrFYvP0pLoUkTOP74sCMTEUkNJZEMatkSrrsueGDVyy+HHY2ISM0piWTYlVdCu3bwy1/Cjh1hRyMiUjNKIhnWoAHcdhv8618wa1bY0YiI1IySSAj694cuXYKurU2bwo5GRCR5SiIhKCiAu++GL76A++8POxoRkeQpiYSkRw84++yga6u0NOxoRESSoyQSojvugI0b4aabwo5ERCQ5GU8iFhhmZkPNbJSZHZ2gbn0zm2tmPSq55tMpDzQDOnSAIUNg6tTg4VUiIrkmjJbIKOBtd5/q7vcDg82saWwlM9sTuB84MNHFzOwi4Jy0RJoBv/41NGwI48aFHYmISPVlNImYWSHQz91LoornAwPiVD8TGAN8keB6ragkyWS74mK49lp45hl47bWwoxERqZ5Mt0S6A5tjyhYD/WMruvtsd99QyfVGA/emKLbQXHUV7L23FiCKSO7JdBJpD6yJKVsTKa8WMzsX+Ju75/xKi4YN4dZb4e234cknw45GRKTqMp1EioGymLIyoKWZVTmWyBjK0e7+ahXqDjGzEjMrKc3iubSXXgqdOsGvfgWbY9tqIiJZKtNJZBVQN6asCFjt7tXpyLkWmFiViu4+zd27unvXVq1aVeMWmVVQAPfcA59/Dg88EHY0IiJVk+kksgSI/UveDPisqheIDKb3A/5uZu+Y2TuR8nfMrEqJJVv98Idw1llB19bq1WFHIyJSuUwnkdeBJjFdV4cAs6t6AXcvdff93b1T+REp7+TuV6Y43oy780747ju45ZawIxERqVxGk4i7bwOmAz2jirsBM8zsfDObmcl4stHhh8MvfgGTJ8PHH4cdjYhIYmEsNpwIHGxmw81sFDDd3dcBbYGO5ZXMrIGZDQE6AUPNrGP8y+Wfm26CevWCQXYRkWxm7h52DBnTtWtXLykpqbxiFrjlFrjxRnjjDTjxxLCjEZHazMwWunvXeOe0AWOWGjMG9toLrr46eD67iEg2UhLJUo0awW9+A/Pnwx/+EHY0IiLxKYlksfIxkYsuArNdR5s24cYlIlJOSSSLrVxZvXIRkUxTEhERkaQpiYiISNKUREREJGlKIiIikjQlkSzWunX88sLCYH8tEZGwKYlksRUrgoWG0cerrwZfr7gi7OhERJREcs4pp8ANN8DMmfD442FHIyK1nZJIDho/Hk46CYYNg08/DTsaEanNlERyUJ06QSuksBD694etW8OOSERqKyWRHNWuHfz2t/D228FuvyIiYVASyWHnnQeXXw533AHz5oUdjYjURkoiOe7ee+Gww+BnP4NVq8KORkRqGyWRHNewITzxBKxdCwMGwI4dYUckIrWJkkgeOOKIoEXywgswcWLY0YhIbaIkkieGDYNzzoFrroFFi8KORkRqCyWRP
GEG06dDcTH06wfr14cdkYjUBkoieaRFi2D9yCefaFsUEckMJZE8U74tyowZMGtW2NGISL5TEslDN94IJ5wAQ4fCkiVhRyMi+UxJJA/VqRO0QgoKgm1RysrCjkhE8lXSScTMGqYyEEmtffcNvi5YAHXrBgPv5UebNuHGJiL5o8IkYmaVJZjuZvZnM/tnimOSFPnmm/jlK1dmNg4RyV+JEsVaM7vbzPaLd9LdXwTOB9amIS4REckBiZLIWHf/pbsvNbNDzWyMmX1qZteb2aEA7r4dmJuZUEVEJNvUSXBu53I1d/8I+MjMtrr7gzH1NGwrIlJLJWqJWJyyeI8/ildPstymTWFHICL5IFES8Speo6r1JMNat6743IUXauqviNRcou6sPmZWyO4tjePNbHNMvTOBySmPTGpsxYr45VOmwPDhwdbxv/99sJ5ERCQZiZLIeZEj1mUxr9USyTHDhgXPH7n+ethzT3jwwWD9iIhIdSX6N+g4dy+o7AAGVeeGFhhmZkPNbJSZHZ2gbn0zm2tmPeKc62tm75jZSjN71MyaVieO2u5Xv4Jf/hImT4bx48OORkRyVaKWyAtVvMYH1bznKOANdy8BMLMpZnatu++2NM7M9gTuAA6MvYCZdQGOB44B6gF/BCYBl1YzllrLDO68E9atg1tvhWbN4Oqrw45KRHJNhS0Rd4+bHMyshZmdbWbdzazI3RdW9WaRMZZ+5QkkYj4wIE71M4ExwBdxznUHrnH3MndfD/wPcK6ZOmWqwwymToULLghaJdOnhx2RiOSaRNue9DCzyWZ2UlTZUcBi4GHgSuA5M9urGvfrDsQOzC8G+sdWdPfZ7r6hguv81t2jx2LKgE0xZVIFhYXw2GPw4x/DkCHwxz+GHZGI5JJEYyJjgHvc/Q0AM6sDzAK2AEe7+/kEf/yHV+N+7YE1MWVrIuVV5u7fxRT9EHgsXl0zG2JmJWZWUlpaWp3b1Bp168JTT8Hxx8PFF8Nc7UEgIlWUKIksdPdPo16PBA4FrnT3ZQDuvgZYVY37FfP9Fe5lQMsqbPgYl5k1I+j6+p945919mrt3dfeurVq1SuYWtUKjRvDcc3DYYdC3L7z5ZtgRiUguSPSH+9vyb8ysBTAemOvuT8XU27ca91sF1I0pKwJWu/uOalynPC4DfgMMdvdvK6svie25Z9AKadsWzjoL3n037IhEJNslSiIFZnZSJIE8BuwAhkZXiIyRdKvG/ZYAsc2BZsBn1bhGtLHAJHfX8/tSpHVrmDcP9tgDzjgDPv447IhEJJslSiJTgdOBvxMsKDzN3T8HMLNiM7sFeBN4sRr3ex1oEtN1dQgwu1pRBzEMJpgq/O+osubVvY583777BomktBQOPnj3B1rpoVYiEi1uEjGzIuBCd/8fd+/s7me5+3vl5919lbuPB5oAq6t6M3ffBkwHekYVdwNmmNn5ZjazKtcxs7Mi13szquwMggF2SYEOHaCiuW56qJWIlIu72NDdy8zsSjNzgm6sihjBivUp1bjnRGBU5GFXRcB0d19nZm2BjjsvbNYA+CnQCRhqZqvd/QMz24NginFTM7svKo6GwAnViENERGrIKlpaYWabCKbzrqHi7d4N+IG710tPeKnVtWtXLykpqbyiAIn309KKHJHaw8wWunvXeOcSbXvSmmADxr2AT4Fn3P17T6Ews3EpiVJERHJOhUkkMmX2dwBm1h4YFuli+oe7vxpV9fG0RihZ6amn4Lx4ezyLSK1SpQV+7r7E3e9191uBbWZ2g5ldY2Yd3P3LNMcoIanooVZFRcF+W/feq24tkdoumVXin0R+7grgHTMbm9qQJFusWBEkidjjm2+CVsjVV8MVV8C2bWFHKiJhqXISMbMfmdkfgc8JZk09CLRz97vSFZxkpwYN4MknYexYmDQJ+vSB9evDjkpEwpAwiZhZczO72swWA38lWHTYy90Pcvc73H2VmbXMSKSSVQoKgueRTJoEzz8Pp5wCy5eHHZWIZFpFiw0Lz
OxR4EtgGPAIsI+7X+Duf4+pfkaaY5QsNnw4PPssLF4M3brBhx+GHZGIZFJl60T+BMwjaIHEq1gIXOLuP0pbhCmkdSLps2gR9O4NGzbAn/4Ep50WdkQikirJrhO5jiCJJFIA5EQCkfQ6+mh46y3o1Qt69oSHH4YBA8KOSkTSLW4SiTzG9vnyDRcTMbNZKY9KclK7dvDGG3D++TBwIHz2Gfz614lXvotIbos7JuLu2919cVUu4O5/TW1IksuaNg0G2gcOhJtvDgbgtQuwSP5K6mmCIokUFcH06RWf1y7AIvlDSUTSQl1YIrWDkoiIiCRNSURCUVoadgQikgpKIhKKI4+Ev/0t7ChEpKaURCRtKtoFuEULaN4cfvxjGDMGtmzJbFwikjpKIpI2Fe0CvHo1lJTAiBFw331w3HHw73+HHa2IJENJRELRoAE8+CD85S+wbBl06QJTpuj5JCK5RklEQtW7N7z/frAL8PDhcO65GnQXySVKIhK6Nm2CVe733QcvvhgMus+bF3ZUIlIVSiKSFQoK4KqrYMGCYND9jDOgUaPvb5mibVNEsouSiGSVo44KBt2HD4eNG+PX0bYpItlDSUSyToMGwRMTRST7KYmIiEjSlEQkJ333XdgRiAgoiUiO6tABnnhC60pEwqYkIlmrom1TmjcPZmj17w+nnw6Lq/T4NBFJByURyVoVbZvy9dfBVOBJk4KZXEccAdddBxs2hB2xSO2jJCI5qbAwmAb8n//AxRfD7bfDYYfB00+ri0skk5REJKcVF8OMGfDaa8Hz3fv0CbZS+fTTsCMTqR2URCQvnHwyLFwI994bJJTDD4fGjbXiXSTdlEQkbxQVwejRwUB7nz4Vj5FoxbtI6tTJ9A3NzIChgAP1gNfdfVEFdesDzwC3u/urMed6AB2B7cAmd5+Rvqgll+y1F8yeHUwBFpH0CqMlMgp4292nuvv9wGAzaxpbycz2BO4HDoxzbh/gx+7+oLtPAdaa2UXpDlxERHaX0SRiZoVAP3cviSqeDwyIU/1MYAzwRZxzI4CXo16/AIxOUZhSC4wYAV99FXYUIrkv0y2R7sDmmLLFQP/Yiu4+290rmvnfH/g4qu5WoJmZdUhVoJLfpk2DAw6AsWODx/WKSHIynUTaA2tiytZEyqvEzIqAvat6HTMbYmYlZlZSqkfm1SoVrXhv3Rr+7//gggvgnntg//3hxhth3brMxieSDzKdRIqBspiyMqClmVU1lpYEcce7TnFsZXef5u5d3b1rq1atqhuv5LCKVryvWBG0Qh59FD74AM48E265Bdq3DxYtauW7SNVlOomsAurGlBUBq919RxWvsRrYUcF1VtUsPKltDjsM5syBRYvghBOC7VPat4cJE2Dz5mBNidaaiFQs00lkCRDbHGgGfFbVC7h7GfBlTa8jEq1zZ3juOXjzTejYMVhvcuCBFa8p0VoTkUCmk8jrQJOYrqtDgNnVvM5sgjUiAJhZXeAbd/+o5iFKbXb88fDSS8Gxzz5hRyOS/TKaRNx9GzAd6BlV3A2YYWbnm9nMKl5qEtA76nVPYEJqohSBH/4waJWISGIZX7EOTARGmdl+BOMY0919nZm1ZffWRQPgp0AnYKiZrXb3DwDc/Qsze8zMxgLrgS3u/kiG34fkObPE55csCcZPRGoz81q0b3bXrl29pKSk8ooiEYkSSUEBnHsujBkTDMpXlnREcpWZLXT3rvHOaQNGkQQqWmvSqhWMGwevvAInnQTHHRfs1VUWO/FcJM8piYgkUNFak1Wr4Lbb4IsvYPLkYKFi//7B+pO77tLCRak9lEREaqBRIxg2LFgB/+yzwbTga66BvfeGUaOCFovWmUg+UxIRSYGCAvjJT+Dll4OFi337wpQpFe/LpXUmki+URERSrHPnYEuVpUvDjkQk/ZRERNJkr70Sn//rX2HbtszEIpIuSiIiIendG/bdF66/Hj75JOxoRJKjJCISkqeeCrq+/vd/4aCDoEePoBts48awIxOpO
iURkTRK9EyTvn2DTR//+99guvCyZXDZZfCDH8DQobBggXYRluynFesiWcIdXnsNHnkE/vAH2LSp8voimaAV6yI5wAxOOQVmzoTly2Hq1LAjEqmckohIFmraFC6/PHGdOXM0fiLhUxIRyVEXXQTFxcF2K08/HTyJUSTTlEREctTLL8Oll8K8edCnTzBYf9ll8PzzsHVrUEcD85JuSiIiWSzR7K5TTw3GTZYvhxdfhPPOC/bv6tUrSBKDB+vxvpJ+mp0lkke2bIG//Q2efBKeeQbWr6+4bi36X19qKNHsrDCebCgiaVKvXrAR5E9+EkwRbtiw4rpr10KzZpmLTfKTurNE8lSDBonPFxfDaafB/fcHj/oVSYaSiEgtNXZs8NCtq64KHqZ1xBHBPl7z58OOHUEdDcxLZZRERPJYooH5226DDz+Ejz+Ge++Fli3hjjugWzdo2xaGDNHAvFROA+sistOaNcEU4WefDWZ8ffddxXVr0Z+OWk/bnohIlTRvHqw9mTMHSksT133hBa2YFyUREalAvXqJz591VpB0zjgD7rkn6BpT66T2URIRkaS8+CIMHx5sYf/LX0LHjrDPPvDznwe7EK9dq4H52kBjIiJSoTZt4g+it24dzOwq98UXMHducMybB998AwUFu2Z5xVOL/vTkPI2JiEhSVqwI/tjHHtEJBIIWyC9+EbRAVq+Gf/wDbrgh8bUTJRjJHUoiIpJSderACSfATTclrteyZfB0xwcegA8+UMskV2nbExEJRZ8+wU7Ef/5z8Lq4OHjO/KmnBsfBBwePCq5Kd5qER0lEREIxfXrwdelSeOWVIKG88kowvRhgr7202DEXqDtLRNIm0Yr5cvvtBwMHwu9/HwzQ/+c/8NBDcPLJia+9fXvKwpQa0OwsEclaZhWfa9IEjj8eTjopOI499vu7Fld1dpkkpq3gRSTvXHwxvPEGjB8fvK5TB7p02ZVUTjxR3WGZoJaIiGStRC2R8j9da9bAP/8ZJJQ33oAFC3Y9HjiRWvSnr8ayqiViZgYMBRyoB7zu7ovi1KsHDAc2AnsAT7n7Z1HnTwJOAdYChwJ3u/vn6X8HIpIprVtX3B1Vrnnz4JHAvXoFrzdvhoULg4Ry7bUVX/v554MusJYtUxtzbZPxloiZXQW84e4lkddTgGvd/ZuYencCE9z9q0jimQkMdPftZtYKmOzuF0Tqtgemu/upie6tlohI7ZKoJVPugAOC7e+POy44OnWCunWDcxpTCWRNS8TMCoF+7j4hqng+MAC4P6peM6Czu38F4O5uZkuA3sAzwMnAzlaJuy8xs33T/w5EJF+88krwAK7584PpxY8/HpTXqwedOwcJRWMqlct0d1Z3YHNM2WLgPqKSCHAusCxOvf4ESeQroLuZ1XH3bZGWyeL0hCwiuSpRd1iPHsEBwfjIl1/uSirz58O0aYmv7V61lk6+y3QSaQ+siSlbEymvcj13f8vMPgVeNrPrgcOAn8a7oZkNAYYAtGvXrkbBi0huqWqXk1mw/9c++8D55wdlZWW7urXiad06mA3WpQscfXTwtV273RNLbegOy3QSKQbKYsrKgJZmVuDuO6LqrYtTrzjq9QjgTmA0cDCwCFgde0N3nwZMg2BMpKZvQERqh6KixOd79w4G8OfN27XwsWXLXQmlS5fa0R2W6SSyCojN7UXA6qgEUl6vcZx6qwDMrAXwJHC2u280s17AXDM73N2Xpyd0EZFdHnkk+LppE7z3XpBQyo+77oJt28KNL1MynUSWAK1iypoRNUgeVe+0BPX6Av92940A7v5XM3sROB94IKURi0itVZUpxg0a7JrZVW7zZnj//WAKcUW6dAlmgh111K5jzz13r5ML3WGZTiKvA01iuq4OAWbH1HsaGBxTFl2vMd8foF8Wp0xEJGnJ/qGuXx+OOSZxnebN4S9/2dWiAdh3310JpVOn3OgOy2gSicykmg70BJ6PFHcDfmVm5wM/cffL3H2tmc2PdE99GKl3IPCby
Pd/B6abmUWm/xYCXYC7M/h2RESSNm/ergd8vfsuvPNO8PXdd+G553LnoV1hLDY0YBSwlWCc4x/uXmJmo4CfuXuXSL16wBiCFZfwWZ0AAAbNSURBVOmNgT+5+5Ko6/QFTgX+A7QF5sRb+R5Niw1FJJOS7Y7auBE+/DBxd1ibNsFz7aOPww+Hxo13r5OK7rBEiw21d5aISJZKtA5lwIDgiZAffhgM7pfbf/9dSeX22yv++er86c+aFesiIpIav/td8HXHDvjssyChRB8vvJCZOJRERESyVFVmhxUUBPt/HXAAnHPOrvKtW4MtXNJNSUREJEvVZBpvotX2qaTH44qISNKURERE8lRVnnFfU+rOEhHJU5lY1a6WiIiIJE1JREREkqYkIiIiSVMSERGRpCmJiIhI0mrV3llmVgp8nuSPtyTOkxOlyvT51Zw+w5rR55e8fd099llQQC1LIjVhZiUVbUAmldPnV3P6DGtGn196qDtLRESSpiQiIiJJUxKpumlhB5Dj9PnVnD7DmtHnlwYaExERkaSpJSIiIklTEhERkaQpiVTCAsPMbKiZjTKzo8OOKReZ2SYz86ijbdgxZSszKzCz35nZgJjyI83sqsjv4ggz0/+/cST4/MbF/A4+HFKIeUW/hJUbBbzt7lPd/X5gsJk1DTuoXGJmdYCxQIPyw92XhRtVdjKzesA9QOeY8sbAcHef4O5TgX8QfKYSpaLPL+I7oBG7fg+HZTC0vKUkkoCZFQL93L0kqng+MCCciHJWE2C5u28uP8IOKIv1BG4D3okpvxRYWP7C3d8B+ppZhh6CmjMq+vwAytx9Y9Tv4bYMx5aXlEQS6w7E/sFbDPQPIZZctgewNuwgcoG7P+PupXFO9Qc+jilbB/ww/VHljgSfH8D2jAZTSyiJJNYeWBNTtiZSLlXXBJhoZhvM7D0zOyvsgHKQfhdrrquZfWlmq81supk1CTugfKAkklgxUBZTVga01KBmtWwC/hc4CpgMPGFmXcINKedU9LtYHEIsuWoRcGbk2BeYGW44+UHPWE9sFRDb51wErHb3HSHEk5Pc/RPgk8jLTyKDxEOAy8OLKudU9Lu4KoRYcpK7/7b8ezM7l+B38QfuvjzEsHKe/jWd2BIgdvvjZsBnIcSST/5G8C9BqTr9LqaQu68H3kS/hzWmJJLY60CTmK6rQ4DZIcWTL5oDn4YdRI6ZDXSMKWsGvBRCLPliT5SEa0xJJIHIFMDpBNMGy3UDZoQSUI4ys0vMrFHk+zrACGBCuFHlnMeAk83MIFh4CDzt7lvDDSs3mNkRZnZC1OsewL/dfWV4UeUHbcBYicj/tKOArQR90P+IWTciCUQ+v+lAV+DPBJ/jHHePna4q7Py8fgrcQjCl9wZ3fyty7gjgbOBrgvHMKe6uaatRKvr8zOwU4EHg38ACgunRM/T51ZySiIiIJE3dWSIikjQlERERSZqSiIiIJE1JREREkqYkIiIiSVMSERGRpCmJiIhI0pRERJIQWQF9TeQxq0+Z2aVRx+Vm9m1kdX667l9oZleYWZmZ7Zeu+4hURosNRWrAzBwY6O4zYsrvdfcxGbj/58Ap7r403fcSiUctEZH0eDBD99G/AiVUSiIiKWZmh7n7ksj3DcxsYtgxiaSLkohI6p0JYGYtgbeAo8xsiJm9YGYrI19/UF7ZzA40s9vM7GIzG2Nm95pZ/ZjzvzOzu8zsczN7onw334g9zWySmX0SuXa9jL1TqfU0JiJSA5ExkUeBf0aKioAx7r5/5PyZBI8EPsfd34s81XEusMbdf2JmzYH5wLHuvjbyM4MjrwdHEsL7wOXu/oqZ7UPwRMjx7u5mthT4E3A9sA14GZjq7o9n5AOQWk+PxxWpuVeiB9bNbM+oc5uAz939PQieqGdmNwB/N7O6wEBgUXkCiZgJTDKz64EOQD13fyXy818AN8Tcf7K7b4rc+1WgfSrfnEgi6s4SSb0/VXL+fYL/9/Yg+IP/VfTJyIOm1gAHAW2Ab
yq53rao77cDhdUJVqQmlEREUszdP4SdXVmN41RpDKxw968Jnp3eMvpkpIXSAlgaOd8+0g0mknWURETSwMwKCcYu1hOTJIBfADdFvp8JHGtmTaLODwBmufuyyFM0PwAmRK6JmV0UGbQXCZ3GRESSYGZdgY6RlxfEtBTqAD8GmkdeNzGzqwm6qNoBX7v7VAB3X21mfYDrzexfQCOCLqwhUdfrB8wAvjSzJcA9wDdm9nOCFsu1ZnZv1H23mdlf3H1hqt+3SCzNzhJJIzPrAfza3Xuk4dqF7r49Mt3X3H1Hqu8hUhm1RERylLtvj3x1tHJdQqIxERERSZqSiEiaRBYS9gY6mNlPzaw47JhEUk1jIiJpEtkKvgjYTLB2w9y9LNyoRFJLSURERJKm7iwREUmakoiIiCRNSURERJKmJCIiIkn7/+V5MRkfciuyAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "train(epochs=20)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "273.188px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 1 }