Caffe2 - C++ API
A deep learning, cross-platform ML framework
sparse_normalize_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/utils/math.h"
5 
6 namespace caffe2 {
7 
8 template <typename T, class Context>
9 class SparseNormalizeOp final : public Operator<Context> {
10  public:
11  USE_OPERATOR_CONTEXT_FUNCTIONS;
12  SparseNormalizeOp(const OperatorDef& operator_def, Workspace* ws)
13  : Operator<Context>(operator_def, ws),
14  use_max_norm_(
15  OperatorBase::GetSingleArgument<bool>("use_max_norm", true)),
16  norm_(OperatorBase::GetSingleArgument<float>("norm", 1.0)) {
17  CAFFE_ENFORCE_GE(norm_, 0, "norm should be bigger than 0");
18  }
19 
20  bool RunOnDevice() override {
21  CAFFE_ENFORCE_EQ(
22  Input(PARAM).size_from_dim(1),
23  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
24 
26  this, Input(INDICES));
27  }
28 
29  template <typename SIndex>
30  bool DoRunWithType();
31 
32  protected:
33  bool use_max_norm_;
34  float norm_;
35  INPUT_TAGS(PARAM, INDICES, GRAD);
36  OUTPUT_TAGS(OUTPUT_PARAM);
37 };
38 
39 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...