Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_distill_lr_loss.py
1 ## @package batch_distill_lr_loss
2 # Module caffe2.python.layers.batch_distill_lr_loss
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 from caffe2.python.layers.tags import (
13  Tags
14 )
15 import numpy as np
16 
17 
class BatchDistillLRLoss(ModelLayer):
    """Logistic-regression loss blended between true and teacher labels.

    The model's ``logit`` is scored with SigmoidCrossEntropyWithLogits
    against both the true ``label`` and a ``teacher_label``. The two
    results are mixed with weight ``w = teacherWeight``::

        loss = AveragedLoss((1 - w) * xent(logit, label))
               + AveragedLoss(w * xent(logit, teacher_label))

    The layer tags itself with ``Tags.EXCLUDE_FROM_PREDICTION``. Its output
    is a float32 ``schema.Scalar`` record.
    """

    def __init__(
            self, model, input_record,
            name='batch_distill_lr_loss', teacherWeight=0.0, **kwargs):
        """Validate arguments and declare the output schema.

        Args:
            model: model the layer belongs to (forwarded to ModelLayer).
            input_record: record that must contain the scalar fields
                ``teacher_label``, ``label`` and ``logit``.
            name: layer name.
            teacherWeight: weight in [0, 1] given to the teacher-label loss.
                The true-label loss is weighted by ``1 - teacherWeight``.
            **kwargs: forwarded to ModelLayer.

        Raises:
            AssertionError: if teacherWeight is outside [0, 1], or if
                input_record does not contain the required fields.
        """
        super(BatchDistillLRLoss, self).__init__(
            model, name, input_record, **kwargs)

        # NOTE(review): these checks are `assert`s, so they are skipped under
        # `python -O`. They are kept as asserts so callers relying on
        # AssertionError are unaffected.
        assert 0 <= teacherWeight <= 1, (
            'teacherWeight=%0.2f should be in [0, 1]' % teacherWeight
        )
        self._teacherWeight = teacherWeight

        assert schema.is_schema_subset(
            schema.Struct(
                ('teacher_label', schema.Scalar()),
                ('label', schema.Scalar()),
                ('logit', schema.Scalar()),
            ),
            input_record
        )
        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )

    def _prepare_label(self, net, field, blob_suffix):
        """Return ``field``'s blob cast to float and expanded at axis 1.

        A Cast op (blob ``float_<blob_suffix>``) is added only when the
        field's declared type is not np.float32. An ExpandDims op (blob
        ``expanded_<blob_suffix>``) is always added.
        """
        blob = field()
        if field.field_type() != np.float32:
            blob = net.Cast(
                blob,
                net.NextScopedBlob('float_' + blob_suffix),
                to=core.DataType.FLOAT,
            )
        # Assuming 1-D input: inserting an axis at position 1 turns a (N,)
        # label into (N, 1). Presumably this matches the logit's shape --
        # confirm against the producer of `logit`.
        return net.ExpandDims(
            blob, net.NextScopedBlob('expanded_' + blob_suffix), dims=[1])

    def add_ops(self, net):
        """Add the blended distillation-loss ops to ``net``.

        Writes the sum of the two weighted, averaged cross-entropies into
        ``self.output_schema``'s blob.
        """
        # Label first, then teacher label, to keep the original op order.
        label = self._prepare_label(net, self.input_record.label, 'label')
        teacher_label = self._prepare_label(
            net, self.input_record.teacher_label, 'teacher_label')
        logit = self.input_record.logit()

        true_xent = net.SigmoidCrossEntropyWithLogits(
            [logit, label],
            net.NextScopedBlob('cross_entropy')
        )
        teacher_xent = net.SigmoidCrossEntropyWithLogits(
            [logit, teacher_label],
            net.NextScopedBlob('teacher_cross_entropy')
        )

        # Scale each cross-entropy by its blend weight before averaging.
        scaled_true_xent = net.Scale(
            true_xent,
            net.NextScopedBlob('scaled_cross_entropy'),
            scale=1.0 - self._teacherWeight,
        )
        scaled_teacher_xent = net.Scale(
            teacher_xent,
            net.NextScopedBlob('scaled_teacher_cross_entropy'),
            scale=self._teacherWeight,
        )

        true_loss = net.AveragedLoss(
            scaled_true_xent,
            net.NextScopedBlob('true_loss')
        )
        teacher_loss = net.AveragedLoss(
            scaled_teacher_xent,
            net.NextScopedBlob('teacher_loss')
        )

        net.Add(
            [true_loss, teacher_loss],
            self.output_schema.field_blobs()
        )