Caffe2 - Python API
A deep learning, cross-platform ML framework
cnn.py
1 ## @package cnn
2 # Module caffe2.python.cnn
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import brew
9 from caffe2.python.model_helper import ModelHelper
10 from caffe2.proto import caffe2_pb2
11 import logging
12 
13 
15  """A helper model so we can write CNN models more easily, without having to
16  manually define parameter initializations and operators separately.
17  """
18 
def __init__(self, order="NCHW", name=None,
             use_cudnn=True, cudnn_exhaustive_search=False,
             ws_nbytes_limit=None, init_params=True,
             skip_sparse_optim=False,
             param_model=None):
    """Create a (deprecated) CNN model helper.

    Args:
        order: storage order for image tensors; must be "NCHW" or "NHWC".
        name: model name; "CNN" is used when None.
        use_cudnn: stored on ``self`` and placed in the default arg scope.
        cudnn_exhaustive_search: stored on ``self`` and placed in the
            default arg scope.
        ws_nbytes_limit: stored on ``self``; added to the arg scope only
            when truthy (so ``None`` and ``0`` are both left out).
        init_params, skip_sparse_optim, param_model: passed unchanged to
            the parent class ``__init__``.

    Raises:
        ValueError: if ``order`` is neither "NCHW" nor "NHWC".
    """
    # Validate before anything else so an invalid order fails fast: no
    # deprecation warning is logged and the parent helper is never
    # initialised with a bogus 'order' in its arg scope.
    if order not in ("NHWC", "NCHW"):
        raise ValueError(
            "Cannot understand the CNN storage order %s." % order
        )

    logging.warning(
        "[====DEPRECATE WARNING====]: you are creating an "
        "object from CNNModelHelper class which will be deprecated soon. "
        "Please use ModelHelper object with brew module. For more "
        "information, please refer to caffe2.ai and python/brew.py, "
        "python/brew_test.py for more information."
    )

    cnn_arg_scope = {
        'order': order,
        'use_cudnn': use_cudnn,
        'cudnn_exhaustive_search': cudnn_exhaustive_search,
    }
    if ws_nbytes_limit:
        cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit
    super(CNNModelHelper, self).__init__(
        skip_sparse_optim=skip_sparse_optim,
        name="CNN" if name is None else name,
        init_params=init_params,
        param_model=param_model,
        arg_scope=cnn_arg_scope,
    )

    # Kept as attributes because the layer wrappers on this class read them
    # and pass them explicitly to the brew helpers.
    self.order = order
    self.use_cudnn = use_cudnn
    self.cudnn_exhaustive_search = cudnn_exhaustive_search
    self.ws_nbytes_limit = ws_nbytes_limit
55 
def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
    """Forward to ``brew.image_input`` using this helper's storage order.

    ``use_gpu_transform`` (default False) is always passed explicitly,
    together with ``order=self.order``; all other kwargs pass through.
    """
    image_input = brew.image_input
    return image_input(self, blob_in, blob_out,
                       order=self.order,
                       use_gpu_transform=use_gpu_transform,
                       **kwargs)
65 
def VideoInput(self, blob_in, blob_out, **kwargs):
    """Forward to ``brew.video_input``.

    Unlike ``ImageInput``, no storage order is injected here; every keyword
    argument is passed through as given.
    """
    video_input = brew.video_input
    return video_input(self, blob_in, blob_out, **kwargs)
73 
def PadImage(self, blob_in, blob_out, **kwargs):
    """Add a PadImage operator to ``self.net`` and return its result.

    The value returned by ``self.net.PadImage`` is handed back to the
    caller, as every other wrapper on this class does, so the output can be
    chained into the next layer.
    """
    # TODO(wyiming): remove this dummy helper later
    return self.net.PadImage(blob_in, blob_out, **kwargs)
77 
def ConvNd(self, *args, **kwargs):
    """Forward to ``brew.conv_nd`` with this helper's convolution settings.

    ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` come from the attributes set in ``__init__``;
    supplying any of them again in ``kwargs`` is a ``TypeError``.
    """
    conv_nd = brew.conv_nd
    return conv_nd(self, *args,
                   order=self.order,
                   use_cudnn=self.use_cudnn,
                   ws_nbytes_limit=self.ws_nbytes_limit,
                   cudnn_exhaustive_search=self.cudnn_exhaustive_search,
                   **kwargs)
88 
def Conv(self, *args, **kwargs):
    """Forward to ``brew.conv`` with this helper's convolution settings.

    ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` come from the attributes set in ``__init__``;
    supplying any of them again in ``kwargs`` is a ``TypeError``.
    """
    conv = brew.conv
    return conv(self, *args,
                order=self.order,
                use_cudnn=self.use_cudnn,
                ws_nbytes_limit=self.ws_nbytes_limit,
                cudnn_exhaustive_search=self.cudnn_exhaustive_search,
                **kwargs)
99 
def ConvTranspose(self, *args, **kwargs):
    """Forward to ``brew.conv_transpose`` with this helper's conv settings.

    ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` come from the attributes set in ``__init__``;
    supplying any of them again in ``kwargs`` is a ``TypeError``.
    """
    conv_transpose = brew.conv_transpose
    return conv_transpose(self, *args,
                          order=self.order,
                          use_cudnn=self.use_cudnn,
                          ws_nbytes_limit=self.ws_nbytes_limit,
                          cudnn_exhaustive_search=self.cudnn_exhaustive_search,
                          **kwargs)
110 
def GroupConv(self, *args, **kwargs):
    """Forward to ``brew.group_conv`` with this helper's conv settings.

    ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` come from the attributes set in ``__init__``;
    supplying any of them again in ``kwargs`` is a ``TypeError``.
    """
    group_conv = brew.group_conv
    return group_conv(self, *args,
                      order=self.order,
                      use_cudnn=self.use_cudnn,
                      ws_nbytes_limit=self.ws_nbytes_limit,
                      cudnn_exhaustive_search=self.cudnn_exhaustive_search,
                      **kwargs)
121 
def GroupConv_Deprecated(self, *args, **kwargs):
    """Forward to ``brew.group_conv_deprecated`` with this helper's settings.

    ``order``, ``use_cudnn``, ``cudnn_exhaustive_search`` and
    ``ws_nbytes_limit`` come from the attributes set in ``__init__``;
    supplying any of them again in ``kwargs`` is a ``TypeError``.
    """
    legacy_group_conv = brew.group_conv_deprecated
    return legacy_group_conv(
        self, *args,
        order=self.order,
        use_cudnn=self.use_cudnn,
        ws_nbytes_limit=self.ws_nbytes_limit,
        cudnn_exhaustive_search=self.cudnn_exhaustive_search,
        **kwargs)
132 
def FC(self, *args, **kwargs):
    """Forward to ``brew.fc`` with this model as the first argument."""
    fc_helper = brew.fc
    return fc_helper(self, *args, **kwargs)
135 
def PackedFC(self, *args, **kwargs):
    """Forward to ``brew.packed_fc`` with this model as the first argument."""
    packed_fc = brew.packed_fc
    return packed_fc(self, *args, **kwargs)
138 
def FC_Prune(self, *args, **kwargs):
    """Forward to ``brew.fc_prune`` with this model as the first argument."""
    fc_prune = brew.fc_prune
    return fc_prune(self, *args, **kwargs)
141 
def FC_Decomp(self, *args, **kwargs):
    """Forward to ``brew.fc_decomp`` with this model as the first argument."""
    fc_decomp = brew.fc_decomp
    return fc_decomp(self, *args, **kwargs)
144 
def FC_Sparse(self, *args, **kwargs):
    """Forward to ``brew.fc_sparse`` with this model as the first argument."""
    fc_sparse = brew.fc_sparse
    return fc_sparse(self, *args, **kwargs)
147 
def Dropout(self, *args, **kwargs):
    """Forward to ``brew.dropout``, injecting ``order`` and ``use_cudnn``."""
    dropout = brew.dropout
    return dropout(self, *args,
                   use_cudnn=self.use_cudnn,
                   order=self.order,
                   **kwargs)
152 
def LRN(self, *args, **kwargs):
    """Forward to ``brew.lrn``, injecting ``order`` and ``use_cudnn``."""
    lrn = brew.lrn
    return lrn(self, *args,
               use_cudnn=self.use_cudnn,
               order=self.order,
               **kwargs)
157 
def Softmax(self, *args, **kwargs):
    """Forward to ``brew.softmax``, injecting only ``use_cudnn``."""
    softmax = brew.softmax
    return softmax(self, *args,
                   use_cudnn=self.use_cudnn,
                   **kwargs)
160 
def SpatialBN(self, *args, **kwargs):
    """Forward to ``brew.spatial_bn``, injecting this helper's ``order``."""
    spatial_bn = brew.spatial_bn
    return spatial_bn(self, *args,
                      order=self.order,
                      **kwargs)
163 
def InstanceNorm(self, *args, **kwargs):
    """Forward to ``brew.instance_norm``, injecting this helper's ``order``."""
    instance_norm = brew.instance_norm
    return instance_norm(self, *args,
                         order=self.order,
                         **kwargs)
166 
def Relu(self, *args, **kwargs):
    """Forward to ``brew.relu``, injecting ``order`` and ``use_cudnn``."""
    relu = brew.relu
    return relu(self, *args,
                use_cudnn=self.use_cudnn,
                order=self.order,
                **kwargs)
171 
def PRelu(self, *args, **kwargs):
    """Forward to ``brew.prelu`` with this model as the first argument."""
    prelu = brew.prelu
    return prelu(self, *args, **kwargs)
174 
def Concat(self, *args, **kwargs):
    """Forward to ``brew.concat``, injecting this helper's ``order``."""
    concat = brew.concat
    return concat(self, *args,
                  order=self.order,
                  **kwargs)
177 
def DepthConcat(self, *args, **kwargs):
    """The old depth concat function - we should move to use concat.

    Reports the deprecation through ``logging.warning``, the same channel
    ``__init__`` uses for its deprecation notice, instead of printing to
    stdout. It then forwards every argument unchanged to ``self.Concat``
    and returns its result.
    """
    logging.warning("DepthConcat is deprecated. use Concat instead.")
    return self.Concat(*args, **kwargs)
182 
def Sum(self, *args, **kwargs):
    """Forward to ``brew.sum`` with this model as the first argument."""
    sum_helper = brew.sum
    return sum_helper(self, *args, **kwargs)
185 
def Transpose(self, *args, **kwargs):
    """Forward to ``brew.transpose``, injecting only ``use_cudnn``."""
    transpose = brew.transpose
    return transpose(self, *args,
                     use_cudnn=self.use_cudnn,
                     **kwargs)
188 
def Iter(self, *args, **kwargs):
    """Forward to ``brew.iter`` with this model as the first argument."""
    iter_helper = brew.iter
    return iter_helper(self, *args, **kwargs)
191 
def Accuracy(self, *args, **kwargs):
    """Forward to ``brew.accuracy`` with this model as the first argument."""
    accuracy = brew.accuracy
    return accuracy(self, *args, **kwargs)
194 
def MaxPool(self, *args, **kwargs):
    """Forward to ``brew.max_pool``, injecting ``use_cudnn`` and ``order``."""
    max_pool = brew.max_pool
    return max_pool(self, *args,
                    order=self.order,
                    use_cudnn=self.use_cudnn,
                    **kwargs)
199 
def MaxPoolWithIndex(self, *args, **kwargs):
    """Forward to ``brew.max_pool_with_index``, injecting only ``order``."""
    max_pool_with_index = brew.max_pool_with_index
    return max_pool_with_index(self, *args,
                               order=self.order,
                               **kwargs)
202 
def AveragePool(self, *args, **kwargs):
    """Forward to ``brew.average_pool``, injecting ``use_cudnn`` and ``order``."""
    average_pool = brew.average_pool
    return average_pool(self, *args,
                        order=self.order,
                        use_cudnn=self.use_cudnn,
                        **kwargs)
207 
@property
def XavierInit(self):
    """Initializer spec: the 'XavierFill' op name with no extra arguments."""
    fill_args = {}
    return 'XavierFill', fill_args
211 
def ConstantInit(self, value):
    """Initializer spec: 'ConstantFill' with ``value`` as its fill argument."""
    fill_args = dict(value=value)
    return 'ConstantFill', fill_args
214 
@property
def MSRAInit(self):
    """Initializer spec: the 'MSRAFill' op name with no extra arguments."""
    fill_args = {}
    return 'MSRAFill', fill_args
218 
@property
def ZeroInit(self):
    """Initializer spec: 'ConstantFill' with no explicit fill value.

    NOTE(review): this relies on ConstantFill's default value being zero,
    as the name suggests; confirm against the ConstantFill op schema.
    """
    fill_args = {}
    return 'ConstantFill', fill_args
222 
def AddWeightDecay(self, weight_decay):
    """Forward to ``brew.add_weight_decay`` for this model."""
    add_weight_decay = brew.add_weight_decay
    return add_weight_decay(self, weight_decay)
225 
@property
def CPU(self):
    """A new ``caffe2_pb2.DeviceOption`` whose device type is CPU."""
    option = caffe2_pb2.DeviceOption()
    option.device_type = caffe2_pb2.CPU
    return option
231 
@property
def GPU(self):
    """A new ``caffe2_pb2.DeviceOption`` targeting CUDA device 0.

    Property getters are never called with arguments, so the device id is
    fixed at 0 here. To target another GPU, build a
    ``caffe2_pb2.DeviceOption`` directly and set its ``cuda_gpu_id``.
    """
    device_option = caffe2_pb2.DeviceOption()
    device_option.device_type = caffe2_pb2.CUDA
    device_option.cuda_gpu_id = 0
    return device_option
def DepthConcat(self, *args, **kwargs)
Definition: cnn.py:178
def Concat(self, *args, **kwargs)
Definition: cnn.py:175