3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
29 homotopy_weighting=
False,
32 super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
36 assert (schema.is_schema_subset(
45 assert jsd_weight >= 0
and jsd_weight <= 1
46 if jsd_weight > 0
or homotopy_weighting:
47 assert 'prediction' in input_record
52 assert pos_label_target <= 1
and pos_label_target >= 0
53 assert neg_label_target <= 1
and neg_label_target >= 0
54 assert pos_label_target >= neg_label_target
58 self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
62 self.get_next_blob_reference(
'output')
65 def init_weight(self, jsd_weight, homotopy_weighting):
66 if homotopy_weighting:
67 self.
mutex = self.create_param(
68 param_name=(
'%s_mutex' % self.name),
70 initializer=(
'CreateMutex', ),
71 optimizer=self.model.NoOptim,
73 self.
counter = self.create_param(
74 param_name=(
'%s_counter' % self.name),
79 'dtype': core.DataType.INT64
82 optimizer=self.model.NoOptim,
85 param_name=(
'%s_xent_weight' % self.name),
90 'dtype': core.DataType.FLOAT
93 optimizer=self.model.NoOptim,
96 param_name=(
'%s_jsd_weight' % self.name),
101 'dtype': core.DataType.FLOAT
104 optimizer=self.model.NoOptim,
107 self.
jsd_weight = self.model.add_global_constant(
108 '%s_jsd_weight' % self.name, jsd_weight
111 '%s_xent_weight' % self.name, 1. - jsd_weight
114 def update_weight(self, net):
120 policy=
'inv', gamma=1e-6, power=0.1,)
122 [self.model.global_constants[
'ONE'], self.
xent_weight],
127 def add_ops(self, net):
129 label = self.input_record.label()
135 net.NextScopedBlob(
'label_float32'),
136 to=core.DataType.FLOAT)
137 label = net.ExpandDims(label, net.NextScopedBlob(
'expanded_label'),
140 label = net.StumpFunc(
142 net.NextScopedBlob(
'smoothed_label'),
147 xent = net.SigmoidCrossEntropyWithLogits(
148 [self.input_record.logit(), label],
149 net.NextScopedBlob(
'cross_entropy'),
153 jsd = net.BernoulliJSD(
154 [self.input_record.prediction(), label],
155 net.NextScopedBlob(
'jsd'),
159 loss = net.WeightedSum(
161 net.NextScopedBlob(
'loss'),
165 if 'weight' in self.input_record.fields:
166 weight_blob = self.input_record.weight()
167 if self.input_record.weight.field_type().base != np.float32:
168 weight_blob = net.Cast(
170 weight_blob +
'_float32',
171 to=core.DataType.FLOAT
173 weight_blob = net.StopGradient(
175 [net.NextScopedBlob(
'weight_stop_gradient')],
179 net.NextScopedBlob(
'weighted_cross_entropy'),
183 net.AveragedLoss(loss, self.output_schema.field_blobs())
185 net.ReduceFrontSum(loss, self.output_schema.field_blobs())
def init_weight(self, jsd_weight, homotopy_weighting)
def update_weight(self, net)