{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Backpropagation with simple layers\n",
    "\n",
    "This notebook implements two minimal computational-graph layers: a multiplication layer and an addition layer. Each layer has a `forward` pass that computes its output, and a `backward` pass that turns an upstream gradient `dout` into gradients for its two inputs. The layers are then chained to compute the total price of a fruit purchase. The backward pass gives the derivative of that price with respect to every input (unit prices, quantities, and the tax rate)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiplication layer\n",
    "\n",
    "For $z = xy$ the local derivatives are $\\partial z / \\partial x = y$ and $\\partial z / \\partial y = x$. `forward` must therefore remember both inputs so that `backward` can use them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Multiplication node. The fields cache the inputs of the most recent\n",
    "# `forward` call; they stay undefined (#undef) until then.\n",
    "mutable struct MulLayer\n",
    "    x\n",
    "    y\n",
    "    MulLayer() = new()\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "forward (generic function with 1 method)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Return x * y, caching both inputs in `layer` for the backward pass.\n",
    "function forward(layer::MulLayer, x, y)\n",
    "    layer.x = x\n",
    "    layer.y = y\n",
    "    return x * y\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "backward (generic function with 1 method)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Given the upstream gradient `dout`, return (dx, dy). Each input's\n",
    "# gradient is `dout` times the *other* cached input, so `forward` must\n",
    "# have been called on this layer first.\n",
    "function backward(layer::MulLayer, dout)\n",
    "    dx = dout * layer.y\n",
    "    dy = dout * layer.x\n",
    "    return dx, dy\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: buying apples\n",
    "\n",
    "Two apples at 100 each with a tax multiplier of 1.1, i.e. `price = (apple * apple_num) * tax`. Running the backward pass from `dprice = 1` yields the derivative of the final price with respect to each input. For example, `dapple = tax * apple_num = 2.2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "apple = 100\n",
    "apple_num = 2\n",
    "tax = 1.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MulLayer(#undef, #undef)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mul_apple_layer = MulLayer()\n",
    "mul_tax_layer = MulLayer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "220.00000000000003"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "apple_price = forward(mul_apple_layer, apple, apple_num)\n",
    "price = forward(mul_tax_layer, apple_price, tax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2.2, 110.00000000000001, 200)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Backward pass: visit the layers in the reverse order of the forward pass.\n",
    "dprice = 1\n",
    "dapple_price, dtax = backward(mul_tax_layer, dprice)\n",
    "dapple, dapple_num = backward(mul_apple_layer, dapple_price)\n",
    "(dapple, dapple_num, dtax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Addition layer\n",
    "\n",
    "For $z = x + y$ both local derivatives are 1. The upstream gradient is therefore passed unchanged to both inputs, and nothing needs to be cached."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "backward (generic function with 2 methods)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Addition node. It holds no state, so an immutable struct suffices.\n",
    "struct AddLayer end\n",
    "\n",
    "# Return x + y.\n",
    "function forward(layer::AddLayer, x, y)\n",
    "    return x + y\n",
    "end\n",
    "\n",
    "# Both local derivatives of x + y are 1, so each input receives `dout`.\n",
    "function backward(layer::AddLayer, dout)\n",
    "    dx = dout * 1\n",
    "    dy = dout * 1\n",
    "    return dx, dy\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: apples and oranges\n",
    "\n",
    "Two apples at 100 and three oranges at 150 are priced separately. The addition layer sums the two sub-totals, and the tax multiplier of 1.1 is then applied to the sum."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.1"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "apple = 100\n",
    "apple_num = 2\n",
    "orange = 150\n",
    "orange_num = 3\n",
    "tax = 1.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MulLayer(#undef, #undef)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mul_apple_layer = MulLayer()\n",
    "mul_orange_layer = MulLayer()\n",
    "add_apple_orange_layer = AddLayer()\n",
    "mul_tax_layer = MulLayer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "715.0000000000001"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "apple_price = forward(mul_apple_layer, apple, apple_num)\n",
    "orange_price = forward(mul_orange_layer, orange, orange_num)\n",
    "all_price = forward(add_apple_orange_layer, apple_price, orange_price)\n",
    "price = forward(mul_tax_layer, all_price, tax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2.2, 110.00000000000001, 3.3000000000000003, 165.0, 650)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Backward pass: visit the layers in the reverse order of the forward pass.\n",
    "dprice = 1\n",
    "dall_price, dtax = backward(mul_tax_layer, dprice)\n",
    "dapple_price, dorange_price = backward(add_apple_orange_layer, dall_price)\n",
    "dorange, dorange_num = backward(mul_orange_layer, dorange_price)\n",
    "dapple, dapple_num = backward(mul_apple_layer, dapple_price)\n",
    "(dapple, dapple_num, dorange, dorange_num, dtax)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.0",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}