{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "day34_01_simpleRNN.ipynb", "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyOcbQPro0QML4vSSylreSP3", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/github/rkti498/e_shikaku/blob/main/day34_01_simpleRNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "markdown", "metadata": { "id": "92THlk0B11ZF" }, "source": [ "# RNN バイナリ加算のサンプル" ] }, { "cell_type": "code", "metadata": { "id": "j0boG2fNAaBl" }, "source": [
"import numpy as np\n",
"\n",
"# Activation functions for the hidden layer\n",
"# Sigmoid (logistic) function\n",
"def sigmoid(x):\n",
"    return 1/(1 + np.exp(-x))\n",
"\n",
"# ReLU function\n",
"def relu(x):\n",
"    return np.maximum(0, x)\n",
"\n",
"# Step function (threshold 0)\n",
"def step_function(x):\n",
"    return np.where(x > 0, 1, 0)\n",
"\n",
"# Activation functions for the output layer\n",
"# Softmax; a 2-D input is normalised row by row\n",
"def softmax(x):\n",
"    if x.ndim == 2:\n",
"        x = x.T\n",
"        x = x - np.max(x, axis=0)\n",
"        y = np.exp(x) / np.sum(np.exp(x), axis=0)\n",
"        return y.T\n",
"\n",
"    x = x - np.max(x)  # subtract the max so exp() cannot overflow\n",
"    return np.exp(x) / np.sum(np.exp(x))\n",
"\n",
"# Softmax followed by cross-entropy loss\n",
"def softmax_with_loss(d, x):\n",
"    y = softmax(x)\n",
"    return cross_entropy_error(d, y)\n",
"\n",
"# Loss functions\n",
"# Mean squared error (halved)\n",
"def mean_squared_error(d, y):\n",
"    return np.mean(np.square(d - y)) / 2\n",
"\n",
"# Cross-entropy error\n",
"def cross_entropy_error(d, y):\n",
"    if y.ndim == 1:\n",
"        d = d.reshape(1, d.size)\n",
"        y = y.reshape(1, y.size)\n",
"\n",
"    # One-hot targets are converted to the index of the correct label\n",
"    if d.size == y.size:\n",
"        d = d.argmax(axis=1)\n",
"\n",
"    batch_size = y.shape[0]\n",
"    return -np.sum(np.log(y[np.arange(batch_size), d] + 1e-7)) / batch_size\n",
"\n",
"\n",
"\n",
"# Derivatives of the activation functions\n",
"# Derivative of the sigmoid with respect to its input x\n",
"def d_sigmoid(x):\n",
"    dx = (1.0 - sigmoid(x)) * sigmoid(x)\n",
"    return dx\n",
"\n",
"# Derivative of ReLU\n",
"def d_relu(x):\n",
"    return np.where(x > 0, 1, 0)\n",
"\n",
"# Derivative of the step function (returned as the constant 0)\n",
"def d_step_function(x):\n",
"    return 0\n",
"\n",
"# Derivative of mean squared error; divided by the length of the first axis when d is an array\n",
"def d_mean_squared_error(d, y):\n",
"    if isinstance(d, np.ndarray):\n",
"        batch_size = d.shape[0]\n",
"        dx = (y - d)/batch_size\n",
"    else:\n",
"        dx = y - d\n",
"    return dx\n",
"\n",
"\n",
"# Combined derivative of softmax + cross-entropy\n",
"def d_softmax_with_loss(d, y):\n",
"    batch_size = d.shape[0]\n",
"    if d.size == y.size:  # one-hot targets\n",
"        dx = (y - d) / batch_size\n",
"    else:\n",
"        dx = y.copy()\n",
"        dx[np.arange(batch_size), d] -= 1\n",
"        dx = dx / batch_size\n",
"    return dx\n",
"\n",
"# Combined derivative of sigmoid + cross-entropy\n",
"def d_sigmoid_with_loss(d, y):\n",
"    return y - d\n",
"\n",
"# Central-difference numerical gradient of f at x, for arrays of any shape.\n",
"# x is perturbed in place and always restored, even if f raises.\n",
"def numerical_gradient(f, x):\n",
"    h = 1e-4\n",
"    grad = np.zeros_like(x)\n",
"\n",
"    # np.ndindex yields full index tuples, so 2-D weight matrices work as well as vectors\n",
"    for idx in np.ndindex(x.shape):\n",
"        tmp_val = x[idx]\n",
"        try:\n",
"            # f(x + h)\n",
"            x[idx] = tmp_val + h\n",
"            fxh1 = f(x)\n",
"\n",
"            # f(x - h)\n",
"            x[idx] = tmp_val - h\n",
"            fxh2 = f(x)\n",
"        finally:\n",
"            # restore the original value\n",
"            x[idx] = tmp_val\n",
"\n",
"        grad[idx] = (fxh1 - fxh2) / (2 * h)\n",
"\n",
"    return grad\n",
"\n",
"\n",
"# Unfold an (N, C, H, W) input into a 2-D matrix with one row per output position\n",
"def im2col(input_data, filter_h, filter_w, stride=1, pad=0):\n",
"    N, C, H, W = input_data.shape\n",
"    out_h = (H + 2*pad - filter_h)//stride + 1\n",
"    out_w = (W + 2*pad - filter_w)//stride + 1\n",
"\n",
"    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')\n",
"    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))\n",
"\n",
"    for y in range(filter_h):\n",
"        y_max = y + stride*out_h\n",
"        for x in range(filter_w):\n",
"            x_max = x + stride*out_w\n",
"            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]\n",
"\n",
"    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)\n",
"    return col\n",
"\n",
"\n",
"def 
col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):\n",
"    # Inverse of im2col: fold the patch rows back into an (N, C, H, W) image, summing overlaps\n",
"    N, C, H, W = input_shape\n",
"    out_h = (H + 2*pad - filter_h)//stride + 1\n",
"    out_w = (W + 2*pad - filter_w)//stride + 1\n",
"    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)\n",
"\n",
"    img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))\n",
"    for y in range(filter_h):\n",
"        y_max = y + stride*out_h\n",
"        for x in range(filter_w):\n",
"            x_max = x + stride*out_w\n",
"            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]\n",
"\n",
"    return img[:, :, pad:H + pad, pad:W + pad]" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "Ma8DLL-i1l_0", "outputId": "7e85f0b8-cf2f-4d50-b263-8ebac7e7bbd4" }, "source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"np.random.seed(0)  # make runs reproducible\n",
"\n",
"\n",
"def d_tanh(x):\n",
"    # Same value as 1/cosh(x)**2, but cosh(x)**2 overflows for large |x|\n",
"    return 1.0 - np.tanh(x) ** 2\n",
"\n",
"\n",
"# Data: every integer below largest_number as a binary_dim-bit vector (most significant bit first)\n",
"binary_dim = 8\n",
"largest_number = pow(2, binary_dim)\n",
"binary = np.unpackbits(np.array([range(largest_number)], dtype=np.uint8).T, axis=1)\n",
"\n",
"input_layer_size = 2\n",
"hidden_layer_size = 16\n",
"output_layer_size = 1\n",
"\n",
"weight_init_std = 1\n",
"learning_rate = 0.1\n",
"\n",
"iters_num = 10000\n",
"plot_interval = 100\n",
"\n",
"# Experiment settings\n",
"hidden_activation = 'sigmoid'  # 'sigmoid' / 'relu' / 'tanh'\n",
"weight_init = 'xavier'         # 'random' / 'xavier' / 'he'\n",
"\n",
"# Each hidden activation is paired with its derivative w.r.t. the pre-activation u.\n",
"# The forward and backward passes must use the same pair; mixing them\n",
"# (e.g. tanh forward with the sigmoid derivative) produces wrong gradients.\n",
"activations = {\n",
"    'sigmoid': (sigmoid, d_sigmoid),\n",
"    'relu': (relu, d_relu),\n",
"    'tanh': (np.tanh, d_tanh),\n",
"}\n",
"f_hidden, d_f_hidden = activations[hidden_activation]\n",
"\n",
"\n",
"def init_weight(n_in, n_out):\n",
"    # 'random': N(0, weight_init_std^2), 'xavier': N(0, 1/n_in), 'he': N(0, 2/n_in)\n",
"    scales = {\n",
"        'random': weight_init_std,\n",
"        'xavier': 1 / np.sqrt(n_in),\n",
"        'he': np.sqrt(2) / np.sqrt(n_in),\n",
"    }\n",
"    return np.random.randn(n_in, n_out) * scales[weight_init]\n",
"\n",
"\n",
"# Weight initialisation (biases omitted for simplicity)\n",
"W_in = init_weight(input_layer_size, hidden_layer_size)\n",
"W_out = init_weight(hidden_layer_size, output_layer_size)\n",
"W = init_weight(hidden_layer_size, hidden_layer_size)\n",
"\n",
"# Gradient accumulators\n",
"W_in_grad = np.zeros_like(W_in)\n",
"W_out_grad = np.zeros_like(W_out)\n",
"W_grad = np.zeros_like(W)\n",
"\n",
"# Column t+1 of u / z holds step t; column 0 of z is the zero initial state\n",
"u = np.zeros((hidden_layer_size, binary_dim + 1))\n",
"z = np.zeros((hidden_layer_size, binary_dim + 1))\n",
"y = np.zeros((output_layer_size, binary_dim))\n",
"\n",
"delta_out = np.zeros((output_layer_size, binary_dim))\n",
"# delta[:, t] is the error at u[:, t+1]; the last column stays zero (no step follows the last one)\n",
"delta = np.zeros((hidden_layer_size, binary_dim + 1))\n",
"\n",
"all_losses = []\n",
"\n",
"for i in range(iters_num):\n",
"\n",
"    # Draw a, b below largest_number // 2 so that d = a + b fits in binary_dim bits\n",
"    a_int = np.random.randint(largest_number // 2)\n",
"    a_bin = binary[a_int]  # binary encoding\n",
"    b_int = np.random.randint(largest_number // 2)\n",
"    b_bin = binary[b_int]  # binary encoding\n",
"\n",
"    # Target\n",
"    d_int = a_int + b_int\n",
"    d_bin = binary[d_int]\n",
"\n",
"    # Predicted bits\n",
"    out_bin = np.zeros_like(d_bin)\n",
"\n",
"    # Loss summed over the whole sequence\n",
"    all_loss = 0\n",
"\n",
"    # Forward pass, least significant bit first\n",
"    for t in range(binary_dim):\n",
"        # Input bits at step t\n",
"        X = np.array([a_bin[-t - 1], b_bin[-t - 1]]).reshape(1, -1)\n",
"        # Target bit at step t\n",
"        dd = np.array([d_bin[binary_dim - t - 1]])\n",
"\n",
"        u[:, t + 1] = (np.dot(X, W_in) + np.dot(z[:, t], W)).ravel()\n",
"        z[:, t + 1] = f_hidden(u[:, t + 1])\n",
"        # The output layer is always sigmoid so that predictions lie in (0, 1) like the target bits\n",
"        y[:, t] = sigmoid(np.dot(z[:, t + 1], W_out))\n",
"\n",
"        # Loss\n",
"        loss = mean_squared_error(dd, y[:, t])\n",
"\n",
"        # Sigmoid derivative written in terms of its output: dy/dv = y * (1 - y)\n",
"        delta_out[:, t] = d_mean_squared_error(dd, y[:, t]) * y[:, t] * (1 - y[:, t])\n",
"\n",
"        all_loss += loss\n",
"\n",
"        out_bin[binary_dim - t - 1] = np.round(y[0, t])\n",
"\n",
"    # Backpropagation through time, last step first\n",
"    for t in reversed(range(binary_dim)):\n",
"        X = np.array([a_bin[-t - 1], b_bin[-t - 1]]).reshape(1, -1)\n",
"\n",
"        # Error from the next step's hidden state plus error from this step's output,\n",
"        # times the derivative of the activation used in the forward pass\n",
"        delta[:, t] = (np.dot(delta[:, t + 1], W.T) + np.dot(delta_out[:, t], W_out.T)) * d_f_hidden(u[:, t + 1])\n",
"\n",
"        # Accumulate gradients as outer products (input-side activation x error)\n",
"        W_out_grad += np.outer(z[:, t + 1], delta_out[:, t])\n",
"        W_grad += np.outer(z[:, t], delta[:, t])\n",
"        W_in_grad += np.outer(X, delta[:, t])\n",
"\n",
"    # Apply the gradients\n",
"    W_in -= learning_rate * W_in_grad\n",
"    W_out -= learning_rate * W_out_grad\n",
"    W -= learning_rate * W_grad\n",
"\n",
"    W_in_grad *= 0\n",
"    W_out_grad *= 0\n",
"    W_grad *= 0\n",
"\n",
"    if i % plot_interval == 0:\n",
"        all_losses.append(all_loss)\n",
"        print(\"iters:\" + str(i))\n",
"        print(\"Loss:\" + str(all_loss))\n",
"        print(\"Pred:\" + str(out_bin))\n",
"        print(\"True:\" + str(d_bin))\n",
"        out_int = int(\"\".join(str(bit) for bit in out_bin), 2)\n",
"        print(str(a_int) + \" + \" + str(b_int) + \" = \" + str(out_int))\n",
"        print(\"------------\")\n",
"\n",
"lists = range(0, iters_num, plot_interval)\n",
"plt.plot(lists, all_losses, label=\"loss\")\n",
"plt.xlabel(\"iterations\")\n",
"plt.ylabel(\"loss\")\n",
"plt.show()" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "iters:0\n", "Loss:2.153564894745493\n", "Pred:[0 0 1 0 0 0 0 0]\n", "True:[1 1 0 1 0 1 0 0]\n", "121 + 91 = 32\n", "------------\n", "iters:100\n", "Loss:0.5659978109753718\n", "Pred:[0 0 1 1 1 1 1 0]\n", "True:[0 0 1 1 0 0 1 1]\n", "24 + 27 = 62\n", "------------\n", "iters:200\n", "Loss:1.5234273549747595\n", "Pred:[0 0 0 0 1 0 0 0]\n", "True:[0 1 0 1 1 0 1 
1]\n", "5 + 86 = 8\n", "------------\n", "iters:300\n", "Loss:1.0732684500064333\n", "Pred:[0 0 0 1 1 1 0 0]\n", "True:[0 1 1 1 1 1 0 1]\n", "111 + 14 = 28\n", "------------\n", "iters:400\n", "Loss:1.3702915075700555\n", "Pred:[0 0 0 0 1 1 1 0]\n", "True:[0 1 1 1 1 0 0 1]\n", "10 + 111 = 14\n", "------------\n", "iters:500\n", "Loss:0.9647743604989347\n", "Pred:[0 1 0 0 1 1 0 0]\n", "True:[1 0 0 1 0 1 1 0]\n", "70 + 80 = 76\n", "------------\n", "iters:600\n", "Loss:1.1480619329217434\n", "Pred:[1 1 1 1 1 1 1 0]\n", "True:[1 1 1 0 0 0 0 1]\n", "114 + 111 = 254\n", "------------\n", "iters:700\n", "Loss:1.0431930834757628\n", "Pred:[1 1 1 1 1 0 1 0]\n", "True:[1 0 1 1 1 1 0 0]\n", "73 + 115 = 250\n", "------------\n", "iters:800\n", "Loss:0.8700848545561863\n", "Pred:[0 0 1 1 1 0 0 0]\n", "True:[0 1 0 0 0 0 0 0]\n", "24 + 40 = 56\n", "------------\n", "iters:900\n", "Loss:0.9435390961625657\n", "Pred:[1 1 0 1 0 0 0 0]\n", "True:[1 0 0 1 0 1 1 0]\n", "69 + 81 = 208\n", "------------\n", "iters:1000\n", "Loss:1.2748087371021117\n", "Pred:[0 0 1 1 1 1 0 1]\n", "True:[0 1 0 0 1 1 1 0]\n", "25 + 53 = 61\n", "------------\n", "iters:1100\n", "Loss:1.0681730984674394\n", "Pred:[0 0 1 1 1 0 0 1]\n", "True:[0 1 0 0 0 1 0 1]\n", "25 + 44 = 57\n", "------------\n", "iters:1200\n", "Loss:0.5352576913311647\n", "Pred:[0 1 1 1 0 0 0 0]\n", "True:[0 1 1 1 1 0 0 0]\n", "24 + 96 = 112\n", "------------\n", "iters:1300\n", "Loss:0.6447085338301011\n", "Pred:[0 0 0 0 0 0 0 0]\n", "True:[0 0 0 0 1 1 0 1]\n", "3 + 10 = 0\n", "------------\n", "iters:1400\n", "Loss:0.8743242683856232\n", "Pred:[0 0 1 0 0 0 0 0]\n", "True:[0 0 1 1 1 1 0 1]\n", "17 + 44 = 32\n", "------------\n", "iters:1500\n", "Loss:0.9926579862906834\n", "Pred:[0 0 1 0 0 1 0 0]\n", "True:[0 0 1 1 1 1 1 0]\n", "26 + 36 = 36\n", "------------\n", "iters:1600\n", "Loss:1.1973668032523332\n", "Pred:[1 1 1 1 1 1 0 1]\n", "True:[1 1 0 0 0 0 1 1]\n", "86 + 109 = 253\n", "------------\n", "iters:1700\n", 
"Loss:0.5712612361030289\n", "Pred:[0 0 1 1 0 1 1 0]\n", "True:[0 0 1 0 0 1 1 1]\n", "19 + 20 = 54\n", "------------\n", "iters:1800\n", "Loss:1.1251359893632809\n", "Pred:[1 0 1 1 1 1 0 0]\n", "True:[1 0 1 0 0 0 0 0]\n", "93 + 67 = 188\n", "------------\n", "iters:1900\n", "Loss:0.6588827075097232\n", "Pred:[1 1 1 1 1 0 1 0]\n", "True:[1 0 1 1 1 1 1 0]\n", "68 + 122 = 250\n", "------------\n", "iters:2000\n", "Loss:0.8459203533313725\n", "Pred:[0 1 0 1 1 1 0 0]\n", "True:[1 0 0 1 1 1 0 0]\n", "125 + 31 = 92\n", "------------\n", "iters:2100\n", "Loss:0.6806428475661361\n", "Pred:[1 1 0 1 1 1 0 0]\n", "True:[1 1 0 1 1 0 0 0]\n", "123 + 93 = 220\n", "------------\n", "iters:2200\n", "Loss:0.28051406170320214\n", "Pred:[0 1 0 0 1 0 0 0]\n", "True:[0 1 0 0 1 1 0 0]\n", "41 + 35 = 72\n", "------------\n", "iters:2300\n", "Loss:0.2678601472715533\n", "Pred:[0 0 1 1 1 0 1 1]\n", "True:[0 0 1 1 1 0 0 1]\n", "26 + 31 = 59\n", "------------\n", "iters:2400\n", "Loss:0.5053430276519022\n", "Pred:[0 1 0 0 1 1 1 1]\n", "True:[0 1 0 1 1 1 1 1]\n", "25 + 70 = 79\n", "------------\n", "iters:2500\n", "Loss:0.4214456423294395\n", "Pred:[1 1 1 0 1 0 0 0]\n", "True:[0 1 1 0 1 0 0 0]\n", "69 + 35 = 232\n", "------------\n", "iters:2600\n", "Loss:0.43138557041268244\n", "Pred:[0 0 1 1 0 0 1 1]\n", "True:[0 0 1 1 0 1 1 1]\n", "22 + 33 = 51\n", "------------\n", "iters:2700\n", "Loss:0.08342940450333307\n", "Pred:[1 1 1 0 1 0 1 0]\n", "True:[1 1 1 0 1 0 1 0]\n", "127 + 107 = 234\n", "------------\n", "iters:2800\n", "Loss:0.9843271892876817\n", "Pred:[1 1 0 1 1 1 0 0]\n", "True:[1 0 0 0 1 0 0 0]\n", "17 + 119 = 220\n", "------------\n", "iters:2900\n", "Loss:0.27041290556212505\n", "Pred:[0 1 1 1 0 1 1 1]\n", "True:[0 1 1 1 0 0 1 1]\n", "85 + 30 = 119\n", "------------\n", "iters:3000\n", "Loss:0.17747346976447112\n", "Pred:[0 0 1 1 1 1 0 1]\n", "True:[0 0 1 1 1 1 0 1]\n", "36 + 25 = 61\n", "------------\n", "iters:3100\n", "Loss:0.15474438566810336\n", "Pred:[0 1 0 0 1 1 0 1]\n", 
"True:[0 1 0 0 1 1 0 1]\n", "31 + 46 = 77\n", "------------\n", "iters:3200\n", "Loss:0.2925576273470318\n", "Pred:[1 0 1 1 1 1 1 0]\n", "True:[1 0 0 1 1 1 1 0]\n", "38 + 120 = 190\n", "------------\n", "iters:3300\n", "Loss:0.058212540150715524\n", "Pred:[0 1 1 0 0 0 0 1]\n", "True:[0 1 1 0 0 0 0 1]\n", "65 + 32 = 97\n", "------------\n", "iters:3400\n", "Loss:0.23097330361993285\n", "Pred:[0 1 1 1 0 0 0 1]\n", "True:[0 1 1 1 0 0 0 1]\n", "70 + 43 = 113\n", "------------\n", "iters:3500\n", "Loss:0.12528287152008144\n", "Pred:[0 1 0 1 1 0 1 0]\n", "True:[0 1 0 1 1 0 1 0]\n", "33 + 57 = 90\n", "------------\n", "iters:3600\n", "Loss:0.11931292858767471\n", "Pred:[1 0 0 1 0 1 1 1]\n", "True:[1 0 0 1 0 1 1 1]\n", "35 + 116 = 151\n", "------------\n", "iters:3700\n", "Loss:0.03487172172445985\n", "Pred:[0 1 1 0 0 0 0 0]\n", "True:[0 1 1 0 0 0 0 0]\n", "24 + 72 = 96\n", "------------\n", "iters:3800\n", "Loss:0.42386621971691407\n", "Pred:[1 1 0 0 0 1 1 0]\n", "True:[1 0 0 0 0 1 1 0]\n", "56 + 78 = 198\n", "------------\n", "iters:3900\n", "Loss:0.04679461229142709\n", "Pred:[0 1 0 0 1 1 1 0]\n", "True:[0 1 0 0 1 1 1 0]\n", "51 + 27 = 78\n", "------------\n", "iters:4000\n", "Loss:0.08500932952686045\n", "Pred:[0 1 0 1 1 1 0 0]\n", "True:[0 1 0 1 1 1 0 0]\n", "16 + 76 = 92\n", "------------\n", "iters:4100\n", "Loss:0.028449876185565142\n", "Pred:[1 1 1 1 0 0 1 1]\n", "True:[1 1 1 1 0 0 1 1]\n", "125 + 118 = 243\n", "------------\n", "iters:4200\n", "Loss:0.010387937689509383\n", "Pred:[1 0 0 0 0 1 1 1]\n", "True:[1 0 0 0 0 1 1 1]\n", "87 + 48 = 135\n", "------------\n", "iters:4300\n", "Loss:0.16121578777568452\n", "Pred:[1 0 0 0 0 1 0 1]\n", "True:[1 0 0 0 0 1 0 1]\n", "55 + 78 = 133\n", "------------\n", "iters:4400\n", "Loss:0.008219474578897448\n", "Pred:[0 1 1 1 0 1 1 1]\n", "True:[0 1 1 1 0 1 1 1]\n", "7 + 112 = 119\n", "------------\n", "iters:4500\n", "Loss:0.3146766921785453\n", "Pred:[1 1 1 1 1 1 0 0]\n", "True:[1 0 1 1 1 1 0 0]\n", "81 + 107 = 252\n", 
"------------\n", "iters:4600\n", "Loss:0.023366812397411115\n", "Pred:[1 1 1 0 0 0 1 0]\n", "True:[1 1 1 0 0 0 1 0]\n", "107 + 119 = 226\n", "------------\n", "iters:4700\n", "Loss:0.021387873678467333\n", "Pred:[1 1 0 1 0 1 0 0]\n", "True:[1 1 0 1 0 1 0 0]\n", "98 + 114 = 212\n", "------------\n", "iters:4800\n", "Loss:0.030882954549819937\n", "Pred:[0 1 1 0 1 1 0 1]\n", "True:[0 1 1 0 1 1 0 1]\n", "11 + 98 = 109\n", "------------\n", "iters:4900\n", "Loss:0.12399539928882054\n", "Pred:[0 1 0 0 1 0 1 1]\n", "True:[0 1 0 0 1 0 1 1]\n", "61 + 14 = 75\n", "------------\n", "iters:5000\n", "Loss:0.040624100294295765\n", "Pred:[0 1 1 1 0 1 1 0]\n", "True:[0 1 1 1 0 1 1 0]\n", "99 + 19 = 118\n", "------------\n", "iters:5100\n", "Loss:0.029890623469243406\n", "Pred:[1 0 0 1 1 0 1 1]\n", "True:[1 0 0 1 1 0 1 1]\n", "94 + 61 = 155\n", "------------\n", "iters:5200\n", "Loss:0.006747409869158992\n", "Pred:[1 1 0 1 0 1 0 0]\n", "True:[1 1 0 1 0 1 0 0]\n", "104 + 108 = 212\n", "------------\n", "iters:5300\n", "Loss:0.11362048400667227\n", "Pred:[0 0 1 1 1 1 0 0]\n", "True:[0 0 1 1 1 1 0 0]\n", "43 + 17 = 60\n", "------------\n", "iters:5400\n", "Loss:0.036499078282814464\n", "Pred:[1 0 0 1 0 0 0 1]\n", "True:[1 0 0 1 0 0 0 1]\n", "104 + 41 = 145\n", "------------\n", "iters:5500\n", "Loss:0.0020271442734815285\n", "Pred:[1 0 1 1 1 0 0 0]\n", "True:[1 0 1 1 1 0 0 0]\n", "92 + 92 = 184\n", "------------\n", "iters:5600\n", "Loss:0.019150168217955153\n", "Pred:[0 0 0 1 0 0 1 1]\n", "True:[0 0 0 1 0 0 1 1]\n", "14 + 5 = 19\n", "------------\n", "iters:5700\n", "Loss:0.015615537960966534\n", "Pred:[0 1 1 0 1 1 1 0]\n", "True:[0 1 1 0 1 1 1 0]\n", "11 + 99 = 110\n", "------------\n", "iters:5800\n", "Loss:0.055848845352638045\n", "Pred:[0 1 0 0 0 0 0 0]\n", "True:[0 1 0 0 0 0 0 0]\n", "3 + 61 = 64\n", "------------\n", "iters:5900\n", "Loss:0.027811941122713048\n", "Pred:[1 0 1 1 0 1 0 1]\n", "True:[1 0 1 1 0 1 0 1]\n", "85 + 96 = 181\n", "------------\n", "iters:6000\n", 
"Loss:0.010294577678977852\n", "Pred:[1 0 1 0 0 0 1 0]\n", "True:[1 0 1 0 0 0 1 0]\n", "106 + 56 = 162\n", "------------\n", "iters:6100\n", "Loss:0.0294434376838914\n", "Pred:[0 1 0 1 0 0 0 0]\n", "True:[0 1 0 1 0 0 0 0]\n", "63 + 17 = 80\n", "------------\n", "iters:6200\n", "Loss:0.09631313639715536\n", "Pred:[0 1 0 0 0 1 0 1]\n", "True:[0 1 0 0 0 1 0 1]\n", "24 + 45 = 69\n", "------------\n", "iters:6300\n", "Loss:0.01924271266456333\n", "Pred:[1 1 0 0 1 0 0 1]\n", "True:[1 1 0 0 1 0 0 1]\n", "110 + 91 = 201\n", "------------\n", "iters:6400\n", "Loss:0.08407179514535484\n", "Pred:[0 1 1 1 1 1 1 0]\n", "True:[0 1 1 1 1 1 1 0]\n", "20 + 106 = 126\n", "------------\n", "iters:6500\n", "Loss:0.004977384528896935\n", "Pred:[1 0 0 0 1 0 0 0]\n", "True:[1 0 0 0 1 0 0 0]\n", "36 + 100 = 136\n", "------------\n", "iters:6600\n", "Loss:0.023494576741571575\n", "Pred:[1 0 0 0 1 1 1 0]\n", "True:[1 0 0 0 1 1 1 0]\n", "95 + 47 = 142\n", "------------\n", "iters:6700\n", "Loss:0.003000975419069599\n", "Pred:[1 0 1 1 0 0 1 1]\n", "True:[1 0 1 1 0 0 1 1]\n", "96 + 83 = 179\n", "------------\n", "iters:6800\n", "Loss:0.0021108360627262197\n", "Pred:[1 0 0 1 1 0 0 1]\n", "True:[1 0 0 1 1 0 0 1]\n", "87 + 66 = 153\n", "------------\n", "iters:6900\n", "Loss:0.043114944696639484\n", "Pred:[0 1 0 1 1 0 1 0]\n", "True:[0 1 0 1 1 0 1 0]\n", "63 + 27 = 90\n", "------------\n", "iters:7000\n", "Loss:0.006124100758764581\n", "Pred:[0 1 1 1 1 0 0 0]\n", "True:[0 1 1 1 1 0 0 0]\n", "114 + 6 = 120\n", "------------\n", "iters:7100\n", "Loss:0.06282776676401762\n", "Pred:[0 1 1 0 0 0 0 1]\n", "True:[0 1 1 0 0 0 0 1]\n", "69 + 28 = 97\n", "------------\n", "iters:7200\n", "Loss:0.012013012155202389\n", "Pred:[0 1 0 0 0 0 0 1]\n", "True:[0 1 0 0 0 0 0 1]\n", "32 + 33 = 65\n", "------------\n", "iters:7300\n", "Loss:0.028836632944875068\n", "Pred:[1 1 0 1 1 0 0 1]\n", "True:[1 1 0 1 1 0 0 1]\n", "122 + 95 = 217\n", "------------\n", "iters:7400\n", "Loss:0.03777987985745057\n", "Pred:[0 1 1 0 
0 0 1 1]\n", "True:[0 1 1 0 0 0 1 1]\n", "60 + 39 = 99\n", "------------\n", "iters:7500\n", "Loss:0.0016045064751352685\n", "Pred:[0 1 0 1 1 0 1 0]\n", "True:[0 1 0 1 1 0 1 0]\n", "10 + 80 = 90\n", "------------\n", "iters:7600\n", "Loss:0.011171217721587709\n", "Pred:[1 0 1 0 0 0 0 1]\n", "True:[1 0 1 0 0 0 0 1]\n", "125 + 36 = 161\n", "------------\n", "iters:7700\n", "Loss:0.08617873995102997\n", "Pred:[0 1 0 0 1 1 1 0]\n", "True:[0 1 0 0 1 1 1 0]\n", "40 + 38 = 78\n", "------------\n", "iters:7800\n", "Loss:0.010397438122656017\n", "Pred:[0 1 1 0 1 0 1 0]\n", "True:[0 1 1 0 1 0 1 0]\n", "88 + 18 = 106\n", "------------\n", "iters:7900\n", "Loss:0.02433398525828952\n", "Pred:[1 0 0 0 0 1 0 1]\n", "True:[1 0 0 0 0 1 0 1]\n", "77 + 56 = 133\n", "------------\n", "iters:8000\n", "Loss:0.001718991142316749\n", "Pred:[1 1 0 1 0 1 0 1]\n", "True:[1 1 0 1 0 1 0 1]\n", "117 + 96 = 213\n", "------------\n", "iters:8100\n", "Loss:0.0007114582440403065\n", "Pred:[0 1 1 1 0 1 1 1]\n", "True:[0 1 1 1 0 1 1 1]\n", "88 + 31 = 119\n", "------------\n", "iters:8200\n", "Loss:0.025349840562032013\n", "Pred:[1 1 0 0 1 1 1 1]\n", "True:[1 1 0 0 1 1 1 1]\n", "90 + 117 = 207\n", "------------\n", "iters:8300\n", "Loss:0.021087777651367698\n", "Pred:[0 0 0 0 1 1 1 0]\n", "True:[0 0 0 0 1 1 1 0]\n", "5 + 9 = 14\n", "------------\n", "iters:8400\n", "Loss:0.012730265116910503\n", "Pred:[1 0 1 0 0 0 0 1]\n", "True:[1 0 1 0 0 0 0 1]\n", "111 + 50 = 161\n", "------------\n", "iters:8500\n", "Loss:0.010435261233883586\n", "Pred:[1 0 1 0 1 0 0 1]\n", "True:[1 0 1 0 1 0 0 1]\n", "125 + 44 = 169\n", "------------\n", "iters:8600\n", "Loss:0.06613496188358742\n", "Pred:[0 1 1 1 0 1 0 1]\n", "True:[0 1 1 1 0 1 0 1]\n", "109 + 8 = 117\n", "------------\n", "iters:8700\n", "Loss:0.011993586944276585\n", "Pred:[0 0 1 1 1 0 0 1]\n", "True:[0 0 1 1 1 0 0 1]\n", "10 + 47 = 57\n", "------------\n", "iters:8800\n", "Loss:0.019079905925579396\n", "Pred:[0 0 1 0 0 1 1 1]\n", "True:[0 0 1 0 0 1 1 1]\n", 
"18 + 21 = 39\n", "------------\n", "iters:8900\n", "Loss:0.013600918569904754\n", "Pred:[1 0 1 0 0 1 1 1]\n", "True:[1 0 1 0 0 1 1 1]\n", "75 + 92 = 167\n", "------------\n", "iters:9000\n", "Loss:0.009621368877386419\n", "Pred:[1 0 0 0 1 0 0 1]\n", "True:[1 0 0 0 1 0 0 1]\n", "73 + 64 = 137\n", "------------\n", "iters:9100\n", "Loss:0.01811070525066136\n", "Pred:[0 1 1 0 1 0 1 0]\n", "True:[0 1 1 0 1 0 1 0]\n", "63 + 43 = 106\n", "------------\n", "iters:9200\n", "Loss:0.004258061377861931\n", "Pred:[1 0 1 1 1 1 1 0]\n", "True:[1 0 1 1 1 1 1 0]\n", "64 + 126 = 190\n", "------------\n", "iters:9300\n", "Loss:0.04397747646965129\n", "Pred:[0 1 1 0 1 0 0 0]\n", "True:[0 1 1 0 1 0 0 0]\n", "69 + 35 = 104\n", "------------\n", "iters:9400\n", "Loss:0.03249117145077845\n", "Pred:[0 1 1 0 0 0 0 1]\n", "True:[0 1 1 0 0 0 0 1]\n", "80 + 17 = 97\n", "------------\n", "iters:9500\n", "Loss:0.02220329506937916\n", "Pred:[1 0 1 1 0 0 0 0]\n", "True:[1 0 1 1 0 0 0 0]\n", "70 + 106 = 176\n", "------------\n", "iters:9600\n", "Loss:0.020627497677411313\n", "Pred:[1 0 1 0 0 0 0 1]\n", "True:[1 0 1 0 0 0 0 1]\n", "37 + 124 = 161\n", "------------\n", "iters:9700\n", "Loss:0.010723291665635452\n", "Pred:[0 1 0 1 0 0 0 0]\n", "True:[0 1 0 1 0 0 0 0]\n", "25 + 55 = 80\n", "------------\n", "iters:9800\n", "Loss:0.026674913935543214\n", "Pred:[1 0 0 0 1 1 1 0]\n", "True:[1 0 0 0 1 1 1 0]\n", "102 + 40 = 142\n", "------------\n", "iters:9900\n", "Loss:0.03234901207213073\n", "Pred:[0 1 1 0 1 0 0 0]\n", "True:[0 1 1 0 1 0 0 0]\n", "97 + 7 = 104\n", "------------\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "PlRjNu4cAz9G" }, "source": [ "sigmoid(ランダム初期化)の結果\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "-TkkCuBOJ929" }, "source": [ "sigmoid(Xavier初期化) \n", "" ] 
}, { "cell_type": "markdown", "metadata": { "id": "692rSn7sBHp_" }, "source": [ "ReLU(ランダム初期化)の結果 \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "hBAK-mtmI0Fg" }, "source": [ "ReLU(He初期化)の結果 \n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "I9DJd7AzBTMA" }, "source": [ "tanh(ランダム初期化)の結果 \n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "zsFA4_zRJm2C" }, "source": [ "tanh(Xavier初期化)の結果 \n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "FqjobgfAKPXK" }, "source": [ "### 考察\n", "実施する度に結果が少々ばらつくが、tanh(ランダム初期化)以外の組み合わせはどれもうまく収束する。\n", "tanhだけはランダム初期化だとあまりうまくいかず、初期化方法の影響を受けやすい。\n", "一方、sigmoidとReLUは初期化方法にあまり影響を受けていないように見える。\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-u0A5ThGjKVH" }, "source": [ "# サイン波予測\n", "maxlen:2 iters_num:100 \n", "maxlen:2 iters_num:500 \n", "maxlen:2 iters_num:3000 \n", "maxlen:5 iters_num:100 \n", "maxlen:5 iters_num:500 \n", "maxlen:5 iters_num:3000 \n", "\n", "\n" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "IjjFN0eciCUL", "outputId": "c1b2a888-a131-4985-cec8-d37aaa09d3cd" }, "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.model_selection import train_test_split\n", "\n", "np.random.seed(0)\n", "\n", "# sin曲線\n", "round_num = 10\n", "div_num = 500\n", "ts = np.linspace(0, round_num * np.pi, div_num)\n", "f = np.sin(ts)\n", "\n", "def d_tanh(x):\n", "    return 1/(np.cosh(x)**2 + 1e-4)\n", "\n", "# ひとつの時系列データの長さ\n", "maxlen = 5\n", "\n", "# sin波予測の入力データ\n", "test_head = [[f[k]] for k in range(0, maxlen)]\n", "\n", "data = []\n", "target = []\n", "\n", "for i in range(div_num - maxlen):\n", "    data.append(f[i: i + maxlen])\n", "    target.append(f[i + maxlen])\n", " \n", "X = np.array(data).reshape(len(data), maxlen, 1)\n", "D = np.array(target).reshape(len(data), 1)\n", "\n", "# データ設定\n", "N_train = int(len(data) * 0.8)\n", "N_validation = len(data) - N_train\n", "\n", "x_train, x_test, d_train, d_test = 
train_test_split(X, D, test_size=N_validation)\n",
"\n",
"input_layer_size = 1\n",
"hidden_layer_size = 5\n",
"output_layer_size = 1\n",
"\n",
"weight_init_std = 0.01\n",
"learning_rate = 0.1\n",
"\n",
"iters_num = 3000\n",
"\n",
"# Weight initialisation (biases omitted for simplicity)\n",
"W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)\n",
"W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)\n",
"W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)\n",
"\n",
"# Gradient accumulators\n",
"W_in_grad = np.zeros_like(W_in)\n",
"W_out_grad = np.zeros_like(W_out)\n",
"W_grad = np.zeros_like(W)\n",
"\n",
"# Pre-activations and hidden states of every step of the current sequence\n",
"us = []\n",
"zs = []\n",
"\n",
"u = np.zeros(hidden_layer_size)\n",
"z = np.zeros(hidden_layer_size)\n",
"y = np.zeros(output_layer_size)\n",
"\n",
"delta_out = np.zeros(output_layer_size)\n",
"delta = np.zeros(hidden_layer_size)\n",
"\n",
"# Mean training loss of each epoch\n",
"losses = []\n",
"\n",
"# Training: many-to-one RNN, only the hidden state after the last step feeds the output\n",
"for i in range(iters_num):\n",
"    epoch_loss = 0.0\n",
"    for s in range(x_train.shape[0]):\n",
"        us.clear()\n",
"        zs.clear()\n",
"        z = np.zeros(hidden_layer_size)\n",
"\n",
"        # Target for sequence s\n",
"        d = d_train[s]\n",
"\n",
"        xs = x_train[s]\n",
"\n",
"        # Forward pass over the sequence\n",
"        for t in range(maxlen):\n",
"            x = xs[t]\n",
"            u = np.dot(x, W_in) + np.dot(z, W)\n",
"            us.append(u)\n",
"            z = np.tanh(u)\n",
"            zs.append(z)\n",
"\n",
"        y = np.dot(z, W_out)\n",
"\n",
"        # Loss\n",
"        loss = mean_squared_error(d, y)\n",
"        epoch_loss += loss\n",
"\n",
"        delta_out = d_mean_squared_error(d, y)\n",
"\n",
"        # The output depends only on the final hidden state\n",
"        W_out_grad = np.dot(zs[-1].reshape(-1, 1), delta_out.reshape(1, -1))\n",
"\n",
"        # Backpropagation through time, with u_t = x_t W_in + z_{t-1} W:\n",
"        #   the output error enters only at the last step,\n",
"        #   earlier steps receive error only through the recurrent weights W,\n",
"        #   and dL/dW pairs delta_t with the previous hidden state z_{t-1}.\n",
"        for t in reversed(range(maxlen)):\n",
"            if t == maxlen - 1:\n",
"                delta = np.dot(delta_out, W_out.T) * d_tanh(us[t])\n",
"            else:\n",
"                delta = np.dot(delta, W.T) * d_tanh(us[t])\n",
"\n",
"            # z_{-1} is the zero initial state, so step 0 contributes nothing to W_grad\n",
"            if t > 0:\n",
"                W_grad += np.dot(zs[t - 1].reshape(-1, 1), delta.reshape(1, -1))\n",
"            W_in_grad += np.dot(xs[t].reshape(-1, 1), delta.reshape(1, -1))\n",
"\n",
"        # Apply the gradients (one SGD step per sequence)\n",
"        W -= learning_rate * W_grad\n",
"        W_in -= learning_rate * W_in_grad\n",
"        W_out -= learning_rate * W_out_grad\n",
"\n",
"        W_in_grad *= 0\n",
"        W_out_grad *= 0\n",
"        W_grad *= 0\n",
"    losses.append(epoch_loss / x_train.shape[0])\n",
"\n",
"# Test: predict the next value for each held-out window\n",
"for s in range(x_test.shape[0]):\n",
"    z = np.zeros(hidden_layer_size)\n",
"\n",
"    # Target for sequence s\n",
"    d = d_test[s]\n",
"\n",
"    xs = x_test[s]\n",
"\n",
"    # Forward pass over the sequence\n",
"    for t in range(maxlen):\n",
"        x = xs[t]\n",
"        u = np.dot(x, W_in) + np.dot(z, W)\n",
"        z = np.tanh(u)\n",
"\n",
"    y = np.dot(z, W_out)\n",
"\n",
"    # Loss\n",
"    loss = mean_squared_error(d, y)\n",
"    print('loss:', loss, ' d:', d, ' y:', y)\n",
"\n",
"\n",
"# Free-running prediction: start from the first maxlen samples and feed every prediction back in\n",
"original = np.full(maxlen, None)  # no predictions for the seed window\n",
"pred_num = 200\n",
"\n",
"window = np.array(test_head)  # shape (maxlen, 1), same layout as one training sequence\n",
"\n",
"for s in range(pred_num):\n",
"    z = np.zeros(hidden_layer_size)\n",
"    for t in range(maxlen):\n",
"        u = np.dot(window[t], W_in) + np.dot(z, W)\n",
"        z = np.tanh(u)\n",
"\n",
"    y = np.dot(z, W_out)\n",
"    original = np.append(original, y)\n",
"    # Slide the window: drop the oldest value and append the prediction\n",
"    window = np.vstack([window[1:], y.reshape(1, 1)])\n",
"\n",
"# Reference sine on the same time grid as the training data (ts), covering the seed\n",
"# window plus every predicted step so both curves line up index by index\n",
"dt = ts[1] - ts[0]\n",
"reference = np.sin(ts[0] + dt * np.arange(maxlen + pred_num))\n",
"\n",
"plt.figure()\n",
"plt.ylim([-1.5, 1.5])\n",
"plt.plot(reference, linestyle='dotted', color='#aaaaaa')\n",
"plt.plot(original, linestyle='dashed', color='black')\n",
"plt.xlabel('time step')\n",
"plt.show()" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "loss: 1.0231756688222737e-07 d: [-0.29761864] y: [-0.29716628]\n", "loss: 1.2201090156041725e-08 d: [-0.56307233] y: [-0.56322854]\n", "loss: 4.038245320901024e-11 d: [-0.65766776] y: [-0.65765877]\n", "loss: 1.012613171470193e-08 d: [0.13182648] y: [0.13168417]\n", "loss: 6.953992391246068e-08 d: [0.49909101] y: [0.49871807]\n", "loss: 5.5838333193039375e-08 d: [0.9518317] y: [0.95149752]\n", "loss: 1.006541644062743e-07 d: [0.97784112] y: [0.97739245]\n", "loss: 1.7033534573325946e-08 d: [-0.58880346] y: [-0.58861889]\n", "loss: 5.2622812685924356e-08 d: [-0.78351093] y: [-0.78383534]\n", "loss: 2.60918282588755e-10 d: [-0.49909101] y: [-0.49906816]\n", "loss: 5.712849721315155e-08 d: [0.21857331] y: [0.21823529]\n", "loss: 9.126545094849141e-08 d: [-0.33938943] y: [-0.3389622]\n", "loss: 
1.0304339135538811e-07 d: [-0.43793098] y: [-0.43747701]\n", "loss: 1.1445752972814562e-07 d: [-0.33346065] y: [-0.3329822]\n", "loss: 4.487028896407404e-08 d: [-0.99639027] y: [-0.9960907]\n", "loss: 4.518845122030081e-08 d: [0.88624247] y: [0.88654309]\n", "loss: 2.3166269315110745e-08 d: [-0.92833248] y: [-0.92811723]\n", "loss: 7.877136645016467e-10 d: [-0.52075286] y: [-0.52079255]\n", "loss: 8.254988700985761e-09 d: [-0.55262221] y: [-0.5527507]\n", "loss: 8.291678208274492e-08 d: [0.47711265] y: [0.47670543]\n", "loss: 3.515702303123948e-08 d: [0.60896952] y: [0.60923469]\n", "loss: 4.6349673171343807e-08 d: [-0.94587102] y: [-0.94556656]\n", "loss: 9.366188781076226e-08 d: [0.27953518] y: [0.27910237]\n", "loss: 7.187670231870784e-08 d: [0.73863456] y: [0.73901371]\n", "loss: 2.5175272023918393e-08 d: [-0.00629574] y: [-0.00607135]\n", "loss: 4.238282860583762e-08 d: [0.54208448] y: [0.54179333]\n", "loss: 5.364268767133932e-08 d: [0.99781582] y: [0.99748827]\n", "loss: 7.761284716153206e-08 d: [-0.96441607] y: [-0.96402208]\n", "loss: 7.29946476567447e-08 d: [0.07547747] y: [0.07509539]\n", "loss: 5.374729213271438e-12 d: [0.66239735] y: [0.66240063]\n", "loss: 3.917648631674704e-08 d: [-0.54736419] y: [-0.54708428]\n", "loss: 1.1396823999698672e-07 d: [0.99393675] y: [0.99345932]\n", "loss: 1.8916287872014157e-10 d: [0.96441607] y: [0.96443552]\n", "loss: 6.104853811859931e-08 d: [-0.22471249] y: [-0.22436306]\n", "loss: 3.374712152402054e-08 d: [-0.99393675] y: [-0.99367695]\n", "loss: 7.459208770391463e-08 d: [-0.71705202] y: [-0.71743827]\n", "loss: 3.3346599617550096e-09 d: [0.8649742] y: [0.86505587]\n", "loss: 1.0502769561892032e-07 d: [-0.11933469] y: [-0.11887638]\n", "loss: 3.7402892316320774e-08 d: [-0.40941891] y: [-0.4091454]\n", "loss: 1.1571745806193722e-07 d: [0.39789889] y: [0.39741782]\n", "loss: 1.3219146512941343e-08 d: [0.98611478] y: [0.98595218]\n", "loss: 6.82473798362108e-08 d: [-0.0691982] y: [-0.06882875]\n", "loss: 
7.74733530193459e-08 d: [0.35709413] y: [0.3567005]\n", "loss: 4.2013170888265586e-08 d: [0.99583607] y: [0.9955462]\n", "loss: 2.1526965553360173e-08 d: [-0.92597363] y: [-0.92618113]\n", "loss: 6.817128658024936e-08 d: [0.36882689] y: [0.36845764]\n", "loss: 1.1810524118698761e-08 d: [0.91617219] y: [0.9160185]\n", "loss: 1.511514695515008e-09 d: [0.52611726] y: [0.52617224]\n", "loss: 8.054071771358459e-08 d: [0.96606148] y: [0.96566013]\n", "loss: 1.956018209202484e-08 d: [0.43793098] y: [0.43773319]\n", "loss: 2.365114287876332e-09 d: [-0.86811636] y: [-0.86818514]\n", "loss: 7.426592869136916e-08 d: [-0.99975723] y: [-0.99937184]\n", "loss: 3.369004516388851e-08 d: [0.77562491] y: [0.77588449]\n", "loss: 4.2428732558073345e-09 d: [0.04405617] y: [0.04414829]\n", "loss: 9.775107382870593e-09 d: [0.71705202] y: [0.71719184]\n", "loss: 4.388321808758248e-10 d: [0.8773359] y: [0.87736552]\n", "loss: 3.008800038067477e-08 d: [0.91363079] y: [0.9138761]\n", "loss: 3.4353135857868203e-09 d: [-0.97784112] y: [-0.97775823]\n", "loss: 8.641444917528756e-10 d: [0.96101064] y: [0.96105221]\n", "loss: 8.711809080390049e-08 d: [0.26742375] y: [0.26700633]\n", "loss: 5.0207765974484416e-08 d: [-0.87122411] y: [-0.871541]\n", "loss: 2.838017375633304e-08 d: [-0.91617219] y: [-0.91641043]\n", "loss: 1.5567226856080684e-09 d: [0.87122411] y: [0.87127991]\n", "loss: 9.803777837499867e-09 d: [0.98394564] y: [0.98380562]\n", "loss: 4.14806161320395e-08 d: [0.4036669] y: [0.40337887]\n", "loss: 8.246547424877302e-08 d: [0.0880268] y: [0.08762068]\n", "loss: 3.519939115353052e-08 d: [0.81012572] y: [0.81039105]\n", "loss: 3.3475921950765465e-08 d: [0.41515469] y: [0.41489593]\n", "loss: 3.920213266201233e-08 d: [-0.99524241] y: [-0.9949624]\n", "loss: 6.676591329593643e-09 d: [-0.94789551] y: [-0.94801107]\n", "loss: 2.7259396135763863e-08 d: [0.16916853] y: [0.16893503]\n", "loss: 5.901954968651478e-08 d: [-0.95374324] y: [-0.95339967]\n", "loss: 7.404264207154143e-08 d: 
[-0.72577151] y: [-0.72615633]\n", "loss: 1.950089841485295e-08 d: [0.74286391] y: [0.7430614]\n", "loss: 9.1706563452403e-09 d: [-0.94380904] y: [-0.94394447]\n", "loss: 1.8342443304828862e-08 d: [-0.83516734] y: [-0.83535887]\n", "loss: 4.0168629526980415e-08 d: [-0.94170965] y: [-0.94142621]\n", "loss: 1.0438172705832991e-07 d: [0.32156366] y: [0.32110676]\n", "loss: 7.96519652519902e-08 d: [-0.48263615] y: [-0.48223703]\n", "loss: 4.522391104134645e-08 d: [0.03776568] y: [0.03746493]\n", "loss: 8.671646127145176e-08 d: [0.34530476] y: [0.34488831]\n", "loss: 3.0057469610811475e-08 d: [0.56307233] y: [0.56282714]\n", "loss: 3.344824485318678e-08 d: [0.90843947] y: [0.90869812]\n", "loss: 9.740854716739049e-08 d: [0.99940055] y: [0.99895917]\n", "loss: 5.2790241458557934e-08 d: [0.85534252] y: [0.85566745]\n", "loss: 1.3438734027154209e-08 d: [0.93739898] y: [0.93756292]\n", "loss: 1.0825441262363102e-07 d: [0.99738016] y: [0.99691486]\n", "loss: 7.346796784291748e-08 d: [-0.6992734] y: [-0.69965672]\n", "loss: 9.62971935624344e-08 d: [-0.10682399] y: [-0.10638513]\n", "loss: 4.974432822949506e-09 d: [-0.54208448] y: [-0.54218422]\n", "loss: 7.137974552014359e-08 d: [0.9995987] y: [0.99922087]\n", "loss: 1.1991319740559976e-07 d: [0.29761864] y: [0.29712892]\n", "loss: 3.111699813899164e-08 d: [0.99322482] y: [0.99297535]\n", "loss: 9.573921959195159e-08 d: [0.33346065] y: [0.33302307]\n", "loss: 5.1818369725854564e-08 d: [-0.83168816] y: [-0.83201008]\n", "loss: 1.1452214926486475e-08 d: [-0.98504973] y: [-0.98489839]\n", "loss: 5.381547067764966e-08 d: [-0.64332332] y: [-0.6436514]\n", "loss: 1.0539082479572855e-07 d: [0.43226238] y: [0.43180327]\n", "loss: 3.26643400999611e-08 d: [-0.81380058] y: [-0.81405617]\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": 
"iI6SQJcHj62q" }, "source": [ "maxlen:2 iters_num:100\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "ylqvxrZ-9IvP" }, "source": [ "maxlen:2 iters_num:500\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "qTUnnjwx9X_8" }, "source": [ "maxlen:2 iters_num:3000 \n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "xyfTBafWRb2M" }, "source": [ "maxlen:5 iters_num:100 \n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "52L8VqJ3RxqZ" }, "source": [ "maxlen:5 iters_num:500 \n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "5SD7pir8UNRJ" }, "source": [ "maxlen:5 iters_num:3000 \n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "pK4Jprp9UVx5" }, "source": [ "### 考察\n", "\n", "- maxlen:2 のときは学習がうまく進まない。iters_num を増やしても改善しない。\n", "- maxlen:5 にすると、iters_num:3000 でほぼ正確なサイン波を予測できている。\n" ] } ] }