1 #ifndef CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_ 2 #define CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/types.h" 9 #if EIGEN_VERSION_AT_LEAST(3, 3, 0) 11 #include "unsupported/Eigen/CXX11/Tensor" 14 namespace arg_ops_eigen {
// Row-major Eigen::TensorMap aliases of rank 1, 2 and 3 over element type T.
// The functions below construct them from a pointer plus one dimension per
// rank, e.g. `EigenTensorMap2D<T>(ptr, rows, cols)`.
// NOTE(review): T is used here, but no `template <typename T>` header is
// visible for these aliases in this extract. The leading numbers (17, 20, 23)
// look like source line numbers fused in during extraction. Confirm against
// the full file.
17 using EigenTensorMap1D = Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>>;
20 using EigenTensorMap2D = Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>>;
23 using EigenTensorMap3D = Eigen::TensorMap<Eigen::Tensor<T, 3, Eigen::RowMajor>>;
// ComputeArgMaxEigen: views X as a row-major T tensor and the output buffer Y
// as a TIndex tensor. It then assigns the (cast-to-TIndex) result into Y
// through `.device(device)`. The branch taken depends on prev_size/next_size.
// NOTE(review): the function name suggests an argmax over the `n` axis, but
// the reduction call itself is not visible in this extract. Confirm.
// NOTE(review): this extract is incomplete. The leading numbers (25, 26, 29, ...)
// appear to be fused source line numbers. Missing lines include:
//   - the declarations of X, n, Y and device;
//   - the condition of the first branch;
//   - the reduction call before `.template cast<TIndex>()`;
//   - closing braces.
// Restore these from the full file before changing any code.
25 template <
class Device,
typename T>
26 void ComputeArgMaxEigen(
29 const TIndex prev_size,
30 const TIndex next_size,
// First branch: X viewed as (prev_size, n), Y as (prev_size).
// const_cast: the aliases map non-const T. Presumably X is a const T* that
// is only read here. Confirm.
34 EigenTensorMap1D<TIndex>(Y, prev_size).device(device) =
35 EigenTensorMap2D<T>(
const_cast<T*
>(X), prev_size, n)
37 .template cast<TIndex>();
38 }
else if (prev_size == 1) {
// prev_size == 1: X viewed as (n, next_size), Y as (next_size).
39 EigenTensorMap1D<TIndex>(Y, next_size).device(device) =
40 EigenTensorMap2D<T>(
const_cast<T*
>(X), n, next_size)
42 .template cast<TIndex>();
// Remaining case: X viewed as (prev_size, n, next_size), Y as
// (prev_size, next_size).
44 EigenTensorMap2D<TIndex>(Y, prev_size, next_size).device(device) =
45 EigenTensorMap3D<T>(
const_cast<T*
>(X), prev_size, n, next_size)
47 .template cast<TIndex>();
// ComputeArgMinEigen: views X as a row-major T tensor and the output buffer Y
// as a TIndex tensor. It then assigns the (cast-to-TIndex) result into Y
// through `.device(device)`. The branch taken depends on prev_size/next_size.
// NOTE(review): the function name suggests an argmin over the `n` axis, but
// the reduction call is not visible in this extract. The visible tokens below
// match ComputeArgMaxEigen except for the fused line numbers, so the
// distinguishing call must be on a missing line. Confirm.
// NOTE(review): this extract is incomplete. The leading numbers (51, 52, 55, ...)
// appear to be fused source line numbers. Missing lines include:
//   - the declarations of X, n, Y and device;
//   - the condition of the first branch;
//   - the reduction call before `.template cast<TIndex>()`;
//   - closing braces.
// Restore these from the full file before changing any code.
51 template <
class Device,
typename T>
52 void ComputeArgMinEigen(
55 const TIndex prev_size,
56 const TIndex next_size,
// First branch: X viewed as (prev_size, n), Y as (prev_size).
// const_cast: the aliases map non-const T. Presumably X is a const T* that
// is only read here. Confirm.
60 EigenTensorMap1D<TIndex>(Y, prev_size).device(device) =
61 EigenTensorMap2D<T>(
const_cast<T*
>(X), prev_size, n)
63 .template cast<TIndex>();
64 }
else if (prev_size == 1) {
// prev_size == 1: X viewed as (n, next_size), Y as (next_size).
65 EigenTensorMap1D<TIndex>(Y, next_size).device(device) =
66 EigenTensorMap2D<T>(
const_cast<T*
>(X), n, next_size)
68 .template cast<TIndex>();
// Remaining case: X viewed as (prev_size, n, next_size), Y as
// (prev_size, next_size).
70 EigenTensorMap2D<TIndex>(Y, prev_size, next_size).device(device) =
71 EigenTensorMap3D<T>(
const_cast<T*
>(X), prev_size, n, next_size)
73 .template cast<TIndex>();
80 #endif // EIGEN_VERSION_AT_LEAST(3, 3, 0) 82 #endif // CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_ A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...