1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_ 2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 9 #endif // CAFFE2_USE_MKL 16 using Shape = std::array<int, N>;
19 const std::vector<TIndex>& shape(Shape<N> vs) {
20 static thread_local std::vector<TIndex> cache;
21 cache.resize(vs.size());
22 for (
auto i = 0; i < vs.size(); ++i) {
28 inline const std::vector<TIndex>& shape(
int i) {
29 return shape<1>(Shape<1>({i}));
32 inline const std::vector<TIndex>& shape(
int i,
int j) {
33 return shape<2>(Shape<2>({i, j}));
/**
 * Sparse (CSR) x dense matrix product: c = A * b.
 *
 * A is an m x k sparse matrix stored in CSR form:
 *   @param acsr    non-zero values of A.
 *   @param ia      row pointers of A; the caller sizes this array as m + 1
 *                  entries (m = ia.size() - 1), and the CPU specialization
 *                  hands ia / ia + 1 to MKL as the row begin / end arrays.
 *   @param ja      column index of each value in acsr.
 *   @param m       number of rows of A (and of c).
 *   @param k       number of columns of A (and rows of b).
 *   @param n       number of columns of b (and of c).
 *   @param b       dense k x n input matrix.
 *   @param c       dense m x n output matrix.
 *   @param context device context for the implementation.
 *
 * The CPUContext/float specialization in this header calls MKL
 * mkl_scsrmm with alpha = 1 and beta = 0, so c is overwritten rather than
 * accumulated into. It also passes n as the leading dimension of both
 * b and c.
 * NOTE(review): index base and memory layout are selected by the "GLNC"
 * descriptor passed to MKL; confirm against the MKL csrmm docs before
 * changing it.
 */
template <typename T, class Context>
void Sparse_mm(
    const T* acsr,
    const int* ia,
    const int* ja,
    int m,
    int k,
    int n,
    const T* b,
    T* c,
    Context* context);
/**
 * Matrix transpose helper.
 *
 * @param o       source matrix with m rows and n columns
 *                (assumed row-major; TODO confirm against the specialization).
 * @param t       destination buffer for the transposed result
 *                (presumably n x m; confirm).
 * @param m       number of rows of o.
 * @param n       number of columns of o.
 * @param context device context for the implementation.
 *
 * NOTE(review): the float/CPUContext specialization in this header is only
 * partially visible. It loops over i < m and j < n, but its exact index
 * mapping should be verified there before relying on the layout above.
 */
template <typename T, class Context>
void trans_mat(const T* o, T* t, int m, int n, Context* context);
44 void trans_mat<float, CPUContext>(
50 for(
int i = 0; i < m; ++i){
51 for(
int j = 0; j < n; ++j){
60 void Sparse_mm<float, CPUContext>(
70 float alpha = 1.0, beta = 0.;
71 mkl_scsrmm(
"N", &m, &n, &k, &alpha,
"GLNC",
72 acsr, ja, ia, ia+1, b, &n, &beta, c, &n);
78 template <
typename T,
class Context,
class Engine=DefaultEngine>
81 USE_OPERATOR_CONTEXT_FUNCTIONS;
86 bool RunOnDevice()
override {
87 const auto& Xt = Input(0);
88 const auto& Wcsr = Input(1);
89 const auto& iw = Input(2);
90 const auto& jw = Input(3);
92 const auto& b = Input(4);
95 CAFFE_ENFORCE_EQ(Xt.ndim(), 2);
96 CAFFE_ENFORCE_EQ(b.ndim(), 1);
98 int K = Xt.ndim() > 1 ? Xt.dim32(0) : 1;
100 int M = Xt.size() / K;
102 int N = iw.dim32(0)-1;
103 CAFFE_ENFORCE_EQ(N, b.dim32(0));
104 Yt->Resize(shape(N, M));
107 Sparse_mm<T, Context>(
108 Wcsr.template data<T>(), iw.template data<int>(),
109 jw.template data<int>(), N, K, M, Xt.template data<T>(),
110 Yt->template mutable_data<T>(), &context_);
112 if (bias_multiplier_.size() != M) {
114 bias_multiplier_.Resize(shape(M));
115 math::Set<T, Context>(
116 M,
static_cast<T
>(1), bias_multiplier_.template mutable_data<T>(),
119 math::Gemm<T, Context, Engine>(
120 CblasNoTrans, CblasNoTrans, N, M, 1, 1,
121 b.template data<T>(), bias_multiplier_.template data<T>(), 1,
122 Yt->template mutable_data<T>(), &context_);
133 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_ Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...