4 from __future__
import absolute_import
5 from __future__
import division
6 from __future__
import print_function
10 Utility for creating ResNets 11 See "Deep Residual Learning for Image Recognition" by He, Zhang et. al. 2015 17 Helper class for constructing residual blocks. 20 def __init__(self, model, prev_blob, no_bias, is_test, spatial_bn_mom=0.9):
27 self.
no_bias = 1
if no_bias
else 0
29 def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0):
37 weight_init=(
"MSRAFill", {}),
53 def add_spatial_bn(self, num_filters):
66 Add a "bottleneck" component as decribed in He et. al. Figure 3 (right) 75 spatial_batch_norm=
True,
88 if spatial_batch_norm:
98 stride=(1
if down_sampling
is False else 2),
102 if spatial_batch_norm:
107 last_conv = self.
add_conv(base_filters, output_filters, kernel=1)
108 if spatial_batch_norm:
114 if (output_filters > input_filters):
115 shortcut_blob = brew.conv(
121 weight_init=(
"MSRAFill", {}),
123 stride=(1
if down_sampling
is False else 2),
126 if spatial_batch_norm:
127 shortcut_blob = brew.spatial_bn(
130 'shortcut_projection_%d_spatbn' % self.
comp_count,
138 self.
model, [shortcut_blob, last_conv],
147 def add_simple_block(
152 spatial_batch_norm=
True 162 stride=(1
if down_sampling
is False else 2),
166 if spatial_batch_norm:
170 last_conv = self.
add_conv(num_filters, num_filters, kernel=3, pad=1)
171 if spatial_batch_norm:
175 if (num_filters != input_filters):
176 shortcut_blob = brew.conv(
182 weight_init=(
"MSRAFill", {}),
184 stride=(1
if down_sampling
is False else 2),
187 if spatial_batch_norm:
188 shortcut_blob = brew.spatial_bn(
191 'shortcut_projection_%d_spatbn' % self.
comp_count,
198 self.
model, [shortcut_blob, last_conv],
230 weight_init=(
"MSRAFill", {}),
246 brew.relu(model,
'conv1_spatbn_relu',
'conv1_spatbn_relu')
247 brew.max_pool(model,
'conv1_spatbn_relu',
'pool1', kernel=3, stride=2)
251 is_test=is_test, spatial_bn_mom=0.1)
254 builder.add_bottleneck(64, 64, 256)
255 builder.add_bottleneck(256, 64, 256)
256 builder.add_bottleneck(256, 64, 256)
259 builder.add_bottleneck(256, 128, 512, down_sampling=
True)
260 for _
in range(1, 4):
261 builder.add_bottleneck(512, 128, 512)
264 builder.add_bottleneck(512, 256, 1024, down_sampling=
True)
265 for _
in range(1, 6):
266 builder.add_bottleneck(1024, 256, 1024)
269 builder.add_bottleneck(1024, 512, 2048, down_sampling=
True)
270 builder.add_bottleneck(2048, 512, 2048)
271 builder.add_bottleneck(2048, 512, 2048)
274 final_avg = brew.average_pool(
278 kernel=final_avg_kernel,
285 model, final_avg,
'last_out_L{}'.format(num_labels), 2048, num_labels
292 if (label
is not None):
293 (softmax, loss) = model.SoftmaxWithLoss(
298 return (softmax, loss)
301 return brew.softmax(model, last_out,
"softmax")
304 def create_resnet_32x32(
305 model, data, num_input_channels, num_groups, num_labels, is_test=
False 308 Create residual net for smaller images (sec 4.2 of He et. al (2015)) 309 num_groups = 'n' in the paper 313 model, data,
'conv1', num_input_channels, 16, kernel=3, stride=1
316 model,
'conv1',
'conv1_spatbn', 16, epsilon=1e-3, is_test=is_test
318 brew.relu(model,
'conv1_spatbn',
'relu1')
321 filters = [16, 32, 64]
323 builder =
ResNetBuilder(model,
'relu1', no_bias=0, is_test=is_test)
325 for groupidx
in range(0, 3):
326 for blockidx
in range(0, 2 * num_groups):
327 builder.add_simple_block(
328 prev_filters
if blockidx == 0
else filters[groupidx],
330 down_sampling=(
True if blockidx == 0
and 331 groupidx > 0
else False))
332 prev_filters = filters[groupidx]
336 model, builder.prev_blob,
'final_avg', kernel=8, stride=1
338 brew.fc(model,
'final_avg',
'last_out', 64, num_labels)
339 softmax = brew.softmax(model,
'last_out',
'softmax')
def add_spatial_bn(self, num_filters)
def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0)