{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Theano tensor module: indexing" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using gpu device 1: Tesla C2075 (CNMeM is disabled)\n" ] } ], "source": [ "import theano\n", "import theano.tensor as T\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic indexing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `tensor` module fully supports `numpy`-style basic indexing:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 3 5 7]\n" ] } ], "source": [ "t = T.arange(9)\n", "\n", "print t[1::2].eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `numpy` result:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 3 5 7]\n" ] } ], "source": [ "n = np.arange(9)\n", "\n", "print n[1::2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Mask indexing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Although the `tensor` module supports basic indexing, it does not support mask (boolean) indexing. For example, the following gives the wrong result, because `t > 4` is treated as an array of integer indices (0 and 1) that select rows of `t`, not as a mask:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[0 1 2]\n", " [0 1 2]\n", " [0 1 2]]\n", "\n", " [[0 1 2]\n", " [0 1 2]\n", " [3 4 5]]\n", "\n", " [[3 4 5]\n", " [3 4 5]\n", " [3 4 5]]]\n" ] } ], "source": [ "t = T.arange(9).reshape((3,3))\n", "\n", "print t[t > 4].eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result in `numpy`:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[5 6 7 8]\n" ] } ], "source": [ "n = np.arange(9).reshape((3,3))\n", "\n", "print n[n > 4]" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "To get the same result as `numpy`, convert the mask to index arrays with `nonzero()` and index with those:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[5 6 7 8]\n" ] } ], "source": [ "print t[(t > 4).nonzero()].eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Assignment through indexing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `tensor` module does not support direct assignment through indexing; statements such as `a[5] = b` or `a[5] += b` are not allowed.\n", "\n", "Instead, `set_subtensor` and `inc_subtensor` provide similar functionality. Both return a new symbolic variable rather than modifying their input:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### T.set_subtensor(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Equivalent of `r[10:] = 5`:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "r = T.vector()\n", "\n", "new_r = T.set_subtensor(r[10:], 5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### T.inc_subtensor(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Equivalent of `r[10:] += 5`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "r = T.vector()\n", "\n", "new_r = T.inc_subtensor(r[10:], 5)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }