Caffe2 - Python API
A deep learning, cross-platform ML framework
resnet.py
## @package resnet
# Module caffe2.python.models.resnet

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from caffe2.python import brew
'''
Utility for creating ResNets
See "Deep Residual Learning for Image Recognition" by He, Zhang et al., 2015
'''

class ResNetBuilder():
    '''
    Helper class for constructing residual blocks.
    '''

    def __init__(self, model, prev_blob, no_bias, is_test, spatial_bn_mom=0.9):
        self.model = model
        self.comp_count = 0
        self.comp_idx = 0
        self.prev_blob = prev_blob
        self.is_test = is_test
        self.spatial_bn_mom = spatial_bn_mom
        self.no_bias = 1 if no_bias else 0

    def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0):
        self.comp_idx += 1
        self.prev_blob = brew.conv(
            self.model,
            self.prev_blob,
            'comp_%d_conv_%d' % (self.comp_count, self.comp_idx),
            in_filters,
            out_filters,
            weight_init=("MSRAFill", {}),
            kernel=kernel,
            stride=stride,
            pad=pad,
            no_bias=self.no_bias,
        )
        return self.prev_blob

    def add_relu(self):
        self.prev_blob = brew.relu(
            self.model,
            self.prev_blob,
            self.prev_blob,  # in-place
        )
        return self.prev_blob

    def add_spatial_bn(self, num_filters):
        self.prev_blob = brew.spatial_bn(
            self.model,
            self.prev_blob,
            'comp_%d_spatbn_%d' % (self.comp_count, self.comp_idx),
            num_filters,
            epsilon=1e-3,
            momentum=self.spatial_bn_mom,
            is_test=self.is_test,
        )
        return self.prev_blob

    '''
    Add a "bottleneck" component as described in He et al., Figure 3 (right).
    '''

    def add_bottleneck(
        self,
        input_filters,   # num of feature maps from preceding layer
        base_filters,    # num of filters internally in the component
        output_filters,  # num of feature maps to output
        down_sampling=False,
        spatial_batch_norm=True,
    ):
        self.comp_idx = 0
        shortcut_blob = self.prev_blob

        # 1x1
        self.add_conv(
            input_filters,
            base_filters,
            kernel=1,
            stride=1
        )

        if spatial_batch_norm:
            self.add_spatial_bn(base_filters)

        self.add_relu()

        # 3x3 (note the pad, required for keeping dimensions)
        self.add_conv(
            base_filters,
            base_filters,
            kernel=3,
            stride=(1 if down_sampling is False else 2),
            pad=1
        )

        if spatial_batch_norm:
            self.add_spatial_bn(base_filters)
        self.add_relu()

        # 1x1
        last_conv = self.add_conv(base_filters, output_filters, kernel=1)
        if spatial_batch_norm:
            last_conv = self.add_spatial_bn(output_filters)

        # Summation with input signal (shortcut)
        # If we need to increase the number of feature maps, we need to
        # do a projection for the shortcut
        if (output_filters > input_filters):
            shortcut_blob = brew.conv(
                self.model,
                shortcut_blob,
                'shortcut_projection_%d' % self.comp_count,
                input_filters,
                output_filters,
                weight_init=("MSRAFill", {}),
                kernel=1,
                stride=(1 if down_sampling is False else 2),
                no_bias=self.no_bias,
            )
            if spatial_batch_norm:
                shortcut_blob = brew.spatial_bn(
                    self.model,
                    shortcut_blob,
                    'shortcut_projection_%d_spatbn' % self.comp_count,
                    output_filters,
                    epsilon=1e-3,
                    momentum=self.spatial_bn_mom,
                    is_test=self.is_test,
                )

        self.prev_blob = brew.sum(
            self.model, [shortcut_blob, last_conv],
            'comp_%d_sum_%d' % (self.comp_count, self.comp_idx)
        )
        self.comp_idx += 1
        self.add_relu()

        # Keep track of the number of high-level components in this ResNetBuilder
        self.comp_count += 1

    def add_simple_block(
        self,
        input_filters,
        num_filters,
        down_sampling=False,
        spatial_batch_norm=True
    ):
        self.comp_idx = 0
        shortcut_blob = self.prev_blob

        # 3x3
        self.add_conv(
            input_filters,
            num_filters,
            kernel=3,
            stride=(1 if down_sampling is False else 2),
            pad=1
        )

        if spatial_batch_norm:
            self.add_spatial_bn(num_filters)
        self.add_relu()

        last_conv = self.add_conv(num_filters, num_filters, kernel=3, pad=1)
        if spatial_batch_norm:
            last_conv = self.add_spatial_bn(num_filters)

        # If dimensions increase, we need a projection for the shortcut
        if (num_filters != input_filters):
            shortcut_blob = brew.conv(
                self.model,
                shortcut_blob,
                'shortcut_projection_%d' % self.comp_count,
                input_filters,
                num_filters,
                weight_init=("MSRAFill", {}),
                kernel=1,
                stride=(1 if down_sampling is False else 2),
                no_bias=self.no_bias,
            )
            if spatial_batch_norm:
                shortcut_blob = brew.spatial_bn(
                    self.model,
                    shortcut_blob,
                    'shortcut_projection_%d_spatbn' % self.comp_count,
                    num_filters,
                    epsilon=1e-3,
                    is_test=self.is_test,
                )

        self.prev_blob = brew.sum(
            self.model, [shortcut_blob, last_conv],
            'comp_%d_sum_%d' % (self.comp_count, self.comp_idx)
        )
        self.comp_idx += 1
        self.add_relu()

        # Keep track of the number of high-level components in this ResNetBuilder
        self.comp_count += 1

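# Example (illustrative sketch): ResNetBuilder is driven by chaining block
# calls on a starting blob inside a model_helper.ModelHelper, as done by
# create_resnet50 below; the 'pool1' blob name is taken from that function
# and the filter counts mirror its conv2_x stage.
#
#   builder = ResNetBuilder(model, 'pool1', no_bias=1, is_test=True)
#   builder.add_bottleneck(64, 64, 256)    # projection shortcut (64 -> 256 maps)
#   builder.add_bottleneck(256, 64, 256)   # identity shortcut
#   next_input = builder.prev_blob         # output blob of the last block
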
# The conv1 and final_avg kernel/stride args provide a basic mechanism for
# adapting resnet50 for different sizes of input images.
def create_resnet50(
    model,
    data,
    num_input_channels,
    num_labels,
    label=None,
    is_test=False,
    no_loss=False,
    no_bias=0,
    conv1_kernel=7,
    conv1_stride=2,
    final_avg_kernel=7,
):
    # conv1 + maxpool
    brew.conv(
        model,
        data,
        'conv1',
        num_input_channels,
        64,
        weight_init=("MSRAFill", {}),
        kernel=conv1_kernel,
        stride=conv1_stride,
        pad=3,
        no_bias=no_bias
    )

    brew.spatial_bn(
        model,
        'conv1',
        'conv1_spatbn_relu',
        64,
        epsilon=1e-3,
        momentum=0.1,
        is_test=is_test
    )
    brew.relu(model, 'conv1_spatbn_relu', 'conv1_spatbn_relu')
    brew.max_pool(model, 'conv1_spatbn_relu', 'pool1', kernel=3, stride=2)

    # Residual blocks...
    builder = ResNetBuilder(model, 'pool1', no_bias=no_bias,
                            is_test=is_test, spatial_bn_mom=0.1)

    # conv2_x (ref Table 1 in He et al. (2015))
    builder.add_bottleneck(64, 64, 256)
    builder.add_bottleneck(256, 64, 256)
    builder.add_bottleneck(256, 64, 256)

    # conv3_x
    builder.add_bottleneck(256, 128, 512, down_sampling=True)
    for _ in range(1, 4):
        builder.add_bottleneck(512, 128, 512)

    # conv4_x
    builder.add_bottleneck(512, 256, 1024, down_sampling=True)
    for _ in range(1, 6):
        builder.add_bottleneck(1024, 256, 1024)

    # conv5_x
    builder.add_bottleneck(1024, 512, 2048, down_sampling=True)
    builder.add_bottleneck(2048, 512, 2048)
    builder.add_bottleneck(2048, 512, 2048)

    # Final layers
    final_avg = brew.average_pool(
        model,
        builder.prev_blob,
        'final_avg',
        kernel=final_avg_kernel,
        stride=1,
        global_pooling=True,
    )

    # Final dimension of the "image" is reduced to 7x7
    last_out = brew.fc(
        model, final_avg, 'last_out_L{}'.format(num_labels), 2048, num_labels
    )

    if no_loss:
        return last_out

    # If we create a model for training, use softmax-with-loss
    if (label is not None):
        (softmax, loss) = model.SoftmaxWithLoss(
            [last_out, label],
            ["softmax", "loss"],
        )

        return (softmax, loss)
    else:
        # For inference, we just return softmax
        return brew.softmax(model, last_out, "softmax")

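# Example (illustrative sketch): per the comment above create_resnet50, the
# conv1_kernel/conv1_stride/final_avg_kernel arguments can be tuned for other
# input resolutions; the values below are hypothetical, not prescribed defaults.
#
#   create_resnet50(model, 'data', num_input_channels=3, num_labels=10,
#                   is_test=True, conv1_kernel=5, conv1_stride=1,
#                   final_avg_kernel=4)
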
def create_resnet_32x32(
    model, data, num_input_channels, num_groups, num_labels, is_test=False
):
    '''
    Create a residual net for smaller images (sec 4.2 of He et al. (2015))
    num_groups = 'n' in the paper
    '''
    # conv1
    brew.conv(
        model, data, 'conv1', num_input_channels, 16, kernel=3, stride=1
    )
    brew.spatial_bn(
        model, 'conv1', 'conv1_spatbn', 16, epsilon=1e-3, is_test=is_test
    )
    brew.relu(model, 'conv1_spatbn', 'relu1')

    # Number of blocks as described in sec 4.2
    filters = [16, 32, 64]

    builder = ResNetBuilder(model, 'relu1', no_bias=0, is_test=is_test)
    prev_filters = 16
    for groupidx in range(0, 3):
        for blockidx in range(0, 2 * num_groups):
            builder.add_simple_block(
                prev_filters if blockidx == 0 else filters[groupidx],
                filters[groupidx],
                down_sampling=(True if blockidx == 0 and
                               groupidx > 0 else False))
        prev_filters = filters[groupidx]

    # Final layers
    brew.average_pool(
        model, builder.prev_blob, 'final_avg', kernel=8, stride=1
    )
    brew.fc(model, 'final_avg', 'last_out', 64, num_labels)
    softmax = brew.softmax(model, 'last_out', 'softmax')
    return softmax
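
A minimal usage sketch (assumes a working Caffe2 install; the blob name 'data',
the batch size, and the 224x224 input resolution are illustrative choices, not
requirements of the module):

    from caffe2.python import model_helper, workspace
    from caffe2.python.models import resnet
    import numpy as np

    # Build an inference-mode ResNet-50 on a ModelHelper.
    model = model_helper.ModelHelper(name="resnet50_example", init_params=True)
    resnet.create_resnet50(
        model, 'data', num_input_channels=3, num_labels=1000, is_test=True
    )

    # Feed a dummy NCHW batch, initialize parameters, and run one forward pass.
    workspace.FeedBlob('data', np.random.rand(1, 3, 224, 224).astype(np.float32))
    workspace.RunNetOnce(model.param_init_net)
    workspace.RunNetOnce(model.net)
    print(workspace.FetchBlob('softmax').shape)  # (1, 1000)

The 32x32 variant is driven the same way, e.g.
resnet.create_resnet_32x32(model, 'data', 3, num_groups=3, num_labels=10,
is_test=True) for CIFAR-sized inputs.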