1 #ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ 2 #define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/operators/create_scope_op.h" 11 template <
class Context>
18 OperatorBase::GetSingleArgument<int64_t>(
"has_trip_count", 0)),
19 has_cond_(OperatorBase::GetSingleArgument<int64_t>(
"has_cond", 0)),
20 save_scopes_(OperatorBase::GetSingleArgument<int64_t>(
"save_scopes", 0)) {
22 this->
template HasSingleArgumentOfType<NetDef>(
"body"),
23 "body net must be specified in ONNXWhile operator");
24 body_net_def_ = this->
template GetSingleArgument<NetDef>(
"body", NetDef());
25 if (!body_net_def_.has_name()) {
26 body_net_def_.set_name(
"loop_net");
30 USE_OPERATOR_CONTEXT_FUNCTIONS;
38 bool RunOnDevice()
override {
42 auto loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_);
43 scope_ = std::make_shared<LocalScope>(loop_ws, body_net_def_);
45 constexpr int64_t num_inputs_before_lcds = 2;
49 int num_loop_carried_deps = InputSize() - num_inputs_before_lcds;
50 int64_t max_trip_count = *Input(0).template data<int64_t>();
51 const bool first_iter_condition = *Input(1).template data<bool>();
56 num_loop_carried_deps + 2,
57 scope_->net()->external_input().size(),
58 "Body graph must have 2+N inputs, where N is the number of " 59 "loop carried dependencies.");
63 int num_scan_outputs =
64 scope_->net()->external_output().size() - num_loop_carried_deps - 1;
69 "Body graph must have N+K outputs, where N is the number " 70 "of loop-carried dependencies and K is the number of scan " 74 for (
int i = 0; i < num_loop_carried_deps; ++i) {
75 scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds));
79 scope_->set_iteration(0ll);
82 scope_->set_input_condition(first_iter_condition);
84 auto valid_iter_num = [
this, max_trip_count](int64_t i) {
85 if (has_trip_count_) {
86 return i < max_trip_count;
93 [
this, first_iter_condition](int64_t i,
bool cond_value) {
96 return (
bool)first_iter_condition;
106 for (
int i = 0; i < num_scan_outputs; ++i) {
107 Output(i + num_loop_carried_deps)->Resize(0);
108 Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
113 std::vector<std::vector<TIndex>> scan_outputs_sizes;
115 std::shared_ptr<Workspace> cur_ws =
nullptr;
116 bool cur_output_condition =
false;
119 int64_t itr = scope_->iteration();
120 if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) {
121 if (!scope_->net()->Run()) {
125 cur_ws = scope_->workspace();
126 cur_output_condition = scope_->output_condition();
128 loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_);
129 scope_ = std::make_shared<LocalScope>(loop_ws, body_net_def_);
133 for (
int i = 0; i < num_loop_carried_deps; ++i) {
134 Blob* b = cur_ws->GetBlob(
135 scope_->net()->external_output()[i + 1]);
137 scope_->lcd_tensor(i)->CopyFrom(t);
140 for (
int i = 0; i < num_scan_outputs; ++i) {
141 int net_output_idx = i + 1 + num_loop_carried_deps;
144 scope_->net()->external_output()[net_output_idx])
146 auto* scan_output_target = Output(i + num_loop_carried_deps);
148 auto dims = scan_output.
dims();
149 scan_outputs_sizes.push_back(dims);
150 dims.insert(dims.begin(), 1);
151 scan_output_target->Resize(dims);
152 scan_output_target->CopyFrom(scan_output);
154 auto dims = scan_output.
dims();
157 scan_outputs_sizes[i],
158 "Size of scan output changed across iterations");
159 dims.insert(dims.begin(), itr);
160 scan_output_target->Extend(1, 2.0f, &context_);
162 TIndex timestep_size = 1;
163 for (
const TIndex t : scan_outputs_sizes[i]) {
167 const void* src_data = scan_output.
raw_data();
168 auto& sot_meta = scan_output_target->meta();
170 (
char*)scan_output_target->raw_mutable_data(sot_meta) +
171 timestep_size * scan_output.
itemsize() * itr;
172 memcpy(dst_data, src_data, timestep_size * scan_output.
itemsize());
175 scope_->set_iteration(itr + 1ll);
176 scope_->set_input_condition(cur_output_condition);
182 if (scope_->iteration() > 0) {
184 for (
int i = 0; i < num_loop_carried_deps; ++i) {
185 Output(i)->CopyFrom(*scope_->lcd_tensor(i));
189 for (
int i = 0; i < num_loop_carried_deps; ++i) {
190 Output(i)->CopyFrom(Input(i + num_inputs_before_lcds));
201 const std::shared_ptr<Workspace>& loop_ws,
202 const NetDef& body_net_def) : loop_ws_(loop_ws) {
203 CAFFE_ENFORCE(loop_ws_,
204 "Failed to initialize local loop workspace");
207 lcd_tensors_.clear();
208 for (
int i = 2; i < body_net_def.external_input_size(); ++i) {
209 Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i));
211 lcd_tensors_.push_back(t);
214 auto* iteration_var_blob = loop_ws_->CreateBlob(
215 body_net_def.external_input(0));
217 iteration_var_blob->template GetMutable<Tensor<Context>>();
219 input_condition_var_ = loop_ws_->CreateBlob(
220 body_net_def.external_input(1))
223 auto* condition_var_blob =
224 loop_ws_->CreateBlob(body_net_def.external_output(0));
225 condition_var_ = condition_var_blob->template GetMutable<Tensor<Context>>();
226 condition_var_->
Resize(1);
227 condition_var_->template mutable_data<bool>();
229 body_net_ = loop_ws_->GetNet(body_net_def.name());
231 body_net_ = loop_ws_->CreateNet(body_net_def,
true);
233 CAFFE_ENFORCE(body_net_,
"Failed to initialize loop subnet");
240 std::shared_ptr<Workspace> workspace()
const {
244 int64_t iteration()
const {
245 auto* iteration_var_ptr =
246 iteration_var_->template mutable_data<int64_t>();
247 return *iteration_var_ptr;
251 return lcd_tensors_[idx];
254 void set_iteration(int64_t itr) {
255 iteration_var_->
Resize(1);
256 auto* iteration_var_ptr =
257 iteration_var_->template mutable_data<int64_t>();
258 *iteration_var_ptr = itr;
261 void set_input_condition(
bool cond_value) {
262 input_condition_var_->Resize(1);
263 auto* input_condition_var_ptr =
264 input_condition_var_->template mutable_data<bool>();
265 *input_condition_var_ptr = cond_value;
268 bool output_condition()
const {
269 auto* condition_var_ptr =
270 condition_var_->template mutable_data<bool>();
271 return *condition_var_ptr;
275 std::shared_ptr<Workspace> loop_ws_;
282 std::vector<Tensor<Context>*> lcd_tensors_;
285 NetDef body_net_def_;
289 bool has_trip_count_;
293 std::shared_ptr<LocalScope> scope_;
298 #endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
Blob is a general container that hosts a typed pointer.
size_t itemsize() const
Return the number of bytes each item takes in the tensor.
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
const vector< TIndex > & dims() const
Returns the dimensions of the tensor as a vector.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Resize(Ts...dim_source)
Resizes a tensor.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
const void * raw_data() const
Returns a const raw void* pointer of the underlying storage.