1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
14 This class modifies the net passed in by adding ops to compute norms for 18 blobs: list of blobs to compute norm for 19 logging_frequency: frequency for printing norms to logs 20 p: type of norm. Currently it supports p=1 or p=2 21 compute_averaged_norm: norm or averaged_norm (averaged_norm = norm/size) 24 def __init__(self, blobs, logging_frequency, p=2, compute_averaged_norm=False):
30 if compute_averaged_norm:
33 def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):
38 for blob_name
in self.
_blobs:
40 if not net.BlobIsDefined(blob):
41 raise Exception(
'blob {0} is not defined in net {1}'.format(
45 norm = net.LpNorm(blob, norm_name, p=p, average=compute_averaged_norm)
53 if net.output_record()
is None:
54 net.set_output_record(
58 net.AppendOutputRecordField(
62 def field_name_suffix(self):