{ "cells": [
 { "cell_type": "code", "execution_count": 1, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "zeropad (generic function with 2 methods)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# zeropad(a, pad_width)\n",
    "#\n",
    "# Return a new array holding `a` surrounded by zeros. `pad_width` carries one\n",
    "# `(before, after)` pair per dimension of `a`; dimension `d` of the result has\n",
    "# length `size(a, d) + before + after` and `a` starts at index `before + 1`.\n",
    "function zeropad(a::AbstractArray{T,N}, pad_width::NTuple{N,Tuple{Int,Int}}) where {T,N}\n",
    "    # Map over tuples (not vectors) so the number of dimensions of the result\n",
    "    # is carried by the tuple length instead of a runtime-sized vector.\n",
    "    dims = map((len, p) -> len + p[1] + p[2], size(a), pad_width)\n",
    "    interior = map((len, p) -> (p[1] + 1):(p[1] + len), size(a), pad_width)\n",
    "    padded = zeros(T, dims)\n",
    "    padded[interior...] = a\n",
    "    padded\n",
    "end\n",
    "\n",
    "# Varargs convenience form: one `(before, after)` pair per dimension, e.g.\n",
    "# `zeropad(a, (1, 1), (0, 2))` for a matrix.\n",
    "@inline zeropad(a::AbstractArray{T,N}, pad_width::Tuple{Int,Int}...) where {T,N} = zeropad(a, pad_width)"
   ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "im2col (generic function with 3 methods)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# im2col(input_data, filter_w, filter_h, stride=1, pad=0)\n",
    "#\n",
    "# Unroll every filter-sized window of `input_data` into one column of a matrix.\n",
    "# `input_data` is read as (W, H, C, N); its first two dimensions are zero-padded\n",
    "# by `pad` on each side. The result has `filter_w*filter_h*C` rows (filter_w\n",
    "# varying fastest, then filter_h, then C) and `out_w*out_h*N` columns (out_w\n",
    "# varying fastest), with `out_w = (W + 2pad - filter_w) ÷ stride + 1` and the\n",
    "# same formula for `out_h`.\n",
    "#\n",
    "# Throws `ArgumentError` when `stride < 1`, `pad < 0`, or the filter does not\n",
    "# fit inside the padded input.\n",
    "function im2col(input_data::AbstractArray{T,4}, filter_w::Int, filter_h::Int, stride::Int=1, pad::Int=0) where {T}\n",
    "    stride >= 1 || throw(ArgumentError(\"stride must be at least 1\"))\n",
    "    pad >= 0 || throw(ArgumentError(\"pad must be non-negative\"))\n",
    "    W, H, C, N = size(input_data)\n",
    "    # Without this check the truncating ÷ below can still report one output\n",
    "    # position for an oversized filter and fail later with a BoundsError.\n",
    "    (1 <= filter_w <= W + 2pad && 1 <= filter_h <= H + 2pad) ||\n",
    "        throw(ArgumentError(\"filter does not fit inside the padded input\"))\n",
    "    out_h = (H + 2pad - filter_h) ÷ stride + 1\n",
    "    out_w = (W + 2pad - filter_w) ÷ stride + 1\n",
    "    img = pad == 0 ? input_data : zeropad(input_data, (pad, pad), (pad, pad), (0, 0), (0, 0))\n",
    "    col = zeros(T, (out_w, out_h, filter_w, filter_h, C, N))\n",
    "    for y = 1:filter_h, x = 1:filter_w\n",
    "        # Each strided range holds exactly out_w (resp. out_h) indices: the\n",
    "        # position of filter element (x, y) inside every window.\n",
    "        xs = x:stride:(x + stride*out_w - 1)\n",
    "        ys = y:stride:(y + stride*out_h - 1)\n",
    "        col[:, :, x, y, :, :] = img[xs, ys, :, :]\n",
    "    end\n",
    "    # Bring the filter/channel axes to the front so that one column is one window.\n",
    "    reshape(permutedims(col, (3, 4, 5, 1, 2, 6)), filter_w*filter_h*C, out_w*out_h*N)\n",
    "end"
   ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "col2im (generic function with 3 methods)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0)\n",
    "#\n",
    "# Fold a matrix laid out like the result of `im2col` back into an array of size\n",
    "# `input_shape` = (W, H, C, N). Positions covered by several windows receive the\n",
    "# sum of all their contributions; the `pad`-wide border is cropped at the end.\n",
    "#\n",
    "# NOTE: the filter height comes BEFORE the filter width in this signature, the\n",
    "# opposite of `im2col`. The order is kept so existing callers keep working.\n",
    "function col2im(col::AbstractArray{T,2}, input_shape::NTuple{4,Int}, filter_h::Int, filter_w::Int, stride::Int=1, pad::Int=0) where {T}\n",
    "    W, H, C, N = input_shape\n",
    "    out_h = (H + 2pad - filter_h) ÷ stride + 1\n",
    "    out_w = (W + 2pad - filter_w) ÷ stride + 1\n",
    "    # Undo the reshape/permutedims performed at the end of im2col.\n",
    "    windows = permutedims(reshape(col, filter_w, filter_h, C, out_w, out_h, N), (4, 5, 1, 2, 3, 6))\n",
    "\n",
    "    # Working canvas: the padded image plus stride - 1 spare rows/columns; only\n",
    "    # the W x H interior is returned.\n",
    "    canvas = zeros(T, (W + 2*pad + stride - 1, H + 2*pad + stride - 1, C, N))\n",
    "    for y = 1:filter_h, x = 1:filter_w\n",
    "        xs = x:stride:(x + stride*out_w - 1)\n",
    "        ys = y:stride:(y + stride*out_h - 1)\n",
    "        canvas[xs, ys, :, :] += windows[:, :, x, y, :, :]\n",
    "    end\n",
    "\n",
    "    return canvas[pad+1:pad+W, pad+1:pad+H, :, :]\n",
    "end"
   ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "(7, 7, 3, 1)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "x1 = rand(Float32, (7, 7, 3, 1));\n",
    "size(x1)"
   ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "(75, 9)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "col1 = im2col(x1, 5, 5, 1, 0);\n",
    "size(col1)"
   ] },
 { "cell_type": "code", "execution_count": 7, "metadata":
{}, "outputs": [ { "data": { "text/plain": [ "(75, 90)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "x2 = rand(Float32, (7, 7, 3, 10));\n",
    "col2 = im2col(x2, 5, 5, 1, 0);\n",
    "size(col2)"
   ] },
 { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [],
   "source": [
    "# Convolution layer parameters.\n",
    "#   W      : filters, read by `forward` as (FW, FH, C, FN)\n",
    "#   b      : one bias per filter\n",
    "#   stride : step between neighbouring windows\n",
    "#   pad    : zero padding added on each side of the first two input dimensions\n",
    "# `Convolution(W, b, stride=1, pad=0)` takes the element type T from W and b.\n",
    "mutable struct Convolution{T}\n",
    "    W::AbstractArray{T,4}\n",
    "    b::AbstractArray{T,1}\n",
    "    stride::Int\n",
    "    pad::Int\n",
    "    (::Type{Convolution})(\n",
    "        W::AbstractArray{T,4},\n",
    "        b::AbstractArray{T,1},\n",
    "        stride::Int=1,\n",
    "        pad::Int=0) where {T} = new{T}(W, b, stride, pad)\n",
    "end"
   ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "forward (generic function with 1 method)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# forward(self::Convolution, x)\n",
    "#\n",
    "# Apply the layer to `x`, read as (W, H, C, N), and return an array of size\n",
    "# (out_w, out_h, FN, N). Throws `DimensionMismatch` when the channel count of\n",
    "# `x` differs from the channel count of the filters.\n",
    "function forward(self::Convolution{T}, x::AbstractArray{T,4}) where {T}\n",
    "    FW, FH, FC, FN = size(self.W)\n",
    "    W, H, C, N = size(x)\n",
    "    C == FC || throw(DimensionMismatch(\"input has $C channels but the filters have $FC\"))\n",
    "    out_h = 1 + (H + 2*self.pad - FH) ÷ self.stride\n",
    "    out_w = 1 + (W + 2*self.pad - FW) ÷ self.stride\n",
    "\n",
    "    # im2col takes the filter WIDTH first, then the height.\n",
    "    col = im2col(x, FW, FH, self.stride, self.pad)\n",
    "    # One row per filter, flattened in the same (FW, FH, C) order as the rows of col.\n",
    "    col_w = reshape(self.W, (:, FN)).'\n",
    "    out = col_w * col .+ self.b\n",
    "\n",
    "    # Rows of `out` are filters; move that axis behind the two spatial ones.\n",
    "    return permutedims(reshape(out, (:, out_w, out_h, N)), (2, 3, 1, 4))\n",
    "end"
   ] },
 { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [],
   "source": [
    "# Max-pooling layer parameters: window height and width, stride between\n",
    "# windows, and zero padding. T is the element type `forward` accepts; it has to\n",
    "# be given explicitly, e.g. `Pooling{Int}(2, 2, 2, 0)`.\n",
    "mutable struct Pooling{T}\n",
    "    pool_h::Int\n",
    "    pool_w::Int\n",
    "    stride::Int\n",
    "    pad::Int\n",
    "    (::Type{Pooling{T}})(pool_h::Int, pool_w::Int, stride::Int=1, pad::Int=0) where {T} =\n",
    "        new{T}(pool_h, pool_w, stride, pad)\n",
    "end"
   ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "forward (generic function with 2 methods)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# forward(self::Pooling, x)\n",
    "#\n",
    "# Take the maximum over every pool_w x pool_h window of `x`, read as\n",
    "# (W, H, C, N), independently per channel. Returns an array of size\n",
    "# (out_w, out_h, C, N).\n",
    "function forward(self::Pooling{T}, x::AbstractArray{T,4}) where {T}\n",
    "    W, H, C, N = size(x)\n",
    "    # The padding has to be counted here as well because im2col below pads the input.\n",
    "    out_h = 1 + (H + 2*self.pad - self.pool_h) ÷ self.stride\n",
    "    out_w = 1 + (W + 2*self.pad - self.pool_w) ÷ self.stride\n",
    "\n",
    "    # im2col takes the window WIDTH first, then the height.\n",
    "    # NOTE(review): im2col pads with zeros, so with pad > 0 a border window whose\n",
    "    # inputs are all negative yields 0 - confirm this is the intended behaviour.\n",
    "    col = im2col(x, self.pool_w, self.pool_h, self.stride, self.pad)\n",
    "    # Split the channels apart: each column now holds one window of one channel.\n",
    "    col = reshape(col, (self.pool_h*self.pool_w, :))\n",
    "\n",
    "    out = maximum(col, 1)\n",
    "    return permutedims(reshape(out, (C, out_w, out_h, N)), (2, 3, 1, 4))\n",
    "end"
   ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "4×4×3 Array{Int64,3}:\n", "[:, :, 1] =\n", " 1 0 1 3\n", " 2 1 0 2\n", " 3 2 4 0\n", " 0 4 2 1\n", "\n", "[:, :, 2] =\n", " 3 4 3 2\n", " 0 2 0 3\n", " 6 4 1 3\n", " 5 3 0 1\n", "\n", "[:, :, 3] =\n", " 4 0 3 4\n", " 2 1 0 2\n", " 1 0 6 4\n", " 2 4 2 3" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Small fixed 4x4 input with 3 channels, used by the worked examples below.\n",
    "P0 = reshape([1,2,3,0,0,1,2,4,1,0,4,2,3,2,0,1,3,0,6,5,4,2,4,3,3,0,1,0,2,3,3,1,4,2,1,2,0,1,0,4,3,0,6,2,4,2,4,3],\n",
    " (4,4,3))"
   ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "Pooling{Int64}(2, 2, 2, 0)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "pool0 = Pooling{Int}(2, 2, 2, 0)"
   ] },
 { "cell_type": "code", "execution_count": 15, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "2×2×3×1 Array{Int64,4}:\n", "[:, :, 1, 1] =\n", " 2 3\n", " 4 4\n", "\n", "[:, :, 2, 1] =\n", " 4 3\n", " 6 3\n", "\n", "[:, :, 3, 1] =\n", " 4 4\n", " 4 6" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "forward(pool0, reshape(P0, (4, 4, 3, 1)))"
   ] },
 { "cell_type": "code", "execution_count": 16, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "12×4 Array{Int64,2}:\n", " 1 3 1 4\n", " 2 0 0 2\n", " 0 2 3 0\n", " 1 4 2 1\n", " 3 6 3 1\n", " 0 5 0 0\n", " 4 4 2 3\n", " 2 3 3 1\n", " 4 1 3 6\n", " 2 2 0 2\n", " 0 0 4 4\n", " 1 4 2 3" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Pooling by hand, step 1: one column per 2x2 window, the 3 channels stacked in it.\n",
    "im2col(reshape(P0, (4, 4, 3, 1)), 2, 2, 2, 0)"
   ] },
 { "cell_type": "code", "execution_count": 17, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "4×12 Array{Int64,2}:\n", " 1 3 4 3 6 1 1 3 3 4 1 6\n", " 2 0 2 0 5 2 0 0 0 2 0 2\n", " 0 4 0 2 4 0 3 2 4 0 3 4\n", " 1 2 1 4 3 4 2 3 2 1 1 3" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Step 2: split the channels so that each column is one window of one channel.\n",
    "reshape(im2col(reshape(P0, (4, 4, 3, 1)), 2, 2, 2, 0), (2*2, :))"
   ] },
 { "cell_type": "code", "execution_count": 18, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "1×12 Array{Int64,2}:\n", " 2 4 4 4 6 4 3 3 4 4 3 6" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Step 3: maximum of every column.\n",
    "o0 = maximum(reshape(im2col(reshape(P0, (4, 4, 3, 1)), 2, 2, 2, 0), (2*2, :)), 1)"
   ] },
 { "cell_type": "code", "execution_count": 19, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "2×2×3×1 Array{Int64,4}:\n", "[:, :, 1, 1] =\n", " 2 3\n", " 4 4\n", "\n", "[:, :, 2, 1] =\n", " 4 3\n", " 6 3\n", "\n", "[:, :, 3, 1] =\n", " 4 4\n", " 4 6" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Step 4: back to (out_w, out_h, C, N); this matches forward(pool0, ...) above.\n",
    "permutedims(reshape(o0, (3, 2, 2, :)), (2, 3, 1, 4))"
   ] },
 { "cell_type": "code", "execution_count": 20, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "4×4×3×1 Array{Int64,4}:\n", "[:, :, 1, 1] =\n", " 1 0 1 3\n", " 2 1 0 2\n", " 3 2 4 0\n", " 0 4 2 1\n", "\n", "[:, :, 2, 1] =\n", " 3 4 3 2\n", " 0 2 0 3\n", " 6 4 1 3\n", " 5 3 0 1\n", "\n", "[:, :, 3, 1] =\n", " 4 0 3 4\n", " 2 1 0 2\n", " 1 0 6 4\n", " 2 4 2 3" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "X0 = reshape(P0, (4, 4, 3, 1))"
   ] },
 { "cell_type": "code", "execution_count": 21, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "Convolution{Int64}([0 2 -2; -1 1 -2; 0 0 -1]\n", "\n", "[-1 -2 2; -2 -1 -1; -2 0 -1]\n", "\n", "[1 1 0; 2 1 2; 2 -1 2]\n", "\n", "[-1 -2 -2; 0 -2 0; 1 -2 -1]\n", "\n", "[0 -2 2; -2 0 -1; 1 0 1]\n", "\n", "[2 -2 -1; 2 1 2; -1 2 -2]\n", "\n", "[-2 -1 -1; 1 0 1; -2 -2 2]\n", "\n", "[-1 2 -1; -1 2 0; 1 -1 0]\n", "\n", "[-2 2 2; -2 2 2; -2 0 0], [-2, -2, 1], 1, 1)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Random integer filters and biases. No seed is set, so re-running this cell\n",
    "# gives different values from the recorded outputs below.\n",
    "conv0 = Convolution(rand(-2:2, 3, 3, 3, 3), rand(-2:2, 3), 1, 1)"
   ] },
 { "cell_type": "code", "execution_count": 22, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "4×4×3×1 Array{Int64,4}:\n", "[:, :, 1, 1] =\n", " -7 -1 -12 0\n", " 1 -6 -24 11\n", " -5 -21 7 6\n", " -9 -18 17 6\n", "\n", "[:, :, 2, 1] =\n", " -5 3 -4 -4\n", " -14 -10 -18 -10\n", " -23 -14 -9 2\n", " -11 -33 -19 -3\n", "\n", "[:, :, 3, 1] =\n", " 13 -6 22 -2\n", " 9 0 5 -13\n", " 21 11 -1 3\n", " 32 11 12 0" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "forward(conv0, X0)"
   ] },
 { "cell_type": "code", "execution_count": 23, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "27×16 Array{Int64,2}:\n", " 0 0 0 0 0 1 2 3 0 0 1 2 0 1 0 4\n", " 0 0 0 0 1 2 3 0 0 1 2 4 1 0 4 2\n", " 0 0 0 0 2 3 0 0 1 2 4 0 0 4 2 0\n", " 0 1 2 3 0 0 1 2 0 1 0 4 0 3 2 0\n", " 1 2 3 0 0 1 2 4 1 0 4 2 3 2 0 1\n", " 2 3 0 0 1 2 4 0 0 4 2 0 2 0 1 0\n", " 0 0 1 2 0 1 0 4 0 3 2 0 0 0 0 0\n", " 0 1 2 4 1 0 4 2 3 2 0 1 0 0 0 0\n", " 1 2 4 0 0 4 2 0 2 0 1 0 0 0 0 0\n", " 0 0 0 0 0 3 0 6 0 4 2 4 0 3 0 1\n", " 0 0 0 0 3 0 6 5 4 2 4 3 3 0 1 0\n", " 0 0 0 0 0 6 5 0 2 4 3 0 0 1 0 0\n", " 0 3 0 6 0 4 2 4 0 3 0 1 0 2 3 3\n", " ⋮ ⋮ ⋮ ⋮\n", " 0 4 2 4 0 3 0 1 0 2 3 3 0 0 0 0\n", " 4 2 4 3 3 0 1 0 2 3 3 1 0 0 0 0\n", " 2 4 3 0 0 1 0 0 3 3 1 0 0 0 0 0\n", " 0 0 0 0 0 4 2 1 0 0 1 0 0 3 0 6\n", " 0 0 0 0 4 2 1 2 0 1 0 4 3 0 6 2\n", " 0 0 0 0 2 1 2 0 1 0 4 0 0 6 2 0\n", " 0 4 2 1 0 0 1 0 0 3 0 6 0 4 2 4\n", " 4 2 1 2 0 1 0 4 3 0 6 2 4 2 4 3\n", " 2 1 2 0 1 0 4 0 0 6 2 0 2 4 3 0\n", " 0 0 1 0 0 3 0 6 0 4 2 4 0 0 0 0\n", " 0 1 0 4 3 0 6 2 4 2 4 3 0 0 0 0\n", " 1 0 4 0 0 6 2 0 2 4 3 0 0 0 0 0" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Convolution by hand, step 1: one column per 3x3 window of the zero-padded input.\n",
    "im2col(X0, 3, 3, 1, 1)"
   ] },
 { "cell_type": "code", "execution_count": 24, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "27×3 Array{Int64,2}:\n", " 0 -1 -2\n", " -1 0 1\n", " 0 1 -2\n", " 2 -2 -1\n", " 1 -2 0\n", " 0 -2 -2\n", " -2 -2 -1\n", " -2 0 1\n", " -1 -1 2\n", " -1 0 -1\n", " -2 -2 -1\n", " -2 1 1\n", " -2 -2 2\n", " ⋮ \n", " 2 2 -1\n", " -1 -1 0\n", " -1 1 0\n", " 1 2 -2\n", " 2 2 -2\n", " 2 -1 -2\n", " 1 -2 2\n", " 1 1 2\n", " -1 2 0\n", " 0 -1 2\n", " 2 2 2\n", " 2 -2 0" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Step 2: flatten the filters, one column per filter.\n",
    "reshape(conv0.W, (:, 3))"
   ] },
 { "cell_type": "code", "execution_count": 25, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "3×16 Array{Int64,2}:\n", " -5 3 -3 -7 1 -4 -19 -16 -10 -22 9 19 2 13 8 8\n", " -3 -12 -21 -9 5 -8 -12 -31 -2 -16 -7 -17 -2 -8 4 -1\n", " 12 8 20 31 -7 -1 10 10 21 4 -2 11 -3 -14 2 -1" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Step 3: every filter against every window in a single matrix product.\n",
    "reshape(conv0.W, (:, 3)).' * im2col(X0, 3, 3, 1, 1)"
   ] },
 { "cell_type": "code", "execution_count": 26, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "3×16 Array{Int64,2}:\n", " -7 1 -5 -9 -1 -6 -21 -18 -12 -24 7 17 0 11 6 6\n", " -5 -14 -23 -11 3 -10 -14 -33 -4 -18 -9 -19 -4 -10 2 -3\n", " 13 9 21 32 -6 0 11 11 22 5 -1 12 -2 -13 3 0" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Step 4: add the per-filter bias to every column.\n",
    "o1 = reshape(conv0.W, (:, 3)).' * im2col(X0, 3, 3, 1, 1) .+ conv0.b"
   ] },
 { "cell_type": "code", "execution_count": 28, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "4×4×3×1 Array{Int64,4}:\n", "[:, :, 1, 1] =\n", " -7 -1 -12 0\n", " 1 -6 -24 11\n", " -5 -21 7 6\n", " -9 -18 17 6\n", "\n", "[:, :, 2, 1] =\n", " -5 3 -4 -4\n", " -14 -10 -18 -10\n", " -23 -14 -9 2\n", " -11 -33 -19 -3\n", "\n", "[:, :, 3, 1] =\n", " 13 -6 22 -2\n", " 9 0 5 -13\n", " 21 11 -1 3\n", " 32 11 12 0" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# Step 5: back to (out_w, out_h, FN, N); this matches forward(conv0, X0) above.\n",
    "permutedims(reshape(o1, (3, 4, 4, :)), (2, 3, 1, 4))"
   ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] }
 ],
 "metadata": {
  "kernelspec": { "display_name": "Julia 0.6.0", "language": "julia", "name": "julia-0.6" },
  "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "0.6.0" }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}