// NOTE(review): this text is an extract of caffe2's Eigen utility header
// (guard CAFFE2_OPERATORS_UTILS_EIGEN_H_). Original line numbers are fused
// into the content, and several lines are missing, e.g. the
// `template <typename T>` lines that the templated aliases below (which use
// `T`) require. It will not compile as shown; restore from the upstream file.
//
// Include guard and logging include, followed by shorthand aliases for
// commonly used Eigen types.
// EArrX*: dynamic-length, single-column Eigen arrays
// (EArrXt<T> = Eigen::Array<T, Eigen::Dynamic, 1>); EArrXb is the bool
// instantiation.
3 #ifndef CAFFE2_OPERATORS_UTILS_EIGEN_H_ 4 #define CAFFE2_OPERATORS_UTILS_EIGEN_H_ 8 #include "caffe2/core/logging.h" 14 using EArrXt = Eigen::Array<T, Eigen::Dynamic, 1>;
15 using EArrXf = Eigen::ArrayXf;
16 using EArrXd = Eigen::ArrayXd;
17 using EArrXi = Eigen::ArrayXi;
18 using EArrXb = EArrXt<bool>;
// EArrXX*: dynamic-size 2-D arrays using Eigen's default storage order
// (no storage option is given, unlike the row-major aliases below).
22 using EArrXXt = Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>;
23 using EArrXXf = Eigen::ArrayXXf;
// ERArrXX*: dynamic-size 2-D arrays with explicit Eigen::RowMajor storage.
// NOTE(review): the head of the ERArrXXt declaration (its template header and
// `using ERArrXXt =`) is not visible in this extract; only its right-hand
// side follows.
28 Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
29 using ERArrXXf = ERArrXXt<float>;
// EVecX*: dynamic-length column vectors (Eigen::Matrix<T, Eigen::Dynamic, 1>).
33 using EVecXt = Eigen::Matrix<T, Eigen::Dynamic, 1>;
34 using EVecXd = Eigen::VectorXd;
35 using EVecXf = Eigen::VectorXf;
// ERVecX*: dynamic-length row vectors.
38 using ERVecXd = Eigen::RowVectorXd;
39 using ERVecXf = Eigen::RowVectorXf;
// EMatX*: dynamic-size matrices using Eigen's default storage order.
43 using EMatXt = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
44 using EMatXd = Eigen::MatrixXd;
45 using EMatXf = Eigen::MatrixXf;
// ERMatX*: dynamic-size matrices with explicit Eigen::RowMajor storage.
// NOTE(review): the head of the ERMatXt declaration (template header and
// `using ERMatXt =`) is not visible in this extract.
50 Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
51 using ERMatXd = ERMatXt<double>;
52 using ERMatXf = ERMatXt<float>;
// Views a const std::vector<T> as a read-only Eigen column array without
// copying. The returned Eigen::Map points at arr.data() and does not own the
// storage, so it becomes invalid if `arr` is destroyed or reallocated.
// NOTE(review): the length is narrowed with static_cast<int>; a vector with
// more than INT_MAX elements would not fit. Confirm callers never pass one.
// NOTE(review): the template header and closing brace of this function are
// not visible in this extract.
57 Eigen::Map<const EArrXt<T>> AsEArrXt(
const std::vector<T>& arr) {
58 return {arr.data(),
static_cast<int>(arr.size())};
// Views a mutable std::vector<T> as a writable Eigen column array without
// copying. Writes through the returned Map modify `arr` directly. The Map
// does not own the storage and becomes invalid if `arr` is destroyed or
// reallocated (e.g. by push_back/resize).
// NOTE(review): the length is narrowed with static_cast<int>. The template
// header and closing brace are not visible in this extract.
61 Eigen::Map<EArrXt<T>> AsEArrXt(std::vector<T>& arr) {
62 return {arr.data(),
static_cast<int>(arr.size())};
// Gathers elements of a single-column array by index:
//   (*out_array)[i] = array[indices[i]]  for every i in [0, indices.size()).
// *out_array is resized to indices.size() before being filled.
// Requires `array` to have exactly one column (checked with CAFFE_ENFORCE_EQ).
// NOTE(review): this extract omits the line naming this function and its
// closing braces. The three-argument `GetSubArray(array, indices, &ret)`
// call further down suggests this is the out-parameter GetSubArray overload.
66 template <
class Derived,
class Derived1,
class Derived2>
68 const Eigen::ArrayBase<Derived>& array,
69 const Eigen::ArrayBase<Derived1>& indices,
70 Eigen::ArrayBase<Derived2>* out_array) {
71 CAFFE_ENFORCE_EQ(array.cols(), 1);
// Resize through derived(): DenseBase::resize only asserts that the size is
// unchanged, so the concrete Array/Matrix resize is needed here.
74 out_array->derived().resize(indices.size());
75 for (
int i = 0; i < indices.size(); i++) {
// Upper-bound check only; presumably compiled out in release builds (DCHECK
// convention, confirm). NOTE(review): negative indices are not checked.
76 DCHECK_LT(indices[i], array.size());
77 (*out_array)[i] = array[indices[i]];
// Value-returning form of GetSubArray: allocates an EArrXt of the input's
// scalar type with indices.size() elements and fills it with
// array[indices[i]] via the three-argument (out-parameter) GetSubArray.
// NOTE(review): the remainder of this body (presumably returning `ret`) and
// its closing brace are not visible in this extract.
82 template <
class Derived,
class Derived1>
83 EArrXt<typename Derived::Scalar> GetSubArray(
84 const Eigen::ArrayBase<Derived>& array,
85 const Eigen::ArrayBase<Derived1>& indices) {
// Result element type matches the source array's scalar type.
86 using T =
typename Derived::Scalar;
87 EArrXt<T> ret(indices.size());
88 GetSubArray(array, indices, &ret);
// Convenience overload taking the indices as a std::vector<int>. It wraps
// them with AsEArrXt (a non-copying Eigen::Map view) and forwards to the
// overload that takes Eigen indices.
// NOTE(review): the closing brace is not visible in this extract.
93 template <
class Derived>
94 EArrXt<typename Derived::Scalar> GetSubArray(
95 const Eigen::ArrayBase<Derived>& array,
96 const std::vector<int>& indices) {
97 return GetSubArray(array, AsEArrXt(indices));
// Gathers whole rows: copies row row_indices[i] of `array2d` into row i of
// *out_array, casting each element to out_array's scalar type
// (Derived2::Scalar). *out_array is first resized to
// row_indices.size() x array2d.cols().
// NOTE(review): the left-hand side of the row assignment (the line before
// the `.template cast<...>` expression) and the closing braces are not
// visible in this extract.
101 template <
class Derived,
class Derived1,
class Derived2>
102 void GetSubArrayRows(
103 const Eigen::ArrayBase<Derived>& array2d,
104 const Eigen::ArrayBase<Derived1>& row_indices,
105 Eigen::ArrayBase<Derived2>* out_array) {
// Resize through derived(): DenseBase::resize only asserts that the size is
// unchanged, so the concrete Array/Matrix resize is needed here.
106 out_array->derived().resize(row_indices.size(), array2d.cols());
108 for (
int i = 0; i < row_indices.size(); i++) {
// NOTE(review): array2d.size() is rows() * cols() in Eigen. For inputs with
// more than one column, this check therefore accepts out-of-range row
// indices in [rows(), rows() * cols()). array2d.rows() is likely the
// intended bound. Negative indices are also not checked.
109 DCHECK_LT(row_indices[i], array2d.size());
// `.template` is required because cast<> is a member template called on an
// expression whose type depends on the template parameter Derived.
111 array2d.row(row_indices[i]).template cast<typename Derived2::Scalar>();
116 template <
class Derived>
117 std::vector<int> GetArrayIndices(
const Eigen::ArrayBase<Derived>& array) {
118 std::vector<int> ret;
119 for (
int i = 0; i < array.size(); i++) {
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...