{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "day34_01_simpleRNN.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyOcbQPro0QML4vSSylreSP3",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/rkti498/e_shikaku/blob/main/day34_01_simpleRNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "92THlk0B11ZF"
      },
      "source": [
        "# RNN バイナリ加算のサンプル"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "j0boG2fNAaBl"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "# 中間層の活性化関数\n",
        "# シグモイド関数(ロジスティック関数)\n",
        "def sigmoid(x):\n",
        "    return 1/(1 + np.exp(-x))\n",
        "\n",
        "# ReLU関数\n",
        "def relu(x):\n",
        "    return np.maximum(0, x)\n",
        "\n",
        "# ステップ関数(閾値0)\n",
        "def step_function(x):\n",
        "    return np.where( x > 0, 1, 0) \n",
        "\n",
        "# 出力層の活性化関数\n",
        "# ソフトマックス関数\n",
        "def softmax(x):\n",
        "    if x.ndim == 2:\n",
        "        x = x.T\n",
        "        x = x - np.max(x, axis=0)\n",
        "        y = np.exp(x) / np.sum(np.exp(x), axis=0)\n",
        "        return y.T\n",
        "\n",
        "    x = x - np.max(x) # オーバーフロー対策\n",
        "    return np.exp(x) / np.sum(np.exp(x))\n",
        "\n",
        "# ソフトマックスとクロスエントロピーの複合関数\n",
        "def softmax_with_loss(d, x):\n",
        "    y = softmax(x)\n",
        "    return cross_entropy_error(d, y)\n",
        "\n",
        "# 誤差関数\n",
        "# 平均二乗誤差\n",
        "def mean_squared_error(d, y):\n",
        "    return np.mean(np.square(d - y)) / 2\n",
        "\n",
        "# クロスエントロピー\n",
        "def cross_entropy_error(d, y):\n",
        "    if y.ndim == 1:\n",
        "        d = d.reshape(1, d.size)\n",
        "        y = y.reshape(1, y.size)\n",
        "        \n",
        "    # 教師データがone-hot-vectorの場合、正解ラベルのインデックスに変換\n",
        "    if d.size == y.size:\n",
        "        d = d.argmax(axis=1)\n",
        "             \n",
        "    batch_size = y.shape[0]\n",
        "    return -np.sum(np.log(y[np.arange(batch_size), d] + 1e-7)) / batch_size\n",
        "\n",
        "\n",
        "\n",
        "# 活性化関数の導関数\n",
        "# シグモイド関数(ロジスティック関数)の導関数\n",
        "def d_sigmoid(x):\n",
        "    dx = (1.0 - sigmoid(x)) * sigmoid(x)\n",
        "    return dx\n",
        "\n",
        "# ReLU関数の導関数\n",
        "def d_relu(x):\n",
        "    return np.where( x > 0, 1, 0)\n",
        "    \n",
        "# ステップ関数の導関数\n",
        "def d_step_function(x):\n",
        "    return 0\n",
        "\n",
        "# 平均二乗誤差の導関数\n",
        "def d_mean_squared_error(d, y):\n",
        "    if type(d) == np.ndarray:\n",
        "        batch_size = d.shape[0]\n",
        "        dx = (y - d)/batch_size\n",
        "    else:\n",
        "        dx = y - d\n",
        "    return dx\n",
        "\n",
        "\n",
        "# ソフトマックスとクロスエントロピーの複合導関数\n",
        "def d_softmax_with_loss(d, y):\n",
        "    batch_size = d.shape[0]\n",
        "    if d.size == y.size: # 教師データがone-hot-vectorの場合\n",
        "        dx = (y - d) / batch_size\n",
        "    else:\n",
        "        dx = y.copy()\n",
        "        dx[np.arange(batch_size), d] -= 1\n",
        "        dx = dx / batch_size\n",
        "    return dx\n",
        "\n",
        "# シグモイドとクロスエントロピーの複合導関数\n",
        "def d_sigmoid_with_loss(d, y):\n",
        "    return y - d\n",
        "\n",
        "# 数値微分\n",
        "def numerical_gradient(f, x):\n",
        "    h = 1e-4\n",
        "    grad = np.zeros_like(x)\n",
        "\n",
        "    for idx in range(x.size):\n",
        "        tmp_val = x[idx]\n",
        "        # f(x + h)の計算\n",
        "        x[idx] = tmp_val + h\n",
        "        fxh1 = f(x)\n",
        "\n",
        "        # f(x - h)の計算\n",
        "        x[idx] = tmp_val - h\n",
        "        fxh2 = f(x)\n",
        "\n",
        "        grad[idx] = (fxh1 - fxh2) / (2 * h)\n",
        "        # 値を元に戻す\n",
        "        x[idx] = tmp_val\n",
        "\n",
        "    return grad\n",
        "\n",
        "\n",
        "def im2col(input_data, filter_h, filter_w, stride=1, pad=0):\n",
        "    N, C, H, W = input_data.shape\n",
        "    out_h = (H + 2*pad - filter_h)//stride + 1\n",
        "    out_w = (W + 2*pad - filter_w)//stride + 1\n",
        "\n",
        "    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')\n",
        "    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))\n",
        "\n",
        "    for y in range(filter_h):\n",
        "        y_max = y + stride*out_h\n",
        "        for x in range(filter_w):\n",
        "            x_max = x + stride*out_w\n",
        "            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]\n",
        "\n",
        "    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)\n",
        "    return col\n",
        "\n",
        "\n",
        "def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):\n",
        "    N, C, H, W = input_shape\n",
        "    out_h = (H + 2*pad - filter_h)//stride + 1\n",
        "    out_w = (W + 2*pad - filter_w)//stride + 1\n",
        "    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)\n",
        "\n",
        "    img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))\n",
        "    for y in range(filter_h):\n",
        "        y_max = y + stride*out_h\n",
        "        for x in range(filter_w):\n",
        "            x_max = x + stride*out_w\n",
        "            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]\n",
        "\n",
        "    return img[:, :, pad:H + pad, pad:W + pad]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "Ma8DLL-i1l_0",
        "outputId": "7e85f0b8-cf2f-4d50-b263-8ebac7e7bbd4"
      },
      "source": [
        "import numpy as np\n",
        "# from common import functions\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "def d_tanh(x):\n",
        "    return 1/(np.cosh(x) ** 2)\n",
        "\n",
        "# データを用意\n",
        "# 2進数の桁数\n",
        "binary_dim = 8\n",
        "# 最大値 + 1\n",
        "largest_number = pow(2, binary_dim)\n",
        "# largest_numberまで2進数を用意\n",
        "binary = np.unpackbits(np.array([range(largest_number)],dtype=np.uint8).T,axis=1)\n",
        "\n",
        "input_layer_size = 2\n",
        "hidden_layer_size = 16\n",
        "output_layer_size = 1\n",
        "\n",
        "weight_init_std = 1\n",
        "learning_rate = 0.1\n",
        "\n",
        "iters_num = 10000\n",
        "plot_interval = 100\n",
        "\n",
        "# ウェイト初期化 (バイアスは簡単のため省略)\n",
        "# W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)\n",
        "# W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)\n",
        "# W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)\n",
        "# Xavier\n",
        "W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size))\n",
        "W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size))\n",
        "W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size))\n",
        "# He\n",
        "# W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size)) * np.sqrt(2)\n",
        "# W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)\n",
        "# W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)\n",
        "\n",
        "\n",
        "# 勾配\n",
        "W_in_grad = np.zeros_like(W_in)\n",
        "W_out_grad = np.zeros_like(W_out)\n",
        "W_grad = np.zeros_like(W)\n",
        "\n",
        "u = np.zeros((hidden_layer_size, binary_dim + 1))\n",
        "z = np.zeros((hidden_layer_size, binary_dim + 1))\n",
        "y = np.zeros((output_layer_size, binary_dim))\n",
        "\n",
        "delta_out = np.zeros((output_layer_size, binary_dim))\n",
        "delta = np.zeros((hidden_layer_size, binary_dim + 1))\n",
        "\n",
        "all_losses = []\n",
        "\n",
        "for i in range(iters_num):\n",
        "    \n",
        "    # A, B初期化 (a + b = d)\n",
        "    a_int = np.random.randint(largest_number/2)\n",
        "    a_bin = binary[a_int] # binary encoding\n",
        "    b_int = np.random.randint(largest_number/2)\n",
        "    b_bin = binary[b_int] # binary encoding\n",
        "    \n",
        "    # 正解データ\n",
        "    d_int = a_int + b_int\n",
        "    d_bin = binary[d_int]\n",
        "    \n",
        "    # 出力バイナリ\n",
        "    out_bin = np.zeros_like(d_bin)\n",
        "    \n",
        "    # 時系列全体の誤差\n",
        "    all_loss = 0    \n",
        "    \n",
        "    # 時系列ループ\n",
        "    for t in range(binary_dim):\n",
        "        # 入力値\n",
        "        X = np.array([a_bin[ - t - 1], b_bin[ - t - 1]]).reshape(1, -1)\n",
        "        # 時刻tにおける正解データ\n",
        "        dd = np.array([d_bin[binary_dim - t - 1]])\n",
        "        \n",
        "        u[:,t+1] = np.dot(X, W_in) + np.dot(z[:,t].reshape(1, -1), W)\n",
        "\n",
        "        z[:,t+1] = sigmoid(u[:,t+1])\n",
        "        y[:,t] = sigmoid(np.dot(z[:,t+1].reshape(1, -1), W_out))\n",
        "\n",
        "        # z[:,t+1] = relu(u[:,t+1])\n",
        "        # y[:,t] = relu(np.dot(z[:,t+1].reshape(1, -1), W_out))\n",
        "\n",
        "        z[:,t+1] = np.tanh(u[:,t+1])    \n",
        "        y[:,t] = np.tanh(np.dot(z[:,t+1].reshape(1, -1), W_out))\n",
        "\n",
        "\n",
        "        #誤差\n",
        "        loss = mean_squared_error(dd, y[:,t])\n",
        "        \n",
        "        delta_out[:,t] = d_mean_squared_error(dd, y[:,t]) * d_sigmoid(y[:,t])        \n",
        "        \n",
        "        all_loss += loss\n",
        "\n",
        "        out_bin[binary_dim - t - 1] = np.round(y[:,t])\n",
        "    \n",
        "    \n",
        "    for t in range(binary_dim)[::-1]:\n",
        "        X = np.array([a_bin[-t-1],b_bin[-t-1]]).reshape(1, -1)        \n",
        "\n",
        "        delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_sigmoid(u[:,t+1])\n",
        "\n",
        "        # delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_relu(u[:,t+1])\n",
        "\n",
        "        # delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_tanh(u[:,t+1])    \n",
        "\n",
        "        # 勾配更新\n",
        "        W_out_grad += np.dot(z[:,t+1].reshape(-1,1), delta_out[:,t].reshape(-1,1))\n",
        "        W_grad += np.dot(z[:,t].reshape(-1,1), delta[:,t].reshape(1,-1))\n",
        "        W_in_grad += np.dot(X.T, delta[:,t].reshape(1,-1))\n",
        "    \n",
        "    # 勾配適用\n",
        "    W_in -= learning_rate * W_in_grad\n",
        "    W_out -= learning_rate * W_out_grad\n",
        "    W -= learning_rate * W_grad\n",
        "    \n",
        "    W_in_grad *= 0\n",
        "    W_out_grad *= 0\n",
        "    W_grad *= 0\n",
        "    \n",
        "\n",
        "    if(i % plot_interval == 0):\n",
        "        all_losses.append(all_loss)        \n",
        "        print(\"iters:\" + str(i))\n",
        "        print(\"Loss:\" + str(all_loss))\n",
        "        print(\"Pred:\" + str(out_bin))\n",
        "        print(\"True:\" + str(d_bin))\n",
        "        out_int = 0\n",
        "        for index,x in enumerate(reversed(out_bin)):\n",
        "            out_int += x * pow(2, index)\n",
        "        print(str(a_int) + \" + \" + str(b_int) + \" = \" + str(out_int))\n",
        "        print(\"------------\")\n",
        "\n",
        "lists = range(0, iters_num, plot_interval)\n",
        "plt.plot(lists, all_losses, label=\"loss\")\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "iters:0\n",
            "Loss:2.153564894745493\n",
            "Pred:[0 0 1 0 0 0 0 0]\n",
            "True:[1 1 0 1 0 1 0 0]\n",
            "121 + 91 = 32\n",
            "------------\n",
            "iters:100\n",
            "Loss:0.5659978109753718\n",
            "Pred:[0 0 1 1 1 1 1 0]\n",
            "True:[0 0 1 1 0 0 1 1]\n",
            "24 + 27 = 62\n",
            "------------\n",
            "iters:200\n",
            "Loss:1.5234273549747595\n",
            "Pred:[0 0 0 0 1 0 0 0]\n",
            "True:[0 1 0 1 1 0 1 1]\n",
            "5 + 86 = 8\n",
            "------------\n",
            "iters:300\n",
            "Loss:1.0732684500064333\n",
            "Pred:[0 0 0 1 1 1 0 0]\n",
            "True:[0 1 1 1 1 1 0 1]\n",
            "111 + 14 = 28\n",
            "------------\n",
            "iters:400\n",
            "Loss:1.3702915075700555\n",
            "Pred:[0 0 0 0 1 1 1 0]\n",
            "True:[0 1 1 1 1 0 0 1]\n",
            "10 + 111 = 14\n",
            "------------\n",
            "iters:500\n",
            "Loss:0.9647743604989347\n",
            "Pred:[0 1 0 0 1 1 0 0]\n",
            "True:[1 0 0 1 0 1 1 0]\n",
            "70 + 80 = 76\n",
            "------------\n",
            "iters:600\n",
            "Loss:1.1480619329217434\n",
            "Pred:[1 1 1 1 1 1 1 0]\n",
            "True:[1 1 1 0 0 0 0 1]\n",
            "114 + 111 = 254\n",
            "------------\n",
            "iters:700\n",
            "Loss:1.0431930834757628\n",
            "Pred:[1 1 1 1 1 0 1 0]\n",
            "True:[1 0 1 1 1 1 0 0]\n",
            "73 + 115 = 250\n",
            "------------\n",
            "iters:800\n",
            "Loss:0.8700848545561863\n",
            "Pred:[0 0 1 1 1 0 0 0]\n",
            "True:[0 1 0 0 0 0 0 0]\n",
            "24 + 40 = 56\n",
            "------------\n",
            "iters:900\n",
            "Loss:0.9435390961625657\n",
            "Pred:[1 1 0 1 0 0 0 0]\n",
            "True:[1 0 0 1 0 1 1 0]\n",
            "69 + 81 = 208\n",
            "------------\n",
            "iters:1000\n",
            "Loss:1.2748087371021117\n",
            "Pred:[0 0 1 1 1 1 0 1]\n",
            "True:[0 1 0 0 1 1 1 0]\n",
            "25 + 53 = 61\n",
            "------------\n",
            "iters:1100\n",
            "Loss:1.0681730984674394\n",
            "Pred:[0 0 1 1 1 0 0 1]\n",
            "True:[0 1 0 0 0 1 0 1]\n",
            "25 + 44 = 57\n",
            "------------\n",
            "iters:1200\n",
            "Loss:0.5352576913311647\n",
            "Pred:[0 1 1 1 0 0 0 0]\n",
            "True:[0 1 1 1 1 0 0 0]\n",
            "24 + 96 = 112\n",
            "------------\n",
            "iters:1300\n",
            "Loss:0.6447085338301011\n",
            "Pred:[0 0 0 0 0 0 0 0]\n",
            "True:[0 0 0 0 1 1 0 1]\n",
            "3 + 10 = 0\n",
            "------------\n",
            "iters:1400\n",
            "Loss:0.8743242683856232\n",
            "Pred:[0 0 1 0 0 0 0 0]\n",
            "True:[0 0 1 1 1 1 0 1]\n",
            "17 + 44 = 32\n",
            "------------\n",
            "iters:1500\n",
            "Loss:0.9926579862906834\n",
            "Pred:[0 0 1 0 0 1 0 0]\n",
            "True:[0 0 1 1 1 1 1 0]\n",
            "26 + 36 = 36\n",
            "------------\n",
            "iters:1600\n",
            "Loss:1.1973668032523332\n",
            "Pred:[1 1 1 1 1 1 0 1]\n",
            "True:[1 1 0 0 0 0 1 1]\n",
            "86 + 109 = 253\n",
            "------------\n",
            "iters:1700\n",
            "Loss:0.5712612361030289\n",
            "Pred:[0 0 1 1 0 1 1 0]\n",
            "True:[0 0 1 0 0 1 1 1]\n",
            "19 + 20 = 54\n",
            "------------\n",
            "iters:1800\n",
            "Loss:1.1251359893632809\n",
            "Pred:[1 0 1 1 1 1 0 0]\n",
            "True:[1 0 1 0 0 0 0 0]\n",
            "93 + 67 = 188\n",
            "------------\n",
            "iters:1900\n",
            "Loss:0.6588827075097232\n",
            "Pred:[1 1 1 1 1 0 1 0]\n",
            "True:[1 0 1 1 1 1 1 0]\n",
            "68 + 122 = 250\n",
            "------------\n",
            "iters:2000\n",
            "Loss:0.8459203533313725\n",
            "Pred:[0 1 0 1 1 1 0 0]\n",
            "True:[1 0 0 1 1 1 0 0]\n",
            "125 + 31 = 92\n",
            "------------\n",
            "iters:2100\n",
            "Loss:0.6806428475661361\n",
            "Pred:[1 1 0 1 1 1 0 0]\n",
            "True:[1 1 0 1 1 0 0 0]\n",
            "123 + 93 = 220\n",
            "------------\n",
            "iters:2200\n",
            "Loss:0.28051406170320214\n",
            "Pred:[0 1 0 0 1 0 0 0]\n",
            "True:[0 1 0 0 1 1 0 0]\n",
            "41 + 35 = 72\n",
            "------------\n",
            "iters:2300\n",
            "Loss:0.2678601472715533\n",
            "Pred:[0 0 1 1 1 0 1 1]\n",
            "True:[0 0 1 1 1 0 0 1]\n",
            "26 + 31 = 59\n",
            "------------\n",
            "iters:2400\n",
            "Loss:0.5053430276519022\n",
            "Pred:[0 1 0 0 1 1 1 1]\n",
            "True:[0 1 0 1 1 1 1 1]\n",
            "25 + 70 = 79\n",
            "------------\n",
            "iters:2500\n",
            "Loss:0.4214456423294395\n",
            "Pred:[1 1 1 0 1 0 0 0]\n",
            "True:[0 1 1 0 1 0 0 0]\n",
            "69 + 35 = 232\n",
            "------------\n",
            "iters:2600\n",
            "Loss:0.43138557041268244\n",
            "Pred:[0 0 1 1 0 0 1 1]\n",
            "True:[0 0 1 1 0 1 1 1]\n",
            "22 + 33 = 51\n",
            "------------\n",
            "iters:2700\n",
            "Loss:0.08342940450333307\n",
            "Pred:[1 1 1 0 1 0 1 0]\n",
            "True:[1 1 1 0 1 0 1 0]\n",
            "127 + 107 = 234\n",
            "------------\n",
            "iters:2800\n",
            "Loss:0.9843271892876817\n",
            "Pred:[1 1 0 1 1 1 0 0]\n",
            "True:[1 0 0 0 1 0 0 0]\n",
            "17 + 119 = 220\n",
            "------------\n",
            "iters:2900\n",
            "Loss:0.27041290556212505\n",
            "Pred:[0 1 1 1 0 1 1 1]\n",
            "True:[0 1 1 1 0 0 1 1]\n",
            "85 + 30 = 119\n",
            "------------\n",
            "iters:3000\n",
            "Loss:0.17747346976447112\n",
            "Pred:[0 0 1 1 1 1 0 1]\n",
            "True:[0 0 1 1 1 1 0 1]\n",
            "36 + 25 = 61\n",
            "------------\n",
            "iters:3100\n",
            "Loss:0.15474438566810336\n",
            "Pred:[0 1 0 0 1 1 0 1]\n",
            "True:[0 1 0 0 1 1 0 1]\n",
            "31 + 46 = 77\n",
            "------------\n",
            "iters:3200\n",
            "Loss:0.2925576273470318\n",
            "Pred:[1 0 1 1 1 1 1 0]\n",
            "True:[1 0 0 1 1 1 1 0]\n",
            "38 + 120 = 190\n",
            "------------\n",
            "iters:3300\n",
            "Loss:0.058212540150715524\n",
            "Pred:[0 1 1 0 0 0 0 1]\n",
            "True:[0 1 1 0 0 0 0 1]\n",
            "65 + 32 = 97\n",
            "------------\n",
            "iters:3400\n",
            "Loss:0.23097330361993285\n",
            "Pred:[0 1 1 1 0 0 0 1]\n",
            "True:[0 1 1 1 0 0 0 1]\n",
            "70 + 43 = 113\n",
            "------------\n",
            "iters:3500\n",
            "Loss:0.12528287152008144\n",
            "Pred:[0 1 0 1 1 0 1 0]\n",
            "True:[0 1 0 1 1 0 1 0]\n",
            "33 + 57 = 90\n",
            "------------\n",
            "iters:3600\n",
            "Loss:0.11931292858767471\n",
            "Pred:[1 0 0 1 0 1 1 1]\n",
            "True:[1 0 0 1 0 1 1 1]\n",
            "35 + 116 = 151\n",
            "------------\n",
            "iters:3700\n",
            "Loss:0.03487172172445985\n",
            "Pred:[0 1 1 0 0 0 0 0]\n",
            "True:[0 1 1 0 0 0 0 0]\n",
            "24 + 72 = 96\n",
            "------------\n",
            "iters:3800\n",
            "Loss:0.42386621971691407\n",
            "Pred:[1 1 0 0 0 1 1 0]\n",
            "True:[1 0 0 0 0 1 1 0]\n",
            "56 + 78 = 198\n",
            "------------\n",
            "iters:3900\n",
            "Loss:0.04679461229142709\n",
            "Pred:[0 1 0 0 1 1 1 0]\n",
            "True:[0 1 0 0 1 1 1 0]\n",
            "51 + 27 = 78\n",
            "------------\n",
            "iters:4000\n",
            "Loss:0.08500932952686045\n",
            "Pred:[0 1 0 1 1 1 0 0]\n",
            "True:[0 1 0 1 1 1 0 0]\n",
            "16 + 76 = 92\n",
            "------------\n",
            "iters:4100\n",
            "Loss:0.028449876185565142\n",
            "Pred:[1 1 1 1 0 0 1 1]\n",
            "True:[1 1 1 1 0 0 1 1]\n",
            "125 + 118 = 243\n",
            "------------\n",
            "iters:4200\n",
            "Loss:0.010387937689509383\n",
            "Pred:[1 0 0 0 0 1 1 1]\n",
            "True:[1 0 0 0 0 1 1 1]\n",
            "87 + 48 = 135\n",
            "------------\n",
            "iters:4300\n",
            "Loss:0.16121578777568452\n",
            "Pred:[1 0 0 0 0 1 0 1]\n",
            "True:[1 0 0 0 0 1 0 1]\n",
            "55 + 78 = 133\n",
            "------------\n",
            "iters:4400\n",
            "Loss:0.008219474578897448\n",
            "Pred:[0 1 1 1 0 1 1 1]\n",
            "True:[0 1 1 1 0 1 1 1]\n",
            "7 + 112 = 119\n",
            "------------\n",
            "iters:4500\n",
            "Loss:0.3146766921785453\n",
            "Pred:[1 1 1 1 1 1 0 0]\n",
            "True:[1 0 1 1 1 1 0 0]\n",
            "81 + 107 = 252\n",
            "------------\n",
            "iters:4600\n",
            "Loss:0.023366812397411115\n",
            "Pred:[1 1 1 0 0 0 1 0]\n",
            "True:[1 1 1 0 0 0 1 0]\n",
            "107 + 119 = 226\n",
            "------------\n",
            "iters:4700\n",
            "Loss:0.021387873678467333\n",
            "Pred:[1 1 0 1 0 1 0 0]\n",
            "True:[1 1 0 1 0 1 0 0]\n",
            "98 + 114 = 212\n",
            "------------\n",
            "iters:4800\n",
            "Loss:0.030882954549819937\n",
            "Pred:[0 1 1 0 1 1 0 1]\n",
            "True:[0 1 1 0 1 1 0 1]\n",
            "11 + 98 = 109\n",
            "------------\n",
            "iters:4900\n",
            "Loss:0.12399539928882054\n",
            "Pred:[0 1 0 0 1 0 1 1]\n",
            "True:[0 1 0 0 1 0 1 1]\n",
            "61 + 14 = 75\n",
            "------------\n",
            "iters:5000\n",
            "Loss:0.040624100294295765\n",
            "Pred:[0 1 1 1 0 1 1 0]\n",
            "True:[0 1 1 1 0 1 1 0]\n",
            "99 + 19 = 118\n",
            "------------\n",
            "iters:5100\n",
            "Loss:0.029890623469243406\n",
            "Pred:[1 0 0 1 1 0 1 1]\n",
            "True:[1 0 0 1 1 0 1 1]\n",
            "94 + 61 = 155\n",
            "------------\n",
            "iters:5200\n",
            "Loss:0.006747409869158992\n",
            "Pred:[1 1 0 1 0 1 0 0]\n",
            "True:[1 1 0 1 0 1 0 0]\n",
            "104 + 108 = 212\n",
            "------------\n",
            "iters:5300\n",
            "Loss:0.11362048400667227\n",
            "Pred:[0 0 1 1 1 1 0 0]\n",
            "True:[0 0 1 1 1 1 0 0]\n",
            "43 + 17 = 60\n",
            "------------\n",
            "iters:5400\n",
            "Loss:0.036499078282814464\n",
            "Pred:[1 0 0 1 0 0 0 1]\n",
            "True:[1 0 0 1 0 0 0 1]\n",
            "104 + 41 = 145\n",
            "------------\n",
            "iters:5500\n",
            "Loss:0.0020271442734815285\n",
            "Pred:[1 0 1 1 1 0 0 0]\n",
            "True:[1 0 1 1 1 0 0 0]\n",
            "92 + 92 = 184\n",
            "------------\n",
            "iters:5600\n",
            "Loss:0.019150168217955153\n",
            "Pred:[0 0 0 1 0 0 1 1]\n",
            "True:[0 0 0 1 0 0 1 1]\n",
            "14 + 5 = 19\n",
            "------------\n",
            "iters:5700\n",
            "Loss:0.015615537960966534\n",
            "Pred:[0 1 1 0 1 1 1 0]\n",
            "True:[0 1 1 0 1 1 1 0]\n",
            "11 + 99 = 110\n",
            "------------\n",
            "iters:5800\n",
            "Loss:0.055848845352638045\n",
            "Pred:[0 1 0 0 0 0 0 0]\n",
            "True:[0 1 0 0 0 0 0 0]\n",
            "3 + 61 = 64\n",
            "------------\n",
            "iters:5900\n",
            "Loss:0.027811941122713048\n",
            "Pred:[1 0 1 1 0 1 0 1]\n",
            "True:[1 0 1 1 0 1 0 1]\n",
            "85 + 96 = 181\n",
            "------------\n",
            "iters:6000\n",
            "Loss:0.010294577678977852\n",
            "Pred:[1 0 1 0 0 0 1 0]\n",
            "True:[1 0 1 0 0 0 1 0]\n",
            "106 + 56 = 162\n",
            "------------\n",
            "iters:6100\n",
            "Loss:0.0294434376838914\n",
            "Pred:[0 1 0 1 0 0 0 0]\n",
            "True:[0 1 0 1 0 0 0 0]\n",
            "63 + 17 = 80\n",
            "------------\n",
            "iters:6200\n",
            "Loss:0.09631313639715536\n",
            "Pred:[0 1 0 0 0 1 0 1]\n",
            "True:[0 1 0 0 0 1 0 1]\n",
            "24 + 45 = 69\n",
            "------------\n",
            "iters:6300\n",
            "Loss:0.01924271266456333\n",
            "Pred:[1 1 0 0 1 0 0 1]\n",
            "True:[1 1 0 0 1 0 0 1]\n",
            "110 + 91 = 201\n",
            "------------\n",
            "iters:6400\n",
            "Loss:0.08407179514535484\n",
            "Pred:[0 1 1 1 1 1 1 0]\n",
            "True:[0 1 1 1 1 1 1 0]\n",
            "20 + 106 = 126\n",
            "------------\n",
            "iters:6500\n",
            "Loss:0.004977384528896935\n",
            "Pred:[1 0 0 0 1 0 0 0]\n",
            "True:[1 0 0 0 1 0 0 0]\n",
            "36 + 100 = 136\n",
            "------------\n",
            "iters:6600\n",
            "Loss:0.023494576741571575\n",
            "Pred:[1 0 0 0 1 1 1 0]\n",
            "True:[1 0 0 0 1 1 1 0]\n",
            "95 + 47 = 142\n",
            "------------\n",
            "iters:6700\n",
            "Loss:0.003000975419069599\n",
            "Pred:[1 0 1 1 0 0 1 1]\n",
            "True:[1 0 1 1 0 0 1 1]\n",
            "96 + 83 = 179\n",
            "------------\n",
            "iters:6800\n",
            "Loss:0.0021108360627262197\n",
            "Pred:[1 0 0 1 1 0 0 1]\n",
            "True:[1 0 0 1 1 0 0 1]\n",
            "87 + 66 = 153\n",
            "------------\n",
            "iters:6900\n",
            "Loss:0.043114944696639484\n",
            "Pred:[0 1 0 1 1 0 1 0]\n",
            "True:[0 1 0 1 1 0 1 0]\n",
            "63 + 27 = 90\n",
            "------------\n",
            "iters:7000\n",
            "Loss:0.006124100758764581\n",
            "Pred:[0 1 1 1 1 0 0 0]\n",
            "True:[0 1 1 1 1 0 0 0]\n",
            "114 + 6 = 120\n",
            "------------\n",
            "iters:7100\n",
            "Loss:0.06282776676401762\n",
            "Pred:[0 1 1 0 0 0 0 1]\n",
            "True:[0 1 1 0 0 0 0 1]\n",
            "69 + 28 = 97\n",
            "------------\n",
            "iters:7200\n",
            "Loss:0.012013012155202389\n",
            "Pred:[0 1 0 0 0 0 0 1]\n",
            "True:[0 1 0 0 0 0 0 1]\n",
            "32 + 33 = 65\n",
            "------------\n",
            "iters:7300\n",
            "Loss:0.028836632944875068\n",
            "Pred:[1 1 0 1 1 0 0 1]\n",
            "True:[1 1 0 1 1 0 0 1]\n",
            "122 + 95 = 217\n",
            "------------\n",
            "iters:7400\n",
            "Loss:0.03777987985745057\n",
            "Pred:[0 1 1 0 0 0 1 1]\n",
            "True:[0 1 1 0 0 0 1 1]\n",
            "60 + 39 = 99\n",
            "------------\n",
            "iters:7500\n",
            "Loss:0.0016045064751352685\n",
            "Pred:[0 1 0 1 1 0 1 0]\n",
            "True:[0 1 0 1 1 0 1 0]\n",
            "10 + 80 = 90\n",
            "------------\n",
            "iters:7600\n",
            "Loss:0.011171217721587709\n",
            "Pred:[1 0 1 0 0 0 0 1]\n",
            "True:[1 0 1 0 0 0 0 1]\n",
            "125 + 36 = 161\n",
            "------------\n",
            "iters:7700\n",
            "Loss:0.08617873995102997\n",
            "Pred:[0 1 0 0 1 1 1 0]\n",
            "True:[0 1 0 0 1 1 1 0]\n",
            "40 + 38 = 78\n",
            "------------\n",
            "iters:7800\n",
            "Loss:0.010397438122656017\n",
            "Pred:[0 1 1 0 1 0 1 0]\n",
            "True:[0 1 1 0 1 0 1 0]\n",
            "88 + 18 = 106\n",
            "------------\n",
            "iters:7900\n",
            "Loss:0.02433398525828952\n",
            "Pred:[1 0 0 0 0 1 0 1]\n",
            "True:[1 0 0 0 0 1 0 1]\n",
            "77 + 56 = 133\n",
            "------------\n",
            "iters:8000\n",
            "Loss:0.001718991142316749\n",
            "Pred:[1 1 0 1 0 1 0 1]\n",
            "True:[1 1 0 1 0 1 0 1]\n",
            "117 + 96 = 213\n",
            "------------\n",
            "iters:8100\n",
            "Loss:0.0007114582440403065\n",
            "Pred:[0 1 1 1 0 1 1 1]\n",
            "True:[0 1 1 1 0 1 1 1]\n",
            "88 + 31 = 119\n",
            "------------\n",
            "iters:8200\n",
            "Loss:0.025349840562032013\n",
            "Pred:[1 1 0 0 1 1 1 1]\n",
            "True:[1 1 0 0 1 1 1 1]\n",
            "90 + 117 = 207\n",
            "------------\n",
            "iters:8300\n",
            "Loss:0.021087777651367698\n",
            "Pred:[0 0 0 0 1 1 1 0]\n",
            "True:[0 0 0 0 1 1 1 0]\n",
            "5 + 9 = 14\n",
            "------------\n",
            "iters:8400\n",
            "Loss:0.012730265116910503\n",
            "Pred:[1 0 1 0 0 0 0 1]\n",
            "True:[1 0 1 0 0 0 0 1]\n",
            "111 + 50 = 161\n",
            "------------\n",
            "iters:8500\n",
            "Loss:0.010435261233883586\n",
            "Pred:[1 0 1 0 1 0 0 1]\n",
            "True:[1 0 1 0 1 0 0 1]\n",
            "125 + 44 = 169\n",
            "------------\n",
            "iters:8600\n",
            "Loss:0.06613496188358742\n",
            "Pred:[0 1 1 1 0 1 0 1]\n",
            "True:[0 1 1 1 0 1 0 1]\n",
            "109 + 8 = 117\n",
            "------------\n",
            "iters:8700\n",
            "Loss:0.011993586944276585\n",
            "Pred:[0 0 1 1 1 0 0 1]\n",
            "True:[0 0 1 1 1 0 0 1]\n",
            "10 + 47 = 57\n",
            "------------\n",
            "iters:8800\n",
            "Loss:0.019079905925579396\n",
            "Pred:[0 0 1 0 0 1 1 1]\n",
            "True:[0 0 1 0 0 1 1 1]\n",
            "18 + 21 = 39\n",
            "------------\n",
            "iters:8900\n",
            "Loss:0.013600918569904754\n",
            "Pred:[1 0 1 0 0 1 1 1]\n",
            "True:[1 0 1 0 0 1 1 1]\n",
            "75 + 92 = 167\n",
            "------------\n",
            "iters:9000\n",
            "Loss:0.009621368877386419\n",
            "Pred:[1 0 0 0 1 0 0 1]\n",
            "True:[1 0 0 0 1 0 0 1]\n",
            "73 + 64 = 137\n",
            "------------\n",
            "iters:9100\n",
            "Loss:0.01811070525066136\n",
            "Pred:[0 1 1 0 1 0 1 0]\n",
            "True:[0 1 1 0 1 0 1 0]\n",
            "63 + 43 = 106\n",
            "------------\n",
            "iters:9200\n",
            "Loss:0.004258061377861931\n",
            "Pred:[1 0 1 1 1 1 1 0]\n",
            "True:[1 0 1 1 1 1 1 0]\n",
            "64 + 126 = 190\n",
            "------------\n",
            "iters:9300\n",
            "Loss:0.04397747646965129\n",
            "Pred:[0 1 1 0 1 0 0 0]\n",
            "True:[0 1 1 0 1 0 0 0]\n",
            "69 + 35 = 104\n",
            "------------\n",
            "iters:9400\n",
            "Loss:0.03249117145077845\n",
            "Pred:[0 1 1 0 0 0 0 1]\n",
            "True:[0 1 1 0 0 0 0 1]\n",
            "80 + 17 = 97\n",
            "------------\n",
            "iters:9500\n",
            "Loss:0.02220329506937916\n",
            "Pred:[1 0 1 1 0 0 0 0]\n",
            "True:[1 0 1 1 0 0 0 0]\n",
            "70 + 106 = 176\n",
            "------------\n",
            "iters:9600\n",
            "Loss:0.020627497677411313\n",
            "Pred:[1 0 1 0 0 0 0 1]\n",
            "True:[1 0 1 0 0 0 0 1]\n",
            "37 + 124 = 161\n",
            "------------\n",
            "iters:9700\n",
            "Loss:0.010723291665635452\n",
            "Pred:[0 1 0 1 0 0 0 0]\n",
            "True:[0 1 0 1 0 0 0 0]\n",
            "25 + 55 = 80\n",
            "------------\n",
            "iters:9800\n",
            "Loss:0.026674913935543214\n",
            "Pred:[1 0 0 0 1 1 1 0]\n",
            "True:[1 0 0 0 1 1 1 0]\n",
            "102 + 40 = 142\n",
            "------------\n",
            "iters:9900\n",
            "Loss:0.03234901207213073\n",
            "Pred:[0 1 1 0 1 0 0 0]\n",
            "True:[0 1 1 0 1 0 0 0]\n",
            "97 + 7 = 104\n",
            "------------\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PlRjNu4cAz9G"
      },
      "source": [
        "sigmoid(ランダム初期化)の結果\n",
        "\n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-TkkCuBOJ929"
      },
      "source": [
        "sigmoid(Xavier初期化)  \n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "692rSn7sBHp_"
      },
      "source": [
        "ReLU(ランダム初期化)の結果  \n",
        "![image.png]()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBAK-mtmI0Fg"
      },
      "source": [
        "ReLU(He初期化)の結果  \n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I9DJd7AzBTMA"
      },
      "source": [
        "tanh(ランダム初期化)の結果  \n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zsFA4_zRJm2C"
      },
      "source": [
        "tanh(Xavier初期化)の結果  \n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FqjobgfAKPXK"
      },
      "source": [
        "### 考察\n",
        "実施するたびに結果が多少ばらつくが、ほとんどの組み合わせでうまく収束する。  \n",
        "ただし、tanhのランダム初期化だけはあまりうまくいかない。  \n",
        "それ以外の組み合わせ(sigmoid、ReLU)は、初期化方法の影響をあまり受けていないように見える。\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-u0A5ThGjKVH"
      },
      "source": [
        "# サイン波予測\n",
        "maxlen:2 iters_num:100  \n",
        "maxlen:2 iters_num:500  \n",
        "maxlen:2 iters_num:3000  \n",
        "maxlen:5 iters_num:100  \n",
        "maxlen:5 iters_num:500  \n",
        "maxlen:5 iters_num:3000  \n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "IjjFN0eciCUL",
        "outputId": "c1b2a888-a131-4985-cec8-d37aaa09d3cd"
      },
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "np.random.seed(0)\n",
        "\n",
        "# Sine curve: div_num samples of sin(t) for t in [0, round_num * pi]\n",
        "round_num = 10\n",
        "div_num = 500\n",
        "ts = np.linspace(0, round_num * np.pi, div_num)\n",
        "f = np.sin(ts)\n",
        "\n",
        "def d_tanh(x):\n",
        "    \"\"\"Derivative of tanh(x), i.e. 1 / cosh(x)^2, with an extra 1e-4 term in the denominator.\"\"\"\n",
        "    return 1/(np.cosh(x)**2 + 1e-4)\n",
        "\n",
        "def rnn_forward(xs, W_in, W, W_out):\n",
        "    \"\"\"Run the tanh RNN over one sequence, starting from a zero hidden state.\n",
        "\n",
        "    xs is iterated one time step at a time; each step is fed through W_in and\n",
        "    combined with the previous hidden state through W.\n",
        "\n",
        "    Returns (us, zs, y): the per-step pre-activations, the per-step hidden\n",
        "    states, and the linear output computed from the last hidden state only.\n",
        "    \"\"\"\n",
        "    z = np.zeros(W.shape[0])\n",
        "    us, zs = [], []\n",
        "    for x in xs:\n",
        "        u = np.dot(x, W_in) + np.dot(z, W)\n",
        "        z = np.tanh(u)\n",
        "        us.append(u)\n",
        "        zs.append(z)\n",
        "    y = np.dot(z, W_out)\n",
        "    return us, zs, y\n",
        "\n",
        "# Length of one input window (number of time steps fed to the RNN)\n",
        "maxlen = 5\n",
        "\n",
        "# Seed window for the free-running sine prediction below\n",
        "test_head = [[f[k]] for k in range(0, maxlen)]\n",
        "\n",
        "# Sliding windows: maxlen consecutive samples -> the sample that follows them\n",
        "data = []\n",
        "target = []\n",
        "\n",
        "for i in range(div_num - maxlen):\n",
        "    data.append(f[i: i + maxlen])\n",
        "    target.append(f[i + maxlen])\n",
        "\n",
        "X = np.array(data).reshape(len(data), maxlen, 1)\n",
        "D = np.array(target).reshape(len(data), 1)\n",
        "\n",
        "# Train / test split (80% / 20%)\n",
        "N_train = int(len(data) * 0.8)\n",
        "N_validation = len(data) - N_train\n",
        "\n",
        "x_train, x_test, d_train, d_test = train_test_split(X, D, test_size=N_validation)\n",
        "\n",
        "input_layer_size = 1\n",
        "hidden_layer_size = 5\n",
        "output_layer_size = 1\n",
        "\n",
        "weight_init_std = 0.01\n",
        "learning_rate = 0.1\n",
        "\n",
        "iters_num = 3000\n",
        "\n",
        "# Weight initialisation (biases are omitted for simplicity)\n",
        "W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)\n",
        "W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)\n",
        "W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)\n",
        "\n",
        "# Mean training loss per epoch\n",
        "losses = []\n",
        "\n",
        "# Training: per-sample SGD with backpropagation through time\n",
        "for i in range(iters_num):\n",
        "    epoch_loss = 0.0\n",
        "    for s in range(x_train.shape[0]):\n",
        "        # Target value and input window for sample s\n",
        "        d = d_train[s]\n",
        "        xs = x_train[s]\n",
        "\n",
        "        us, zs, y = rnn_forward(xs, W_in, W, W_out)\n",
        "\n",
        "        # Loss and its gradient with respect to the output\n",
        "        epoch_loss += mean_squared_error(d, y)\n",
        "        delta_out = d_mean_squared_error(d, y)\n",
        "\n",
        "        W_grad = np.zeros_like(W)\n",
        "        W_in_grad = np.zeros_like(W_in)\n",
        "        W_out_grad = np.outer(zs[-1], delta_out)\n",
        "\n",
        "        # y is computed from the last hidden state only, so the output error\n",
        "        # enters the recurrence once, at the final time step.\n",
        "        delta = np.dot(delta_out, W_out.T) * d_tanh(us[-1])\n",
        "        for t in reversed(range(maxlen)):\n",
        "            # u_t = x_t W_in + z_{t-1} W, so the gradient of W pairs delta_t\n",
        "            # with the previous hidden state (zero before the first step).\n",
        "            z_prev = zs[t - 1] if t > 0 else np.zeros(hidden_layer_size)\n",
        "            W_grad += np.outer(z_prev, delta)\n",
        "            W_in_grad += np.outer(xs[t], delta)\n",
        "            if t > 0:\n",
        "                delta = np.dot(delta, W.T) * d_tanh(us[t - 1])\n",
        "\n",
        "        # Apply gradients\n",
        "        W -= learning_rate * W_grad\n",
        "        W_in -= learning_rate * W_in_grad\n",
        "        W_out -= learning_rate * W_out_grad\n",
        "\n",
        "    losses.append(epoch_loss / x_train.shape[0])\n",
        "\n",
        "# Evaluation on the held-out windows\n",
        "for s in range(x_test.shape[0]):\n",
        "    # Target value and input window for sample s\n",
        "    d = d_test[s]\n",
        "    xs = x_test[s]\n",
        "\n",
        "    _, _, y = rnn_forward(xs, W_in, W, W_out)\n",
        "\n",
        "    # Loss\n",
        "    loss = mean_squared_error(d, y)\n",
        "    print('loss:', loss, '   d:', d, '   y:', y)\n",
        "\n",
        "# Free-running prediction: feed each prediction back in as the newest input.\n",
        "# `original` holds the predicted curve; the first maxlen entries are NaN\n",
        "# padding for the seed window, which matplotlib leaves as a gap.\n",
        "original = np.full(maxlen, np.nan)\n",
        "pred_num = 200\n",
        "\n",
        "# Keep the window as a (maxlen, 1) array so every step has the same shape\n",
        "xs = np.array(test_head, dtype=float)\n",
        "\n",
        "for s in range(0, pred_num):\n",
        "    _, _, y = rnn_forward(xs, W_in, W, W_out)\n",
        "    original = np.append(original, y)\n",
        "    # Slide the window: drop the oldest sample, append the new prediction\n",
        "    xs = np.vstack([xs[1:], np.reshape(y, (1, 1))])\n",
        "\n",
        "# Ground truth sampled on the same grid as the training data, so that it\n",
        "# lines up index-for-index with `original` (seed window + predictions).\n",
        "true_curve = np.sin(np.arange(maxlen + pred_num) * (ts[1] - ts[0]))\n",
        "\n",
        "plt.figure()\n",
        "plt.ylim([-1.5, 1.5])\n",
        "plt.plot(true_curve, linestyle='dotted', color='#aaaaaa', label='sin (true)')\n",
        "plt.plot(original, linestyle='dashed', color='black', label='RNN prediction')\n",
        "plt.xlabel('step')\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "loss: 1.0231756688222737e-07    d: [-0.29761864]    y: [-0.29716628]\n",
            "loss: 1.2201090156041725e-08    d: [-0.56307233]    y: [-0.56322854]\n",
            "loss: 4.038245320901024e-11    d: [-0.65766776]    y: [-0.65765877]\n",
            "loss: 1.012613171470193e-08    d: [0.13182648]    y: [0.13168417]\n",
            "loss: 6.953992391246068e-08    d: [0.49909101]    y: [0.49871807]\n",
            "loss: 5.5838333193039375e-08    d: [0.9518317]    y: [0.95149752]\n",
            "loss: 1.006541644062743e-07    d: [0.97784112]    y: [0.97739245]\n",
            "loss: 1.7033534573325946e-08    d: [-0.58880346]    y: [-0.58861889]\n",
            "loss: 5.2622812685924356e-08    d: [-0.78351093]    y: [-0.78383534]\n",
            "loss: 2.60918282588755e-10    d: [-0.49909101]    y: [-0.49906816]\n",
            "loss: 5.712849721315155e-08    d: [0.21857331]    y: [0.21823529]\n",
            "loss: 9.126545094849141e-08    d: [-0.33938943]    y: [-0.3389622]\n",
            "loss: 1.0304339135538811e-07    d: [-0.43793098]    y: [-0.43747701]\n",
            "loss: 1.1445752972814562e-07    d: [-0.33346065]    y: [-0.3329822]\n",
            "loss: 4.487028896407404e-08    d: [-0.99639027]    y: [-0.9960907]\n",
            "loss: 4.518845122030081e-08    d: [0.88624247]    y: [0.88654309]\n",
            "loss: 2.3166269315110745e-08    d: [-0.92833248]    y: [-0.92811723]\n",
            "loss: 7.877136645016467e-10    d: [-0.52075286]    y: [-0.52079255]\n",
            "loss: 8.254988700985761e-09    d: [-0.55262221]    y: [-0.5527507]\n",
            "loss: 8.291678208274492e-08    d: [0.47711265]    y: [0.47670543]\n",
            "loss: 3.515702303123948e-08    d: [0.60896952]    y: [0.60923469]\n",
            "loss: 4.6349673171343807e-08    d: [-0.94587102]    y: [-0.94556656]\n",
            "loss: 9.366188781076226e-08    d: [0.27953518]    y: [0.27910237]\n",
            "loss: 7.187670231870784e-08    d: [0.73863456]    y: [0.73901371]\n",
            "loss: 2.5175272023918393e-08    d: [-0.00629574]    y: [-0.00607135]\n",
            "loss: 4.238282860583762e-08    d: [0.54208448]    y: [0.54179333]\n",
            "loss: 5.364268767133932e-08    d: [0.99781582]    y: [0.99748827]\n",
            "loss: 7.761284716153206e-08    d: [-0.96441607]    y: [-0.96402208]\n",
            "loss: 7.29946476567447e-08    d: [0.07547747]    y: [0.07509539]\n",
            "loss: 5.374729213271438e-12    d: [0.66239735]    y: [0.66240063]\n",
            "loss: 3.917648631674704e-08    d: [-0.54736419]    y: [-0.54708428]\n",
            "loss: 1.1396823999698672e-07    d: [0.99393675]    y: [0.99345932]\n",
            "loss: 1.8916287872014157e-10    d: [0.96441607]    y: [0.96443552]\n",
            "loss: 6.104853811859931e-08    d: [-0.22471249]    y: [-0.22436306]\n",
            "loss: 3.374712152402054e-08    d: [-0.99393675]    y: [-0.99367695]\n",
            "loss: 7.459208770391463e-08    d: [-0.71705202]    y: [-0.71743827]\n",
            "loss: 3.3346599617550096e-09    d: [0.8649742]    y: [0.86505587]\n",
            "loss: 1.0502769561892032e-07    d: [-0.11933469]    y: [-0.11887638]\n",
            "loss: 3.7402892316320774e-08    d: [-0.40941891]    y: [-0.4091454]\n",
            "loss: 1.1571745806193722e-07    d: [0.39789889]    y: [0.39741782]\n",
            "loss: 1.3219146512941343e-08    d: [0.98611478]    y: [0.98595218]\n",
            "loss: 6.82473798362108e-08    d: [-0.0691982]    y: [-0.06882875]\n",
            "loss: 7.74733530193459e-08    d: [0.35709413]    y: [0.3567005]\n",
            "loss: 4.2013170888265586e-08    d: [0.99583607]    y: [0.9955462]\n",
            "loss: 2.1526965553360173e-08    d: [-0.92597363]    y: [-0.92618113]\n",
            "loss: 6.817128658024936e-08    d: [0.36882689]    y: [0.36845764]\n",
            "loss: 1.1810524118698761e-08    d: [0.91617219]    y: [0.9160185]\n",
            "loss: 1.511514695515008e-09    d: [0.52611726]    y: [0.52617224]\n",
            "loss: 8.054071771358459e-08    d: [0.96606148]    y: [0.96566013]\n",
            "loss: 1.956018209202484e-08    d: [0.43793098]    y: [0.43773319]\n",
            "loss: 2.365114287876332e-09    d: [-0.86811636]    y: [-0.86818514]\n",
            "loss: 7.426592869136916e-08    d: [-0.99975723]    y: [-0.99937184]\n",
            "loss: 3.369004516388851e-08    d: [0.77562491]    y: [0.77588449]\n",
            "loss: 4.2428732558073345e-09    d: [0.04405617]    y: [0.04414829]\n",
            "loss: 9.775107382870593e-09    d: [0.71705202]    y: [0.71719184]\n",
            "loss: 4.388321808758248e-10    d: [0.8773359]    y: [0.87736552]\n",
            "loss: 3.008800038067477e-08    d: [0.91363079]    y: [0.9138761]\n",
            "loss: 3.4353135857868203e-09    d: [-0.97784112]    y: [-0.97775823]\n",
            "loss: 8.641444917528756e-10    d: [0.96101064]    y: [0.96105221]\n",
            "loss: 8.711809080390049e-08    d: [0.26742375]    y: [0.26700633]\n",
            "loss: 5.0207765974484416e-08    d: [-0.87122411]    y: [-0.871541]\n",
            "loss: 2.838017375633304e-08    d: [-0.91617219]    y: [-0.91641043]\n",
            "loss: 1.5567226856080684e-09    d: [0.87122411]    y: [0.87127991]\n",
            "loss: 9.803777837499867e-09    d: [0.98394564]    y: [0.98380562]\n",
            "loss: 4.14806161320395e-08    d: [0.4036669]    y: [0.40337887]\n",
            "loss: 8.246547424877302e-08    d: [0.0880268]    y: [0.08762068]\n",
            "loss: 3.519939115353052e-08    d: [0.81012572]    y: [0.81039105]\n",
            "loss: 3.3475921950765465e-08    d: [0.41515469]    y: [0.41489593]\n",
            "loss: 3.920213266201233e-08    d: [-0.99524241]    y: [-0.9949624]\n",
            "loss: 6.676591329593643e-09    d: [-0.94789551]    y: [-0.94801107]\n",
            "loss: 2.7259396135763863e-08    d: [0.16916853]    y: [0.16893503]\n",
            "loss: 5.901954968651478e-08    d: [-0.95374324]    y: [-0.95339967]\n",
            "loss: 7.404264207154143e-08    d: [-0.72577151]    y: [-0.72615633]\n",
            "loss: 1.950089841485295e-08    d: [0.74286391]    y: [0.7430614]\n",
            "loss: 9.1706563452403e-09    d: [-0.94380904]    y: [-0.94394447]\n",
            "loss: 1.8342443304828862e-08    d: [-0.83516734]    y: [-0.83535887]\n",
            "loss: 4.0168629526980415e-08    d: [-0.94170965]    y: [-0.94142621]\n",
            "loss: 1.0438172705832991e-07    d: [0.32156366]    y: [0.32110676]\n",
            "loss: 7.96519652519902e-08    d: [-0.48263615]    y: [-0.48223703]\n",
            "loss: 4.522391104134645e-08    d: [0.03776568]    y: [0.03746493]\n",
            "loss: 8.671646127145176e-08    d: [0.34530476]    y: [0.34488831]\n",
            "loss: 3.0057469610811475e-08    d: [0.56307233]    y: [0.56282714]\n",
            "loss: 3.344824485318678e-08    d: [0.90843947]    y: [0.90869812]\n",
            "loss: 9.740854716739049e-08    d: [0.99940055]    y: [0.99895917]\n",
            "loss: 5.2790241458557934e-08    d: [0.85534252]    y: [0.85566745]\n",
            "loss: 1.3438734027154209e-08    d: [0.93739898]    y: [0.93756292]\n",
            "loss: 1.0825441262363102e-07    d: [0.99738016]    y: [0.99691486]\n",
            "loss: 7.346796784291748e-08    d: [-0.6992734]    y: [-0.69965672]\n",
            "loss: 9.62971935624344e-08    d: [-0.10682399]    y: [-0.10638513]\n",
            "loss: 4.974432822949506e-09    d: [-0.54208448]    y: [-0.54218422]\n",
            "loss: 7.137974552014359e-08    d: [0.9995987]    y: [0.99922087]\n",
            "loss: 1.1991319740559976e-07    d: [0.29761864]    y: [0.29712892]\n",
            "loss: 3.111699813899164e-08    d: [0.99322482]    y: [0.99297535]\n",
            "loss: 9.573921959195159e-08    d: [0.33346065]    y: [0.33302307]\n",
            "loss: 5.1818369725854564e-08    d: [-0.83168816]    y: [-0.83201008]\n",
            "loss: 1.1452214926486475e-08    d: [-0.98504973]    y: [-0.98489839]\n",
            "loss: 5.381547067764966e-08    d: [-0.64332332]    y: [-0.6436514]\n",
            "loss: 1.0539082479572855e-07    d: [0.43226238]    y: [0.43180327]\n",
            "loss: 3.26643400999611e-08    d: [-0.81380058]    y: [-0.81405617]\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iI6SQJcHj62q"
      },
      "source": [
        "maxlen:2 iters_num:100\n",
        "\n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ylqvxrZ-9IvP"
      },
      "source": [
        "maxlen:2 iters_num:500\n",
        "\n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qTUnnjwx9X_8"
      },
      "source": [
        "maxlen:2 iters_num:3000 \n",
        "\n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xyfTBafWRb2M"
      },
      "source": [
        "maxlen:5 iters_num:100 \n",
        "\n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "52L8VqJ3RxqZ"
      },
      "source": [
        "maxlen:5 iters_num:500  \n",
        "\n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5SD7pir8UNRJ"
      },
      "source": [
        "maxlen:5 iters_num:3000  \n",
        "\n",
        "![image.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pK4Jprp9UVx5"
      },
      "source": [
        "### 考察\n",
        "maxlen:2 のときは学習がうまく進まない。iters_numを増やしてもうまくいかない。\n",
        "maxlen:5 にすると、iters_num:3000でほぼ正確なサイン波を予測できている。\n"
      ]
    }
  ]
}