3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
# Constructor: forwards model, name, input_record and any extra keyword
# arguments to the parent initializer; name defaults to
# 'pairwise_dot_product'.
# NOTE(review): the numbers embedded in this listing skip values (for example
# 18 -> 20 and 48 -> 57), so some statements of the original are not shown
# here. The comments below describe only what is visible.
16 def __init__(self, model, input_record, output_dim,
17 name=
'pairwise_dot_product', **kwargs):
18 super(PairwiseDotProduct, self).__init__(model, name, input_record, **kwargs)
# Message for an input_record that is not a Struct; the check that uses it is
# not visible here (presumably an isinstance test -- TODO confirm).
# NOTE(review): "Excpected" is misspelled in this runtime message.
20 "Incorrect input type. Excpected Struct, but received: {0}".
# Input layout condition: `^` is XOR between "'all_embeddings' is present" and
# "both 'x_embeddings' and 'y_embeddings' are present", so exactly one of the
# two layouts satisfies it. The statement wrapping this condition is not
# visible here -- presumably an assert; confirm.
23 (
'all_embeddings' in input_record) ^
24 (
'x_embeddings' in input_record
and 'y_embeddings' in input_record)
26 "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
# When 'all_embeddings' is supplied, x and y both refer to that same entry.
29 if 'all_embeddings' in input_record:
30 x_embeddings = input_record[
'all_embeddings']
31 y_embeddings = input_record[
'all_embeddings']
# Separate-entry case: x and y are read from their own keys. The `else:` that
# presumably introduces these two assignments is not visible -- TODO confirm.
33 x_embeddings = input_record[
'x_embeddings']
34 y_embeddings = input_record[
'y_embeddings']
# Messages for x_embeddings / y_embeddings not being a Scalar; the type checks
# that use them are not visible here.
37 "Incorrect input type for x. Expected Scalar, " +
38 "but received: {0}".format(x_embeddings))
40 "Incorrect input type for y. Expected Scalar, " +
41 "but received: {0}".format(y_embeddings)
# 'indices_to_gather' is optional: it is read only when present in
# input_record. The message below reports a non-Scalar value; what is done
# with indices_to_gather afterwards is not visible here.
44 if 'indices_to_gather' in input_record:
45 indices_to_gather = input_record[
'indices_to_gather']
47 "Incorrect type of indices_to_gather. " 48 "Expected Scalar, but received: {0}".format(indices_to_gather)
# dtype is the `base` of the first entry returned by
# x_embeddings.field_types().
57 dtype = x_embeddings.field_types()[0].base
# A (dtype, (output_dim,)) pair and a blob reference obtained from
# get_next_blob_reference('output') appear as arguments of a call whose
# opening is not visible here -- presumably the layer's output schema; confirm.
60 (dtype, (output_dim,)),
61 self.get_next_blob_reference(
'output')
64 def add_ops(self, net):
72 flattened = net.Flatten(