Caffe2 - Python API
A deep learning, cross-platform ML framework
compute_norm_for_blobs.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from caffe2.python import core, schema
7 from caffe2.python.modeling.net_modifier import NetModifier
8 
9 import numpy as np
10 
11 
    """
    This class modifies the net passed in by adding ops to compute norms for
    certain blobs.

    Args:
        blobs: list of blobs to compute norm for
        logging_frequency: frequency for printing norms to logs; it is passed
            to a Print op as ``every_n``, and no Print op is added when it is
            less than 1
        p: type of norm. Currently it supports p=1 or p=2
        compute_averaged_norm: norm or averaged_norm (averaged_norm = norm/size)

    Each computed norm is also registered in the net's output record under the
    field name ``str(blob) + field_name_suffix()``, where the suffix is
    ``'_l{p}_norm'``, or ``'_averaged_l{p}_norm'`` when averaged norms are
    requested.
    """
23 
24  def __init__(self, blobs, logging_frequency, p=2, compute_averaged_norm=False):
25  self._blobs = blobs
26  self._logging_frequency = logging_frequency
27  self._p = p
28  self._compute_averaged_norm = compute_averaged_norm
29  self._field_name_suffix = '_l{}_norm'.format(p)
30  if compute_averaged_norm:
31  self._field_name_suffix = '_averaged' + self._field_name_suffix
32 
33  def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):
34 
35  p = self._p
36  compute_averaged_norm = self._compute_averaged_norm
37 
38  for blob_name in self._blobs:
39  blob = core.BlobReference(blob_name)
40  if not net.BlobIsDefined(blob):
41  raise Exception('blob {0} is not defined in net {1}'.format(
42  blob, net.Name()))
43 
44  norm_name = net.NextScopedBlob(prefix=blob + self._field_name_suffix)
45  norm = net.LpNorm(blob, norm_name, p=p, average=compute_averaged_norm)
46 
47  if self._logging_frequency >= 1:
48  net.Print(norm, [], every_n=self._logging_frequency)
49 
50  output_field_name = str(blob) + self._field_name_suffix
51  output_scalar = schema.Scalar((np.float, (1,)), norm)
52 
53  if net.output_record() is None:
54  net.set_output_record(
55  schema.Struct((output_field_name, output_scalar))
56  )
57  else:
58  net.AppendOutputRecordField(
59  output_field_name,
60  output_scalar)
61 
62  def field_name_suffix(self):
63  return self._field_name_suffix