Caffe2 - C++ API
A cross-platform deep learning framework
tt_contraction_op.cc
1 #include "caffe2/experiments/operators/tt_contraction_op.h"
2 
3 namespace caffe2 {
4 
5 REGISTER_CPU_OPERATOR(TTContraction, TTContractionOp<float, CPUContext>);
6 
7 OPERATOR_SCHEMA(TTContraction)
8  .NumInputs(2)
9  .NumOutputs(1)
10  .SetDoc(R"DOC(
11 Tensor contraction C = A * B
12 )DOC")
13  .Arg("K", "i_{k-1} * r_k")
14  .Arg("M", "r_{k-1} * o_{k-1}")
15  .Arg("N", "o_k")
16  .Input(0, "A", "2D matrix of size (K x M)")
17  .Input(1, "B", "tensor")
18  .Output(0, "C", "contracted tensor");
19 
20 REGISTER_CPU_OPERATOR(
21  TTContractionGradient,
22  TTContractionGradientOp<float, CPUContext>);
23 
24 OPERATOR_SCHEMA(TTContractionGradient).NumInputs(3).NumOutputs(2);
25 
26 class GetTTContractionGradient : public GradientMakerBase {
27  using GradientMakerBase::GradientMakerBase;
28  vector<OperatorDef> GetGradientDefs() override {
29  return SingleGradientDef(
30  "TTContractionGradient",
31  "",
32  vector<string>{GO(0), I(0), I(1)},
33  vector<string>{GI(0), GI(1)},
34  Def().arg());
35  }
36 };
37 
38 REGISTER_GRADIENT(TTContraction, GetTTContractionGradient);
39 
40 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...