Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op_sparse.h
1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 #ifdef CAFFE2_USE_MKL
8 #include <mkl.h>
9 #endif // CAFFE2_USE_MKL
10 
11 namespace caffe2 {
12 
13 namespace {
14 
15 template<int N>
16 using Shape = std::array<int, N>;
17 
18 template<int N>
19 const std::vector<TIndex>& shape(Shape<N> vs) {
20  static thread_local std::vector<TIndex> cache;
21  cache.resize(vs.size());
22  for (auto i = 0; i < vs.size(); ++i) {
23  cache[i] = vs[i];
24  }
25  return cache;
26 }
27 
28 inline const std::vector<TIndex>& shape(int i) {
29  return shape<1>(Shape<1>({i}));
30 }
31 
32 inline const std::vector<TIndex>& shape(int i, int j) {
33  return shape<2>(Shape<2>({i, j}));
34 }
35 
// Sparse-times-dense matrix product C = A * B; only declared here, each
// supported <T, Context> pair supplies its own specialization.
//
//   acsr, ia, ja  the sparse operand A as three flat arrays (values plus two
//                 integer index arrays).
//   m, k, n       problem sizes: A is m x k, B is k x n, C is m x n.
//   b             dense right-hand operand.
//   c             output buffer that receives the product.
//   context       device context the computation runs on.
template <typename T, class Context>
void Sparse_mm(
    const T* acsr,
    const int* ia,
    const int* ja,
    int m,
    int k,
    int n,
    const T* b,
    T* c,
    Context* context);
39 
// Matrix transpose; only declared here, each supported <T, Context> pair
// supplies its own specialization.
//
//   o        source matrix with m rows and n columns.
//   t        destination buffer for the n x m transpose.
//   context  device context the computation runs on.
template <typename T, class Context>
void trans_mat(
    const T* o,
    T* t,
    int m,
    int n,
    Context* context);
42 
43 template <>
44 void trans_mat<float, CPUContext>(
45  const float* o,
46  float* t,
47  int m,
48  int n,
49  CPUContext* /*context*/) {
50  for(int i = 0; i < m; ++i){
51  for(int j = 0; j < n; ++j){
52  t[j*m+i]=o[i*n+j];
53  }
54  }
55 }
56 
57 // C = A(sparse) * B
58 // No transpose;
59 template <>
60 void Sparse_mm<float, CPUContext>(
61  const float* acsr,
62  const int* ia,
63  const int* ja,
64  int m,
65  int k,
66  int n,
67  const float* b,
68  float* c,
69  CPUContext* /*context*/) {
70  float alpha = 1.0, beta = 0.;
71  mkl_scsrmm("N", &m, &n, &k, &alpha, "GLNC",
72  acsr, ja, ia, ia+1, b, &n, &beta, c, &n);
73 }
74 
75 }
76 
77 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
78 template <typename T, class Context, class Engine=DefaultEngine>
79 class FullyConnectedOp_SPARSE final : public Operator<Context> {
80  public:
81  USE_OPERATOR_CONTEXT_FUNCTIONS;
82  FullyConnectedOp_SPARSE(const OperatorDef& operator_def, Workspace* ws)
83  : Operator<Context>(operator_def, ws) {}
85 
86  bool RunOnDevice() override {
87  const auto& Xt = Input(0); // transposed X
88  const auto& Wcsr = Input(1);
89  const auto& iw = Input(2);
90  const auto& jw = Input(3);
91  // Notice that we do not need to transpose b
92  const auto& b = Input(4);
93  auto* Yt = Output(0); //transposed Y
94  // here we assume X is k-by-m
95  CAFFE_ENFORCE_EQ(Xt.ndim(), 2);
96  CAFFE_ENFORCE_EQ(b.ndim(), 1);
97  // batch size
98  int K = Xt.ndim() > 1 ? Xt.dim32(0) : 1;
99  // Feature dimension
100  int M = Xt.size() / K;
101  // number of outputs.
102  int N = iw.dim32(0)-1;
103  CAFFE_ENFORCE_EQ(N, b.dim32(0));
104  Yt->Resize(shape(N, M));
105 
106  // Y' = W * X';
107  Sparse_mm<T, Context>(
108  Wcsr.template data<T>(), iw.template data<int>(),
109  jw.template data<int>(), N, K, M, Xt.template data<T>(),
110  Yt->template mutable_data<T>(), &context_);
111  // Add bias term
112  if (bias_multiplier_.size() != M) {
113  // If the helper bias multiplier is not M, reshape and fill it with one.
114  bias_multiplier_.Resize(shape(M));
115  math::Set<T, Context>(
116  M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
117  &context_);
118  }
119  math::Gemm<T, Context, Engine>(
120  CblasNoTrans, CblasNoTrans, N, M, 1, 1,
121  b.template data<T>(), bias_multiplier_.template data<T>(), 1,
122  Yt->template mutable_data<T>(), &context_);
123  return true;
124  }
125 
126  protected:
127  Tensor<Context> bias_multiplier_;
128 };
129 
130 
131 } // namespace caffe2
132 
133 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...