Caffe2 - C++ API
A deep learning, cross platform ML framework
weighted_multi_sampling_op.cc
1 #include "caffe2/operators/weighted_multi_sampling_op.h"
2 #include "caffe2/utils/math.h"
3 
4 namespace caffe2 {
5 
6 template <class Context>
7 bool WeightedMultiSamplingOp<Context>::RunOnDevice() {
8  const auto& weight = Input(0);
9  CAFFE_ENFORCE_EQ(weight.ndim(), 1, "Input should be 1-D vector");
10  auto dims = weight.dims();
11  size_t data_size = weight.dim32(0);
12  auto* indices = Output(0);
13 
14  auto num_samples = num_samples_;
15  if (InputSize() == 2) {
16  CAFFE_ENFORCE(
17  !OperatorBase::HasArgument("num_samples"),
18  "New shape is specified by the input blob, do not pass in "
19  "the argument `num_samples`.");
20  num_samples = Input(1).size();
21  indices->ResizeLike(Input(1));
22  } else {
23  indices->Resize(num_samples);
24  }
25 
26  int* indices_data = indices->template mutable_data<int>();
27  if (data_size == 0) {
28  indices->Resize(0);
29  return true;
30  }
31 
32  const float* weight_data = weight.template data<float>();
33 
34  for (int i = 0; i < num_samples; ++i) {
35  float r;
36  math::RandUniform<float, Context>(
37  1, 0.0f, weight_data[data_size - 1], &r, &context_);
38  auto lb = std::lower_bound(weight_data, weight_data + data_size, r);
39  CAFFE_ENFORCE(
40  lb != weight_data + data_size, "Cannot find ", r, " in input CDF.");
41  indices_data[i] = static_cast<int>(lb - weight_data);
42  }
43  return true;
44 }
45 
46 REGISTER_CPU_OPERATOR(
47  WeightedMultiSampling,
48  WeightedMultiSamplingOp<CPUContext>);
49 
50 OPERATOR_SCHEMA(WeightedMultiSampling)
51  .NumInputs(1, 2)
52  .NumOutputs(1)
53  .TensorInferenceFunction([](const OperatorDef& def,
54  const vector<TensorShape>& in) {
55  vector<TensorShape> out(1);
56  if (in[0].dims(0) == 0) {
57  out[0].set_data_type(TensorProto::INT32);
58  out[0].add_dims(0);
59  return out;
60  }
61 
62  const ArgumentHelper args(def);
63  if (args.HasArgument("num_samples")) {
64  CAFFE_ENFORCE_EQ(
65  in.size(),
66  1,
67  "New shape must not be specified by the input blob and the "
68  "argument `num_samples` at the same time.");
69  int num_samples = args.GetSingleArgument<int64_t>("num_samples", 0);
70  out[0] =
71  CreateTensorShape(vector<int64_t>{num_samples}, TensorProto::INT32);
72  return out;
73  } else {
74  CAFFE_ENFORCE_EQ(
75  in.size(),
76  2,
77  "New shape must be specified by either the input blob or the "
78  "argument `num_samples`.");
79  std::vector<int64_t> output_dims = GetDimsVector(in[1]);
80  out[0] = CreateTensorShape(output_dims, TensorProto::INT32);
81  return out;
82  }
83  })
84  .SetDoc(R"DOC(
85 The operator performs sampling based on the input sampling weights.
86 All weights are cummulative probability thus sorted. The output is
87 a 1-D tensor (Tensor<int>). If two inputs are given, the second input
88 is used to provide shape of the output sample tensor. Otherwise, we use
89 argument `num_samples` to determine the number of samples to generate.
90 )DOC")
91  .Input(
92  0,
93  "sampling_cdf",
94  "An optional 1-D Tensor<float>."
95  "Input cumulative sampling probability (such as [0.2, 0.5, 0.8, 1.5])."
96  " All weights must be non-negative numbers. Note that the last value of"
97  " CDF is not necessary 1. If the last value is not 1, all values in"
98  " sampling_cdf will be scaled by this number.")
99  .Input(
100  1,
101  "shape_tensor (optional)",
102  "Tensor whose shape will be applied to output.")
103  .Output(
104  0,
105  "sampled_indexes",
106  "The output tensor contains indices sampled from distribution given"
107  "by the weight vector in the input tensor"
108  "The output is a 1-D Tensor<int> of size determined by argument"
109  "`num_samples` or the second input tensor.")
110  .Arg("num_samples", "number of samples to sample from the input data");
111 
112 SHOULD_NOT_DO_GRADIENT(WeightedMultiSample);
113 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37