{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.simplefilter(action='ignore')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. 准备 CIFAR-10 数据" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train: (50000, 32, 32, 3)\n", "y_train: (50000, 1)\n", "x_test: (10000, 32, 32, 3)\n", "y_test: (10000, 1)\n" ] } ], "source": [ "print('x_train:', x_train.shape)\n", "print('y_train:', y_train.shape)\n", "print('x_test:', x_test.shape)\n", "print('y_test:', y_test.shape)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "x_train_normalize = x_train.astype('float32') / 255.0\n", "x_test_normalize = x_test.astype('float32') / 255.0" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "y_train_one_hot = tf.keras.utils.to_categorical(y_train)\n", "y_test_one_hot = tf.keras.utils.to_categorical(y_test)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(10000, 10)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_test_one_hot.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
建立模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.1 建立 Sequential 模型" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model = tf.keras.models.Sequential([\n", "\n", " # 卷积层1与池化层1\n", " tf.keras.layers.Conv2D(input_shape=(32, 32, 3), filters=32, kernel_size=(3, 3), padding='same', activation='relu'),\n", " tf.keras.layers.Dropout(0.3),\n", " tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu'),\n", " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", " \n", " # 卷积层2与池化层2\n", " tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'),\n", " tf.keras.layers.Dropout(0.3),\n", " tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'),\n", " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", " \n", " # 卷积层3与池化层3\n", " tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu'),\n", " tf.keras.layers.Dropout(0.3),\n", " tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu'),\n", " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", " \n", " # 平坦层\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dropout(0.3),\n", " \n", " # 隐藏层1(2500个神经元)\n", " tf.keras.layers.Dense(2500, activation='relu'),\n", " tf.keras.layers.Dropout(0.3),\n", " \n", " # 隐藏层2(1500个神经元)\n", " tf.keras.layers.Dense(1500, activation='relu'),\n", " tf.keras.layers.Dropout(0.3),\n", " \n", " # 输出层\n", " tf.keras.layers.Dense(10, activation='softmax')\n", "])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.2 查看模型的摘要" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "conv2d (Conv2D) 
(None, 32, 32, 32) 896 \n", "_________________________________________________________________\n", "dropout (Dropout) (None, 32, 32, 32) 0 \n", "_________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 32, 32, 32) 9248 \n", "_________________________________________________________________\n", "max_pooling2d (MaxPooling2D) (None, 16, 16, 32) 0 \n", "_________________________________________________________________\n", "conv2d_2 (Conv2D) (None, 16, 16, 64) 18496 \n", "_________________________________________________________________\n", "dropout_1 (Dropout) (None, 16, 16, 64) 0 \n", "_________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 16, 16, 64) 36928 \n", "_________________________________________________________________\n", "max_pooling2d_1 (MaxPooling2 (None, 8, 8, 64) 0 \n", "_________________________________________________________________\n", "conv2d_4 (Conv2D) (None, 8, 8, 128) 73856 \n", "_________________________________________________________________\n", "dropout_2 (Dropout) (None, 8, 8, 128) 0 \n", "_________________________________________________________________\n", "conv2d_5 (Conv2D) (None, 8, 8, 128) 147584 \n", "_________________________________________________________________\n", "max_pooling2d_2 (MaxPooling2 (None, 4, 4, 128) 0 \n", "_________________________________________________________________\n", "flatten (Flatten) (None, 2048) 0 \n", "_________________________________________________________________\n", "dropout_3 (Dropout) (None, 2048) 0 \n", "_________________________________________________________________\n", "dense (Dense) (None, 2500) 5122500 \n", "_________________________________________________________________\n", "dropout_4 (Dropout) (None, 2500) 0 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 1500) 3751500 \n", "_________________________________________________________________\n", "dropout_5 
(Dropout) (None, 1500) 0 \n", "_________________________________________________________________\n", "dense_2 (Dense) (None, 10) 15010 \n", "=================================================================\n", "Total params: 9,176,018\n", "Trainable params: 9,176,018\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. 加载之前训练的模型" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "加载模型失败, 开始训练一个新模型\n" ] } ], "source": [
"try:\n",
"    model.load_weights('save_model/model_cifar10_cnn_deeper.h5')\n",
"    print('加载模型成功! 继续训练模型.')\n",
"except Exception as err:\n",
"    # Not a bare except: KeyboardInterrupt/SystemExit must still propagate.\n",
"    # A missing, unreadable or incompatible weights file falls back to a fresh model,\n",
"    # but the reason is reported instead of being silently discarded.\n",
"    print('加载模型失败, 开始训练一个新模型')\n",
"    print('原因:', err)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. 训练模型" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 40000 samples, validate on 10000 samples\n", "Epoch 1/10\n", " - 327s - loss: 1.4048 - acc: 0.4818 - val_loss: 1.3247 - val_acc: 0.5224\n", "Epoch 3/10\n", " - 330s - loss: 1.2062 - acc: 0.5643 - val_loss: 1.1510 - val_acc: 0.5969\n", "Epoch 4/10\n", " - 331s - loss: 1.0788 - acc: 0.6155 - val_loss: 0.9656 - val_acc: 0.6586\n", "Epoch 5/10\n", " - 330s - loss: 0.9780 - acc: 0.6520 - val_loss: 0.9960 - val_acc: 0.6565\n", "Epoch 6/10\n", " - 326s - loss: 0.9038 - acc: 0.6809 - val_loss: 0.8209 - val_acc: 0.7134\n", "Epoch 7/10\n", " - 329s - loss: 0.8324 - acc: 0.7054 - val_loss: 0.7892 - val_acc: 0.7240\n", "Epoch 8/10\n", " - 328s - loss: 0.7756 - acc: 0.7239 - val_loss: 0.7598 - val_acc: 0.7374\n", "Epoch 9/10\n", " - 327s - loss: 0.7320 - acc: 0.7431 - val_loss: 0.7315 - val_acc: 0.7497\n", "Epoch 10/10\n", " - 327s - loss: 0.6898 - acc: 0.7553 - val_loss: 0.7132 - val_acc: 0.7483\n" ] } ], "source": [ "train_history = model.fit(x=x_train_normalize, y=y_train_one_hot, validation_split=0.2,\n", " epochs=10, batch_size=128, verbose=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. 以图形显示训练过程" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
"def show_train_history(train_history, train, validation):\n",
"    \"\"\"Plot one training metric and its validation counterpart per epoch.\n",
"\n",
"    train / validation: keys of train_history.history, e.g. 'acc' / 'val_acc'.\n",
"    Older Keras records metrics=['accuracy'] as 'acc', newer versions as 'accuracy';\n",
"    either spelling is accepted so the plotting calls work on both.\n",
"    \"\"\"\n",
"    history = train_history.history\n",
"    aliases = {'acc': 'accuracy', 'accuracy': 'acc',\n",
"               'val_acc': 'val_accuracy', 'val_accuracy': 'val_acc'}\n",
"\n",
"    def resolve(key):\n",
"        # Fall back to the other spelling only when the requested key is absent.\n",
"        if key not in history and aliases.get(key) in history:\n",
"            return aliases[key]\n",
"        return key\n",
"\n",
"    plt.plot(history[resolve(train)])\n",
"    plt.plot(history[resolve(validation)])\n",
"    plt.title('Train History')\n",
"    plt.xlabel('Epoch')\n",
"    plt.ylabel(train)\n",
"    plt.legend(['train', 'validation'], loc='upper left')\n",
"    plt.show()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_train_history(train_history, 'acc', 'val_acc')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_train_history(train_history, 'loss', 'val_loss')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. 评估模型的准确率" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10000/10000 [==============================] - 21s 2ms/step\n", "\n", "accuracy: 0.7447\n" ] } ], "source": [ "scores = model.evaluate(x_test_normalize, y_test_one_hot)\n", "print()\n", "print('accuracy:', scores[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7. 进行预测" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 7.1 执行预测" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [
"# Sequential.predict_classes() was removed in TensorFlow 2.6; for a softmax output\n",
"# it is exactly the argmax over the class axis, which works on every version.\n",
"predictions = np.argmax(model.predict(x_test_normalize), axis=-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 7.2 预测结果" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([8, 8, 8, 8, 6, 6, 1, 6, 3, 9])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 7.3 定义函数以显示10项预测结果" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "label_dict = {\n", " 0: 'airplane',\n", " 1: 'automobile',\n", " 2: 'bird',\n", " 3: 'cat',\n", " 4: 'deer',\n", " 5: 'dog',\n", " 6: 'frog',\n", " 7: 'horse',\n", " 8: 'ship',\n", " 9: 'truck'\n", "}" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
"def plot_images_labels_prediction(images, labels, predictions, idx, num=10):\n",
"    \"\"\"Show up to 25 images in a 5x5 grid, titled with the true (and predicted) class.\n",
"\n",
"    images: image array\n",
"    labels: ground-truth class ids, shape (n, 1) or (n,); each id is a key of label_dict\n",
"    predictions: predicted class ids; pass an empty list to show only the true labels\n",
"    idx: index of the first image to show\n",
"    num: number of images to show, default 10; capped at 25 and at the images left after idx\n",
"    \"\"\"\n",
"    fig = plt.gcf()\n",
"    fig.set_size_inches(12, 14)\n",
"    true_labels = np.asarray(labels).reshape(-1)\n",
"    # Stay inside the 5x5 grid and never index past the end of the data.\n",
"    num = max(0, min(num, 25, len(images) - idx))\n",
"    for i in range(num):\n",
"        ax = plt.subplot(5, 5, i + 1)\n",
"        ax.imshow(images[idx], cmap='binary')\n",
"        title = str(true_labels[idx]) + ',' + label_dict[true_labels[idx]]\n",
"        if len(predictions) > 0:\n",
"            title += '=>' + label_dict[predictions[idx]]\n",
"        ax.set_title(title, fontsize=10)\n",
"        ax.set_xticks([])\n",
"        ax.set_yticks([])\n",
"        idx += 1\n",
"    plt.show()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_images_labels_prediction(x_test, y_test, predictions, idx=0, num=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 8. 查看预测概率" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 8.1 使用测试数据进行预测" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "predicted_probability = model.predict(x_test_normalize)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 8.2 建立 show_predicted_probability 函数以查看预测概率" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [
"def show_predicted_probability(x, y, predictions, predicted_probability, idx):\n",
"    \"\"\"Print the true and predicted class of sample idx, show the image,\n",
"    then print the model's probability for every class.\n",
"\n",
"    y: ground-truth class ids, shape (n, 1) or (n,)\n",
"    predictions: predicted class ids\n",
"    predicted_probability: per-class output of model.predict, one row per sample\n",
"    \"\"\"\n",
"    print('label:', label_dict[np.ravel(y)[idx]])\n",
"    print('predict:', label_dict[predictions[idx]])\n",
"    plt.figure(figsize=(2, 2))\n",
"    plt.imshow(x[idx])\n",
"    plt.show()\n",
"    for class_id, probability in enumerate(predicted_probability[idx]):\n",
"        print(label_dict[class_id], ':', probability)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 8.3 查看第0项数据预测的概率" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label: cat\n", "predict: ship\n" ] }, { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAJIAAACPCAYAAAARM4LLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFvFJREFUeJztXWtsHNd1/s7M7IPL5ZJLkSKpF03bspRHGydxlDhNEyeNATdFkfwI2iRAkQAB+qcBWqA/GuRPW6AFUqBo+6+AgbgxiqKu2yRIkBht0yRNGtiJLcuvSLJkWhJlSnyIryW5y33MzO2PXc0559qU1tJoJVH3AwTd2Xvn7t3hmXue9xwyxsDB4Xrh3ewFOOwMOEJySAWOkBxSgSMkh1TgCMkhFThCckgFjpAcUsF1ERIRPUJEp4homoi+mtaiHG4/0LUaJInIB3AawMMAZgE8B+DzxpgT6S3P4XZBcB33HgEwbYw5AwBE9ASATwPYlpAG+gKzq5QFAJDVR2R/wpDEbsDtN91j8Jbj3vR98sLYm7Kc355D3Cj67Hfxyi8nz2GPMuatn8GVfkv8pi/f7jmaba9iawo558Jqc8kYM7rNpAmuh5D2AnhDXM8C+OCVbthVyuLPv3AYAEAmVn3ZDC+FPP3HbTYbSTuMWnxPNqvGRTHPaaynQ16UtD2fPzetfj0OPC6Tras+H3KNPH8Uh2pcK+R1xLH1hyWeI4x0X0OM1cSin5V8gZrNluqLIrFGcZ8nfhcANMWzqurlo9bksX/772dn0AVuuLBNRH9IREeJ6OjmVnj1GxxuS1zPjnQBwH5xva/zmYIx5lEAjwLA5Fi/aXZo15gtPVC8ITnoXcIDbyFBIHaW7bkSKKM7G81m0g5jMZ/F2nyxWwXW/BSLtz/kXdJ+22Mxf5Pyqi/yc9wnxgFAM+IvpDgSbf0C5sVvC0gv0gvETtkS6yU9hxFrNpag4ftvf3+5nh3pOQAHiWiKiLIAPgfge9cxn8NtjGvekYwxIRF9BcB/AfABPGaMOZ7ayhxuK1wPa4Mx5ikAT6W0FofbGNdFSG8fBuYyvzcN3RMxD6dIyw5xi+Ubv0/IEdDajJRv4ljLLdlMJmmHhttxy/oucV8Y6jlIqMWekK3I19qj8Vku2opyqm9+meWWalNrlpub3Ocb/u6BvF5jlvh3lwp9qq8vx88x9vi5eW+Sg3jODDRatj2gCzgXiUMqcITkkAp6ytrIGARRh6X5evv0hGqd8y17UyC2ZaHze7aaKqYM7e3Z4zkyWWYH43fdp4atry0l7aXlmurLBMzCPAg1PtSPccvw/CdnllSfyQ0n7ZavzRzNIrPEzcpK0r6wuKbGFXP8fdG87jswxmvcNcBrzAd6jWT4GWctm2lkNEvvBm5HckgFjpAcUoEjJIdU0GP1H7jsjqRgSH8qHJGh5aT0PObnzZBV2qyvVesoEmZ/S/2HmD8rXAwf/OTDatjzTz+TtC+uLau+aigdrizfzMxeUuPOXmBPUW5oQvXtG5viNeYGVF8z4N+TKbLDPaxvqnHLixeTdmFoWPXNbi4k7bpwO40NaCW/kGH1P2ppWdC7hsgityM5pAJHSA6poKesLSYPDa+9nVdqBdUXCW96uajV/5LPbCoQ1uVYsDlAxZqxBb0DaSqo1VaT9o+//101bmGN17Gwqd+zmQt838wch2L5+aIaF/mlpN1fGlF9mQKPDfLaKp0Tnvy8x6xzqakjJSb2HUja9a2q6jt7llnbSoXjqXzSa7xrlK8zkRXvFL39cB+3IzmkAkdIDqmgp6wtjAmXttrawkpLa20/e/qnSfsdB7XF9+PvYvZQFhbxONKamScckZ6ntZTIsOVcxoKdnTmrxq1sseZkCmXV5xeZHXjljaTdNzSoxjXrzFKapNlGqcy/rVTUv3Nxfj5pr6+yZXsgq/9M+T5miedXteU8M7A7aV+aP5+0iwsbatx4iefoIz1/GOvw3W7gdiS
HVOAIySEVOEJySAW99f77OQSDbctubVnTcCvLltyVmg7kqjXZK17KssofG0tNFR5/39fmhXqTZYJLIqZuaUPLWdJSXB49oPqq8XrSHgHP51tqfDPDa6xXtWxS3+Q5Jsd2qb6akIUWhcpPGW3Br6wIS7Rlwd+qsjnAz/IzWFxfVePmhGlgckQ/b0+LdV3B7UgOqcARkkMq6Clry/f149CvHwEAzP7ilOorDjJrO/LgEdVX8PmwZ1OwCi/QKj5lmMVERpsXBnbzEbwXX57m7x3S7GXv5LuStvE0S8kIlhU32KHbbFpOZrEu31Ktj7/0ctIu5SxHaj+bA/qFBfzi/IIaJ4P2fIvtlQf4GVTEqeTVFa3Sn52vJO09Y+OqL8hqj0E3cDuSQypwhOSQChwhOaSCnspInh+gMNiWSSbv1kH3W4KFH5i6V/WNtFgmWBPJMVqW+h+FrO4e+ehnVN+Bux9I2lO/di5pP//CS2pcucjywsVF7X4IDAfW58Q5OTs/zaZQwSvC1QEA5X6+z44fi4TsMzLKMmOjpX/n0irLN2QdgBgQbpfA5z9vs66D1868MZu0R4e0+eLgPh1w1w2uuiMR0WNEtEhEvxKfDRPRD4notc7/5SvN4bDz0Q1r+yaAR6zPvgrgR8aYgwB+1Ll2uINxVdZmjPkZEd1lffxpAA912o8D+F8Af3a1ucjz4Ofaau3FhZOq7/73fyBp9w9qq7S/wTHQUcjbf2B5xc+8waaBj5SnVB8K+5LmQD9v8/lAB3z1CWtwPqtVa2lF3ruHY7FPvP66GpbNsiV+fUNbtu/adzBp33f4napvZYWtz8USmy8uzi+qcSQyhQ2Vdcx2RViwZXqavoI2h2xt8DOYfkOvsS/bu7Q2Y8aYuU57HsDYNc7jsENw3VqbaSdM3PbcgczYVqmsbzfM4TbHtWptC0Q0YYyZI6IJAIvbDZQZ2w4dOmQy+XY8c72uraeNBqttmaxmbYV+EQMtHKT20e5iwN7Ybz76DdX3u7//FZ6/ygFk2ZyV8UwcfZq6e6/qW1zhY0D1TdbMxnfruOyVdWYbMlMcANx9L2uk99yrNdfKC8eSdnWDjyCtV7XGFYoY660tnedySATZRYZZVmlIW9HDJv9O39OZYWbntv1zbotr3ZG+B+CLnfYXAXz3CmMd7gB0o/7/K4BnABwiolki+jKArwN4mIheA/DJzrXDHYxutLbPb9P1WymvxeE2Rm+PbBOB/Davrm3q81j1GgdyZSyP9sayCN7yWUbKoKLGTQyxWvzayWnVd3FWXNdY1pmZPafGvXecIw/2Tmqv+J5FVk6r02xhH85ZkQZDLDOdOaPnn9jDctfaulY+WkL2WbjE0QWxlYSdhMW6ZslIMp+4vKvfOmiAmM0GWdLn5prL83i7cL42h1TgCMkhFfSWtRkkcdW+lXFkYoQDzAp5zdp+/DJbjsuiPMPBYa3S5nO8rWcDveVfWjyXtOMGW38P3KMt4L747kJJuxBHxtg6vrzC6nllXavn8rjd6Kgu4xEItl1vavNFUzhnt+qybIaOy5bX9YY2L4Qh7w27RviMG5F+Vlni55OzkrlHRptfuoHbkRxSgSMkh1TgCMkhFfT2XBsBmaCtog8WdTDVkAhat4u4rBuR4mWVldqRAb38/izLAZGng93PXTyXtMfK7EaYvFd74Ovitmef1xEKF+ZYthoosvyUyejCNcenz4sr/a7G4rphyUibVVbDh4ZZPQ8t9X9ugV0Y/QM670AgciMUCizrZO1IhhabF6Kqzow7tvsGBLY5OHQDR0gOqaDnyUj9TlLQ8d3WWSpB03Fde6Mn9rGKflSwqDXS1lrjs7V8cESrzIMlZnuZPG/dd1msrTjIZoh/euyfVV9NrGt9i2Oxa1bWNFEME+NlrXbXV9giXs3Za+Tf8+qp15L2woJOdrouIgOGhvSfsNTPgXq+SOWTaeo1+sK6P9qvxYDB/PZlYbeD25EcUoEjJIdU0NvjSJ6XaA+lsmZ
toSjqmwu0hnHfFGcFOfo8s6X1jD62FBMHco3t1SzlxMlfJO0Pf+xLSfuZp3+hxlWr7EhtNfVxpMV5WQua38HNln4fAzCrKHs6C8jePp6/cuk11Rf6rAmO7eZ2ZCUHlcFs9S1tVa8Ky3kYMwts1XWV2N0Z1hD3FLUluxFaZWK7gNuRHFKBIySHVOAIySEV9FxGuhxgVR7RAfOhSP9S93Rpz3yRg/9lcPv5N3QA1kc+wClp6ps6uqAwwCr03AU+rjx9+rReRyTKd+pEZqiucyDdwC4+11apaDllUNRdO3Tfu1Xfcy+9mrSPvXpOr/+h307a8gDEmWkdpFcRZ9Jiay+ob7FcNDnG8mRfv/YkDA9znwm0DBY2XSlSh5sER0gOqaCnrM2YGHHY3pYHh/VR6eoWW3lrkd5a5dHjA/s5uOz0ca0+V2rMzor9OpHo/nu4PXOarcsXLs6pcQ8+yEfHazVd3mpAxFsP72Fr+/mVV9W4rQavI9uvj1SXRjlz3HsH9qm+SyJO+9wMZ0mpbungtbUKr8sOnBs0/Hsmi3zf7pLm0xliM0SzpdX9fnKWbYebBEdIDqnAEZJDKuhtvbawhY3lNg/vs86uNUQuAIqt0uKiENvIMHvnT3tn1LjFFfZwL/ta/R8UmdgOv5tNCGdm3lDjWsIhv2YF9R88yClpDk6x0DUzp8/XHT/+Cq9jSbsfsjmWDctFHUA2e5xlrflllmHIMof4InpBRkYAwKQQbw4MsBki72kVv1Hn5xPH2p3UCm9AvTYi2k9EPyGiE0R0nIj+uPO5y9rmkKAb1hYC+FNjzDsBfAjAHxHRO+GytjkIdHP2fw7AXKe9QUQnAezFNWRtazQaODPdZkcHDr5D9eU9kQzdKr0Z5MUWLdoDA9qEUCyxBfzw4UOq73/++6mkXauwRbwwvFuNm57leOj9+7QJYerQ+5J2TmSLu/uAHrcmMq+dOKlNFLFh3nlhTav168IEUo+Y9a+vaRa7e5zNBueXdd/wfmbbyzkhPsSWCSEUFckDHXPeiG9wwvZOCsD3AvglXNY2B4GuCYmIigC+BeBPjDEq+8GVsrbJjG0bG5tvNcRhB6ArQqL2ed9vAfgXY8y3Ox8vdLK14UpZ24wxjxpjHjDGPGCzIoedg6vKSEREAL4B4KQx5u9E1+WsbV9Hl1nbao0QL0636e3Au3XhmhisupOtfopE5jJL7NqajmDcNXx/0v7UIx9Xffe/53DSfvLb3+HvIu06GBxk5XPvHu3CkJlm/ZDXOzyuH+PEFEdIVvq0/PHCS+z6mNvUrgiTYRlvcJzNHCP36LNrvpBpIuvM2ylxBnB6XuRC8PW4LVF3t2Y97jCWz+Sn6Abd2JF+A8AfAHiFiF7sfPY1tAnoyU4GtxkAv9fVNzrsSHSjtf0cOmeThMva5gCgx5btekQ4XWkHWC1F2qprMrzVek1tKTaxLMPO7T0TWnX/zQ+zep7P6DNjU5Psuf+dz34uaf/Hd36gxi2JOmZzFW0dr9c5wCwL5gcrW5o3TM+IgLumPjNmRtgsUd6trd6x0FdkGpo4b40jtnS3rEiJSsT35TM8Lh/ovaBKbDZoZbRl27hy7Q43C46QHFJBT1lbIyKcXmvT7nd//orqu3+SY7jHs/oodkGcgZ4YZ+frxEhJjbvnbqFlGW2dnRNBY489wezs2Isn9BqF8/hNvkvD750Rsd1RTq8j8phVBNCx0qHQEkNP9+XlX0NoY/Wmft+Nx32BZZX2Y2bHps4/IIRm05mY5/RJz99sucA2h5sER0gOqcARkkMq6KmMFIGw2QnS+tExfZ7stdc5SO2R9+tUM/fsYcvu2TPsTf/oB/SZsbxQYzea2mL95H8+l7RfOMEpXWqhlclMyBxexsq2JizsnsgEK2UWAIhEXbdGrOdoiYy0RFrNbkCUKTWiLl2g55CHIQoFHfSWBc8v8r8jssrGR6IztEqdZgd0Avpu4HYkh1TgCMkhFfSUtQVBgF0j7XNYK6vaIju
3ygkxn35JnxOLWpPiirfy0XHtVCWf2dSzR3+l+n7w42eSdiMWlmIrhY7nbf9uRSI5uhFsLo61FV2yJdupmgn4kZNvnQn3+bcFos/39Z9JRlH41no9kaUtEuaKGNp6Lfne+Lh2Cg+U+Pp5dAe3IzmkAkdIDqnAEZJDKuhxwnZKeL9dky2ss3xwbkHXMWtUOXH6R9/HdWD7hibUuIo4q/XTXx5VfXXDKm4rZDkil9Muhli4GGo1HVgv4Qt1+k1H5YX4l7PkG/LEtWf15Vh26+tj90kQ6HEtoa5vVHW22kjIbg1RAGiwrNMIjU3wdTGv59+ySsx3A7cjOaQCR0gOqaDH9doM4svnqYxlNfaZxTSh1eLFTU6UfuwUW6U/VdMmhA1RnvzCqt6ec0VWmcMaz19v6OTwhYJgKRn9eORYEgF2nhX3LVV8Y7EvI97djMVWN8V58aaICZdsDtDmBcm+AKAqoheKoiTq0KjOItwMedypV7W5JWOZM7qB25EcUoEjJIdUcNNKkcIqRer7IkbZaFYhA8XOLTLLeuzJp9S4Tzz0QNI+e1HX76hF0sor2EveyvSR5euCr9+zrDhatLXBrKdlOT2NYDcZSyPyA/5t9n2+sGZLB/GWlTlO9vmWdXyozBnido2xVru0vKLGrS1xXPnaeX2s/N4pneGkG7gdySEVOEJySAWOkBxSQU9lJD/wMTzUDpqq17V6LjO3Zn0rYF7IHJ6wiP/s2ZfVuLMX2TRQqeqgsZVNTpUjNF/09+t8BKGwbOdy2voeCPkp38cqsm9ldg/EebLIeldDId9QrM0XRqS8iVq8/mZLH2ToE6l9RnbtUn3lEZaLmsLE0sha1uscrzEOdGRAtX4DitoQUZ6IniWilzoZ2/6y8/kUEf2SiKaJ6N+IKHu1uRx2LrphbQ0AnzDGvAfA/QAeIaIPAfgbAH9vjLkXwCqAL9+4ZTrc6ujm7L8BcFn/zHT+GQCfAPCFzuePA/gLAP94xblig0Zn28xZJNyIRNlMX29uoeAcRgRyeX2aLc0Ild8LNLsJW8xGJKusi6wcAFAVTlA7yE2yOlnRu8/KOOJ5ImF7XrPHvgKvuWlV2V5aYRU9FkfCAyt2vCxKlo4N6/jq8XFW/9eqbInfWNN14zYrHEgoK3oDwNIlneWlG3SbH8nvZCJZBPBDAK8DWDMmcanPop0O0OEORVeEZIyJjDH3A9gH4AiAw1e5JYHM2NaqrV/9BofbEm9L/TfGrAH4CYAHAQwRJUE5+wBc2OaeJGNbplB6qyEOOwDdZGwbBdAyxqwRUR+Ah9EWtH8C4LMAnkCXGdviOEajU481Z2UQK4iVxFaRFelcj8UZ9thys8QiasCuOWYi/j7pPZfty2u8DFtGWl1lOWNFrLFU1LkKBoWbomS5WfIQ2dZiHXkQkDAp5Pi3NKzy9TmRokbeAwBhrSLafN/m2rIaFwuTQj5nlZS3DyV0gW7sSBMAHqd2jjwPwJPGmO8T0QkATxDRXwF4Ae30gA53KLrR2l5GOyWy/fkZtOUlBweQvbXf0C8juoR2vskRAG9fx9yZuNWfxaQxZvRqg3pKSMmXEh01xjxw9ZE7HzvlWTinrUMqcITkkApuFiE9epO+91bEjngWN0VGcth5cKzNIRX0lJCI6BEiOtWJYbrjCgXu5GqcPWNtHcv4abRdLLMAngPweWPMiSveuIPQqSI1YYw5RkQDaKcf+gyALwFYMcZ8vfOClY0xVyyieKuhlzvSEQDTxpgzxpgm2j66T/fw+286jDFzxphjnfYGAFmN8/HOsMfRJq7bCr0kpL0AZEnrOzqGaadV43TC9k3AtVbjvJXRS0K6AGC/uN42hmkn43qqcd7K6CUhPQfgYOf0SRbA59CuQnnHoItqnECXsV23Gnrt/f8UgH8A4AN4zBjz1z378lsARPQRAP8H4BUgidD7Gtpy0pMADqBTjdMYs/KWk9yicJZth1TghG2HVOAIySEVOEJySAWOkBxSgSMkh1T
gCMkhFThCckgFjpAcUsH/A/0jrUaEeTXJAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "airplane : 0.005173354\n", "automobile : 0.015833795\n", "bird : 0.020151181\n", "cat : 0.3291398\n", "deer : 0.0040378636\n", "dog : 0.1168231\n", "frog : 0.13499218\n", "horse : 0.0046175853\n", "ship : 0.3520366\n", "truck : 0.017194541\n" ] } ], "source": [ "show_predicted_probability(x_test, y_test, predictions, predicted_probability, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 8.4 查看第3项数据预测的概率" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label: airplane\n", "predict: ship\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJIAAACPCAYAAAARM4LLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFRFJREFUeJztXV1sXMd1/s79We6SXFIUJVGUJZmy9WM5luM2jpuiLmAkNeD2xXlIi7hAkQIB+tICLdCHBEELtEALuC9t3woYqFE/FHUNpECDIkARpE7tALEjWa7tWLIlVbIryZREWeKfyN29P6cPu7pzzpGXXIrXS4maDxA0d2d27uzluXN+5vwQM8PDY70INnoBHpsDnpA8SoEnJI9S4AnJoxR4QvIoBZ6QPEqBJySPUrAuQiKiZ4joQyI6Q0TfLWtRHncf6HYNkkQUAjgF4GkAFwAcBfAcM58ob3kedwuidXz3CQBnmPksABDRywCeBdCVkOr1YR4fHwcABNGA6gvIbY5hoDfKTBB7nmVFm0i/BPKKzL3l/CR7aQ0vUpeh3K0DAJFdiepdoYe7Dlvp5Zf3U+1bn0jX+fOsUbRPnT57lZm3d71hB+shpPsAnBfXFwD8ykpfGB8fx5/9+XcAAMPbDqq+Wlgp2iP1YdW30HTEc2P+06IdBLkal4uHHxlirAnCrYbiZwfmjyIfqunK8uwz+3L5uV1HpB9xEITuVisQmXxJyP5Ocz/9PTfnwID7zZVAv7hgd02VUHUtfXqyaD/1m7/9cdebCXzuwjYR/QERHSOiYwsLi5/37Tw2COvZkS4C2COud3c+U2DmFwC8AAD3T01xzlUAQBqOqXFJPFS0s1DvSEEsdqRlR4yc3VDj4ti1m6zf2kS81Y3IvT+Gw6KVuG09CPWbury0XLRD0RfLGwNotRI3R5CoPs5bYn79HlcqbldO00x8R6+xLZ521m92vLEx91wHanWxDr295uKaBvT6s0X9/HvBenakowAOENE+IqoA+CaAH6xjPo+7GLe9IzFzSkR/BOA/AYQAXmTm90tbmcddhfWwNjDzDwH8sKS1eNzFWBchrRUERsApACAzMkxGuWg3VF+17pY5fv9E0Q7mrqtxw0tOfmo1mnr+4WrRzke3FO16RWtON9cHAIHR/FpNJ99kuVtvtaoFLWlRsKp6N/Xc3i9N3DpyIyNJzbISafmmVquJYUIOgpbVcmSibSScFU0Wnw1/ROJRC
jwheZSCvrI2RogUbZU0QEX15aHbv5us1e5QXA8JfX1kUG/r+fGjRbt1VdusJh85VLRpxrG5Jg2pccOh29YXlrV5oSpYxQC7ewfjxlwh1H+j4aM56O4dJZrthYm495BjowNzc2pctOfhor20ZVT15akTC7LAzVfN9fMmwXKDTPeF2dr3F78jeZQCT0gepcATkkcp6KuM1Eabb5NQswEgYCdXZKk5txCCBgnZpEF1NSzOnbxD23aovqUFJ3Mk504V7ZRqalzuRBjciM3hqNDDK4lbY+u8lumQuHEErbs3hBkibOi+yC0RzZ3udy5fuqbG1ckdxtPoNtUnzRKJOAaJA63S5+LcJQz074zsQXYP8DuSRynwhORRCvqr/jOQdRzT8kxv6yxpOtdba0uwwSxy3xtdMCfr253Vu7bjftWXslChK+5n87adatxyLHyJLn2q+iBO/G9UHUvkiXE1LM7db2kY36GhumO/rYUl1dcUpoeo5lTy8Ia29Efjjm1TrJ9jJvyM6oKbhYbFpuRYJwXajNI+Ol0b/I7kUQo8IXmUgv5rbZ3tNrPuqZKdGfLOhJ92TK49cOa0Gtd46/WinX5ZH9pCuJoyDxbtimGPDTh2Mzw9q/pC4bqaDwnfcdaW4Sxxc9bHt6i++KJgl4va+h5PCC30vBsXjWjLeWPmXbemQd2XH3RW74ZwlAtIs7ZKKthoanzf7SFxD/A7kkcp8ITkUQo8IXmUgv46thEhDtuqZmBUTGkOyI1lNRL0PnzdncinFz5R40ZiJ8MsfHJJ9bWq7pSc4azLdOmKGje0S6jnIzZuzqnhtUUng1VmF9S4hnAiS69Oq75Kw5mv03l9qj9wbaRoJ8tOhuHaA2rc7DkXBVapaRmpPunMHqGw0rOxXjeFJ0NKmgxat3jSrQ6/I3mUAk9IHqWgr6wtIMJApb3fcmisqblQ13PjyyyuF2NH+4uPf1GNG4m+VLSXFjS7SUIZxyV+dktv+XHN8YMbWUv1BcKXORHOX3Fg4t8qIvwcGsvClLG0qNc4JO7dEHMMDGv2tbXuYtcy47O9WBPXwkpfS/SekYrfYh43ktvIB+F3JI9S4AnJoxR4QvIoBf2VkYIAQ0PtU/O0qo8VkszF1YO03JIKdZQq7tS9NqEd3+dvuOONmTl9/EDi5L615NTzilV9Z90cqTkrGKg4+WNeHOlUY/MYA3dtM4c0l6QsqCWouWXn5dASwwYjvY76bpdyIbSauozpl/uEDV2DCr5TffltnJGsuiMR0YtEdIWIfiE+20pEPyKi053/x1aaw2PzoxfW9k8AnjGffRfAj5n5AIAfd6497mGsytqY+TUimjIfPwvgqU77JQA/AfCd1eYiIkQd9b1WH1R9iyLcOoo0fWdiu47EKXbAWj3P4a4p1D7hkVDRJUNJWtq5rCZS1ESBfjxxJFLZiPmyVN9LhounxqEsrgm1O9N6d0WYNqRzXJxqFthimYlNz1/NBJvKxLpsVjbxgd1Nbs3utjpuV9ieYOabtv9LACZWGuyx+bFurY3bWRK6WrBkxrb5+dluwzzuctyu1naZiCaZeZqIJgFc6TZQZmw7cPAQVzr5CitVfetc+BrX4qrqS8lt0QvzIiOIyahWHd1atCeGdKiS9NbSWTr0Nh6Kdysk/Z5Vot4eF4sDaMvaMmFhZ6MdBeK6IhmwWUczSLt1IRIabgaZuFX/TsrdbwkNJwttnHkPuN0d6QcAvtVpfwvAv9/mPB6bBL2o//8C4GcADhHRBSL6NoDnATxNRKcB/Ebn2uMeRi9a23Ndur5W8lo87mL0OWMbEHUcrELSqntVeAPMXtEhytcWnXPYzPSFoj1W1/Fkjzx8pGjHVR2KLR25EqEWByaGTspIQdA9o5qUOWxWtkyZKKze3T2tfCDjy9T8JrRbzBGQlhPlHHHo5M7YavQy1sLImlnQP/Xfw0PBE5JHKeh7XNtNlhCZ7TQXbGPBOKXNzDj/69nrLif8qXd/rsZ98M7Pivb+/
Q+rvqn9h4v22DZhPzXbeJYLazDr90yODJUzm54jirqXiciFei7rqth5QjGHNdJJVrpSXRJlhmDrfy7uStoy32jp617gdySPUuAJyaMUeELyKAUbkLGtDSkDAEC16o5FHjr0kOrbf/i+or204OSl948fV+PePvZG0X79NV0d6uSJwp0KBw8/VrQPHDqsxm0Zc7H6lYp+PGHYTS5aIaO6kXAS4eiWpzrvgIT0DMiMCSFXRzy9gayMJMwGgfFySHPv/O+xQfCE5FEK+szauFB/AxMLxqKemrUoh8LqvWXc+Ss/+ZROOLp//76i/dP//onqO3fOmQ1uvO0cz6xry5FHXazcnj17VF8kKk9mop5aZkKcc2FCuKVMqWAxtpSqtBSQtKKb911lADL1UqQ5QK3rFvXffS83rNOy0l7gdySPUuAJyaMU9Jm1EajD0gITBhREjt3EoT0EFZZioSEFsQ5pOnDw0aKdp/odmZ7+ftG+ftVlMTnd1BlBLl/8sGg/eEBrj4e/4ObfMTFZtCNTzzRN3LoS488ty4uxOXClboel5tB2JZ9qln1CfLBTs+SPt5T7sslJV4ffkTxKgSckj1LgCcmjFPTdsn0zNUxo+HIoVOGK9QWTpgLB29m8B7JM+u49U6pvaspdH73sHOVSk9F15oozB8xc1RnhTp502WT37dtftB988IAaNzHhLPH1ug4rh0iU3jApdbKWqB0iMtLaE35p2baH/0zdwq2NqUHFxmmEfYxr8/BQ8ITkUQr6nIwUCDtbb2i3YKkmkznMlI5c6oDU6rQiQ0hVx8bV6y7Rp1KzDYuVbMSWAlu47sL33r4qDo/fOarGbR13OTV27tTW8Z2TU2KNmu2NjzuTwvYJVyOFTOBZLkwIqcl2IjOoKMu2eVQkQsLZ1oXxyUg9NgqekDxKgSckj1LQX/WfGcQ3ZSTTJWQmYs33lVMWdZdvpMq8bDLGXrrkVP7paSffzM/pY5ZYOK/Vh3TqnSEhdw1G7nuZceK/KGLvTn90VvU1Gv9VtFNTFn18266ifeSIC144sF/LWdu3O6+HEVOKdKDmZEGZmB5G7knlkk0Cgdbnof4T0R4iepWIThDR+0T0x53PfdY2jwK9sLYUwJ8y88MAvgLgD4noYfisbR4CvcT+TwOY7rQXiOgkgPtwO1nbCEAnhio3fsEsKmvbJKBCUwWFwuJrtmtZbvOd42+pvsXrM0V7q8gWd2F6Ro0bGXWsIY60CSFPXcLUkWERgxbrU/xK5OaPB4ZUXxi4WirXZrXnwccfnSjac7OOPR4/pv9MlYpb1549uk7Jrsm9RXtyl2OJuyb2qnFDw46BUM3E7wWmynkPWJOw3UkB+EsA3oTP2uYh0DMhEdEwgO8D+BNmnpd9K2Vtkxnb5swb6LF50BMhEVGMNhH9MzP/W+fjy51sbVgpaxszv8DMjzPz46NbRj9riMcmwKoyErWD1/8RwElm/lvRdTNr2/PoMWsbc44kbXtCypN6AKDULSUwsehSEmK4PmtCWBQqf2NZ17Q9dNDFr/3yY48X7bfe/YUa9+Yxd9wxt6gz3mapS8WzY9Kp6k8++aQaF1WdjPHRxzq+7o03XH6CLxzW+QlGRt2LdvmSM1FcvnxZjUsSt46dwlMTAPbtm3LrFUcfNxY0N5BBCXGk5bhGq3u8XTf0Ykf6NQC/B+A9IvqfzmffQ5uAXulkcPsYwO+s+e4emwa9aG0/RfeATp+1zQPABji23TxdvzXeyzWtE7xMsppT9yTktUGndv/6U5rGZWyYjE87+NgTatwjX/py0TYVURGIG24bd9niHnjgQTUuEnVWpg48qvp27T3k1lvTWeVGBWuTXgjXrn2qxkmWtWP7TtUnHelCkYU3yLU4nIn6eIl53jn5kG2PDYInJI9S0FfWluc5lpfb1uFwXh+qRizKYJkaI6lIPJ6mwnHLHJbKbGjWlznNZJJz9/60jGPYrr0u7NvW6CRxHYhsbuf+TydPXW6JA2gTUl0fd
fPnxjJ/fc6tMRJsaWhkSo2D8Le+Nresuj657NYiTw8GAn04Lc63QcOaDBrXG1gr/I7kUQo8IXmUAk9IHqWgrzLS4sICXnvtVQDAXPqu6hsSJ+ZZU1uUEyFLJKKEepZpC6xUmROTDS0TspBUixtNE1uWSed/faofixj/rVucQ9nw8BY1TpZyt8nPZJZbm/FWJ4SXieO1fBMJp7rAOKXJ7yl/QOPPT6LcKw3qOYKG9ojoBX5H8igFnpA8SkGf49oCVOM2C0tC7Q8divphAwMjqi8XKXBkrJbN7Cat5TJr2s17F+OET3jO5vBYWK/ZZC4jkYZGWg0CaHNFJMqgNpv68FiZA4xlXoaPJ4k4nDb102S2u5XYo0TL+LCzmL+hOTgGQm1J7wV+R/IoBZ6QPEqBJySPUtD3uLa849i2eOO66hoUtcVsUtVM0LtMpddKtJkgTYVpP9AyEgtZSDqG5alJVi7U/yw18XUk1XpRM828jsxOLmo29BGGPNbJbabZvFuOA70OKQveUqtWtOX8YaJlwVTISEtbdP3fnXuGsVb4HcmjFHhC8igFfWVtrWQZ58+/DwA4c0mrzIMiQ21k0slkasN2Gc8yo+Lnudu+40rQtU96AmTWPVmwCqt2ywTrwQqpcULhOGdP+FstwVZNOhnp0Cez/hLpLLPyVN+WKVXWbPF5AvNMx5z5ZdcRXY9lVLtw9wS/I3mUAk9IHqWgz1obIeC2dhbbw0xh2bbJN6W2hEBmLdHbtSxvGpqE8JIDBCzuZZKmq+rWxulNvnaSZdmSYZlYb2LWmIu6KmycwuXPZsk6bXYWpbWZWiSRu05Fu75LB0LvPnKwaEekQ7RnT72HtcLvSB6lwBOSRynwhORRCvpu2U47Yc9ZS1t8E5FKJU21aQBCfpJVM3MjOwTSsc3IN7mQVaTanZt4r0rs1mHEDzWHVNXtuExaka2KL9ZoS9ZL+YyELAjjoRCLG9qyocmgM6OMHXIpb+6b0lnfGiIM/OwHOgVQNVnEWtFLxrYqEf2ciN7pZGz7y87n+4joTSI6Q0T/SkSV1eby2LzohbU1AXyVmb8I4DEAzxDRVwD8DYC/Y+b9AK4D+Pbnt0yPOx29xP4zgJt7Xdz5xwC+CuB3O5+/BOAvAPzDipMRgM5uHsbGIUvYA+LI8hRxLeLJQph6Z3LdZBOaut4BkWFtbESnvpRh2VlmrMG5dDYT8w3ozVg6qNnaatJskBn2u7DgWIo0V0iTAQDMCwfsaJte/96DTq0fG3N+5Rc/OKPGfXrmnJvDnBBUzd+mF/SaHynsZCK5AuBHAP4XwCxzITRcQDsdoMc9ip4IiZkzZn4MwG4ATwB4aJWvFJAZ2xrNtefd8bg7sCb1n5lnAbwK4FcBbCEqzMe7AVzs8p0iY1t1YO0lLj3uDvSSsW07gISZZ4moBuBptAXtVwF8A8DL6DFjGxgIb9aabZnMtWiKYXrnCsWJv2xbpy7p8H9rKXQRGCBi3paWtFO8jtW3J+vCbJA4+aaRWFnts+PTOp1iQt2Vyd8t12vMBPUdTi7afnCf6gvEmj88+mbRbl7RDv2hcLALjWeedbjrBb3YkSYBvETtEIoAwCvM/B9EdALAy0T0VwDeRjs9oMc9il60tnfRTolsPz+Ltrzk4QGyJ+2f682IZtDON7kNwNW+3fjOxp3+LO5n5u2rDeorIRU3JTrGzI+vPnLzY7M8C39o61EKPCF5lIKNIqQXNui+dyI2xbPYEBnJY/PBszaPUtBXQiKiZ4jow44P0z1XKHAzV+PsG2vrWMZPoX3EcgHAUQDPMfOJFb+4idCpIjXJzMeJqA7gLQBfB/D7AK4x8/OdF2yMmVcuoniHoZ870hMAzjDzWWZuoX1G92wf77/hYOZpZj7eaS8AkNU4X+oMewlt4rqr0E9Cug/AeXF9T/swbbZqnF7Y3gDcbjXOOxn9JKSLAGQoQ1cfps2M9VTjvJPRT0I6CuBAJ/qkAuCbaFehvGfQQzVOoFffr
jsM/T79/y0Af492CMCLzPzXfbv5HQAiehLA6wDeg/Oa+x7actIrAPaiU42Tma995iR3KLxl26MUeGHboxR4QvIoBZ6QPEqBJySPUuAJyaMUeELyKAWekDxKgSckj1Lw/0+wFLljZLdHAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "airplane : 0.35164014\n", "automobile : 0.030136198\n", "bird : 0.00475265\n", "cat : 0.0015198193\n", "deer : 0.0004779164\n", "dog : 0.000117743686\n", "frog : 0.00021989802\n", "horse : 0.00014916781\n", "ship : 0.57337856\n", "truck : 0.037607934\n" ] } ], "source": [ "show_predicted_probability(x_test, y_test, predictions, predicted_probability, 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 9. 显示混淆矩阵" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 9.1 查看预测结果的形状" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "predictions shape: (10000,)\n" ] } ], "source": [ "print('predictions shape:', predictions.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 9.2 查看测试 label 真实值的形状" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y_test shape: (10000, 1)\n", "[[3]\n", " [8]\n", " [8]\n", " ...\n", " [5]\n", " [1]\n", " [7]]\n" ] } ], "source": [ "print('y_test shape:', y_test.shape)\n", "print(y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 9.3 将测试 label 真实值转换为一维数据" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([3, 8, 8, ..., 5, 1, 7])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_test.reshape(-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 9.4 建立混淆矩阵" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
predict0123456789
label
066221461912314316852
149071422402650
24776216375529511218
31413535754713893162328
41385575705246437136
559461943561540271613
6210254119158685132
7139305760599723634
8202258114392016
9139037237420851
\n", "
" ], "text/plain": [ "predict 0 1 2 3 4 5 6 7 8 9\n", "label \n", "0 662 21 46 19 12 3 14 3 168 52\n", "1 4 907 1 4 2 2 4 0 26 50\n", "2 47 7 621 63 75 52 95 11 21 8\n", "3 14 13 53 575 47 138 93 16 23 28\n", "4 13 8 55 75 705 24 64 37 13 6\n", "5 5 9 46 194 35 615 40 27 16 13\n", "6 2 10 25 41 19 15 868 5 13 2\n", "7 13 9 30 57 60 59 9 723 6 34\n", "8 20 22 5 8 1 1 4 3 920 16\n", "9 13 90 3 7 2 3 7 4 20 851" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.crosstab(y_test.reshape(-1), predictions, rownames=['label'], colnames=['predict'])" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}\n" ] } ], "source": [ "print(label_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10. 保存模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 10.1 保存为 json" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [
"import os\n",
"\n",
"# open() does not create missing directories, so create save_model/ before the first save.\n",
"os.makedirs('save_model', exist_ok=True)\n",
"model_json = model.to_json()\n",
"with open('save_model/model_cifar10_cnn_deeper.json', 'w') as json_file:\n",
"    json_file.write(model_json)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 10.2 保存为 yaml" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [
"try:\n",
"    model_yaml = model.to_yaml()\n",
"except RuntimeError as err:\n",
"    # TensorFlow 2.6+ removed Model.to_yaml() and raises RuntimeError; the json/h5 files still get saved.\n",
"    print(err)\n",
"else:\n",
"    with open('save_model/model_cifar10_cnn_deeper.yaml', 'w') as yaml_file:\n",
"        yaml_file.write(model_yaml)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 10.3 保存为 h5" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型已保存到: save_model/model_cifar10_cnn_deeper.h5\n" ] } ], "source": [
"model.save_weights('save_model/model_cifar10_cnn_deeper.h5', save_format='h5')\n",
"print('模型已保存到:', 'save_model/model_cifar10_cnn_deeper.h5')" ] } ], "metadata": { "kernelspec": { "display_name": "tensorflow-keras-practice", "language": "python", "name": "tensorflow-keras-practice" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }