1 #include "caffe2/operators/weighted_sample_op.h" 6 bool WeightedSampleOp<float, CPUContext>::RunOnDevice() {
// Samples one column index per row of the weights input (Input(0)), read as a
// (batch_size x weights_dim) matrix. Each index is chosen by inverse-CDF
// sampling over the row's prefix sums, i.e. with probability proportional to
// the row's weights.
// Output(0) receives the sampled indices as a (batch_size, 1) int tensor.
// If a second input is present, Output(1) receives the Input(1) value at each
// sampled (row, index) position as a (batch_size, 1) float tensor.
// NOTE(review): weights are assumed non-negative (as the schema doc states);
// no check for negative weights is visible in this function.
10 "The number of tensors of the input and the output must be the same.");
11 auto& weights = Input(0);
12 int batch_size = weights.dim(0);
13 int weights_dim = weights.dim(1);
14 auto* out_idx = Output(0);
// Sampling path: there is at least one row and one column.
16 if (batch_size > 0 && weights_dim > 0) {
// cum_mass_ is a reusable buffer holding the running prefix sum of the
// current row's weights.
17 cum_mass_.resize(weights_dim);
18 const float* mat_weights = weights.template data<float>();
19 const float* mat_values =
nullptr;
20 out_idx->Resize(batch_size, 1);
21 int* output_indices = out_idx->template mutable_data<int>();
// Stays nullptr unless the optional values input is supplied.
22 float* output_values =
nullptr;
// Optional second input: values to gather at the sampled indices.
24 if (InputSize() == 2) {
25 auto& values = Input(1);
29 "The sampling weights tensor and the sampling values tensor must have the same dimensions.");
30 mat_values = values.template data<float>();
31 auto* out_value = Output(1);
32 out_value->Resize(batch_size, 1);
33 output_values = out_value->template mutable_data<float>();
// Sample independently for each row i. Data is read row-major, so row i
// starts at flat offset i * weights_dim.
36 for (
int i = 0; i < batch_size; i++) {
38 int offset = i * weights_dim;
// Prefix sums: cum_mass_[j] = weights[0] + ... + weights[j] of row i.
40 cum_mass_[0] = mat_weights[offset];
41 for (
int j = 1; j < weights_dim; j++) {
42 cum_mass_[j] = cum_mass_[j - 1] + mat_weights[offset + j];
// Draw one uniform sample r between 0 and the row's total weight (the last
// prefix sum).
45 math::RandUniform<float, CPUContext>(
46 1, 0.0f, cum_mass_[cum_mass_.size() - 1], &r, &context_);
// Pad the last prefix sum after drawing r. This is presumably so that a draw
// which rounds to (or slightly past) the total still finds a bucket below.
// NOTE(review): the fixed 0.01f pad is lost to float rounding once the total
// is large (the float spacing near 1e6 is 0.0625). It only helps rows with a
// moderate total weight.
49 cum_mass_[cum_mass_.size() - 1] += 0.01f;
// Select the first bucket whose prefix sum is >= r.
// NOTE(review): if r is exactly 0.0f and the row begins with zero weights,
// this picks index 0 even though its weight is 0. std::upper_bound would skip
// zero-weight buckets, but it returns end() whenever r equals the last prefix
// sum, so the padding caveat above would then matter.
50 auto lb = lower_bound(cum_mass_.begin(), cum_mass_.end(), r);
51 CAFFE_ENFORCE(lb != cum_mass_.end(),
"Cannot find ", r,
" in cum_mass_.");
52 output_indices[i] =
static_cast<int>(lb - cum_mass_.begin());
// Value at the sampled position of row i in the values input.
56 static_cast<float>(mat_values[offset + (lb - cum_mass_.begin())]);
// The mutable_data<>() results below are discarded. The calls are presumably
// there so the outputs get typed storage even when nothing is sampled —
// confirm.
61 out_idx->template mutable_data<int>();
62 if (OutputSize() == 2) {
63 auto* out_value = Output(1);
65 out_value->template mutable_data<float>();
// Registers the float / CPUContext specialization of WeightedSampleOp as the
// CPU implementation of the "WeightedSample" operator type.
72 REGISTER_CPU_OPERATOR(WeightedSample, WeightedSampleOp<float, CPUContext>);
// Schema for WeightedSample: shape inference plus the operator, input and
// output documentation. The compiler concatenates adjacent string literals,
// so each fragment below ends with its separating space (and a period where
// a sentence ends). Otherwise the rendered text runs words together.
74 OPERATOR_SCHEMA(WeightedSample)
77 .TensorInferenceFunction([](
const OperatorDef& def,
78 const vector<TensorShape>& in) {
// Keep in sync with RunOnDevice, which resizes both outputs to
// (batch_size, 1) for non-empty input. Hence the inferred shapes are 2-D.
79 vector<TensorShape> out(2);
80 int batch_size = in[0].dims(0);
81 out[0] = CreateTensorShape(vector<int>{batch_size, 1}, TensorProto::INT32);
82 out[1] = CreateTensorShape(vector<int>{batch_size, 1}, TensorProto::FLOAT);
86 The operator performs sampling based on the input sampling weights for 87 each batch. All weights must be non-negative numbers. 88 The input is a 2-D tensor (Tensor<float>) of size (batch_size x weights_dim). 89 For each batch, an index is randomly sampled from the distribution given by 90 the weights of the corresponding batch. 91 The output is a 2-D tensor (Tensor<int>) of size (batch_size x 1) and 92 contains the index(es) of the sampled output. 97 "A 2-D Tensor<float> of size (batch_size x weights_dim). " 98 "All weights must be non-negative numbers.")
102 "An optional 2-D Tensor<float> of size (batch_size x weights_dim). " 103 "Its values correspond to the sampling weights.")
107 "The output tensor contains index(es) sampled from distribution given " 108 "by the weight vector(s) in the input tensor. " 109 "The output is a 2-D Tensor<int> of size (batch_size x 1)")
113 "The output tensor contains value(s) selected by the sampled index(es). " 114 "It is a 2-D Tensor<float> of size (batch_size x 1)");
// WeightedSample draws random samples. This declares that no gradient should
// be computed for it.
116 SHOULD_NOT_DO_GRADIENT(WeightedSample);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...