1 #include "caffe2/experiments/operators/funhash_op.h" 6 REGISTER_CPU_OPERATOR(FunHash, FunHashOp<float, CPUContext>);
7 REGISTER_CPU_OPERATOR(FunHashGradient, FunHashGradientOp<float, CPUContext>);
9 OPERATOR_SCHEMA(FunHash)
13 This layer compresses a fully-connected layer for sparse inputs 15 It takes four required inputs and an optional fifth input. 16 The first three inputs `scalars`, `indices`, and `segment_ids` are 17 the sparse segmented representation of sparse data, which are the 18 same as the last three inputs of the `SparseSortedSegmentWeightedSum` 19 operator. If the argument `num_segments` is specified, it would be used 20 as the first dimension for the output; otherwise it would be derived 21 from the maximum segment ID. 23 The fourth input is a 1D weight vector. Each entry of the fully-connected 24 layer would be randomly mapped from one of the entries in this vector. 26 When the optional fifth input vector is present, each weight of the 27 fully-connected layer would be the linear combination of K entries 28 randomly mapped from the weight vector, provided the input 29 (length-K vector) serves as the coefficients. 31 .Input(0, "scalars",
"Values of the non-zero entries of the sparse data.")
32 .Input(1,
"indices",
"Indices to the non-zero valued features.")
33 .Input(2,
"segment_ids",
34 "Segment IDs corresponding to the non-zero entries.")
35 .Input(3,
"weight",
"Weight vector")
37 "Optional coefficients for linear combination of hashed weights.")
39 "Output tensor with the first dimension equal to the number " 41 .Arg(
"num_outputs",
"Number of outputs")
42 .Arg(
"num_segments",
"Number of segments");
44 OPERATOR_SCHEMA(FunHashGradient)
48 class GetFunHashGradient :
public GradientMakerBase {
49 using GradientMakerBase::GradientMakerBase;
50 vector<OperatorDef> GetGradientDefs()
override {
51 if (def_.input_size() == 4) {
52 return SingleGradientDef(
53 "FunHashGradient",
"",
54 vector<string>{GO(0), I(0), I(1), I(2), I(3)},
55 vector<string>{GI(3)});
58 return SingleGradientDef(
59 "FunHashGradient",
"",
60 vector<string>{GO(0), I(0), I(1), I(2), I(3), I(4)},
61 vector<string>{GI(3), GI(4)});
65 REGISTER_GRADIENT(FunHash, GetFunHashGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...