Caffe2 - C++ API
A deep learning, cross-platform ML framework
onnx_while_op.h
#ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
#define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/create_scope_op.h"

namespace caffe2 {

template <class Context>
class ONNXWhileOp final : public Operator<Context> {
 public:
  ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        parent_ws_(ws),
        has_trip_count_(
            OperatorBase::GetSingleArgument<int64_t>("has_trip_count", 0)),
        has_cond_(OperatorBase::GetSingleArgument<int64_t>("has_cond", 0)),
        save_scopes_(
            OperatorBase::GetSingleArgument<int64_t>("save_scopes", 0)) {
    CAFFE_ENFORCE(
        this->template HasSingleArgumentOfType<NetDef>("body"),
        "body net must be specified in ONNXWhile operator");
    body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef());
    if (!body_net_def_.has_name()) {
      body_net_def_.set_name("loop_net");
    }
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;
  // Operator
  //  Inputs: max trip count, condition, initial loop-carried dependencies
  //  Outputs: final loop-carried dependencies, scan_outputs
  // Body
  //  Inputs: iteration number, condition, loop-carried dependencies
  //  Outputs: condition, loop-carried dependencies, scan_outputs
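  //
  // Informally, and only as a sketch of the control flow implemented below,
  // the operator behaves roughly like:
  //
  //   lcd  = initial loop-carried dependencies (Inputs 2..);
  //   cond = Input(1);
  //   for (int64_t i = 0; (!has_trip_count_ || i < max_trip_count) &&
  //                       (!has_cond_ || cond); ++i) {
  //     (cond, lcd, scan_i...) = body(i, cond, lcd);
  //     append each scan_i to its scan_output along a new leading
  //     (timestep) axis;
  //   }
  //   Outputs = (lcd, scan_outputs);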
  bool RunOnDevice() override {
    // Clear workspaces from the previous invocations of the loop
    // and set up a local scope for the first iteration
    ws_stack_.clear();
    auto loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_);
    scope_ = std::make_shared<LocalScope>(loop_ws, body_net_def_);

    constexpr int64_t num_inputs_before_lcds = 2;
    // The first input is the maximum trip count. The second input is the
    // condition variable (for the first iteration). The remaining inputs are
    // loop-carried dependencies.
    int num_loop_carried_deps = InputSize() - num_inputs_before_lcds;
    int64_t max_trip_count = *Input(0).template data<int64_t>();
    const bool first_iter_condition = *Input(1).template data<bool>();
    // Body graph has 2+N inputs: iteration number, condition value, and N
    // loop-carried dependencies
    CAFFE_ENFORCE_EQ(
        num_loop_carried_deps + 2,
        scope_->net()->external_input().size(),
        "Body graph must have 2+N inputs, where N is the number of "
        "loop-carried dependencies.");

    // Body graph has 1+N+K outputs: recalculated condition variable, N
    // loop-carried dependencies, and K scan_outputs
    int num_scan_outputs =
        scope_->net()->external_output().size() - num_loop_carried_deps - 1;

    CAFFE_ENFORCE_GE(
        num_scan_outputs,
        0,
        "Body graph must have 1+N+K outputs, where N is the number "
        "of loop-carried dependencies and K is the number of scan "
        "outputs");
    // Copy initial loop-carried dependencies
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds));
    }

    // Initialize iteration variable
    scope_->set_iteration(0ll);

    // Initialize input condition variable
    scope_->set_input_condition(first_iter_condition);

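    // Per the ONNX Loop contract, either termination criterion may be absent:
    // when has_trip_count_ is false the trip-count check always passes, and
    // when has_cond_ is false the condition check always passes, so the op
    // can act as a counted for-loop, a while-loop, or both combined.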
    auto valid_iter_num = [this, max_trip_count](int64_t i) {
      if (has_trip_count_) {
        return i < max_trip_count;
      } else {
        return true;
      }
    };

    auto condition_true =
        [this, first_iter_condition](int64_t i, bool cond_value) {
          if (has_cond_) {
            if (i == 0) {
              return (bool)first_iter_condition;
            } else {
              return cond_value;
            }
          } else {
            return true;
          }
        };

    // Allocate scan_outputs for zero-iteration case
    for (int i = 0; i < num_scan_outputs; ++i) {
      Output(i + num_loop_carried_deps)->Resize(0);
      Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
    }
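    // The mutable_data<int32_t>() call above only materializes empty storage
    // so that each scan output is a defined tensor even if the body never
    // runs; if the loop does execute, the CopyFrom on the first iteration
    // below replaces both the element type and the contents.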
110 
111  // Use this to keep track of the sizes of the scan outputs and validate
112  // they're the same across iterations.
113  std::vector<std::vector<TIndex>> scan_outputs_sizes;
114 
115  std::shared_ptr<Workspace> cur_ws = nullptr;
116  bool cur_output_condition = false;
117 
118  while (true) {
119  int64_t itr = scope_->iteration();
120  if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) {
121  if (!scope_->net()->Run()) {
122  return false;
123  }
124 
125  cur_ws = scope_->workspace();
126  cur_output_condition = scope_->output_condition();
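        // With save_scopes_ set, the just-executed workspace stays on
        // ws_stack_ and the next iteration runs in a fresh LocalScope, so
        // every iteration's blobs remain accessible afterwards (e.g. for a
        // backward pass); otherwise the single scope is reused in place.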
        if (save_scopes_) {
          loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_);
          scope_ = std::make_shared<LocalScope>(loop_ws, body_net_def_);
        }

        // Copy forward loop-carried dependencies
        for (int i = 0; i < num_loop_carried_deps; ++i) {
          Blob* b = cur_ws->GetBlob(scope_->net()->external_output()[i + 1]);
          const Tensor<Context>& t = b->template Get<Tensor<Context>>();
          scope_->lcd_tensor(i)->CopyFrom(t);
        }
        // Copy out scan_outputs
        for (int i = 0; i < num_scan_outputs; ++i) {
          int net_output_idx = i + 1 + num_loop_carried_deps;
          const Tensor<Context>& scan_output =
              cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx])
                  ->template Get<Tensor<Context>>();
          auto* scan_output_target = Output(i + num_loop_carried_deps);
          if (itr == 0) {
            auto dims = scan_output.dims();
            scan_outputs_sizes.push_back(dims);
            dims.insert(dims.begin(), 1);
            scan_output_target->Resize(dims);
            scan_output_target->CopyFrom(scan_output);
          } else {
            auto dims = scan_output.dims();
            CAFFE_ENFORCE_EQ(
                dims,
                scan_outputs_sizes[i],
                "Size of scan output changed across iterations");
            dims.insert(dims.begin(), itr);
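            // Grow the accumulated scan output by one timestep (the second
            // argument controls how much Extend over-allocates, amortizing
            // reallocations across iterations) and copy this iteration's
            // slice into the newly added slot.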
            scan_output_target->Extend(1, 2.0f, &context_);

            TIndex timestep_size = 1;
            for (const TIndex t : scan_outputs_sizes[i]) {
              timestep_size *= t;
            }

            const void* src_data = scan_output.raw_data();
            auto& sot_meta = scan_output_target->meta();
            void* dst_data =
                (char*)scan_output_target->raw_mutable_data(sot_meta) +
                timestep_size * scan_output.itemsize() * itr;
            memcpy(dst_data, src_data, timestep_size * scan_output.itemsize());
          }
        }
        scope_->set_iteration(itr + 1ll);
        scope_->set_input_condition(cur_output_condition);
      } else {
        break;
      }
    }

    if (scope_->iteration() > 0) {
      // Copy out final loop-carried dependencies
      for (int i = 0; i < num_loop_carried_deps; ++i) {
        Output(i)->CopyFrom(*scope_->lcd_tensor(i));
      }
    } else {
      // Zero iterations ran; pass the initial loop-carried dependencies
      // straight through to the outputs
      for (int i = 0; i < num_loop_carried_deps; ++i) {
        Output(i)->CopyFrom(Input(i + num_inputs_before_lcds));
      }
    }

    return true;
  }

 private:
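  // A LocalScope is the per-iteration view of the loop: it instantiates the
  // body net inside its own workspace and caches pointers to the iteration
  // counter, the two condition blobs, and the loop-carried dependency
  // tensors.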
  class LocalScope {
   public:
    LocalScope(
        const std::shared_ptr<Workspace>& loop_ws,
        const NetDef& body_net_def)
        : loop_ws_(loop_ws) {
      CAFFE_ENFORCE(loop_ws_, "Failed to initialize local loop workspace");

      // Create loop-carried deps in Workspace
      lcd_tensors_.clear();
      for (int i = 2; i < body_net_def.external_input_size(); ++i) {
        Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i));
        Tensor<Context>* t = b->template GetMutable<Tensor<Context>>();
        lcd_tensors_.push_back(t);
      }
      // First input is the iteration variable
      auto* iteration_var_blob =
          loop_ws_->CreateBlob(body_net_def.external_input(0));
      iteration_var_ =
          iteration_var_blob->template GetMutable<Tensor<Context>>();

      input_condition_var_ =
          loop_ws_->CreateBlob(body_net_def.external_input(1))
              ->template GetMutable<Tensor<Context>>();

      auto* condition_var_blob =
          loop_ws_->CreateBlob(body_net_def.external_output(0));
      condition_var_ =
          condition_var_blob->template GetMutable<Tensor<Context>>();
      condition_var_->Resize(1);
      condition_var_->template mutable_data<bool>();

      body_net_ = loop_ws_->GetNet(body_net_def.name());
      if (!body_net_) {
        body_net_ = loop_ws_->CreateNet(body_net_def, true);
      }
      CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet");
    }

    NetBase* net() const {
      return body_net_;
    }

    std::shared_ptr<Workspace> workspace() const {
      return loop_ws_;
    }

    int64_t iteration() const {
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      return *iteration_var_ptr;
    }

    Tensor<Context>* lcd_tensor(int idx) {
      return lcd_tensors_[idx];
    }

    void set_iteration(int64_t itr) {
      iteration_var_->Resize(1);
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      *iteration_var_ptr = itr;
    }

    void set_input_condition(bool cond_value) {
      input_condition_var_->Resize(1);
      auto* input_condition_var_ptr =
          input_condition_var_->template mutable_data<bool>();
      *input_condition_var_ptr = cond_value;
    }

    bool output_condition() const {
      auto* condition_var_ptr =
          condition_var_->template mutable_data<bool>();
      return *condition_var_ptr;
    }

   private:
    std::shared_ptr<Workspace> loop_ws_;

    NetBase* body_net_; // owned by a workspace
    Tensor<Context>* iteration_var_;
    Tensor<Context>* input_condition_var_;
    Tensor<Context>* condition_var_;

    std::vector<Tensor<Context>*> lcd_tensors_;
  };

  NetDef body_net_def_;
  Workspace* parent_ws_;
  detail::WorkspaceStack ws_stack_;

  bool has_trip_count_;
  bool has_cond_;
  bool save_scopes_;

  std::shared_ptr<LocalScope> scope_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
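
For reference, below is a minimal sketch of how an ONNXWhile OperatorDef might be assembled against this interface. It is illustrative only: the blob names and the body net are hypothetical, and it assumes the standard caffe2.proto fields (OperatorDef::arg, with Argument::n holding the embedded body NetDef) together with the input/output contract documented above.

#include "caffe2/proto/caffe2.pb.h"

caffe2::OperatorDef MakeOnnxWhileExample() {
  // Hypothetical body net: inputs (i, cond_in, x_in), outputs (cond_out, x_out).
  caffe2::NetDef body;
  body.set_name("loop_body");
  body.add_external_input("i");        // iteration number
  body.add_external_input("cond_in");  // condition for this iteration
  body.add_external_input("x_in");     // loop-carried dependency
  body.add_external_output("cond_out");
  body.add_external_output("x_out");
  // ... operators computing cond_out and x_out would be added to `body` ...

  caffe2::OperatorDef loop_op;
  loop_op.set_type("ONNXWhile");
  loop_op.add_input("max_trip_count");  // int64 scalar
  loop_op.add_input("cond_init");       // bool scalar
  loop_op.add_input("x");               // initial loop-carried dependency
  loop_op.add_output("x_final");        // final loop-carried dependency

  auto* body_arg = loop_op.add_arg();
  body_arg->set_name("body");
  body_arg->mutable_n()->CopyFrom(body);

  auto* trip_arg = loop_op.add_arg();
  trip_arg->set_name("has_trip_count");
  trip_arg->set_i(1);

  auto* cond_arg = loop_op.add_arg();
  cond_arg->set_name("has_cond");
  cond_arg->set_i(1);

  return loop_op;
}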
Referenced definitions:
Blob (blob.h:25): a general container that hosts a typed pointer.
Tensor (tensor.h:93): the basic class in Caffe2 that stores a contiguous memory with its shape information...
size_t itemsize() const (tensor.h:597): returns the number of bytes each item takes in the tensor.
const vector<TIndex>& dims() const (tensor.h:611): returns the dimensions of the tensor as a vector.
void Resize(Ts... dim_source) (tensor.h:288): resizes a tensor.
const void* raw_data() const (tensor.h:472): returns a const raw void* pointer of the underlying storage.
Workspace (workspace.h:47): a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...