#ifndef CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_
#define CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/rnn/recurrent_network_executor.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"

CAFFE2_DECLARE_bool(caffe2_rnn_executor);
namespace caffe2 {
namespace detail {

struct Param {
  std::string param;
  std::string grad;
  std::string cellGradient;
};

struct RecurrentInput {
  std::string state;
  std::string input;
};

struct RecurrentGradient {
  std::string param;
  std::string grad;
  std::string externalGrad;
  std::string lastExternalGrad;
  int32_t offset;
};

struct OffsetAlias {
  std::string src;
  std::string dst;
  int32_t offset{0};
};

struct Link {
  std::string internal;
  std::string external;
  int32_t offset{0};
  int32_t window{1};
};

struct ScratchWorkspaces {
  std::vector<std::shared_ptr<Workspace>> stepWorkspaces;
  std::shared_ptr<Workspace> sharedBlobsWs = nullptr;
};
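// Writes the current timestep t into the 1-element int32 timestep blob
// that the step net reads.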
inline void UpdateTimestepBlob(Workspace* ws, std::string blob_name, int t) {
  ws->CreateBlob(blob_name)->GetMutable<TensorCPU>()->Resize(1);
  auto timestepBlob = ws->GetBlob(blob_name);
  CAFFE_ENFORCE(timestepBlob);
  timestepBlob->GetMutable<TensorCPU>()->mutable_data<int32_t>()[0] = t;
}
std::map<string, string> GetRecurrentMapping(
    const std::vector<detail::Link>& links,
    bool backward);
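// Exposes a slice of the recurrent state tensor `oc.src`, starting at
// timestep `oc.offset` (a negative offset counts from the end), as the
// blob `oc.dst` by sharing the underlying data rather than copying it.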
template <typename T, typename Context>
void applyOffsetAlias(
    const OffsetAlias& oc,
    Workspace* ws,
    Context* /*context*/) {
  VLOG(1) << "Aliasing: " << oc.src << " to: " << oc.dst
          << " at offset: " << oc.offset;
  auto srcBlob = ws->GetBlob(oc.src);
  CAFFE_ENFORCE(srcBlob);
  auto* src = srcBlob->template GetMutable<Tensor<Context>>();
  auto* dst = ws->GetBlob(oc.dst)->template GetMutable<Tensor<Context>>();
  auto timestep = src->size() / src->dim(0);
  auto dims = src->dims();
  const int32_t startDstTimestep =
      oc.offset >= 0 ? oc.offset : src->dim(0) + oc.offset;
  const int32_t numDstTimesteps = src->dim(0) - startDstTimestep;
  CAFFE_ENFORCE(
      numDstTimesteps >= 1, "Invalid number of timesteps: ", numDstTimesteps);
  dims[0] = numDstTimesteps;
  dst->Resize(dims);
  CAFFE_ENFORCE(timestep == dst->size() / numDstTimesteps, "Invalid offset");
  dst->ShareExternalPointer(
      src->template mutable_data<T>() + startDstTimestep * timestep,
      dst->size() * sizeof(T));
}
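// Copies the n elements at `src` into `dst` repeat_n times back to back;
// used below to broadcast a single initial state across the batch.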
template <typename T, class Context>
void repeatCopy(
    size_t repeat_n,
    size_t n,
    const T* src,
    T* dst,
    Context* context) {
  for (int i = 0; i < repeat_n; ++i) {
    context->template Copy<T, Context, Context>(n, src, dst + i * n);
  }
}
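// Resizes the recurrent state blob to
// (seqLen + initialStateLength) x batchSize x stateSize and fills its
// first initialStateLength timesteps from the initial-state input.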
template <typename T, typename Context>
void initializeRecurrentInput(
    const RecurrentInput& rc,
    int32_t seqLen,
    int32_t batchSize,
    Workspace* ws,
    Context* context) {
  auto stateBlob = ws->GetBlob(rc.state);
  CAFFE_ENFORCE(stateBlob);
  auto* state = stateBlob->template GetMutable<Tensor<Context>>();

  auto inputBlob = ws->GetBlob(rc.input);
  CAFFE_ENFORCE(inputBlob);
  const auto& input = inputBlob->template Get<Tensor<Context>>();
  CAFFE_ENFORCE_GE(input.ndim(), 1, rc.input);
  CAFFE_ENFORCE_LE(input.ndim(), 3, rc.input);

  const auto stateSize = input.dim(input.ndim() - 1);
  // A 3-D input provides initial states for more than one timestep
  // (e.g. to give left padding for a convolutional step net).
  auto initialStateLength = 1;
  if (input.ndim() == 3) {
    initialStateLength = input.dim(0);
  }
  // States at [0, ..., seqLen + initialStateLength - 1] (inclusive).
  state->Resize(seqLen + initialStateLength, batchSize, stateSize);

  if (input.ndim() >= 2) {
    // The initial state is given per batch element: copy it verbatim.
    CAFFE_ENFORCE_EQ(input.dim(input.ndim() - 2), batchSize, rc.input);
    context->template Copy<T, Context, Context>(
        batchSize * stateSize * initialStateLength,
        input.template data<T>(),
        state->template mutable_data<T>());
  } else {
    // A 1-D input is a single state vector, broadcast across the batch.
    repeatCopy<T, Context>(
        batchSize,
        stateSize,
        input.template data<T>(),
        state->template mutable_data<T>(),
        context);
  }
}
void PrependOps(std::vector<OperatorDef> ops, NetDef* netdef);
void AddApplyLinkOps(
    const vector<Link>& links,
    std::string timestep,
    const DeviceOption& device_option,
    NetDef* netdef);

void extractLinks(
    OperatorBase* op,
    const std::string& internalArg,
    const std::string& externalArg,
    const std::string& offsetArg,
    const std::string& windowArg,
    std::vector<detail::Link>* links);
NetDef extractNetDef(const OperatorDef& op, const std::string& argName);
} // namespace detail
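/**
 * RecurrentNetworkOp runs a given step net once per timestep of the
 * input sequence, linking each step's blobs to windows of the shared
 * state tensors. With enable_rnn_executor=1 (and the
 * caffe2_rnn_executor flag), steps are scheduled by a
 * RecurrentNetworkExecutor instead of plain Caffe2 nets.
 */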
template <class Context>
class RecurrentNetworkOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RecurrentNetworkOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        sharedWs_(ws),
        enable_rnn_executor_(OperatorBase::template GetSingleArgument<bool>(
            "enable_rnn_executor",
            false)),
        timestep_(OperatorBase::template GetSingleArgument<std::string>(
            "timestep",
            "timestep")) {
    CAFFE_ENFORCE(ws);

    stepNetDef_ = detail::extractNetDef(operator_def, "step_net");
    recurrentInputs_ = constructRecurrentInputs(operator_def, sharedWs_);
    links_ = constructLinks();
    aliases_ = constructAliases();

    stepNetDef_.add_external_input(timestep_);
    detail::AddApplyLinkOps(
        links_, timestep_, operator_def.device_option(), &stepNetDef_);
    if (FLAGS_caffe2_rnn_executor && enable_rnn_executor_) {
      VLOG(1) << "Use RecurrentNetworkExecutor";
      auto recurrent_map =
          detail::GetRecurrentMapping(links_, false /* backward */);
      rnnExecutor_ =
          createRNNExecutor<Context>(
              stepNetDef_,
              recurrent_map,
              timestep_,
              ArgumentHelper(operator_def));
    } else {
      // Fix for legacy models that set the step net type to "rnn".
      if (stepNetDef_.type() == "rnn") {
        stepNetDef_.set_type("async_simple");
      }
      CAFFE_ENFORCE(stepNetDef_.type() != "async_dag");
    }
  }
  size_t NumObservers() override {
    size_t num = this->observers_list_.size();
    if (rnnExecutor_) {
      num += rnnExecutor_->NumObserversStepNet();
    }
    return num;
  }
  std::vector<detail::RecurrentInput> constructRecurrentInputs(
      const OperatorDef& operator_def,
      Workspace* sharedWs) {
    const auto states =
        OperatorBase::GetRepeatedArgument<std::string>("recurrent_states");
    const auto inputs =
        OperatorBase::GetRepeatedArgument<int>("initial_recurrent_state_ids");
    CAFFE_ENFORCE_EQ(states.size(), inputs.size(), "states/inputs mismatch");
    std::vector<detail::RecurrentInput> ris;
    for (auto i = 0; i < states.size(); ++i) {
      // States need to live in the shared workspace: they are written on
      // the forward pass and read again on the backward pass.
      sharedWs->CreateBlob(states[i]);

      detail::RecurrentInput ri;
      ri.state = states[i];
      ri.input = operator_def.input(inputs[i]);
      ris.push_back(ri);
    }
    return ris;
  }
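  // Aliases expose a slice of an internal state tensor as an external
  // output blob at a given timestep offset; see applyOffsetAlias above.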
  std::vector<detail::OffsetAlias> constructAliases() {
    const auto& src =
        OperatorBase::GetRepeatedArgument<std::string>("alias_src");
    const auto& dst =
        OperatorBase::GetRepeatedArgument<std::string>("alias_dst");
    const auto& offset =
        OperatorBase::GetRepeatedArgument<int32_t>("alias_offset");
    CAFFE_ENFORCE(
        src.size() == offset.size(), "alias_src/alias_offset mismatch");
    CAFFE_ENFORCE(
        dst.size() == offset.size(), "alias_dst/alias_offset mismatch");
    std::vector<detail::OffsetAlias> aliases;
    for (auto i = 0; i < src.size(); ++i) {
      detail::OffsetAlias oc;
      oc.src = src[i];
      oc.dst = dst[i];
      oc.offset = offset[i];
      aliases.push_back(oc);
    }
    return aliases;
  }
  /**
   * Some blobs can be marked as to be recomputed on backward pass.
   * Instead of allocating them in every step workspace, create them once
   * in the shared workspace so all steps reuse the same buffer on the
   * forward pass.
   */
  void initializeBlobsToRecomputeOnBackward(Workspace* sharedBlobsWs) {
    std::vector<std::string> v;
    const auto& blobs = OperatorBase::GetRepeatedArgument<std::string>(
        "recompute_blobs_on_backward", v);
    for (const auto& b : blobs) {
      // CreateBlob is a no-op if the blob already exists.
      sharedBlobsWs->CreateBlob(b);
    }
  }
  std::vector<detail::Link> constructLinks() {
    std::vector<detail::Link> links;
    detail::extractLinks(
        this,
        "link_internal",
        "link_external",
        "link_offset",
        "link_window",
        &links);
    return links;
  }
  template <typename T>
  bool DoRunWithType() {
    const auto seqLen = Input(0).dim32(0);
    const auto batchSize = Input(0).dim32(1);
    for (const auto& ri : recurrentInputs_) {
      detail::initializeRecurrentInput<T, Context>(
          ri, seqLen, batchSize, sharedWs_, &context_);
    }
    // If there is no backward step net, this operator is forward-only
    // and we do not need to keep a workspace per timestep.
    bool has_backward_pass =
        OperatorBase::HasSingleArgumentOfType<NetDef>("backward_step_net") ||
        (OperatorBase::HasSingleArgumentOfType<string>("backward_step_net") &&
         OperatorBase::GetSingleArgument<string>("backward_step_net", "") !=
             "");

    detail::ScratchWorkspaces* scratch =
        OperatorBase::Output<detail::ScratchWorkspaces>(OutputSize() - 1);
    std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
        scratch->stepWorkspaces;
    std::shared_ptr<Workspace>& sharedBlobsWs = scratch->sharedBlobsWs;
    if (!sharedBlobsWs) {
      sharedBlobsWs = std::make_shared<Workspace>(sharedWs_);
    }

    // The caller may mark some forward activations to be recomputed on
    // the backward pass; those live in the shared workspace instead of
    // the per-step workspaces.
    initializeBlobsToRecomputeOnBackward(sharedBlobsWs.get());
    if (has_backward_pass && seqLen > stepWorkspaces.size()) {
      stepWorkspaces.resize(seqLen);
    }

    // In forward-only mode we cycle over a small fixed set of workspaces.
    // The RNN executor gets a couple extra so it can overlap timesteps.
    int num_workspaces_on_fwd_only = rnnExecutor_ ? 4 : 2;
    if (!has_backward_pass &&
        stepWorkspaces.size() < num_workspaces_on_fwd_only) {
      // Other ops may share these step workspaces, so never shrink the
      // vector below its current size.
      stepWorkspaces.resize(num_workspaces_on_fwd_only);
    }
    for (auto t = 0; t < seqLen; ++t) {
      auto& currentStepWorkspace =
          (has_backward_pass ? stepWorkspaces[t]
                             : stepWorkspaces[t % num_workspaces_on_fwd_only]);
      if (!currentStepWorkspace) {
        currentStepWorkspace =
            std::make_shared<Workspace>(sharedBlobsWs.get());
      }

      if (rnnExecutor_) {
        if (!has_backward_pass) {
          // Limit timestep parallelism since we cycle over the workspaces.
          rnnExecutor_->SetMaxParallelTimesteps(num_workspaces_on_fwd_only);
        }
        rnnExecutor_->EnsureTimestepInitialized(
            t, currentStepWorkspace.get(), this->observers_list_);
      } else {
        // Run the step net directly as a plain Caffe2 net.
        detail::UpdateTimestepBlob(currentStepWorkspace.get(), timestep_, t);
        auto* stepNet = currentStepWorkspace->GetNet(stepNetDef_.name());
        if (stepNet == nullptr) {
          stepNet = currentStepWorkspace->CreateNet(stepNetDef_);
        }
        CAFFE_ENFORCE(stepNet, "Step Net construction failure");
        stepNet->RunAsync();
      }
    }
    if (rnnExecutor_) {
      rnnExecutor_->Run(seqLen);
    }

    for (const auto& alias : aliases_) {
      detail::applyOffsetAlias<T, Context>(alias, sharedWs_, &context_);
    }

    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }
 protected:
  NetDef stepNetDef_;
  Workspace* sharedWs_;
  bool enable_rnn_executor_;
  std::unique_ptr<RecurrentNetworkExecutorBase> rnnExecutor_;

  std::vector<detail::Link> links_;
  std::vector<detail::OffsetAlias> aliases_;
  std::vector<detail::RecurrentInput> recurrentInputs_;
  std::string timestep_;
};
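/**
 * RecurrentNetworkGradientOp runs the backward step net over the step
 * workspaces saved by the forward pass, in reverse timestep order, and
 * accumulates gradients for the parameters, the input sequence, and the
 * initial recurrent states.
 */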
template <class Context>
class RecurrentNetworkGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RecurrentNetworkGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        sharedWs_(ws),
        enable_rnn_executor_(OperatorBase::template GetSingleArgument<bool>(
            "enable_rnn_executor",
            false)),
        timestep_(OperatorBase::template GetSingleArgument<std::string>(
            "timestep",
            "timestep")),
        gradInputs_(OperatorBase::template GetRepeatedArgument<int32_t>(
            "outputs_with_grads")) {
    CAFFE_ENFORCE(ws);

    stepNetDef_ = detail::extractNetDef(operator_def, "backward_step_net");

    links_ = constructLinks();
    params_ = constructParams(operator_def);
    recurrentGradients_ = constructRecurrentGradients(operator_def);
    recurrentInputIds_ = OperatorBase::template GetRepeatedArgument<int32_t>(
        "initial_recurrent_state_ids");
    stepNetDef_.add_external_input(timestep_);

    AddGradientInputAccumulationOps(operator_def);
    detail::AddApplyLinkOps(
        links_, timestep_, operator_def.device_option(), &stepNetDef_);
    AddParamGradientAccumulationOps(operator_def);

    if (FLAGS_caffe2_rnn_executor && enable_rnn_executor_) {
      InitializeExecutor(operator_def);
    }
  }
  // Returns the renamed blob name if a "<name>.rename" argument was
  // given, otherwise the name unchanged.
  std::string remappedName(std::string blob_name) {
    return OperatorBase::template GetSingleArgument<std::string>(
        blob_name + ".rename", blob_name);
  }

  detail::Link remappedLink(const detail::Link& link) {
    detail::Link renamed_link = link;
    renamed_link.internal = remappedName(link.internal);
    renamed_link.external = remappedName(link.external);
    return renamed_link;
  }
  void renameOpInputOutput(std::string from_name, std::string to_name) {
    for (int j = 0; j < stepNetDef_.op_size(); j++) {
      auto* op = stepNetDef_.mutable_op(j);
      for (int i = 0; i < op->input_size(); i++) {
        if (op->input(i) == from_name) {
          op->set_input(i, to_name);
        }
      }
      for (int i = 0; i < op->output_size(); i++) {
        if (op->output(i) == from_name) {
          op->set_output(i, to_name);
        }
      }
    }
  }
  std::vector<detail::Param> constructParams(const OperatorDef& operator_def) {
    std::vector<detail::Param> params;
    const auto& param = OperatorBase::GetRepeatedArgument<int32_t>("param");
    const auto& param_grads =
        OperatorBase::GetRepeatedArgument<string>("param_grads");
    CAFFE_ENFORCE(
        param_grads.empty() || param_grads.size() == param.size(),
        "param_grads must be empty or match param in size");
    for (int i = 0; i < param.size(); ++i) {
      detail::Param p;
      // The forward inputs come after the gradient (GO) inputs.
      p.param = operator_def.input(param[i] + gradInputs_.size());
      p.grad = operator_def.output(i + numSequences_);

      std::string grad_blob =
          param_grads.empty() ? p.grad : remappedName(param_grads[i]);
      // Each step writes its parameter gradient into a temporary blob
      // that is summed into p.grad once per timestep (see
      // AddParamGradientAccumulationOps).
      p.cellGradient = grad_blob + "_tmpstep";
      params.push_back(p);

      renameOpInputOutput(grad_blob, p.cellGradient);
    }
    return params;
  }
  std::vector<detail::RecurrentGradient> constructRecurrentGradients(
      const OperatorDef& operator_def) {
    std::vector<detail::RecurrentGradient> rgs;
    const auto& recurrent =
        OperatorBase::GetRepeatedArgument<std::string>("recurrent_states");
    const auto& alias_src =
        OperatorBase::GetRepeatedArgument<std::string>("alias_src");
    const auto& offset =
        OperatorBase::GetRepeatedArgument<int32_t>("alias_offset");

    for (auto i = 0; i < recurrent.size(); ++i) {
      detail::RecurrentGradient rg;
      rg.param = recurrent[i];
      rg.grad = remappedName(recurrent[i] + "_grad");

      for (int j = 0; j < alias_src.size(); ++j) {
        if (alias_src[j] != recurrent[i]) {
          continue;
        }
        int idx = -1;
        for (int k = 0; k < gradInputs_.size(); ++k) {
          if (gradInputs_[k] == j) {
            idx = k;
          }
        }
        if (idx == -1) {
          continue;
        }

        CAFFE_ENFORCE(offset[j] == 1 || offset[j] == -1);
        if (offset[j] == 1) {
          rg.externalGrad = operator_def.input(idx);
        } else if (offset[j] == -1) {
          rg.lastExternalGrad = operator_def.input(idx);
        }
      }
      rg.offset = 1;
      rgs.push_back(rg);
    }
    return rgs;
  }
  std::vector<detail::Link> constructLinks() {
    std::vector<detail::Link> links;
    detail::extractLinks(
        this,
        "link_internal",
        "link_external",
        "link_offset",
        "link_window",
        &links);
    detail::extractLinks(
        this,
        "backward_link_internal",
        "backward_link_external",
        "backward_link_offset",
        "",
        &links);
    for (int i = 0; i < links.size(); i++) {
      links[i] = remappedLink(links[i]);
    }
    return links;
  }
  void InitializeExecutor(const OperatorDef& operator_def) {
    VLOG(1) << "Use RecurrentNetworkExecutor for backward";
    auto recurrent_map =
        detail::GetRecurrentMapping(links_, true /* backward */);
    rnnExecutor_ = createRNNExecutor<Context>(
        stepNetDef_, recurrent_map, timestep_, ArgumentHelper(operator_def));
  }
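  // For every recurrent gradient that has an external (per-timestep)
  // gradient input, prepend an accumulation op to the backward step net
  // that adds the external gradient slice for timestep t into the
  // recurrent gradient blob.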
  void AddGradientInputAccumulationOps(const OperatorDef& operator_def) {
    std::vector<OperatorDef> ops;
    for (const auto& rg : recurrentGradients_) {
      if (rg.externalGrad.empty()) {
        continue;
      }
      VLOG(1) << "Accumulating into: " << rg.grad << " from "
              << rg.externalGrad << ", offset: " << rg.offset;

      OperatorDef opdef;
      opdef.set_type("rnn_internal_accumulate_gradient_input");
      opdef.add_input(timestep_);
      opdef.add_input(rg.externalGrad);
      opdef.add_input(rg.grad);
      opdef.add_output(rg.grad);

      // Add dependencies for the RNN executor: links reading rg.grad
      // must wait until this accumulation has run.
      for (auto& l : links_) {
        if (rg.grad == l.external) {
          Argument* dep_arg = opdef.add_arg();
          dep_arg->set_name("rnn_dependency." + l.internal);
          dep_arg->set_s(l.internal);
        }
      }

      opdef.mutable_device_option()->CopyFrom(operator_def.device_option());

      Argument* offset_arg = opdef.add_arg();
      offset_arg->set_name("offset");
      offset_arg->set_i(rg.offset);
      ops.push_back(opdef);

      stepNetDef_.add_external_input(rg.externalGrad);
      stepNetDef_.add_external_input(rg.grad);
    }
    detail::PrependOps(ops, &stepNetDef_);
  }
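  // Parameter gradients are computed into temporary "_tmpstep" blobs by
  // each step (see constructParams) and summed into the real gradient
  // blobs by the Sum ops appended here.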
  void AddParamGradientAccumulationOps(const OperatorDef& operator_def) {
    for (const auto& param : params_) {
      OperatorDef opdef;
      opdef.set_type("Sum");
      opdef.add_input(param.grad);
      opdef.add_input(param.cellGradient);
      opdef.add_output(param.grad);
      opdef.mutable_device_option()->CopyFrom(operator_def.device_option());
      stepNetDef_.add_op()->CopyFrom(opdef);
      stepNetDef_.add_external_input(param.grad);
    }
  }
  void CreateSharedBlobs(
      const std::shared_ptr<Workspace>& step0Ws,
      Workspace* sharedBlobsWs) {
    // Output blobs of the backward step net that the first step workspace
    // does not already have are created once in the shared workspace, so
    // every step writes to the same buffer.
    for (auto& op : stepNetDef_.op()) {
      for (const string& outp : op.output()) {
        if (!step0Ws->HasBlob(outp)) {
          sharedBlobsWs->CreateBlob(outp);
        }
      }
    }
  }
  template <typename T>
  bool DoRunWithType() {
    const auto seqLen = Input(gradInputs_.size()).dim32(0);
    VLOG(1) << "seqLen: " << seqLen;

    const detail::ScratchWorkspaces& scratch =
        OperatorBase::Input<detail::ScratchWorkspaces>(InputSize() - 1);
    const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
        scratch.stepWorkspaces;
    CAFFE_ENFORCE_GE(stepWorkspaces.size(), seqLen);
    Workspace& sharedBlobsWs = *scratch.sharedBlobsWs.get();
    const auto batchSize = Input(0).dim32(1);
    for (auto& param : params_) {
      auto pBlob = sharedWs_->GetBlob(param.param);
      CAFFE_ENFORCE(pBlob);
      const auto& p = pBlob->template Get<Tensor<Context>>();

      auto gBlob = sharedWs_->GetBlob(param.grad);
      CAFFE_ENFORCE(gBlob);
      auto* g = gBlob->template GetMutable<Tensor<Context>>();
      g->ResizeLike(p);
      math::Set<T, Context>(
          g->size(),
          convert::To<float, T>(0.0),
          g->template mutable_data<T>(),
          &context_);
    }
    for (auto& rg : recurrentGradients_) {
      auto pBlob = sharedWs_->GetBlob(rg.param);
      CAFFE_ENFORCE(pBlob);
      const auto& p = pBlob->template Get<Tensor<Context>>();

      auto gBlob = sharedWs_->CreateBlob(rg.grad);
      CAFFE_ENFORCE(gBlob);
      auto* g = gBlob->template GetMutable<Tensor<Context>>();
      g->ResizeLike(p);
      CAFFE_ENFORCE_EQ(g->ndim(), 3);
      const auto timestep = g->size() / g->dim(0);
      // Zero the gradient for the last timestep; earlier timesteps are
      // filled by the backward pass itself.
      math::Set<T, Context>(
          timestep,
          convert::To<float, T>(0.0),
          g->template mutable_data<T>() + (g->dim(0) - 1) * timestep,
          &context_);
    }
    // This loop allows for several input sequences, but the rest of the
    // code supports only one, so numSequences_ is a constant 1.
    for (int i = 0; i < numSequences_; ++i) {
      // The first gradInputs_.size() inputs are the gradient (GO) inputs;
      // the forward inputs I(0..N) follow them.
      const int gradientInputIndex = i + gradInputs_.size();
      const auto& inputName = this->debug_def().input(gradientInputIndex);
      auto gradientName = remappedName(inputName + "_grad");
      VLOG(1) << "Initializing gradient for input " << gradientInputIndex
              << " (" << inputName << ") as blob " << gradientName
              << ". Size: " << Input(gradientInputIndex).size();
      auto pGradientBlob = sharedWs_->GetBlob(gradientName);
      CAFFE_ENFORCE(pGradientBlob);
      auto* g = pGradientBlob->template GetMutable<Tensor<Context>>();
      g->ResizeLike(Input(gradientInputIndex));
      g->template mutable_data<T>();
    }
    auto accumulateFinalInputGradients = [&]() {
      for (const auto& rg : recurrentGradients_) {
        if (rg.lastExternalGrad.empty()) {
          continue;
        }
        VLOG(1) << "Accumulating into: " << rg.grad << " from "
                << rg.lastExternalGrad << " for final time step (sep. blob)";
        auto gBlob = sharedWs_->GetBlob(rg.grad);
        CAFFE_ENFORCE(gBlob);
        auto* g = gBlob->template GetMutable<Tensor<Context>>();

        auto oglastBlob = sharedWs_->GetBlob(rg.lastExternalGrad);
        CAFFE_ENFORCE(oglastBlob);
        const auto& oglast = oglastBlob->template Get<Tensor<Context>>();
        CAFFE_ENFORCE_EQ(g->dim(1), oglast.dim(1));
        CAFFE_ENFORCE_EQ(g->dim(2), oglast.dim(2));

        const auto t = g->dim(0) - 1;
        const auto timestep_size = g->size() / g->dim(0);
        CAFFE_ENFORCE_EQ(timestep_size, oglast.size());
        T* g_data_with_offset =
            g->template mutable_data<T>() + t * timestep_size;
        math::Add<T, Context>(
            timestep_size,
            oglast.template data<T>(),
            g_data_with_offset,
            g_data_with_offset,
            &context_);
      }
    };
    accumulateFinalInputGradients();
    // Create shared blobs for the backward step net outputs before
    // running; the first step workspace is used as the reference.
    if (stepWorkspaces.size() > 0) {
      CreateSharedBlobs(stepWorkspaces[0], &sharedBlobsWs);
    }
    for (int32_t t = seqLen - 1; t >= 0; --t) {
      if (rnnExecutor_) {
        rnnExecutor_->EnsureTimestepInitialized(
            t, stepWorkspaces[t].get(), this->observers_list_);
      } else {
        auto* stepNet = stepWorkspaces[t].get()->GetNet(stepNetDef_.name());
        if (stepNet == nullptr) {
          stepNet = stepWorkspaces[t].get()->CreateNet(stepNetDef_);
        }
        CAFFE_ENFORCE(stepNet);
        stepNet->RunAsync();
      }
    }

    if (rnnExecutor_) {
      rnnExecutor_->RunBackwards(seqLen);
    }
    CAFFE_ENFORCE_EQ(recurrentInputIds_.size(), recurrentGradients_.size());
    // See GetRecurrentNetworkGradient to understand the offsetting here:
    // gradient outputs for the initial recurrent states come after the
    // parameter gradients and the sequence gradients.
    for (int i = 0; i < recurrentInputIds_.size(); ++i) {
      auto outputIdx = i + params_.size() + numSequences_;
      // The forward inputs come after the gradient (GO) inputs.
      int inputId = recurrentInputIds_[i] + gradInputs_.size();
      VLOG(1) << "Resetting output " << this->debug_def().output(outputIdx)
              << " like input " << this->debug_def().input(inputId);
      Output(outputIdx)->ResizeLike(Input(inputId));
      T* output_data = Output(outputIdx)->template mutable_data<T>();
      auto pBlob = sharedWs_->GetBlob(recurrentGradients_[i].grad);
      CAFFE_ENFORCE(pBlob);
      auto* p = pBlob->template GetMutable<Tensor<Context>>();

      if (Input(inputId).ndim() >= 2) {
        // The gradient state blob outlives this op, and if the backward
        // pass changes it the output should change as well, so it is safe
        // to share the data here.
        Output(outputIdx)->template ShareExternalPointer<T>(
            p->template mutable_data<T>());
      } else {
        // The 1-D initial state was broadcast across the batch, so its
        // gradient is the sum over the batch dimension.
        const auto recurrentStateSize = Input(inputId).dim32(0);
        math::Set<T, Context>(
            recurrentStateSize,
            convert::To<float, T>(0.0),
            output_data,
            &context_);
        math::AddStripedBatch<T, Context>(
            recurrentStateSize,
            p->template data<T>(),
            output_data,
            recurrentStateSize,
            batchSize,
            &context_);
      }
    }

    return true;
  }
  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 protected:
  NetDef stepNetDef_;
  Workspace* sharedWs_;
  bool enable_rnn_executor_;
  std::unique_ptr<RecurrentNetworkExecutorBase> rnnExecutor_;
  std::vector<detail::Link> links_;
  std::vector<detail::Param> params_;
  std::vector<detail::RecurrentGradient> recurrentGradients_;
  std::string timestep_;
  const int numSequences_{1};
  std::vector<int32_t> recurrentInputIds_;
  std::vector<int32_t> gradInputs_;
};
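// Presumably the op behind the "rnn_internal_accumulate_gradient_input"
// defs assembled in AddGradientInputAccumulationOps above (registration
// lives in the .cc file): it adds the external gradient slice for
// timestep t (read from the timestep blob) into the gradient blob at
// timestep t + offset.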
template <class Context>
class AccumulateInputGradientOp : public Operator<Context> {
 public:
  AccumulateInputGradientOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        offset_(OperatorBase::GetSingleArgument<int>("offset", -1)) {
    CAFFE_ENFORCE(offset_ >= 0, "Offset not set");
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <typename T>
  bool DoRunWithType() {
    const auto& t0 = OperatorBase::Input<Tensor<CPUContext>>(0);
    const auto t = t0.template data<int32_t>()[0];
    auto& og = Input(1);
    auto* g = Output(0);

    T* g_data = g->template mutable_data<T>();
    const auto timestep_size = g->size() / g->dim(0);

    CAFFE_ENFORCE(
        (t + offset_) * timestep_size + timestep_size <= g->size(),
        "Accumulation destination address over bounds");
    CAFFE_ENFORCE(
        t * timestep_size + timestep_size <= og.size(),
        "Accumulation source address out of bounds");

    math::Add<T, Context>(
        timestep_size,
        og.template data<T>() + t * timestep_size,
        g_data + (t + offset_) * timestep_size,
        g_data + (t + offset_) * timestep_size,
        &context_);
    return true;
  }
  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 protected:
  int offset_;
};
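// Presumably the op that AddApplyLinkOps inserts into step nets
// (registration lives in the .cc file): it points the internal blob at a
// window_-timesteps-wide view of the external blob starting at timestep
// t + offset_, sharing data instead of copying.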
template <class Context>
class RNNApplyLinkOp : public Operator<Context> {
 public:
  RNNApplyLinkOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        offset_(OperatorBase::GetSingleArgument<int>("offset", -1)),
        window_(OperatorBase::GetSingleArgument<int>("window", -1)) {
    CAFFE_ENFORCE(offset_ >= 0, "offset not set");
    CAFFE_ENFORCE(window_ >= 0, "window not set");
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <typename T>
  bool DoRunWithType() {
    // Both internal and external appear as input and output to enforce
    // correct dependency computation.
    const auto& t0 = OperatorBase::Input<Tensor<CPUContext>>(0);
    const auto t = t0.template data<int32_t>()[0];
    auto& external = Input(1);

    auto* internal_out = Output(0);
    auto* external_out = Output(1);

    CAFFE_ENFORCE_GT(external.size(), 0);
    const TIndex externalTimestepSize = external.size() / external.dim(0);
    auto* externalData = external_out->template mutable_data<T>() +
        (t + offset_) * externalTimestepSize;
    auto internalDims = external_out->dims();
    internalDims[0] = window_;

    internal_out->Resize(internalDims);
    internal_out->ShareExternalPointer(
        externalData, externalTimestepSize * window_);
    return true;
  }
  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 protected:
  int offset_;
  int window_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_