3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
11 get_categorical_limit,
def get_sparse_lookup_predictor_version(version):
    """Validate a sparse_lookup kernel version string and return it.

    Accepted values select the operator family used at serving time:
    'fp32', 'fp16', 'uint8rowwise', 'fused_uint8rowwise'.

    Raises:
        AssertionError: if *version* is not one of the supported names.
    """
    assert version in {'fp32', 'fp16', 'uint8rowwise', 'fused_uint8rowwise'},\
        "Unexpected version of sparse_lookup layer {0}".format(version)
    # Callers (see add_ops) consume the returned value, so echo it back.
    return version
def _is_id_list(input_record):
    """Return True when *input_record* matches the IdList schema."""
    return schema.equal_schemas(input_record, IdList)
def _is_id_score_list(input_record):
    """Return True when *input_record* matches the IdScoreList schema.

    Field types are ignored (check_field_types=False) so score fields of
    differing dtypes still match.
    """
    # NOTE(review): the comparison-schema argument was lost in the damaged
    # source; IdScoreList mirrors the IdList check above and matches this
    # file's usage ("PositionWeighted only support IdScoreList").
    return schema.equal_schemas(input_record,
                                IdScoreList,
                                check_field_types=False)
42 _id_list_supported_reducers = [
43 'LogMeanExp',
'LogSumExp',
'Max',
'Mean',
'Sum',
44 'WeightedSum',
'WeightedMean',
'Sqrt',
'None']
46 _id_score_list_supported_reducers = [
47 'PositionWeighted',
'Mean',
'Sum',
'WeightedSum',
'WeightedMean',
'None']
def __init__(self, model, input_record, inner_shape, reducer,
             weight_init=None, weight_optim=None,
             name='sparse_lookup', regularizer=None, **kwargs):
    """Embedding-lookup layer: a [input_dim] + inner_shape weight table,
    pooled per example according to *reducer*.

    Args:
        model: layer-model this layer is registered with.
        input_record: IdList or IdScoreList schema of sparse ids.
        inner_shape: int or list/tuple giving the embedding dimension(s).
        reducer: pooling reducer name (see the *_supported_reducers lists).
        weight_init: optional (op_name, kwargs) initializer pair; defaults
            to UniformFill in [-sqrt(1/input_dim), sqrt(1/input_dim)].
        weight_optim: optimizer for the embedding table.
        name: layer name.
        regularizer: optional regularizer applied to the table.
    """
    super(SparseLookup, self).__init__(model, name, input_record, **kwargs)

    # Normalize inner_shape to a list so it can be concatenated below.
    if isinstance(inner_shape, int):
        inner_shape = [inner_shape]
    assert isinstance(inner_shape, list) or isinstance(inner_shape, tuple),\
        "Unexpected type for inner_shape, expected list or tuple, got {0}".\
        format(type(inner_shape))

    if reducer == "PositionWeighted":
        assert _is_id_score_list(self.input_record), (
            "PositionWeighted only support IdScoreList, but got {} " +
            "please use PositionWeighted layer to convert IdList " +
            "to IdScoreList").format(repr(self.input_record))
        # NOTE(review): reconstructed from damaged source — the per-position
        # weights are taken from the record's values field; confirm.
        self.external_weights = input_record.values()
    self.reducer = reducer

    input_dim = get_categorical_limit(input_record)
    assert input_dim > 0, (
        "{} should have categorical limit > 0, but got {}".format(
            get_key(input_record)(), input_dim))

    # Xavier-style uniform range derived from the table height.
    scale = math.sqrt(1.0 / input_dim)
    self.shape = [input_dim] + inner_shape
    self.weight_init = weight_init if weight_init else (
        'UniformFill', {'min': -scale, 'max': scale})

    if _is_id_list(self.input_record):
        sparse_key = self.input_record.items()
    elif _is_id_score_list(self.input_record):
        sparse_key = self.input_record.keys()
    else:
        raise NotImplementedError()

    # Expected list length (if annotated) feeds the PS sharding hint.
    if self.input_record.lengths.metadata:
        avg_length = self.input_record.lengths.metadata.expected_value
    else:
        avg_length = None

    self.w = self.create_param(
        param_name='w',
        shape=self.shape,
        initializer=self.weight_init,
        optimizer=weight_optim,
        ps_param=LayerPsParam(
            sparse_key=sparse_key,
            average_length=avg_length),
        regularizer=regularizer
    )

    # Scale/bias blob used by (non-fused) rowwise 8-bit quantization; it is
    # produced by the quantization pass, hence no optimizer.
    # NOTE(review): name/shape reconstructed from the damaged source.
    self.scale_bias_init = ('ConstantFill', {'value': 0.0})
    self.scale_bias = self.create_param(
        param_name='scale_bias',
        shape=[],
        initializer=self.scale_bias_init,
        optimizer=model.NoOptim,
    )

    # Pooled output: one fp32 embedding of inner_shape per example.
    self.output_schema = schema.Scalar(
        (np.float32, inner_shape),
        self.get_next_blob_reference('output'),
    )
def get_memory_usage(self):
    """Return the embedding table size in bytes (4 bytes per element)."""
    element_count = functools.reduce(operator.mul, self.shape)
    return element_count * 4
def get_fp16_compatible_parameters(self):
    """Return parameters that may be stored as fp16: the embedding table.

    NOTE(review): the body was lost in the damaged source; ``[self.w]`` is
    the only fp16-convertible blob this layer creates — confirm.
    """
    return [self.w]
def support_8bit(self):
    """Return True when rowwise 8-bit quantization applies to this table.

    NOTE(review): body reconstructed — rowwise quantization only makes
    sense for a 2-D weight matrix whose row width is at least 8; confirm
    against the project history.
    """
    if len(self.shape) != 2 or self.shape[1] < 8:
        return False
    return True
def get_8bits_compatible_parameters(self, fused=True):
    """Return the blobs participating in rowwise 8-bit quantization.

    With ``fused=True`` the scale/bias live inside the fused weight blob,
    so only ``w`` is reported; otherwise ``scale_bias`` is reported next to
    it. NOTE(review): the guard returning ``[]`` when the table shape does
    not support 8-bit was reconstructed — confirm.
    """
    if not self.support_8bit():
        return []
    if fused:
        RowwiseQuantized8BitsWeight = collections.namedtuple(
            'RowwiseQuantized8BitsWeight', 'w')
        return [RowwiseQuantized8BitsWeight(self.w)]
    else:
        RowwiseQuantized8BitsWeight = collections.namedtuple(
            'RowwiseQuantized8BitsWeight', 'w, scale_bias')
        return [RowwiseQuantized8BitsWeight(self.w, self.scale_bias)]
144 def _gather_wrapper(self, net, version, in_indices, out):
148 if version ==
'fp32':
149 return net.Gather([self.
w, in_indices], out)
150 elif version ==
'fp16':
151 gathered_w = net.Gather([self.
w, in_indices],
'gathered_w')
153 return net.HalfToFloat(gathered_w, out)
154 elif version ==
'uint8rowwise':
155 gathered_w = net.Gather([self.
w, in_indices],
'gathered_w')
156 gathered_scale_bias = net.Gather(
158 'gathered_scale_bias' 161 return net.Rowwise8BitQuantizedToFloat(
162 [gathered_w, gathered_scale_bias], out)
163 elif version ==
'fused_uint8rowwise':
164 gathered_w = net.Gather([self.
w, in_indices],
'gathered_w')
165 return net.Fused8BitRowwiseQuantizedToFloat(gathered_w, out)
167 raise "Unsupported version of operators in SparseLookup " +\
168 "layer: {0}".format(version)
170 def _sparse_lengths_weighted_reducer(
171 self, in_indices, weights, reducer,
172 net, version, grad_on_weights=0):
177 self.input_record.lengths()
179 layer_name =
'SparseLengths' + reducer
181 if version
in [
'fp32',
'fp16']:
184 net.__getattr__(layer_name)(
186 self.output_schema.field_blobs(),
187 grad_on_weights=grad_on_weights,
189 elif version ==
'uint8rowwise':
190 op_input.insert(len(op_input), self.
scale_bias)
191 net.__getattr__(layer_name +
'8BitsRowwise')(
192 op_input, self.output_schema.field_blobs())
193 elif version ==
'fused_uint8rowwise':
194 net.__getattr__(layer_name +
'Fused8BitRowwise')(
195 op_input, self.output_schema.field_blobs())
197 raise "Unsupported version of operator in SparseLookUp " +\
198 "layer: {0}".format(version)
201 def _add_ops_id_list(self, net, version):
203 "Unsupported reducer: {} for ID_LIST".format(self.
reducer)
205 if self.
reducer in [
'Sum',
'Mean',
'WeightedSum',
'WeightedMean']:
207 self.input_record.items(),
208 self.input_record.lengths()]
213 if self.
reducer ==
'WeightedSum':
215 elif self.
reducer ==
'WeightedMean':
218 layer_name =
'SparseLengths' + self.
reducer 219 if version
in [
'fp32',
'fp16']:
222 net.__getattr__(layer_name)(
224 self.output_schema.field_blobs(),
226 elif version ==
'uint8rowwise':
227 op_input.insert(len(op_input), self.
scale_bias)
228 net.__getattr__(layer_name +
'8BitsRowwise')(
229 op_input, self.output_schema.field_blobs())
230 elif version ==
'fused_uint8rowwise':
231 net.__getattr__(layer_name +
'Fused8BitRowwise')(
232 op_input, self.output_schema.field_blobs())
234 raise "Unsupported version of operator in SparseLookUp " +\
235 "layer: {0}".format(version)
238 sqrt_weight = net.LengthsToWeights(
239 [self.input_record.lengths()],
240 [net.NextScopedBlob(
'lengths_sqrt')],
244 self.input_record.items(),
246 'WeightedSum', net, version)
252 self.output_schema.field_blobs())
256 net, version, self.input_record.items(),
'table_rows')
258 segment_ids = net.LengthsToSegmentIds(
259 self.input_record.lengths(),
260 self.input_record.lengths() +
'_sid')
261 net.__getattr__(
'SortedSegmentRange' + self.
reducer)(
262 [table_rows, segment_ids],
263 self.output_schema.field_blobs(),
267 def _add_ops_id_score_list(self, net, version):
269 "Unsupported reducer: {} for ID_SCORE_LIST".format(self.
reducer)
271 if self.
reducer in [
'WeightedSum',
'WeightedMean']:
273 self.input_record.keys(),
274 self.input_record.values(),
277 elif self.
reducer in [
'Sum',
'Mean']:
279 self.input_record.keys(),
280 self.input_record.lengths()]
282 layer_name =
'SparseLengths' + self.
reducer 284 if version
in [
'fp32',
'fp16']:
285 net.__getattr__(layer_name)(
287 self.output_schema.field_blobs(),
289 elif version ==
'uint8rowwise':
290 net.__getattr__(layer_name +
'8BitsRowwise')(
291 op_input, self.output_schema.field_blobs())
292 elif version ==
'fused_uint8rowwise':
293 net.__getattr__(layer_name +
'Fused8BitRowwise')(
294 op_input, self.output_schema.field_blobs())
296 raise "Unsupported version of operator in SparseLookUp " +\
297 "layer: {0}".format(version)
299 elif self.
reducer ==
'PositionWeighted':
301 self.input_record.keys(),
303 'WeightedSum', net, version, grad_on_weights=1)
309 self.output_schema.field_blobs())
311 raise "Only Sum, Mean, None are supported for IdScoreList input." +\
312 "Trying to create with {}".format(self.
reducer)
def add_ops(self, net):
    """Instantiate the lookup ops, choosing the kernel version from scope.

    The version is read from the current arg-scope (keyed by the helper's
    name) and defaults to 'fp32'.
    """
    cur_scope = get_current_scope()
    version = get_sparse_lookup_predictor_version(
        **cur_scope.get(get_sparse_lookup_predictor_version.__name__,
                        {'version': 'fp32'}))

    # Fall back to fp32 when the table shape cannot be rowwise-quantized.
    if not self.support_8bit() and version in {'uint8rowwise',
                                               'fused_uint8rowwise'}:
        version = 'fp32'

    if _is_id_list(self.input_record):
        self._add_ops_id_list(net, version=version)
    elif _is_id_score_list(self.input_record):
        self._add_ops_id_score_list(net, version=version)
    else:
        # Fix: the original raised a plain string (a TypeError at runtime);
        # raise a real exception instead.
        raise ValueError(
            "Unsupported input type {0}".format(self.input_record))
# --- extraction residue: method index, kept as a comment for reference ---
# _sparse_lengths_weighted_reducer(self, in_indices, weights, reducer, net, version, grad_on_weights=0)
# _id_list_supported_reducers (list)
# _id_score_list_supported_reducers (list)
# _add_ops_id_list(self, net, version)
# _add_ops_id_score_list(self, net, version)
# _gather_wrapper(self, net, version, in_indices, out)