1 #include "batch_sparse_to_dense_op.h" 3 #include "caffe2/core/context.h" 7 template <
typename T,
class Context>
// BatchSparseToDenseOp::RunOnDevice
//
// Builds an output of shape {lengths.size(), dense_last_dim_} from a sparse
// (lengths, indices, values) triple: for row i, each of its lengths[i]
// entries writes values_data[k] to
// output_data[i * dense_last_dim_ + indices_data[k]].
//
// NOTE(review): the numeric prefixes on these lines look like original line
// numbers left over from extraction, and they skip values, so some
// statements (the declaration/advance of `k`, some enforce call heads,
// closing braces, the return) are not visible here. The comments below
// describe only what is visible.
8 bool BatchSparseToDenseOp<T, Context>::RunOnDevice() {
9 auto& lengths = Input(LENGTHS);
10 auto& indices = Input(INDICES);
11 auto& values = Input(VALUES);
12 auto* output = Output(0);
// Shape checks: indices and values have the same element count; lengths and
// indices are both 1-D.
13 CAFFE_ENFORCE_EQ(indices.size(), values.size());
14 CAFFE_ENFORCE_EQ(lengths.ndim(), 1);
15 CAFFE_ENFORCE_EQ(indices.ndim(), 1);
17 const TIndex* lengths_data = lengths.template data<TIndex>();
18 const TIndex* indices_data = indices.template data<TIndex>();
19 const T* values_data = values.template data<T>();
// One output row per entry of `lengths`.
20 TIndex batch_size = lengths.size();
21 TIndex lengths_sum = 0;
// The per-row counts must add up to the total number of indices.
22 math::Sum<TIndex, Context>(batch_size, lengths_data, &lengths_sum, &context_);
23 CAFFE_ENFORCE_EQ(lengths_sum, indices.size());
25 vector<TIndex> output_shape = {batch_size};
// When a 4th input is present, the visible code consults only its shape: it
// must be 2-D, and its dim(1) supplies (or must match) dense_last_dim_.
26 if (InputSize() == 4) {
27 auto& shaper = Input(3);
28 CAFFE_ENFORCE_EQ(shaper.ndim(), 2);
// -1 is treated as "not set"; in that case dense_last_dim_ itself is
// overwritten with shaper.dim(1).
// NOTE(review): dense_last_dim_ is declared outside this excerpt; if it is
// operator state, this assignment persists into later runs -- confirm.
29 if (dense_last_dim_ == -1) {
30 dense_last_dim_ = shaper.dim(1);
// Condition and message requiring the two sources to agree (the enclosing
// call head is not visible in this excerpt).
33 dense_last_dim_ == shaper.dim(1),
34 "The last dim argument is not aligned with the shape input last dim");
37 CAFFE_ENFORCE(dense_last_dim_ >= 1,
"The last dim of dense must be >= 1");
39 output_shape.push_back(dense_last_dim_);
40 output->Resize(output_shape);
41 T* output_data = output->template mutable_data<T>();
// Arguments of a call whose name is not visible here; presumably it fills all
// output->size() elements with default_value_ before the scatter -- confirm.
43 output->size(),
static_cast<T
>(default_value_), output_data, &context_);
// Scatter loop: i walks the rows, j counts the entries belonging to row i.
// NOTE(review): `k` is not declared in the visible lines; presumably a
// running offset into indices/values advanced once per inner iteration.
46 for (TIndex i = 0; i < batch_size; ++i) {
47 for (TIndex j = 0; j < lengths_data[i]; ++j) {
// Upper-bound condition on the column index.
// NOTE(review): no check that indices_data[k] >= 0 is visible; a negative
// index would address outside row i -- confirm against the full file.
49 indices_data[k] < dense_last_dim_,
52 ") is larger then last dim of dense (",
55 output_data[i * dense_last_dim_ + indices_data[k]] = values_data[k];
63 template <
typename T,
class Context>
// BatchDenseToSparseOp::RunOnDevice
//
// Gathers from a 2-D `dense` input into an output shaped like `indices`:
// for row i, each of its lengths[i] entries reads
// dense_data[i * dense.dim(1) + indices_data[k]] into output_data[k].
//
// NOTE(review): the numeric prefixes on these lines look like original line
// numbers left over from extraction, and they skip values, so some
// statements (the declaration/advance of `k`, the enforce call head, closing
// braces, the return) are not visible here. The comments below describe only
// what is visible.
64 bool BatchDenseToSparseOp<T, Context>::RunOnDevice() {
65 auto& lengths = Input(LENGTHS);
66 auto& indices = Input(INDICES);
67 auto& dense = Input(DENSE);
68 auto* output = Output(0);
// Shape checks: lengths and indices are 1-D, dense is 2-D.
69 CAFFE_ENFORCE_EQ(lengths.ndim(), 1);
70 CAFFE_ENFORCE_EQ(indices.ndim(), 1);
71 CAFFE_ENFORCE_EQ(dense.ndim(), 2);
72 const TIndex* lengths_data = lengths.template data<TIndex>();
73 const TIndex* indices_data = indices.template data<TIndex>();
74 const T* dense_data = dense.template data<T>();
76 TIndex batch_size = lengths.size();
77 TIndex lengths_sum = 0;
// The per-row counts must add up to the total number of indices.
78 math::Sum<TIndex, Context>(batch_size, lengths_data, &lengths_sum, &context_);
79 CAFFE_ENFORCE_EQ(lengths_sum, indices.size());
// dense must have exactly one row per entry of `lengths`.
81 CAFFE_ENFORCE_EQ(batch_size, dense.dim(0));
// Stores dense's column count in dense_last_dim_; the visible indexing below
// reads dense.dim(1) directly rather than this variable.
82 dense_last_dim_ = dense.dim(1);
// The output takes the same dims as `indices` (one value per index).
83 vector<TIndex> output_shape = indices.dims();
84 output->Resize(output_shape);
85 T* output_data = output->template mutable_data<T>();
// Gather loop: i walks the rows, j counts the entries belonging to row i.
// NOTE(review): `k` is not declared in the visible lines; presumably a
// running offset into indices/output advanced once per inner iteration.
88 for (TIndex i = 0; i < batch_size; ++i) {
89 for (TIndex j = 0; j < lengths_data[i]; ++j) {
// Upper-bound condition on the column index.
// NOTE(review): no check that indices_data[k] >= 0 is visible; a negative
// index would read outside row i -- confirm against the full file.
91 indices_data[k] < dense.dim(1),
94 ") is larger then last dim of dense (",
97 output_data[k] = dense_data[i * dense.dim(1) + indices_data[k]];
// Registers the float/CPUContext instantiation of BatchSparseToDenseOp as a
// CPU operator, then declares the BatchSparseToDense schema: the operator doc
// text followed by descriptions of its inputs, output and arguments.
// NOTE(review): lines appear to be missing from this excerpt (the numeric
// prefixes skip values), so the operator-name argument and most of the
// builder calls that wrap the text below are not visible.
104 REGISTER_CPU_OPERATOR(
106 BatchSparseToDenseOp<float, CPUContext>);
// NOTE(review): two descriptions below are built from adjacent string
// literals with no separating space ("...dense_last_dim" "in the arg list"
// and "...this value." "default_value = 0 ..."); these would concatenate as
// "dense_last_dimin" and "value.default_value" -- confirm and add spaces.
// The doc text also contains typos ("With in" for "Within").
108 OPERATOR_SCHEMA(BatchSparseToDense)
112 Convert sparse matrix representation into dense matrix. 114 A sparse matrix is represented by `lengths` vector, `indices` vector, 115 and `values` vector. Each element in `lengths` vector (lengths[`i`]) represents 116 the number of indices in this batch (batch `i`). 117 With in each batch, `indices` should not have duplicate number. 119 For example, with input: 122 indices = [0, 1, 2, 3, 4, 5] 123 values = [6, 7, 8, 9, 10, 11] 129 output = [[6, 7, 0, 0, 0, 0], 133 after running this operator. 138 "Flatten tensor, used to break down indices and values into per batch indices and values.")
142 "Flatten tensor of total size = \\sum lengths, containing the indices ")
143 .Input(2,
"values",
"Data tensor, dimension has to match `indices`")
146 "output_shape_inference",
147 "Optional, a dense tensor whose shape define the output shape")
151 "2-D dense tensor, with 1st dim = len(lengths), 2nd dim = dense_last_dim" 152 "in the arg list, the tensor is of the same data type as `values`." 153 "Missing values are filled with default_value")
156 "Optional, output dense last dimension. " 157 "If both this argument and output_shape_inference are set, " 158 "it should be consistent with output_shape_inference's last dim")
161 "Optional, missing values are filled with this value." 162 "default_value = 0 when not set");
// Registers the float/CPUContext instantiation of BatchDenseToSparseOp as a
// CPU operator, then declares the BatchDenseToSparse schema: the operator doc
// text followed by descriptions of its inputs and output.
// NOTE(review): lines appear to be missing from this excerpt (the numeric
// prefixes skip values), so the operator-name argument and most of the
// builder calls that wrap the text below are not visible.
164 REGISTER_CPU_OPERATOR(
166 BatchDenseToSparseOp<float, CPUContext>);
// NOTE(review): the doc text below contains typos ("a inverse", "a `indices`
// vector", "With in"); it is runtime string content, so it is left as-is here.
168 OPERATOR_SCHEMA(BatchDenseToSparse)
172 This Op is a inverse of BatchSparseToDenseOp. 173 Basically, given a `lengths` vector, a `indices` vector, 174 and a dense matrix `dense`, output `value` vector so that, along with 175 `lengths` vector and `indices` vector, forms a sparse representation 178 A sparse matrix is represented by `lengths` vector, `indices` vector, 179 and `values` vector. Each element in `lengths` vector (lengths[`i`]) represents 180 the number of indices in this batch (batch `i`). 181 With in each batch, `indices` should not have duplicate number. 183 For example, with input: 186 indices = [0, 1, 2, 3, 4, 5] 187 output = [[6, 7, 0, 0, 0, 0], 193 values = [6, 7, 8, 9, 10, 11] 195 after running this operator. 200 "Flatten lengths, Used to break down indices into per batch indices")
204 "Flatten indices, tensor of total size = \\sum lengths, containing the indices ")
208 "dense 2-D tensor, first dim = len(lengths), last dim > Any(indices)")
212 "Values, tensor of the same size as `indices` and same data type as dense tensor.");
// Gradient maker registered for BatchSparseToDense (see REGISTER_GRADIENT in
// this file). It emits a single "BatchDenseToSparse" op definition.
// NOTE(review): I/GO/GI come from GradientMakerBase, outside this file;
// presumably forward input / output gradient / input gradient -- confirm.
// Lines appear to be missing from this excerpt (numeric prefixes skip), so
// not every SingleGradientDef argument or closing brace is visible.
216 class GetBatchSparseToDenseGradient :
public GradientMakerBase {
217 using GradientMakerBase::GradientMakerBase;
218 vector<OperatorDef> GetGradientDefs()
override {
219 return SingleGradientDef(
220 "BatchDenseToSparse",
// Gradient-op inputs: forward inputs 0 and 1 plus GO(0).
222 vector<string>{I(0), I(1), GO(0)},
// Gradient-op output: GI(2); input 2 is described as "values" in the
// BatchSparseToDense schema.
223 vector<string>{GI(2)});
// Gradient maker registered for BatchDenseToSparse (see REGISTER_GRADIENT in
// this file). It emits a single "BatchSparseToDense" op definition.
// NOTE(review): I/GO/GI come from GradientMakerBase, outside this file;
// presumably forward input / output gradient / input gradient -- confirm.
// Lines appear to be missing from this excerpt (numeric prefixes skip), so
// not every SingleGradientDef argument or closing brace is visible.
227 class GetBatchDenseToSparseGradient :
public GradientMakerBase {
228 using GradientMakerBase::GradientMakerBase;
229 vector<OperatorDef> GetGradientDefs()
override {
230 return SingleGradientDef(
231 "BatchSparseToDense",
// Gradient-op inputs: forward inputs 0 and 1, GO(0), and forward input 2 as a
// fourth input. BatchSparseToDenseOp::RunOnDevice reads only the shape of its
// 4th input, so I(2) presumably serves as the shape reference -- confirm.
233 vector<string>{I(0), I(1), GO(0), I(2)},
// Gradient-op output: GI(2).
234 vector<string>{GI(2)});
// Associates each operator with its gradient maker class defined in this
// file. The two makers each emit the other operator as the gradient op.
238 REGISTER_GRADIENT(BatchSparseToDense, GetBatchSparseToDenseGradient);
239 REGISTER_GRADIENT(BatchDenseToSparse, GetBatchDenseToSparseGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...