3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
12 Implementation of adaptive weighting: https://arxiv.org/pdf/1705.07115.pdf 21 name=
'adaptive_weight',
27 self).__init__(model, name, input_record, **kwargs)
31 self.
data = self.input_record.field_blobs()
36 initializer = (
'ConstantFill', {
'value': np.log(self.
num / 2.)})
38 assert len(weights) == self.
num 39 weights = np.array(weights).astype(np.float32)
40 values = np.log(1. / 2. / weights)
44 'dtype': core.DataType.FLOAT
51 initializer=initializer,
55 def concat_data(self, net):
57 net.NextScopedBlob(
'reshaped_data_%d' % i)
for i
in range(self.
num)
60 for i
in range(self.
num):
63 [reshaped[i], net.NextScopedBlob(
'new_shape_%d' % i)],
66 concated = net.NextScopedBlob(
'concated_data')
68 reshaped, [concated, net.NextScopedBlob(
'concated_new_shape')],
73 def compute_adaptive_sum(self, x, net):
74 mu_exp = net.NextScopedBlob(
'mu_exp')
75 net.Exp(self.
mu, mu_exp)
76 mu_exp_double = net.NextScopedBlob(
'mu_exp_double')
77 net.Scale(mu_exp, mu_exp_double, scale=2.0)
78 weighted_x = net.NextScopedBlob(
'weighted_x')
79 net.Div([x, mu_exp_double], weighted_x)
80 weighted_elements = net.NextScopedBlob(
'weighted_elements')
81 net.Add([weighted_x, self.
mu], weighted_elements)
84 def add_ops(self, net):
def get_next_blob_reference(self, name)
def concat_data(self, net)
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)
def compute_adaptive_sum(self, x, net)