{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convolution and pooling layers with `im2col`\n",
    "\n",
    "Forward passes of a convolution layer and a max-pooling layer, both built on `im2col`. `im2col` unfolds every sliding window of an image batch into one row of a 2-D matrix, so convolution becomes a single matrix product and max pooling becomes a row-wise max."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(0)  # make the random arrays below reproducible"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch layout\n",
    "\n",
    "Image batches are 4-D arrays of shape `(N, C, H, W)`: batch size, channels, height, width. The example below holds 10 single-channel 28x28 images. Indexing the first axis selects one image `(C, H, W)`; indexing the first two axes selects one channel `(H, W)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 1, 28, 28)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = np.random.rand(10, 1, 28, 28)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 28, 28)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 28, 28)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(28, 28)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[0, 0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `im2col`\n",
    "\n",
    "Adapted from [`common/util.py`](https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/common/util.py) in the *Deep Learning from Scratch* repository (O'Reilly Japan)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "In the book's repository the helper is imported like this:\n",
    "\n",
    "```py\n",
    "from common.util import im2col\n",
    "```\n",
    "\n",
    "It is redefined below so that this notebook runs standalone. This version adds two things:\n",
    "\n",
    "- a `pad_value` argument, which max pooling uses to pad with `-inf`;\n",
    "- a clear `ValueError` when the window is larger than the padded input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def im2col(input_data, filter_h, filter_w, stride=1, pad=0, pad_value=0):\n",
    "    \"\"\"Unfold every sliding window of a 4-D batch into one row of a 2-D matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_data : 4-D array of shape (N, C, H, W) (batch size, channels, height, width)\n",
    "    filter_h : window (filter) height\n",
    "    filter_w : window (filter) width\n",
    "    stride : step between neighbouring windows\n",
    "    pad : number of cells added on each side of the height and width axes\n",
    "    pad_value : value written into the padded border (default 0)\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    col : 2-D float array of shape (N*out_h*out_w, C*filter_h*filter_w).\n",
    "        There is one row per output position, ordered (n, out row, out column).\n",
    "        Within a row the columns are ordered (channel, window row, window column).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the window does not fit inside the padded input.\n",
    "    \"\"\"\n",
    "    N, C, H, W = input_data.shape\n",
    "    out_h = (H + 2*pad - filter_h)//stride + 1\n",
    "    out_w = (W + 2*pad - filter_w)//stride + 1\n",
    "    if out_h <= 0 or out_w <= 0:\n",
    "        raise ValueError(\n",
    "            f'window {filter_h}x{filter_w} does not fit the padded input '\n",
    "            f'{H + 2*pad}x{W + 2*pad}')\n",
    "\n",
    "    # Promote the dtype first so a float fill value such as -inf fits integer input.\n",
    "    img = np.asarray(input_data, dtype=np.result_type(input_data, pad_value))\n",
    "    img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant',\n",
    "                 constant_values=pad_value)\n",
    "    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))\n",
    "\n",
    "    # For each offset (y, x) inside the window, gather that element from every\n",
    "    # window at once with a strided slice instead of looping over windows.\n",
    "    for y in range(filter_h):\n",
    "        y_max = y + stride*out_h\n",
    "        for x in range(filter_w):\n",
    "            x_max = x + stride*out_w\n",
    "            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]\n",
    "\n",
    "    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)\n",
    "    return col\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convolution layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Convolution:\n",
    "    \"\"\"Convolution layer (forward pass only) implemented with im2col.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    W : array of shape (FN, C, FH, FW)\n",
    "        Filter weights: FN filters, each with C channels and size FH x FW.\n",
    "    b : array broadcastable against (N*out_h*out_w, FN), typically shape (FN,)\n",
    "        Bias added to every output position of each filter.\n",
    "    stride : int\n",
    "        Step between neighbouring windows.\n",
    "    pad : int\n",
    "        Zero padding added on each side of the height and width axes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, W, b, stride=1, pad=0):\n",
    "        self.W = W\n",
    "        self.b = b\n",
    "        self.stride = stride\n",
    "        self.pad = pad\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Convolve a batch x of shape (N, C, H, W); returns (N, FN, out_h, out_w).\"\"\"\n",
    "        FN, C, FH, FW = self.W.shape\n",
    "        N, x_C, H, W = x.shape\n",
    "        if x_C != C:\n",
    "            raise ValueError(f'input has {x_C} channels but the filters expect {C}')\n",
    "        # Same floor-division formula as im2col, so the reshape below always matches.\n",
    "        out_h = (H + 2*self.pad - FH) // self.stride + 1\n",
    "        out_w = (W + 2*self.pad - FW) // self.stride + 1\n",
    "\n",
    "        col = im2col(x, FH, FW, self.stride, self.pad)  # (N*out_h*out_w, C*FH*FW)\n",
    "        col_W = self.W.reshape(FN, -1).T                # (C*FH*FW, FN)\n",
    "\n",
    "        out = np.dot(col, col_W) + self.b               # (N*out_h*out_w, FN)\n",
    "        out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Max-pooling layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Pooling:\n",
    "    \"\"\"Max-pooling layer (forward pass only) implemented with im2col.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    pool_h, pool_w : int\n",
    "        Height and width of the pooling window.\n",
    "    stride : int\n",
    "        Step between neighbouring windows (pass stride=pool_h for non-overlapping pooling).\n",
    "    pad : int\n",
    "        Cells added on each side of the height and width axes. They are filled with\n",
    "        -inf, so a padded cell is never chosen as the maximum over real values.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pool_h, pool_w, stride=1, pad=0):\n",
    "        self.pool_h = pool_h\n",
    "        self.pool_w = pool_w\n",
    "        self.stride = stride\n",
    "        self.pad = pad\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Max-pool a batch x of shape (N, C, H, W); returns (N, C, out_h, out_w).\"\"\"\n",
    "        N, C, H, W = x.shape\n",
    "        # Include the padding exactly as im2col does; otherwise the reshape below\n",
    "        # breaks whenever pad > 0.\n",
    "        out_h = (H + 2*self.pad - self.pool_h) // self.stride + 1\n",
    "        out_w = (W + 2*self.pad - self.pool_w) // self.stride + 1\n",
    "\n",
    "        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad,\n",
    "                     pad_value=-np.inf)\n",
    "        # One row per (n, out row, out column, channel) window.\n",
    "        col = col.reshape(-1, self.pool_h*self.pool_w)\n",
    "\n",
    "        out = np.max(col, axis=1)\n",
    "        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sanity checks\n",
    "\n",
    "Shape checks on the random batch, plus value checks on a 4x4 image that can be verified by hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 30 filters of 5x5, then 2x2 pooling with stride 2.\n",
    "conv_out = Convolution(np.random.rand(30, 1, 5, 5), np.zeros(30)).forward(x)\n",
    "assert conv_out.shape == (10, 30, 24, 24)\n",
    "assert Pooling(2, 2, stride=2).forward(conv_out).shape == (10, 30, 12, 12)\n",
    "# pad=2 keeps a 5x5 convolution 'same'-sized.\n",
    "conv_same = Convolution(np.random.rand(30, 1, 5, 5), np.zeros(30), pad=2)\n",
    "assert conv_same.forward(x).shape == (10, 30, 28, 28)\n",
    "\n",
    "x_small = np.arange(16, dtype=float).reshape(1, 1, 4, 4)\n",
    "conv_small = Convolution(np.ones((1, 1, 2, 2)), np.zeros(1)).forward(x_small)\n",
    "assert np.array_equal(conv_small[0, 0], [[10, 14, 18], [26, 30, 34], [42, 46, 50]])\n",
    "pool_small = Pooling(2, 2, stride=2).forward(x_small)\n",
    "assert np.array_equal(pool_small[0, 0], [[5, 7], [13, 15]])\n",
    "\n",
    "# Padded max pooling: the border never wins, even when every real value is negative.\n",
    "pool_padded = Pooling(2, 2, stride=2, pad=1).forward(-(x_small + 1))\n",
    "assert np.array_equal(pool_padded[0, 0],\n",
    "                      [[-1, -2, -4], [-5, -6, -8], [-13, -14, -16]])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}