1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
14 This class modifies the net passed in by adding ops to compute histogram for 18 blobs: list of blobs to compute histogram for 19 logging_frequency: frequency for printing 20 lower_bound: left boundary of histogram values 21 upper_bound: right boundary of histogram values 22 num_buckets: number of buckets to use in [lower_bound, upper_bound) 23 accumulate: boolean to output accumulate or per-batch histogram 26 def __init__(self, blobs, logging_frequency, num_buckets=30,
27 lower_bound=0.0, upper_bound=1.0, accumulate=
False):
# Constructor for the histogram net modifier described by the argument list
# above: blobs, logging_frequency, num_buckets (default 30), lower_bound
# (default 0.0), upper_bound (default 1.0) and accumulate (default False).
# NOTE(review): this text is a line-numbered listing with gaps -- the leading
# integers are listing line numbers, not code, and the statements between the
# signature and the message below are not visible here.
# The only visible body fragment formats the message
# "num_buckets need to be greater than 0, got {}" with the caller-supplied
# num_buckets; presumably it belongs to a positivity check on num_buckets
# whose condition is on a line not shown -- TODO confirm.
38 "num_buckets need to be greater than 0, got {}".format(num_buckets))
42 def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):
# Adds histogram-related ops to `net` for every entry of self._blobs.
# NOTE(review): this text is a line-numbered listing with gaps (the leading
# integers are listing line numbers); several statements are only partly
# visible, so the notes below cover just what can be read here.
# Steps visible per blob:
#   1. Raise Exception('blob {0} is not defined in net {1}') when
#      net.BlobIsDefined(blob) is false. `blob` is bound on a line that is
#      not shown; presumably it is derived from blob_name -- TODO confirm.
#   2. net.Cast the blob with to=core.DataType.FLOAT into a blob obtained
#      from net.NextScopedBlob(prefix=blob + '_float'); result is blob_float.
#   3. net.AccumulateHistogram yields (curr_hist, acc_hist), with outputs
#      taken from net.NextScopedBlob using the '_curr_hist' and '_acc_hist'
#      prefixes. Its inputs and remaining arguments are not visible.
#   4. Two calls write to net.NextScopedBlob(prefix=blob + '_cast_hist') with
#      to=core.DataType.FLOAT. The call name, the input and the condition
#      separating them are not visible (presumably net.Cast on one of the two
#      histograms, chosen by the accumulate setting -- TODO confirm).
#   5. net.NormalizeL1 produces normalized_hist; its arguments are not visible.
#   6. When net.output_record() is None, net.set_output_record is called;
#      net.AppendOutputRecordField is presumably the alternative branch --
#      TODO confirm, the branch keyword is not visible.
# init_net, grad_map and blob_to_device are not referenced in the visible lines.
43 for blob_name
in self.
_blobs:
45 if not net.BlobIsDefined(blob):
46 raise Exception(
'blob {0} is not defined in net {1}'.format(
49 blob_float = net.Cast(blob, net.NextScopedBlob(prefix=blob +
50 '_float'), to=core.DataType.FLOAT)
51 curr_hist, acc_hist = net.AccumulateHistogram(
53 [net.NextScopedBlob(prefix=blob +
'_curr_hist'),
54 net.NextScopedBlob(prefix=blob +
'_acc_hist')],
62 net.NextScopedBlob(prefix=blob +
'_cast_hist'),
63 to=core.DataType.FLOAT)
67 net.NextScopedBlob(prefix=blob +
'_cast_hist'),
68 to=core.DataType.FLOAT)
70 normalized_hist = net.NormalizeL1(
82 if net.output_record()
is None:
83 net.set_output_record(
87 net.AppendOutputRecordField(
91 def field_name_suffix(self):