{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Theano Conditionals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`theano` provides two conditional constructs, `ifelse` and `switch`, both of which apply a condition to symbolic variables:\n", "\n", "- `ifelse(condition, var1, var2)`\n", "    - Returns `var1` if the scalar `condition` is true, and `var2` otherwise.\n", "- `switch(tensor, var1, var2)`\n", "    - An elementwise version of `ifelse`, and therefore more general: `tensor` can hold a separate condition for each element.\n", "- `switch` evaluates both outputs, whereas `ifelse` evaluates only the output selected by the condition.\n", "\n", "`ifelse` is imported from `theano.ifelse`, while `switch` lives in the `theano.tensor` module." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using gpu device 1: Tesla K10.G2.8GB (CNMeM is disabled)\n" ] } ], "source": [ "import theano, time\n", "import theano.tensor as T\n", "import numpy as np\n", "from theano.ifelse import ifelse" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose we have two scalar parameters, $a, b$, and two matrices, $\\mathbf{x, y}$, and define the function:\n", "\n", "$$ \n", "\\mathbf z = f(a, b,\\mathbf{x, y}) = \\left\\{ \n", "\\begin{aligned}\n", "    \\mathbf x & ,\\ a \\le b\\\\\n", "    \\mathbf y & ,\\ a > b\n", "\\end{aligned}\n", "\\right.\n", "$$\n", "\n", "Define the variables:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "a, b = T.scalars('a', 'b')\n", "x, y = T.matrices('x', 'y')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Build it with `ifelse`. Less-than-or-equal is `T.le()` and greater-than-or-equal is `T.ge()`; the strict comparisons are `T.lt()` and `T.gt()`:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "z_ifelse = ifelse(T.le(a, b), x, y)\n", "\n", "f_ifelse = theano.function([a, b, x, y], z_ifelse)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Build it with `switch`:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "z_switch = T.switch(T.le(a, b), x, y)\n", "\n", "f_switch = theano.function([a, b, x, y], z_switch)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "Test data:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "val1 = 0.\n", "val2 = 1.\n", "big_mat1 = np.ones((10000, 1000), dtype=theano.config.floatX)\n", "big_mat2 = np.ones((10000, 1000), dtype=theano.config.floatX)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compare the running time of the two functions:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "time spent evaluating both values 0.638598 sec\n", "time spent evaluating one value 0.461249 sec\n" ] } ], "source": [ "n_times = 10\n", "\n", "tic = time.clock()\n", "for i in xrange(n_times):\n", "    f_switch(val1, val2, big_mat1, big_mat2)\n", "print 'time spent evaluating both values %f sec' % (time.clock() - tic)\n", "\n", "tic = time.clock()\n", "for i in xrange(n_times):\n", "    f_ifelse(val1, val2, big_mat1, big_mat2)\n", "print 'time spent evaluating one value %f sec' % (time.clock() - tic)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }