Caffe2 - C++ API
A deep learning, cross-platform ML framework
arg_ops_eigen.h
1 #ifndef CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_
2 #define CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/types.h"
6 
7 #include "Eigen/Core"
8 
9 #if EIGEN_VERSION_AT_LEAST(3, 3, 0)
10 
11 #include "unsupported/Eigen/CXX11/Tensor"
12 
13 namespace caffe2 {
14 namespace arg_ops_eigen {
15 
16 template <typename T>
17 using EigenTensorMap1D = Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>>;
18 
19 template <typename T>
20 using EigenTensorMap2D = Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>>;
21 
22 template <typename T>
23 using EigenTensorMap3D = Eigen::TensorMap<Eigen::Tensor<T, 3, Eigen::RowMajor>>;
24 
25 template <class Device, typename T>
26 void ComputeArgMaxEigen(
27  const Device& device,
28  const T* X,
29  const TIndex prev_size,
30  const TIndex next_size,
31  const TIndex n,
32  TIndex* Y) {
33  if (next_size == 1) {
34  EigenTensorMap1D<TIndex>(Y, prev_size).device(device) =
35  EigenTensorMap2D<T>(const_cast<T*>(X), prev_size, n)
36  .argmax(1)
37  .template cast<TIndex>();
38  } else if (prev_size == 1) {
39  EigenTensorMap1D<TIndex>(Y, next_size).device(device) =
40  EigenTensorMap2D<T>(const_cast<T*>(X), n, next_size)
41  .argmax(0)
42  .template cast<TIndex>();
43  } else {
44  EigenTensorMap2D<TIndex>(Y, prev_size, next_size).device(device) =
45  EigenTensorMap3D<T>(const_cast<T*>(X), prev_size, n, next_size)
46  .argmax(1)
47  .template cast<TIndex>();
48  }
49 }
50 
51 template <class Device, typename T>
52 void ComputeArgMinEigen(
53  const Device& device,
54  const T* X,
55  const TIndex prev_size,
56  const TIndex next_size,
57  const TIndex n,
58  TIndex* Y) {
59  if (next_size == 1) {
60  EigenTensorMap1D<TIndex>(Y, prev_size).device(device) =
61  EigenTensorMap2D<T>(const_cast<T*>(X), prev_size, n)
62  .argmin(1)
63  .template cast<TIndex>();
64  } else if (prev_size == 1) {
65  EigenTensorMap1D<TIndex>(Y, next_size).device(device) =
66  EigenTensorMap2D<T>(const_cast<T*>(X), n, next_size)
67  .argmin(0)
68  .template cast<TIndex>();
69  } else {
70  EigenTensorMap2D<TIndex>(Y, prev_size, next_size).device(device) =
71  EigenTensorMap3D<T>(const_cast<T*>(X), prev_size, n, next_size)
72  .argmin(1)
73  .template cast<TIndex>();
74  }
75 }
76 
77 } // namespace arg_ops_eigen
78 } // namespace caffe2
79 
80 #endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
81 
82 #endif // CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...