Caffe2 - C++ API
A deep learning, cross-platform ML framework
sparse_to_dense_op.h
1 #ifndef CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
2 #define CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
3 
#include <limits>

4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
// SparseToDenseOp: scatters a sparse (INDICES, VALUES) pair into a dense
// tensor.
//
// Inputs:
//   INDICES           - 1-D tensor with one entry per row of VALUES; its
//                       element type is dispatched in RunOnDevice().
//   VALUES            - tensor with at least one dimension; its first
//                       dimension must equal the number of indices.
//   DATA_TO_INFER_DIM - optional third input; when present (and
//                       output_first_dim is not positive) its first
//                       dimension sizes the output.
// Argument:
//   output_first_dim  - int, default 0; a positive value fixes the size of
//                       the output's first dimension (only two inputs are
//                       then allowed).
// Output 0:
//   VALUES' shape with the first dimension replaced. Output row INDICES[i]
//   accumulates VALUES row i, so duplicate indices are summed; rows that no
//   index refers to stay zero.
//
// NOTE(review): this copy is a scraped documentation listing and some
// source lines inside the class are missing - restore from the original
// header before compiling.
10 template <class Context>
11 class SparseToDenseOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  USE_DISPATCH_HELPER;
15 
16  SparseToDenseOp(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<Context>(operator_def, ws),
18  output_first_dim_(
19  OperatorBase::GetSingleArgument<int>("output_first_dim", 0)) {}
20 
// Operator entry point: dispatches on the element type of the INDICES input
// (presumably via DispatchHelper into DoRunWithType<TInd>()).
// NOTE(review): listing line 22 is missing from this copy, so the statement
// below is truncated. It presumably read
//   return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
// - TODO: confirm the supported index types against the original header.
21  bool RunOnDevice() override {
23  this, Input(INDICES));
24  }
25 
26  private:
27  template <typename TInd>
28  int GetOutputFirstDim(
29  const TInd* sparse_indices_vec,
30  const int32_t sparse_indices_len) {
31  if (output_first_dim_ > 0) {
32  CAFFE_ENFORCE_EQ(InputSize(), 2);
33  return output_first_dim_;
34  }
35  if (InputSize() == 3) {
36  auto& data_to_infer_dim = Input(DATA_TO_INFER_DIM);
37  CAFFE_ENFORCE_GE(data_to_infer_dim.ndim(), 1);
38  return data_to_infer_dim.dim32(0);
39  }
40  if (sparse_indices_len <= 0) {
41  return 0;
42  }
43 
44  // Awkward way to get the max element to make it work with both CUDA
45  // and CPU.
46  max_element_.Resize(1);
47  TInd* max_element_ptr = max_element_.template mutable_data<TInd>();
48  math::ReduceMax<TInd>(sparse_indices_len, sparse_indices_vec, max_element_ptr,
49  &scratch_, &context_);
50  max_element_host_.CopyFrom(max_element_);
51  return 1 + max_element_host_.template data<TInd>()[0];
52  }
53 
// Second-level dispatch: with the index type TInd fixed, dispatches on the
// element type of the VALUES input and (presumably) forwards to
// DoRunWithType2<TInd, TData>(). Types not in the list presumably reach
// DoRunWithOtherType2<TInd>(), whose message asks to extend this list.
// NOTE(review): listing lines 57 and 61 are missing from this copy, so the
// type list below is incomplete (likely the opening `TensorTypes2<` and the
// final entry / closing `>,`) - TODO restore from the original header.
54  template <typename TInd>
55  bool DoRunWithType() {
56  return DispatchHelper<
58  float,
59  int32_t,
60  int64_t,
62  TInd>::call(this, Input(VALUES));
63  }
64 
65  template <typename TInd, typename TData>
66  bool DoRunWithType2() {
67  auto& sparse_indices = Input(INDICES);
68  CAFFE_ENFORCE_EQ(sparse_indices.ndim(), 1);
69  auto& sparse_values = Input(VALUES);
70  CAFFE_ENFORCE_GE(sparse_values.ndim(), 1);
71  CAFFE_ENFORCE_EQ(sparse_indices.size(), sparse_values.dim(0));
72 
73  const TInd* sparse_indices_vec = sparse_indices.template data<TInd>();
74  const int32_t sparse_indices_len = sparse_indices.dim32(0);
75  const int output_first_dim =
76  GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
77 
78  auto shape = sparse_values.dims();
79  shape[0] = output_first_dim;
80  auto* output = Output(0);
81  output->Resize(shape);
82 
83  TData* output_data = output->template mutable_data<TData>();
84  memset(output_data, 0, output->nbytes());
85  const auto block_nitems = sparse_values.size_from_dim(1);
86  const TData* sparse_values_vec = sparse_values.template data<TData>();
87 
88  for (int32_t i = 0; i < sparse_indices_len; i++) {
89  const TInd idx = sparse_indices_vec[i];
90  CAFFE_ENFORCE_GE(idx, 0);
91  CAFFE_ENFORCE_LT(idx, output_first_dim);
92  math::Add(
93  block_nitems,
94  output_data + idx * block_nitems,
95  sparse_values_vec + i * block_nitems,
96  output_data + idx * block_nitems,
97  &context_);
98  }
99  return true;
100  }
101 
102  template <typename TInd>
103  bool DoRunWithOtherType2() {
104  CAFFE_THROW(
105  "SparseToDense is not implemented on tensor of type ",
106  Input(VALUES).meta().name(),
107  "Consider adding it a type in the list DispatchHelper or implementing "
108  "a generic version (which won't work for duplicated indices though)");
109  }
110 
// Operator state.
111  private:
// Cached "output_first_dim" argument; used as the output's first dimension
// when positive.
112  int output_first_dim_;
// Scratch buffer handed to math::ReduceMax when computing the max index.
113  Tensor<Context> scratch_;
// Host (CPUContext) copy of max_element_, read to size the output.
114  Tensor<CPUContext> max_element_host_;
// One-element tensor that receives the maximum index, computed on Context.
115  Tensor<Context> max_element_;
116 
// Named input slots; presumably numbered 0, 1, 2 in declaration order.
// DATA_TO_INFER_DIM is optional (only read when InputSize() == 3).
117  INPUT_TAGS(INDICES, VALUES, DATA_TO_INFER_DIM);
118 };
119 
120 } // namespace caffe2
121 
122 #endif // CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory along with its shape information...
Definition: tensor.h:93
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Definition: tensor.h:166
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...