1 #include "caffe2/experiments/operators/tt_contraction_op.h" 5 REGISTER_CPU_OPERATOR(TTContraction, TTContractionOp<float, CPUContext>);
// Operator schema for TTContraction ("Tensor contraction C = A * B").
// Declared arguments: K = "i_{k-1} * r_k", M = "r_{k-1} * o_{k-1}".
// Declared inputs: A, a 2D matrix of size (K x M), and B, a tensor.
// Declared output: C, the contracted tensor.
// NOTE(review): this listing was extracted with its original line numbers
// fused into the text (the leading "7", "11", "13", ... tokens), and source
// lines were dropped (the numbering skips 8-10, 12 and 15). The free text
// "Tensor contraction C = A * B" is presumably the body of a doc string whose
// delimiters were lost. Restore the missing lines from the upstream file
// before compiling rather than reconstructing them by guesswork.
7 OPERATOR_SCHEMA(TTContraction)
11 Tensor contraction C = A * B 13 .Arg("K",
"i_{k-1} * r_k")
14 .Arg(
"M",
"r_{k-1} * o_{k-1}")
16 .Input(0,
"A",
"2D matrix of size (K x M)")
17 .Input(1,
"B",
"tensor")
18 .Output(0,
"C",
"contracted tensor");
20 REGISTER_CPU_OPERATOR(
21 TTContractionGradient,
22 TTContractionGradientOp<float, CPUContext>);
24 OPERATOR_SCHEMA(TTContractionGradient).NumInputs(3).NumOutputs(2);
// Gradient maker for TTContraction. It emits a single "TTContractionGradient"
// op that reads {GO(0), I(0), I(1)} and writes {GI(0), GI(1)}: three inputs
// and two outputs.
// (Presumably, by GradientMakerBase convention: output gradient plus both
// forward inputs in, gradients w.r.t. both inputs out. Confirm.)
// NOTE(review): the extracted listing drops source lines 31 and 34-37 (the
// numbering jumps 30->32 and 33->38). The remaining SingleGradientDef
// arguments and the closing braces of the method and class are therefore
// missing here and must be restored from the upstream file.
26 class GetTTContractionGradient :
public GradientMakerBase {
27 using GradientMakerBase::GradientMakerBase;
// Builds the gradient op definitions for one TTContraction op.
28 vector<OperatorDef> GetGradientDefs()
override {
29 return SingleGradientDef(
30 "TTContractionGradient",
32 vector<string>{GO(0), I(0), I(1)},
33 vector<string>{GI(0), GI(1)},
38 REGISTER_GRADIENT(TTContraction, GetTTContractionGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...