Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_lr_loss.py
## @package batch_lr_loss
# Module caffe2.python.layers.batch_lr_loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    ModelLayer,
)
from caffe2.python.layers.tags import (
    Tags
)
import numpy as np


class BatchLRLoss(ModelLayer):

    def __init__(
        self,
        model,
        input_record,
        name='batch_lr_loss',
        average_loss=True,
        jsd_weight=0.0,
        pos_label_target=1.0,
        neg_label_target=0.0,
        homotopy_weighting=False,
        **kwargs
    ):
        super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)

        self.average_loss = average_loss

        assert (schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('logit', schema.Scalar())
            ),
            input_record
        ))

        self.jsd_fuse = False
        assert jsd_weight >= 0 and jsd_weight <= 1
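        # When fused (see add_ops), the final loss is the convex combination
        #     xent_weight * xent + jsd_weight * jsd
        # with the two weights summing to one; homotopy weighting anneals
        # xent_weight from 1 toward 0 over training (see update_weight).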
        if jsd_weight > 0 or homotopy_weighting:
            assert 'prediction' in input_record
            self.init_weight(jsd_weight, homotopy_weighting)
            self.jsd_fuse = True
        self.homotopy_weighting = homotopy_weighting

        assert pos_label_target <= 1 and pos_label_target >= 0
        assert neg_label_target <= 1 and neg_label_target >= 0
        assert pos_label_target >= neg_label_target
        self.pos_label_target = pos_label_target
        self.neg_label_target = neg_label_target

        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )

    def init_weight(self, jsd_weight, homotopy_weighting):
        if homotopy_weighting:
            self.mutex = self.create_param(
                param_name=('%s_mutex' % self.name),
                shape=None,
                initializer=('CreateMutex', ),
                optimizer=self.model.NoOptim,
            )
            self.counter = self.create_param(
                param_name=('%s_counter' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 0,
                        'dtype': core.DataType.INT64
                    }
                ),
                optimizer=self.model.NoOptim,
            )
            self.xent_weight = self.create_param(
                param_name=('%s_xent_weight' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 1.,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )
            self.jsd_weight = self.create_param(
                param_name=('%s_jsd_weight' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 0.,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )
        else:
            self.jsd_weight = self.model.add_global_constant(
                '%s_jsd_weight' % self.name, jsd_weight
            )
            self.xent_weight = self.model.add_global_constant(
                '%s_xent_weight' % self.name, 1. - jsd_weight
            )

    def update_weight(self, net):
        net.AtomicIter([self.mutex, self.counter], [self.counter])
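        # The 'inv' learning-rate policy computes
        #     lr = base_lr * (1 + gamma * iter) ^ (-power)
        # so with gamma=1e-6 and power=0.1 the cross-entropy weight decays
        # slowly from 1 toward 0, for example: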
        # iter = 0:   lr = 1
        # iter = 1e6: lr = 0.5^0.1 = 0.93
        # iter = 1e9: lr = 1e-3^0.1 = 0.50
        net.LearningRate([self.counter], [self.xent_weight], base_lr=1.0,
                         policy='inv', gamma=1e-6, power=0.1,)
        net.Sub(
            [self.model.global_constants['ONE'], self.xent_weight],
            [self.jsd_weight]
        )
        return self.xent_weight, self.jsd_weight

    def add_ops(self, net):
        # numerically stable sigmoid cross-entropy computed from logits
        label = self.input_record.label()
        # mandatory cast to float32:
        # self.input_record.label.field_type().base is np.float32, but
        # the label type is actually int
        label = net.Cast(
            label,
            net.NextScopedBlob('label_float32'),
            to=core.DataType.FLOAT)
        label = net.ExpandDims(label, net.NextScopedBlob('expanded_label'),
                               dims=[1])
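        # Optional label smoothing: StumpFunc maps labels above the 0.5
        # threshold to pos_label_target and the rest to neg_label_target.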
        if self.pos_label_target != 1.0 or self.neg_label_target != 0.0:
            label = net.StumpFunc(
                label,
                net.NextScopedBlob('smoothed_label'),
                threshold=0.5,
                low_value=self.neg_label_target,
                high_value=self.pos_label_target,
            )
        xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy'),
        )
        # fuse with JSD
        if self.jsd_fuse:
            jsd = net.BernoulliJSD(
                [self.input_record.prediction(), label],
                net.NextScopedBlob('jsd'),
            )
            if self.homotopy_weighting:
                self.update_weight(net)
            loss = net.WeightedSum(
                [xent, self.xent_weight, jsd, self.jsd_weight],
                net.NextScopedBlob('loss'),
            )
        else:
            loss = xent
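        # Optionally scale the per-example loss by the input 'weight' field;
        # the weight is cast to float32 and excluded from the gradient via
        # StopGradient.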
        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            loss = net.Mul(
                [loss, weight_blob],
                net.NextScopedBlob('weighted_cross_entropy'),
            )

        if self.average_loss:
            net.AveragedLoss(loss, self.output_schema.field_blobs())
        else:
            net.ReduceFrontSum(loss, self.output_schema.field_blobs())
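Below is a minimal usage sketch, not part of this file: it assumes the usual
LayerModelHelper workflow, in which the layer registry resolves
model.BatchLRLoss(...) to this class and the call returns the layer's
output_schema. The model name and input schema here are illustrative.

import numpy as np

from caffe2.python import schema
from caffe2.python.layer_model_helper import LayerModelHelper

# Hypothetical model whose input record carries the 'label' and 'logit'
# scalars that BatchLRLoss asserts on.
model = LayerModelHelper(
    "lr_example",
    input_feature_schema=schema.Struct(
        ('label', schema.Scalar(np.float32)),
        ('logit', schema.Scalar(np.float32)),
    ),
    trainer_extra_schema=schema.Struct(),
)

# The call returns the layer's output_schema: a float32 scalar holding the
# averaged (average_loss=True) or summed batch logistic-regression loss.
loss = model.BatchLRLoss(model.input_feature_schema)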