Caffe2 - C++ API
A deep learning, cross-platform ML framework
assert_op.h
1 #ifndef CAFFE2_OPERATORS_ASSERT_OP_H_
2 #define CAFFE2_OPERATORS_ASSERT_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class AssertOp final : public Operator<Context> {
11  public:
12  AssertOp(const OperatorDef& operator_def, Workspace* ws)
13  : Operator<Context>(operator_def, ws),
14  error_msg_(
15  OperatorBase::GetSingleArgument<std::string>("error_msg", "")) {}
16 
17  USE_OPERATOR_CONTEXT_FUNCTIONS;
18 
19  template <typename T>
20  bool DoRunWithType() {
21  // Copy into CPU context for comparison
22  cmp_tensor_.CopyFrom(Input(0));
23  auto* cmp_data = cmp_tensor_.template data<T>();
24 
25  for (TIndex i = 0; i < cmp_tensor_.size(); ++i) {
26  CAFFE_ENFORCE((bool)cmp_data[i], [&]() {
27  std::stringstream ss;
28  ss << "Assert failed for element " << i
29  << " in tensor, value: " << cmp_data[i] << "\n";
30  if (!error_msg_.empty()) {
31  ss << "Error message: " << error_msg_;
32  }
33  return ss.str();
34  }());
35  }
36  return true;
37  }
38 
39  bool RunOnDevice() override {
40  return DispatchHelper<TensorTypes<long, int, bool>>::call(this, Input(0));
41  }
42 
43  private:
44  TensorCPU cmp_tensor_;
45  std::string error_msg_;
46 };
47 
48 } // namespace caffe2
49 
50 #endif /* CAFFE2_OPERATORS_ASSERT_OP_H_ */
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Definition: tensor.h:166
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:593
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.