{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# *Deep Learning* (深層学習) Reading Group: Chapter 7" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

2016/07/02 Machine Learning Nagoya, 5th Study Session

" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 第7章 再帰型ニューラルネット" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "abstract:\n", "\n", "+ RNN(再帰型ニューラルネットワーク)\n", " + 以下のようなデータの特徴をうまく取り扱うNN:\n", " + データの長さがサンプルごとにまちまち\n", " + 系列内の要素の並び(=コンテキスト)に意味がある\n", " + 例:音声・言語・動画\n", "+ LSTM(長・短期記憶)\n", " + より長期のコンテキストをモデル化可能\n", "+ CTC(コネクショニスト時系列分類法)\n", " + 入力系列とは長さの異なる系列を推定(出力)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### 7.1 系列データの分類" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**系列データ**:\n", "\n", "+ 個々の要素の順序付き集まりデータ" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "{\\bf x}^1, {\\bf x}^2, {\\bf x}^3, \\dots , {\\bf x}^T\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "+ 音声・動画・テキストなど\n", "+ 系列の長さ $T$ は、一般に可変\n", "+ インデックス $t = 1, 2, 3, \\dots$ を **時刻**と呼ぶ(*時間*とは言ってない)。" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**RNN(再帰型ニューラルネット)**:\n", "コンテキストを学習し、分類出来る。\n", "\n", "**コンテキスト(文脈)**:系列内の要素の並び、依存関係\n", "\n", "要素の例:\n", "\n", "+ 文章中の「単語」\n", "+ 音声信号中の「音素」" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### 7.2 RNNの構造" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**RNN(再帰型ニューラルネット)**:\n", "\n", "+ 内部に(有向)閉路を持つNNの総称\n", "+ 特徴:\n", " + 情報を一時的に記憶\n", " + 振る舞いを動的に変化" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "from graphviz import Digraph" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "fig_7_3_a = Digraph(\"fig_7_3_a\", 
format=\"svg\")\n", "fig_7_3_a.graph_attr.update(compound=\"true\", splines=\"line\")\n", "\n", "fig_7_3_a.body.extend(['rankdir=BT'])\n", "\n", "fig_7_3_a.node_attr.update(shape='circle', color='black', penwidth='2')\n", "fig_7_3_a.node('na1', '')\n", "fig_7_3_a.node('na2', '')\n", "fig_7_3_a.node('na3', 'i')\n", "fig_7_3_a.node('nc1', '')\n", "fig_7_3_a.node('nc2', '')\n", "fig_7_3_a.node('nc3', 'k')\n", "\n", "c0 = Digraph('cluster0')\n", "c0.node('nb1', '')\n", "c0.node('nb2', '')\n", "c0.node('nb3', 'j')\n", "c0.node('nb4', '')\n", "\n", "fig_7_3_a.subgraph(c0)\n", "\n", "fig_7_3_a.edge('na1', 'nb1')\n", "fig_7_3_a.edge('na1', 'nb2')\n", "fig_7_3_a.edge('na1', 'nb3')\n", "fig_7_3_a.edge('na1', 'nb4')\n", "fig_7_3_a.edge('na2', 'nb1')\n", "fig_7_3_a.edge('na2', 'nb2')\n", "fig_7_3_a.edge('na2', 'nb3')\n", "fig_7_3_a.edge('na2', 'nb4')\n", "fig_7_3_a.edge('na3', 'nb1')\n", "fig_7_3_a.edge('na3', 'nb2')\n", "fig_7_3_a.edge('na3', 'nb3')\n", "fig_7_3_a.edge('na3', 'nb4')\n", "fig_7_3_a.edge('nb1', 'nc1')\n", "fig_7_3_a.edge('nb1', 'nc2')\n", "fig_7_3_a.edge('nb1', 'nc3')\n", "fig_7_3_a.edge('nb2', 'nc1')\n", "fig_7_3_a.edge('nb2', 'nc2')\n", "fig_7_3_a.edge('nb2', 'nc3')\n", "fig_7_3_a.edge('nb3', 'nc1')\n", "fig_7_3_a.edge('nb3', 'nc2')\n", "fig_7_3_a.edge('nb3', 'nc3')\n", "fig_7_3_a.edge('nb4', 'nc1')\n", "fig_7_3_a.edge('nb4', 'nc2')\n", "fig_7_3_a.edge('nb4', 'nc3')\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig_7_3_a" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "fig_7_3_b = Digraph(\"fig_7_3_b\", format=\"svg\")\n", "fig_7_3_b.graph_attr.update(compound=\"true\", splines=\"line\")\n", "\n", "fig_7_3_b.body.extend(['rankdir=BT'])\n", "\n", "fig_7_3_b.node_attr.update(shape='circle', color='black', penwidth='2')\n", "\n", "c1 = Digraph('cluster_1')\n", "c1.body.append('label=\"t-1\"')\n", "c1.node('nc11', '')\n", "c1.node('nc12', 'j\\'')\n", "c1.node('nc13', '')\n", "c1.node('nc14', '')\n", "\n", "c2 = Digraph('cluster_2')\n", "c2.body.append('label=\"t\"')\n", "c2.body.append('labelloc=\"b\"')\n", "c2.node('nc21', '')\n", "c2.node('nc22', '')\n", "c2.node('nc23', 'j')\n", "c2.node('nc24', '')\n", "\n", "fig_7_3_b.subgraph(c1)\n", "fig_7_3_b.subgraph(c2)\n", "\n", "fig_7_3_b.edge('nc11', 'nc21')\n", "fig_7_3_b.edge('nc11', 'nc22')\n", "fig_7_3_b.edge('nc11', 'nc23')\n", "fig_7_3_b.edge('nc11', 
'nc24')\n", "fig_7_3_b.edge('nc12', 'nc21')\n", "fig_7_3_b.edge('nc12', 'nc22')\n", "fig_7_3_b.edge('nc12', 'nc23')\n", "fig_7_3_b.edge('nc12', 'nc24')\n", "fig_7_3_b.edge('nc13', 'nc21')\n", "fig_7_3_b.edge('nc13', 'nc22')\n", "fig_7_3_b.edge('nc13', 'nc23')\n", "fig_7_3_b.edge('nc13', 'nc24')\n", "fig_7_3_b.edge('nc14', 'nc21')\n", "fig_7_3_b.edge('nc14', 'nc22')\n", "fig_7_3_b.edge('nc14', 'nc23')\n", "fig_7_3_b.edge('nc14', 'nc24')\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig_7_3_b" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "fig_7_4 = Digraph(\"fig_7_4\", format=\"svg\")\n", 
"fig_7_4.graph_attr.update(compound=\"true\")\n", "\n", "fig_7_4.body.extend(['rankdir=BT'])\n", "\n", "fig_7_4.node('x', 'x_t', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('A', 'A', shape='rect', color='black', penwidth='2')\n", "fig_7_4.node('y', 'y_t', shape='circle', color='black', penwidth='2')\n", "\n", "fig_7_4.node('x0', 'x_0', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('A0', 'A', shape='rect', color='black', penwidth='2')\n", "fig_7_4.node('y0', 'y_0', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('x1', 'x_1', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('A1', 'A', shape='rect', color='black', penwidth='2')\n", "fig_7_4.node('y1', 'y_1', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('x2', 'x_2', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('A2', 'A', shape='rect', color='black', penwidth='2')\n", "fig_7_4.node('y2', 'y_2', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('xt', 'x_t', shape='circle', color='black', penwidth='2')\n", "fig_7_4.node('At', 'A', shape='rect', color='black', penwidth='2')\n", "fig_7_4.node('yt', 'y_t', shape='circle', color='black', penwidth='2')\n", "\n", "fig_7_4.edge('x', 'A')\n", "fig_7_4.edge('A', 'y')\n", "\n", "fig_7_4.edge('A', 'A', minlen=\"2\", dir=\"back\")\n", "\n", "fig_7_4.edge('x0', 'A0')\n", "fig_7_4.edge('A0', 'y0')\n", "fig_7_4.edge('x1', 'A1')\n", "fig_7_4.edge('A1', 'y1')\n", "fig_7_4.edge('x2', 'A2')\n", "fig_7_4.edge('A2', 'y2')\n", "fig_7_4.edge('xt', 'At')\n", "fig_7_4.edge('At', 'yt')\n", "\n", "fig_7_4.edge('x2', 'xt', minlen=\"2\", constraint=\"false\", color=\"transparent\", label=\"…\")\n", "\n", "fig_7_4.edge('A', 'A0', minlen=\"3\", constraint=\"false\", color=\"transparent\", label=\"=\", fontsize=\"28.0\")\n", "\n", "fig_7_4.edge('A0', 'A1', minlen=\"1\", constraint=\"false\")\n", "fig_7_4.edge('A1', 'A2', minlen=\"1\", constraint=\"false\")\n", "fig_7_4.edge('A2', 'At', 
minlen=\"2\", constraint=\"false\")\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig_7_4" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Error function (the same as for a feedforward network):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "E({\\bf w}) = - \\sum_n \\sum_t \\sum_k d^t_{nk} \\log y^t_k({\\bf x}_n; {\\bf w})\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "where $d^t_n$ is the target output at time step $t$ for the $n$-th sample ${\\bf x}_n$  \n", "(the vector $(d^t_{n1}, d^t_{n2}, \\dots , d^t_{nK})$, with $K$ the number of output units)" ] }, { "cell_type": "markdown", 
"metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**双方向RNN**: \n", "データを 順方向 逆方向 両方の入力で与えるRNNを統合したもの。\n", "\n", "+ データの数が有限ならば有効\n", "+ オンライン学習には不向き" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### 7.3 順伝播計算" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "《略》" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### 7.4 逆伝播計算" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "《略》" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### 7.5 長・短期記憶(LSTM)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### 7.5.1 RNN の勾配消失問題" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_formats = {'svg',}\n", "\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import tensorflow as tf\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "def gradient_vanishing_image():\n", " x_linspace = tf.constant(np.linspace(-5.0, 5.0))\n", " grad1 = tf.sigmoid(x_linspace)\n", " grad2 = tf.sigmoid(grad1)\n", " grad3 = tf.sigmoid(grad2)\n", " \n", " with tf.Session() as sess:\n", " x, y1, y2, y3 = sess.run([x_linspace, grad1, grad2, grad3])\n", " plt.plot(x, y1, \"b\")\n", " plt.plot(x, y2, \"g\")\n", " plt.plot(x, y3, \"r\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "※イメージ" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gradient_vanishing_image()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "※ RNNの、ではなく一般のDNNの勾配消失イメージです。`sigmoid`関数を重ねて適用するとだんだん平らになっていくという実例。" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### 7.5.2 LSTM の概要" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**LSTM(長・短期記憶)**: \n", "RNNの拡張モデル(の1つ)。 \n", "RNNの中間層の各ユニットをメモリユニットと呼ぶ要素で置き換えた構造を持つ。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "fig_7_7 = Digraph(\"fig_7_7\", format=\"svg\")\n", "fig_7_7.graph_attr.update(compound='true', ranksep=\"0.3 equally\")\n", "fig_7_7.body.extend(['rankdir=BT'])\n", "\n", "cb = Digraph('clusterB')\n", "cb.graph_attr.update(color='transparent')\n", "cb.node('x00', '', shape='none', pos='0')\n", "cb.node('x0', '', shape='none', pos='0')\n", "cb.node('x01', '', shape='none', pos='0')\n", "\n", "fig_7_7.subgraph(cb)\n", "\n", "c1 = Digraph('cluster1')\n", "c1.graph_attr.update(color='transparent')\n", "c1.node('xb0', '', shape='none')\n", "c1.node('xc0', '', shape='none')\n", "c1.node('xf', '', shape='none')\n", "c1.node('xd0', '', shape='none')\n", "c1.edge('xb0', 'xc0', color='transparent')\n", "c1.edge('xc0', 'xf', color='transparent')\n", "c1.edge('xf', 'xd0', color='transparent')\n", "\n", "fig_7_7.subgraph(c1)\n", "fig_7_7.edge('x00', 'xb0', color='transparent')\n", "fig_7_7.edge('xf', 'FG', constraint='false', minlen=\"2\", ltail='cluster1')\n", "\n", "c_main = 
Digraph('cluster_main')\n", "c_main.node('B', 'B', shape='circle', style='filled', fillcolor='#009999', penwidth='2', pos='0')\n", "\n", "s1 = Digraph('subgraph1')\n", "s1.graph_attr.update(rank=\"same\")\n", "s1.node('c_main', '', shape='none')\n", "s1.node('e0', '', shape='circle', width='0.3', style='filled', fillcolor='#0000cc', penwidth='1', pos='0')\n", "s1.node('IG', 'C', shape='circle', style='filled', fillcolor='#00cc00', penwidth='2')\n", "s1.edge('c_main', 'e0', color='transparent', minlen=\"2\")\n", "s1.edge('e0', 'IG', minlen=\"2\", dir='back')\n", "\n", "c_main.subgraph(s1)\n", "c_main.edge('B', 'e0')\n", "\n", "s2 = Digraph('subgraph2')\n", "s2.graph_attr.update(rank=\"same\")\n", "s2.node('FG', 'F', shape='circle', style='filled', fillcolor='#00cc00', penwidth='2', group=\"f\")\n", "s2.node('e1', '', shape='circle', width='0.3', style='filled', fillcolor='#0000cc', penwidth='1')\n", "s2.node('A', 'A', shape='circle', style='filled', fillcolor='#ff0000', penwidth='2', pos='0')\n", "s2.node('E', '', shape='none')\n", "s2.edge('FG', 'e1', minlen=\"2\")\n", "s2.edge('e1', 'A', dir='back', minlen=\"1\")\n", "s2.edge('A', 'E', color='transparent')\n", "s2.edge('FG', 'A', dir='back', minlen=\"3\")\n", "s2.edge('A', 'e1', dir='back', minlen=\"1\")\n", "\n", "c_main.subgraph(s2)\n", "c_main.edge('e0', 'A')\n", "c_main.edge('IG', 'A', dir='back')\n", "c_main.edge('c_main', 'A', dir='back', color='transparent')\n", "\n", "s3 = Digraph('subgraph3')\n", "s3.graph_attr.update(rank=\"same\")\n", "s3.node('D0', '', shape='none')\n", "s3.node('e2', '', shape='circle', width='0.3', style='filled', fillcolor='#0000cc', penwidth='1', pos='0')\n", "s3.node('OG', 'D', shape='circle', style='filled', fillcolor='#00cc00', penwidth='2')\n", "s3.edge('D0', 'e2', dir='back', color='transparent', minlen=\"2\")\n", "s3.edge('e2', 'OG', dir='back', minlen=\"2\")\n", "\n", "c_main.subgraph(s3)\n", "c_main.edge('A', 'e2')\n", "c_main.edge('A', 'OG')\n", "c_main.edge('A', 'D0', 
color='transparent')\n", "\n", "fig_7_7.subgraph(c_main)\n", "\n", "fig_7_7.node('y0', '', shape='none')\n", "\n", "fig_7_7.edge('x0', 'B')\n", "fig_7_7.edge('e2', 'y0')\n", "\n", "c2 = Digraph('cluster2')\n", "c2.graph_attr.update(color='transparent')\n", "c2.node('xb1', '', shape='none')\n", "c2.node('x1', '', shape='none')\n", "c2.node('xf1', '', shape='none')\n", "c2.node('x2', '', shape='none')\n", "c2.edge('xb1', 'x1', color='transparent')\n", "c2.edge('x1', 'xf1', color='transparent')\n", "c2.edge('xf1', 'x2', color='transparent')\n", "\n", "fig_7_7.subgraph(c2)\n", "fig_7_7.edge('x01', 'xb1', color='transparent', ltail='clusterB')\n", "fig_7_7.edge('x1', 'IG', constraint='false', ltail='cluster2')\n", "fig_7_7.edge('x2', 'OG', constraint='false', ltail='cluster2')\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "fig_7_7" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "TensorFlow provides the `tf.nn.rnn_cell.BasicLSTMCell` class, which can be used here.  \n", "See: [class tf.nn.rnn_cell.BasicLSTMCell](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/rnn_cell.md#class-tfnnrnn_cellbasiclstmcell-basiclstmcell)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Sketch only: assumes `size` (number of LSTM units), `inputs` (the input tensor for\n", "# one time step), and `old_state` (the previous state) are defined elsewhere.\n", "cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.5)\n", "(cell_output, new_state) = cell(inputs, old_state)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### 7.5.3 Forward Propagation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Omitted)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### 7.5.4 Backpropagation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Omitted)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### 7.6 When Input and Output Sequence Lengths Differ" ] }, { "cell_type": "markdown", 
"metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### 7.6.1 隠れマルコフモデル" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### 7.6.2 コネクショニスト時系列分類法" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "notes" } }, "source": [ "《略》" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "TensorFlow v0.8 (Python 3)", "language": "python", "name": "tensorflow08" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" } }, "nbformat": 4, "nbformat_minor": 0 }