Caffe2 - Python API
A deep learning, cross-platform ML framework
pairwise_dot_product.py
1 ## @package pairwise_dot_product
2 # Module caffe2.python.layers.pairwise_dot_product
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 
13 
class PairwiseDotProduct(ModelLayer):
    """Layer that multiplies two embedding blobs and flattens the result.

    The input record must be a ``schema.Struct`` holding either

    * ``all_embeddings`` -- used as both operands of the product, or
    * ``x_embeddings`` and ``y_embeddings`` -- used as the left and right
      operand respectively,

    but not both forms at once.  An optional ``indices_to_gather`` field
    selects a subset of the flattened products via ``BatchGather``.

    NOTE(review): the operand shapes are not validated here; presumably
    they are (batch, num_embeddings, dim) so that ``BatchMatMul`` with
    ``trans_b=1`` yields every pairwise dot product -- confirm with callers.
    """

    def __init__(self, model, input_record, output_dim,
                 name='pairwise_dot_product', **kwargs):
        """Validate the input record and declare the output schema.

        Args:
            model: model helper forwarded to ``ModelLayer.__init__``.
            input_record: ``schema.Struct`` laid out as described on the
                class.
            output_dim: size used for the one-dimensional shape of the
                output scalar, i.e. ``(output_dim,)``.
            name: layer name forwarded to ``ModelLayer.__init__``.
            **kwargs: extra arguments forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``input_record`` is not a Struct, if it does
                not contain exactly one of the two accepted embedding
                layouts, or if any of the used fields is not a
                ``schema.Scalar``.
        """
        super(PairwiseDotProduct, self).__init__(
            model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))
        # Exactly one layout may be present: the single shared blob or the
        # explicit x/y pair.
        assert (
            ('all_embeddings' in input_record) ^
            ('x_embeddings' in input_record and 'y_embeddings' in input_record)
        ), (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
            "should be given."
        )
        if 'all_embeddings' in input_record:
            # Same blob on both sides of the product.
            x_embeddings = input_record['all_embeddings']
            y_embeddings = input_record['all_embeddings']
        else:
            x_embeddings = input_record['x_embeddings']
            y_embeddings = input_record['y_embeddings']

        assert isinstance(x_embeddings, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, " +
            "but received: {0}".format(x_embeddings))
        assert isinstance(y_embeddings, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, " +
            "but received: {0}".format(y_embeddings)
        )

        if 'indices_to_gather' in input_record:
            indices_to_gather = input_record['indices_to_gather']
            assert isinstance(indices_to_gather, schema.Scalar), (
                "Incorrect type of indices_to_gather. "
                "Expected Scalar, but received: {0}".format(indices_to_gather)
            )
            self.indices_to_gather = indices_to_gather
        else:
            self.indices_to_gather = None

        self.x_embeddings = x_embeddings
        self.y_embeddings = y_embeddings

        # The output reuses the element type of the left operand.
        dtype = x_embeddings.field_types()[0].base

        # add_ops() writes into self.output_schema(), so it has to be
        # declared here as a Scalar bound to a fresh output blob.
        self.output_schema = schema.Scalar(
            (dtype, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Add the operators of this layer to ``net``.

        Emits ``BatchMatMul(x, y, trans_b=1)`` followed by ``Flatten``.
        When ``indices_to_gather`` was supplied, the flattened result goes
        through ``BatchGather`` and that becomes the layer output;
        otherwise the flattened result is written to the output directly.
        """
        Y = net.BatchMatMul(
            [self.x_embeddings(), self.y_embeddings()],
            [self.x_embeddings() + '_matmul'],
            trans_b=1,
        )

        # Explicit None check: the attribute is an optional schema field,
        # so do not depend on the truthiness of the schema object.
        if self.indices_to_gather is not None:
            flattened = net.Flatten(
                Y, Y + '_flatten',
            )
            net.BatchGather(
                [flattened, self.indices_to_gather()],
                self.output_schema(),
            )
        else:
            net.Flatten(Y, self.output_schema())