# Py2/Py3 compatibility: adopt Python-3 semantics for imports, division,
# print, and string literals in this module.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
def __init__(
        self, model, input_record,
        name='batch_distill_lr_loss', teacherWeight=0.0, **kwargs):
    """Logistic-regression loss blending the true label with a teacher
    model's prediction (knowledge distillation).

    The final loss is
        (1 - teacherWeight) * xent(logit, label)
            + teacherWeight * xent(logit, teacher_label).

    Args:
        model: model helper this layer is registered on.
        input_record: schema.Struct; must contain scalar fields
            'teacher_label', 'label', and 'logit'.
        name: base name used for this layer's blobs.
        teacherWeight: weight in [0, 1] applied to the teacher's cross
            entropy; the true-label term receives (1 - teacherWeight).
    """
    super(BatchDistillLRLoss, self).__init__(
        model, name, input_record, **kwargs)

    # Validate early with a readable message; weight must be a convex
    # combination coefficient.
    assert teacherWeight >= 0 and teacherWeight <= 1, (
        'teacherWeight=%0.2f should be in [0, 1]' % teacherWeight
    )
    # Stored for add_ops, which scales the two cross-entropy terms with it.
    self._teacherWeight = teacherWeight

    # The layer requires exactly these three scalar fields on its input.
    assert schema.is_schema_subset(
        schema.Struct(
            ('teacher_label', schema.Scalar()),
            ('label', schema.Scalar()),
            ('logit', schema.Scalar()),
        ),
        input_record
    )
    # Loss is a training-only construct; keep it out of prediction nets.
    self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

    # Single float scalar output holding the combined loss
    # (read via field_blobs() in add_ops).
    self.output_schema = schema.Scalar(
        np.float32,
        self.get_next_blob_reference('output')
    )
def add_ops(self, net):
    """Emit the ops that compute the distillation loss into `net`.

    Builds two sigmoid cross-entropy terms sharing the same logit — one
    against the true label, one against the teacher label — scales them by
    (1 - teacherWeight) and teacherWeight respectively, averages each over
    the batch, and adds them into the layer's output blob.
    """
    label = self.input_record.label()
    # SigmoidCrossEntropyWithLogits expects float labels; cast if needed.
    if self.input_record.label.field_type() != np.float32:
        label = net.Cast(
            label,
            net.NextScopedBlob('float_label'),
            to=core.DataType.FLOAT,
        )

    # Assumes a 1-D label blob; expand to (N, 1) to match the logit shape.
    label = net.ExpandDims(
        label,
        net.NextScopedBlob('expanded_label'),
        dims=[1],
    )

    teacher_label = self.input_record.teacher_label()
    # Same float cast and shape expansion for the teacher's soft label.
    if self.input_record.teacher_label.field_type() != np.float32:
        teacher_label = net.Cast(
            teacher_label,
            net.NextScopedBlob('float_teacher_label'),
            to=core.DataType.FLOAT,
        )
    teacher_label = net.ExpandDims(
        teacher_label,
        net.NextScopedBlob('expanded_teacher_label'),
        dims=[1],
    )

    # Per-example cross entropy against the hard (true) label.
    true_xent = net.SigmoidCrossEntropyWithLogits(
        [self.input_record.logit(), label],
        net.NextScopedBlob('cross_entropy')
    )
    # Per-example cross entropy against the teacher's label.
    teacher_xent = net.SigmoidCrossEntropyWithLogits(
        [self.input_record.logit(), teacher_label],
        net.NextScopedBlob('teacher_cross_entropy')
    )

    # Convex combination of the two terms via self._teacherWeight.
    scaled_true_xent = net.Scale(
        true_xent,
        net.NextScopedBlob('scaled_cross_entropy'),
        scale=1.0 - self._teacherWeight,
    )
    scaled_teacher_xent = net.Scale(
        teacher_xent,
        net.NextScopedBlob('scaled_teacher_cross_entropy'),
        scale=self._teacherWeight,
    )

    # Reduce each scaled term to a batch-mean scalar loss.
    true_loss = net.AveragedLoss(
        scaled_true_xent,
        net.NextScopedBlob('true_loss')
    )
    teacher_loss = net.AveragedLoss(
        scaled_teacher_xent,
        net.NextScopedBlob('teacher_loss')
    )

    # Sum the two mean losses into this layer's declared output blob.
    net.Add(
        [true_loss, teacher_loss],
        self.output_schema.field_blobs()
    )