#if defined(__MACH__) #include #include #endif #if !defined(__WIN32__) #include #include #if !defined(__ANDROID__) #include #endif #endif #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include //===== EXPANDIND: mxnet0.cc ===== // mexnet.cc #define MSHADOW_FORCE_STREAM #define MSHADOW_USE_CUDA 0 #define MSHADOW_USE_CBLAS 1 #define MSHADOW_USE_MKL 0 #define MSHADOW_RABIT_PS 0 #define MSHADOW_DIST_PS 0 #define MXNET_USE_OPENCV 0 #define DISABLE_OPENMP 1 //===== EXPANDIND: mxnet/src/ndarray/unary_function.cc ===== /*! * Copyright (c) 2015 by Contributors * \file unary_function.cc * \brief CPU Implementation of unary function. */ // this will be invoked by gcc and compile CPU version //===== EXPANDIND: mxnet/src/ndarray/unary_function-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file unary-function-inl.h * \brief the real execution functions of ndarray operations */ #ifndef MXNET_NDARRAY_UNARY_FUNCTION_INL_H_ #define MXNET_NDARRAY_UNARY_FUNCTION_INL_H_ //===== EXPANDIND: mxnet/src/common/tblob_op_registry.h ===== /*! * Copyright (c) 2015 by Contributors * \file tblob_op_registry.h * \brief Helper registry to make registration of simple unary binary math function easy. * Register to this registry will enable both symbolic operator and NDArray operator in client. * * More complicated operators can be registered in normal way in ndarray and operator modules. */ #ifndef MXNET_COMMON_TBLOB_OP_REGISTRY_H_ #define MXNET_COMMON_TBLOB_OP_REGISTRY_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/registry.h ===== /*! * Copyright (c) 2015 by Contributors * \file registry.h * \brief Registry utility that helps to build registry singletons. */ #ifndef DMLC_REGISTRY_H_ #define DMLC_REGISTRY_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/base.h ===== /*! * Copyright (c) 2015 by Contributors * \file base.h * \brief defines configuration macros */ #ifndef DMLC_BASE_H_ #define DMLC_BASE_H_ /*! \brief whether use glog for logging */ #ifndef DMLC_USE_GLOG #define DMLC_USE_GLOG 0 #endif /*! * \brief whether throw dmlc::Error instead of * directly calling abort when FATAL error occured * NOTE: this may still not be perfect. * do not use FATAL and CHECK in destructors */ #ifndef DMLC_LOG_FATAL_THROW #define DMLC_LOG_FATAL_THROW 1 #endif /*! \brief whether compile with hdfs support */ #ifndef DMLC_USE_HDFS #define DMLC_USE_HDFS 0 #endif /*! \brief whether compile with s3 support */ #ifndef DMLC_USE_S3 #define DMLC_USE_S3 0 #endif /*! \brief whether or not use parameter server */ #ifndef DMLC_USE_PS #define DMLC_USE_PS 0 #endif /*! \brief whether or not use c++11 support */ #ifndef DMLC_USE_CXX11 #define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ __cplusplus >= 201103L || defined(_MSC_VER)) #endif /// check if g++ is before 4.6 #if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) #if __GNUC__ == 4 && __GNUC_MINOR__ < 6 #pragma message("Will need g++-4.6 or higher to compile all" \ "the features in dmlc-core, " \ "compile without c++0x, some features may be disabled") #undef DMLC_USE_CXX11 #define DMLC_USE_CXX11 0 #endif #endif /*! * \brief Disable copy constructor and assignment operator. 
* * If C++11 is supported, both copy and move constructors and * assignment operators are deleted explicitly. Otherwise, they are * only declared but not implemented. Place this macro in private * section if C++11 is not available. */ #ifndef DISALLOW_COPY_AND_ASSIGN # if DMLC_USE_CXX11 # define DISALLOW_COPY_AND_ASSIGN(T) \ T(T const&) = delete; \ T(T&&) = delete; \ T& operator=(T const&) = delete; \ T& operator=(T&&) = delete # else # define DISALLOW_COPY_AND_ASSIGN(T) \ T(T const&); \ T& operator=(T const&) # endif #endif /// /// code block to handle optionally loading /// #if !defined(__GNUC__) #define fopen64 std::fopen #endif #ifdef _MSC_VER #if _MSC_VER < 1900 // NOTE: sprintf_s is not equivalent to snprintf, // they are equivalent when success, which is sufficient for our case #define snprintf sprintf_s #define vsnprintf vsprintf_s #endif #else #ifdef _FILE_OFFSET_BITS #if _FILE_OFFSET_BITS == 32 #pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") #endif #endif #ifdef __APPLE__ #define off64_t off_t #define fopen64 std::fopen #endif extern "C" { } #endif #ifdef _MSC_VER //! \cond Doxygen_Suppress typedef signed char int8_t; typedef __int16 int16_t; typedef __int32 int32_t; typedef __int64 int64_t; typedef unsigned char uint8_t; typedef unsigned __int16 uint16_t; typedef unsigned __int32 uint32_t; typedef unsigned __int64 uint64_t; //! \endcond #else #endif /*! \brief namespace for dmlc */ namespace dmlc { /*! * \brief safely get the beginning address of a vector * \param vec input vector * \return beginning address of a vector */ template inline T *BeginPtr(std::vector &vec) { // NOLINT(*) if (vec.size() == 0) { return NULL; } else { return &vec[0]; } } /*! * \brief get the beginning address of a vector * \param vec input vector * \return beginning address of a vector */ template inline const T *BeginPtr(const std::vector &vec) { if (vec.size() == 0) { return NULL; } else { return &vec[0]; } } /*! * \brief get the beginning address of a vector * \param str input string * \return beginning address of a string */ inline char* BeginPtr(std::string &str) { // NOLINT(*) if (str.length() == 0) return NULL; return &str[0]; } /*! * \brief get the beginning address of a vector * \param str input string * \return beginning address of a string */ inline const char* BeginPtr(const std::string &str) { if (str.length() == 0) return NULL; return &str[0]; } } // namespace dmlc #if defined(_MSC_VER) && _MSC_VER < 1900 #define constexpr const #define alignof __alignof #endif #endif // DMLC_BASE_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/base.h ===== //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/logging.h ===== /*! * Copyright (c) 2015 by Contributors * \file logging.h * \brief defines logging macros of dmlc * allows use of GLOG, fall back to internal * implementation when disabled */ #ifndef DMLC_LOGGING_H_ #define DMLC_LOGGING_H_ namespace dmlc { /*! * \brief exception class that will be thrown by * default logger if DMLC_LOG_FATAL_THROW == 1 */ struct Error : public std::runtime_error { /*! 
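// ---- Usage sketch for the dmlc/base.h helpers above (Handle and buf are
// ---- illustrative names, not part of dmlc-core):
//
//   class Handle {
//    public:
//     Handle() {}
//    private:
//     // copy/move become deleted (C++11) or declared-but-undefined (C++03)
//     DISALLOW_COPY_AND_ASSIGN(Handle);
//   };
//
//   std::vector<float> buf(16, 0.0f);
//   float *p = dmlc::BeginPtr(buf);   // &buf[0], or NULL when buf is empty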
* \brief constructor * \param s the error message */ explicit Error(const std::string &s) : std::runtime_error(s) {} }; } // namespace dmlc #if defined(_MSC_VER) && _MSC_VER < 1900 #define noexcept(a) #endif #if DMLC_USE_CXX11 #define DMLC_THROW_EXCEPTION noexcept(false) #else #define DMLC_THROW_EXCEPTION #endif #if DMLC_USE_GLOG namespace dmlc { inline void InitLogging(const char* argv0) { google::InitGoogleLogging(argv0); } } // namespace dmlc #else // use a light version of glog #if defined(_MSC_VER) #pragma warning(disable : 4722) #endif namespace dmlc { inline void InitLogging(const char* argv0) { // DO NOTHING } // Always-on checking #define CHECK(x) \ if (!(x)) \ dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ "failed: " #x << ' ' #define CHECK_LT(x, y) CHECK((x) < (y)) #define CHECK_GT(x, y) CHECK((x) > (y)) #define CHECK_LE(x, y) CHECK((x) <= (y)) #define CHECK_GE(x, y) CHECK((x) >= (y)) #define CHECK_EQ(x, y) CHECK((x) == (y)) #define CHECK_NE(x, y) CHECK((x) != (y)) #define CHECK_NOTNULL(x) \ ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) // Debug-only checking. #ifdef NDEBUG #define DCHECK(x) \ while (false) CHECK(x) #define DCHECK_LT(x, y) \ while (false) CHECK((x) < (y)) #define DCHECK_GT(x, y) \ while (false) CHECK((x) > (y)) #define DCHECK_LE(x, y) \ while (false) CHECK((x) <= (y)) #define DCHECK_GE(x, y) \ while (false) CHECK((x) >= (y)) #define DCHECK_EQ(x, y) \ while (false) CHECK((x) == (y)) #define DCHECK_NE(x, y) \ while (false) CHECK((x) != (y)) #else #define DCHECK(x) CHECK(x) #define DCHECK_LT(x, y) CHECK((x) < (y)) #define DCHECK_GT(x, y) CHECK((x) > (y)) #define DCHECK_LE(x, y) CHECK((x) <= (y)) #define DCHECK_GE(x, y) CHECK((x) >= (y)) #define DCHECK_EQ(x, y) CHECK((x) == (y)) #define DCHECK_NE(x, y) CHECK((x) != (y)) #endif // NDEBUG #define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) #define LOG_ERROR LOG_INFO #define LOG_WARNING LOG_INFO #define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) #define LOG_QFATAL LOG_FATAL // Poor man version of VLOG #define VLOG(x) LOG_INFO.stream() #define LOG(severity) LOG_##severity.stream() #define LG LOG_INFO.stream() #define LOG_IF(severity, condition) \ !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) #ifdef NDEBUG #define LOG_DFATAL LOG_ERROR #define DFATAL ERROR #define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) #define DLOG_IF(severity, condition) \ (true || !(condition)) ? 
(void)0 : dmlc::LogMessageVoidify() & LOG(severity) #else #define LOG_DFATAL LOG_FATAL #define DFATAL FATAL #define DLOG(severity) LOG(severity) #define DLOG_IF(severity, condition) LOG_IF(severity, condition) #endif // Poor man version of LOG_EVERY_N #define LOG_EVERY_N(severity, n) LOG(severity) class DateLogger { public: DateLogger() { #if defined(_MSC_VER) _tzset(); #endif } const char* HumanDate() { #if defined(_MSC_VER) _strtime_s(buffer_, sizeof(buffer_)); #else time_t time_value = time(NULL); struct tm now; localtime_r(&time_value, &now); snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", now.tm_hour, now.tm_min, now.tm_sec); #endif return buffer_; } private: char buffer_[9]; }; class LogMessage { public: LogMessage(const char* file, int line) : #ifdef __ANDROID__ log_stream_(std::cout) #else log_stream_(std::cerr) #endif { log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" << line << ": "; } ~LogMessage() { log_stream_ << "\n"; } std::ostream& stream() { return log_stream_; } protected: std::ostream& log_stream_; private: DateLogger pretty_date_; LogMessage(const LogMessage&); void operator=(const LogMessage&); }; #if DMLC_LOG_FATAL_THROW == 0 class LogMessageFatal : public LogMessage { public: LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} ~LogMessageFatal() { log_stream_ << "\n"; abort(); } private: LogMessageFatal(const LogMessageFatal&); void operator=(const LogMessageFatal&); }; #else class LogMessageFatal { public: LogMessageFatal(const char* file, int line) { log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" << line << ": "; } std::ostringstream &stream() { return log_stream_; } ~LogMessageFatal() DMLC_THROW_EXCEPTION { // throwing out of destructor is evil // hopefully we can do it here // also log the message before throw LOG(ERROR) << log_stream_.str(); throw Error(log_stream_.str()); } private: std::ostringstream log_stream_; DateLogger pretty_date_; LogMessageFatal(const LogMessageFatal&); void operator=(const LogMessageFatal&); }; #endif // This class is used to explicitly ignore values in the conditional // logging macros. This avoids compiler warnings like "value computed // is not used" and "statement has no effect". class LogMessageVoidify { public: LogMessageVoidify() {} // This has to be an operator with a precedence lower than << but // higher than "?:". See its usage. void operator&(std::ostream&) {} }; } // namespace dmlc #endif #endif // DMLC_LOGGING_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/logging.h ===== //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/parameter.h ===== /*! * Copyright (c) 2015 by Contributors * \file parameter.h * \brief Provide lightweight util to do parameter setup and checking. */ #ifndef DMLC_PARAMETER_H_ #define DMLC_PARAMETER_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/json.h ===== /*! * Copyright (c) 2015 by Contributors * \file json.h * \brief Lightweight JSON Reader/Writer that read save into C++ data structs. * This includes STL composites and structures. */ #ifndef DMLC_JSON_H_ #define DMLC_JSON_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/type_traits.h ===== /*! * Copyright (c) 2015 by Contributors * \file type_traits.h * \brief type traits information header */ #ifndef DMLC_TYPE_TRAITS_H_ #define DMLC_TYPE_TRAITS_H_ #if DMLC_USE_CXX11 #endif namespace dmlc { /*! * \brief whether a type is pod type * \tparam T the type to query */ template struct is_pod { #if DMLC_USE_CXX11 /*! 
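// ---- Usage sketch for the CHECK/LOG macros defined in dmlc/logging.h above
// ---- (values are illustrative):
//
//   int n_layers = 3;
//   CHECK_GT(n_layers, 0) << "n_layers must be positive, got " << n_layers;
//   LOG(INFO) << "building network with " << n_layers << " layers";
//   LOG_IF(ERROR, n_layers > 100) << "suspiciously deep network";
//   // With DMLC_LOG_FATAL_THROW == 1 (the default), a failed CHECK or
//   // LOG(FATAL) throws dmlc::Error instead of calling abort():
//   try {
//     CHECK_EQ(n_layers, 4) << "layer count mismatch";
//   } catch (const dmlc::Error &e) {
//     LOG(ERROR) << e.what();
//   }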
\brief the value of the traits */ static const bool value = std::is_pod::value; #else /*! \brief the value of the traits */ static const bool value = false; #endif }; /*! * \brief whether a type is integer type * \tparam T the type to query */ template struct is_integral { #if DMLC_USE_CXX11 /*! \brief the value of the traits */ static const bool value = std::is_integral::value; #else /*! \brief the value of the traits */ static const bool value = false; #endif }; /*! * \brief whether a type is floating point type * \tparam T the type to query */ template struct is_floating_point { #if DMLC_USE_CXX11 /*! \brief the value of the traits */ static const bool value = std::is_floating_point::value; #else /*! \brief the value of the traits */ static const bool value = false; #endif }; /*! * \brief whether a type is arithemetic type * \tparam T the type to query */ template struct is_arithmetic { #if DMLC_USE_CXX11 /*! \brief the value of the traits */ static const bool value = std::is_arithmetic::value; #else /*! \brief the value of the traits */ static const bool value = (dmlc::is_integral::value || dmlc::is_floating_point::value); #endif }; /*! * \brief the string representation of type name * \tparam T the type to query * \return a const string of typename. */ template inline const char* type_name() { return ""; } /*! * \brief whether a type have save/load function * \tparam T the type to query */ template struct has_saveload { /*! \brief the value of the traits */ static const bool value = false; }; /*! * \brief template to select type based on condition * For example, IfThenElseType::Type will give int * \tparam cond the condition * \tparam Then the typename to be returned if cond is true * \tparam The typename to be returned if cond is false */ template struct IfThenElseType; /*! \brief macro to quickly declare traits information */ #define DMLC_DECLARE_TRAITS(Trait, Type, Value) \ template<> \ struct Trait { \ static const bool value = Value; \ } /*! \brief macro to quickly declare traits information */ #define DMLC_DECLARE_TYPE_NAME(Type, Name) \ template<> \ inline const char* type_name() { \ return Name; \ } //! 
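// ---- Usage sketch for the type traits above (MyType is a hypothetical user
// ---- type; the macro must be expanded inside namespace dmlc):
//
//   struct MyType { int x; };
//   namespace dmlc {
//   DMLC_DECLARE_TYPE_NAME(MyType, "my-type");   // name shown in parameter docs
//   }
//   // compile-time branch on a trait:
//   typedef dmlc::IfThenElseType<dmlc::is_floating_point<float>::value,
//                                float, double>::Type RealType;   // float here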
\cond Doxygen_Suppress // declare special traits when C++11 is not available #if DMLC_USE_CXX11 == 0 DMLC_DECLARE_TRAITS(is_pod, char, true); DMLC_DECLARE_TRAITS(is_pod, int8_t, true); DMLC_DECLARE_TRAITS(is_pod, int16_t, true); DMLC_DECLARE_TRAITS(is_pod, int32_t, true); DMLC_DECLARE_TRAITS(is_pod, int64_t, true); DMLC_DECLARE_TRAITS(is_pod, uint8_t, true); DMLC_DECLARE_TRAITS(is_pod, uint16_t, true); DMLC_DECLARE_TRAITS(is_pod, uint32_t, true); DMLC_DECLARE_TRAITS(is_pod, uint64_t, true); DMLC_DECLARE_TRAITS(is_pod, float, true); DMLC_DECLARE_TRAITS(is_pod, double, true); DMLC_DECLARE_TRAITS(is_integral, char, true); DMLC_DECLARE_TRAITS(is_integral, int8_t, true); DMLC_DECLARE_TRAITS(is_integral, int16_t, true); DMLC_DECLARE_TRAITS(is_integral, int32_t, true); DMLC_DECLARE_TRAITS(is_integral, int64_t, true); DMLC_DECLARE_TRAITS(is_integral, uint8_t, true); DMLC_DECLARE_TRAITS(is_integral, uint16_t, true); DMLC_DECLARE_TRAITS(is_integral, uint32_t, true); DMLC_DECLARE_TRAITS(is_integral, uint64_t, true); DMLC_DECLARE_TRAITS(is_floating_point, float, true); DMLC_DECLARE_TRAITS(is_floating_point, double, true); #endif DMLC_DECLARE_TYPE_NAME(float, "float"); DMLC_DECLARE_TYPE_NAME(double, "double"); DMLC_DECLARE_TYPE_NAME(int, "int"); DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)"); DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)"); DMLC_DECLARE_TYPE_NAME(std::string, "string"); DMLC_DECLARE_TYPE_NAME(bool, "boolean"); template struct IfThenElseType { typedef Then Type; }; template struct IfThenElseType { typedef Else Type; }; //! \endcond } // namespace dmlc #endif // DMLC_TYPE_TRAITS_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/type_traits.h ===== #if DMLC_USE_CXX11 #endif namespace dmlc { /*! * \brief Lightweight JSON Reader to read any STL compositions and structs. * The user need to know the schema of the * */ class JSONReader { public: /*! * \brief Constructor. * \param is the input stream. */ explicit JSONReader(std::istream *is) : is_(is), line_count_r_(0), line_count_n_(0) {} /*! * \brief Parse next JSON string. * \param out_str the output string. * \throw dmlc::Error when next token is not string */ inline void ReadString(std::string *out_str); /*! * \brief Read Number. * \param out_value output value; * \throw dmlc::Error when next token is not number of ValueType. * \tparam ValueType type of the number */ template inline void ReadNumber(ValueType *out_value); /*! * \brief Begin parsing an object. * \code * std::string key; * // value can be any type that is json serializable. * std::string value; * reader->BeginObject(); * while (reader->NextObjectItem(&key)) { * // do somthing to key value * reader->Read(&value); * } * \endcode */ inline void BeginObject(); /*! * \brief Begin parsing an array. * \code * // value can be any type that is json serializable. * std::string value; * reader->BeginArray(); * while (reader->NextObjectArrayItem(&value)) { * // do somthing to value * } * \endcode */ inline void BeginArray(); /*! * \brief Try to move to next object item. * If this call is successful, user can proceed to call * reader->Read to read in the value. * \param out_key the key to the next object. * \return true if the read is successful, false if we are at end of the object. */ inline bool NextObjectItem(std::string *out_key); /*! * \brief Try to read the next element in the array. * If this call is successful, user can proceed to call * reader->Read to read in the value. * \return true if the read is successful, false if we are at end of the array. 
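// ---- Usage sketch for JSONReader declared above (input string is illustrative):
//
//   std::istringstream is("{\"name\": \"mxnet\", \"nums\": [1, 2, 3]}");
//   dmlc::JSONReader reader(&is);
//   std::string name, key;
//   std::vector<int> nums;
//   reader.BeginObject();
//   while (reader.NextObjectItem(&key)) {
//     if (key == "name") reader.Read(&name);   // dispatches via json::Handler
//     else reader.Read(&nums);                 // STL containers work directly
//   }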
*/ inline bool NextArrayItem(); /*! * \brief Read next ValueType. * \param out_value any STL or json readable type to be read * \throw dmlc::Error when the read of ValueType is not successful. * \tparam ValueType the data type to be read. */ template inline void Read(ValueType *out_value); /*! \return current line count */ inline std::string line_info() const { char temp[64]; std::ostringstream os; os << " Line " << std::max(line_count_r_, line_count_n_); is_->getline(temp, 64); os << ", around ^`" << temp << "`"; return os.str(); } private: /*! \brief internal reader stream */ std::istream *is_; /*! \brief "\\r" counter */ size_t line_count_r_; /*! \brief "\\n" counter */ size_t line_count_n_; /*! * \brief record how many element processed in * current array/object scope. */ std::vector scope_counter_; /*! * \brief Read next nonspace character. * \return the next nonspace character. */ inline int NextNonSpace(); /*! * \brief Read just before next nonspace but not read that. * \return the next nonspace character. */ inline int PeekNextNonSpace(); }; /*! * \brief Lightweight json to write any STL compositions. */ class JSONWriter { public: /*! * \brief Constructor. * \param os the output stream. */ explicit JSONWriter(std::ostream *os) : os_(os) {} /*! * \brief Write a string that do not contain escape characters. * \param s the string to be written. */ inline void WriteNoEscape(const std::string &s); /*! * \brief Write a string that can contain escape characters. * \param s the string to be written. */ inline void WriteString(const std::string &s); /*! * \brief Write a string that can contain escape characters. * \param v the value to be written. * \tparam ValueType The value type to be written. */ template inline void WriteNumber(const ValueType &v); /*! * \brief Start beginning of array. * \param multi_line whether to start an multi_line array. * \code * writer->BeginArray(); * for (auto& v : vdata) { * writer->WriteArrayItem(v); * } * writer->EndArray(); * \endcode */ inline void BeginArray(bool multi_line = true); /*! \brief Finish writing an array. */ inline void EndArray(); /*! * \brief Start beginning of array. * \param multi_line whether to start an multi_line array. * \code * writer->BeginObject(); * for (auto& kv : vmap) { * writer->WriteObjectKeyValue(kv.first, kv.second); * } * writer->EndObject(); * \endcode */ inline void BeginObject(bool multi_line = true); /*! \brief Finish writing object. */ inline void EndObject(); /*! * \brief Write key value pair in the object. * \param key the key of the object. * \param value the value of to be written. * \tparam ValueType The value type to be written. */ template inline void WriteObjectKeyValue(const std::string &key, const ValueType &value); /*! * \brief Write value into array. * \param value The value of to be written. * \tparam ValueType The value type to be written. */ template inline void WriteArrayItem(const ValueType &value); /*! * \brief Write value to json. * \param value any STL or json readable that can be written. * \tparam ValueType the data type to be write. */ template inline void Write(const ValueType &value); private: /*! \brief Output stream */ std::ostream *os_; /*! * \brief record how many element processed in * current array/object scope. */ std::vector scope_counter_; /*! \brief Record whether current is a multiline scope */ std::vector scope_multi_line_; /*! * \brief Write seperating space and newlines */ inline void WriteSeperator(); }; /*! * \brief Helper class to read JSON into a class or struct object. 
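// ---- Usage sketch for JSONWriter declared above:
//
//   std::ostringstream os;
//   dmlc::JSONWriter writer(&os);
//   std::map<std::string, int> m;
//   m["a"] = 1;
//   m["b"] = 2;
//   writer.Write(m);   // emits {"a": 1, "b": 2}, spread over lines when
//                      // the object has more than one entry
//   // BeginArray/WriteArrayItem/EndArray can likewise be called directly,
//   // as the class documentation above shows.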
* \code * struct Param { * std::string name; * int value; * // define load function from JSON * inline void Load(dmlc::JSONReader *reader) { * dmlc::JSONStructReadHelper helper; * helper.DeclareField("name", &name); * helper.DeclareField("value", &value); * helper.ReadAllFields(reader); * } * }; * \endcode */ class JSONObjectReadHelper { public: /*! * \brief Declare field of type T * \param key the key of the of field. * \param addr address of the data type. * \tparam T the data type to be read, must be STL composition of JSON serializable. */ template inline void DeclareField(const std::string &key, T *addr); /*! * \brief Read in all the declared fields. * \param reader the JSONReader to read the json. */ inline void ReadAllFields(JSONReader *reader); private: /*! * \brief The internal reader function. * \param reader The reader to read. * \param addr The memory address to read. */ template inline static void ReaderFunction(JSONReader *reader, void *addr); /*! \brief callback type to reader function */ typedef void (*ReadFunction)(JSONReader *reader, void *addr); /*! \brief the internal map of reader callbacks */ std::map > map_; }; //! \cond Doxygen_Suppress namespace json { /*! * \brief generic serialization handler * \tparam T the type to be serialized */ template struct Handler; template struct NumericHandler { inline static void Write(JSONWriter *writer, const ValueType &value) { writer->WriteNumber(value); } inline static void Read(JSONReader *reader, ValueType *value) { reader->ReadNumber(value); } }; template struct ArrayHandler { inline static void Write(JSONWriter *writer, const ContainerType &array) { typedef typename ContainerType::value_type ElemType; writer->BeginArray(array.size() > 10 || !dmlc::is_pod::value); for (typename ContainerType::const_iterator it = array.begin(); it != array.end(); ++it) { writer->WriteArrayItem(*it); } writer->EndArray(); } inline static void Read(JSONReader *reader, ContainerType *array) { typedef typename ContainerType::value_type ElemType; array->clear(); reader->BeginArray(); while (reader->NextArrayItem()) { ElemType value; Handler::Read(reader, &value); array->insert(array->end(), value); } } }; template struct MapHandler{ inline static void Write(JSONWriter *writer, const ContainerType &map) { writer->BeginObject(map.size() > 1); for (typename ContainerType::const_iterator it = map.begin(); it != map.end(); ++it) { writer->WriteObjectKeyValue(it->first, it->second); } writer->EndObject(); } inline static void Read(JSONReader *reader, ContainerType *map) { typedef typename ContainerType::mapped_type ElemType; map->clear(); reader->BeginObject(); std::string key; while (reader->NextObjectItem(&key)) { ElemType value; reader->Read(&value); (*map)[key] = value; } } }; template struct CommonJSONSerializer { inline static void Write(JSONWriter *writer, const T &value) { value.Save(writer); } inline static void Read(JSONReader *reader, T *value) { value->Load(reader); } }; template<> struct Handler { inline static void Write(JSONWriter *writer, const std::string &value) { writer->WriteString(value); } inline static void Read(JSONReader *reader, std::string *str) { reader->ReadString(str); } }; template struct Handler > : public ArrayHandler > { }; template struct Handler > { inline static void Write(JSONWriter *writer, const std::pair &kv) { writer->BeginArray(); writer->WriteArrayItem(kv.first); writer->WriteArrayItem(kv.second); writer->EndArray(); } inline static void Read(JSONReader *reader, std::pair *kv) { reader->BeginArray(); 
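// ---- Round-trip sketch: STL composites are serialized through the
// ---- json::Handler specializations in this section (data is illustrative):
//
//   std::vector<std::pair<std::string, int> > data;
//   data.push_back(std::make_pair(std::string("lr"), 1));
//   std::ostringstream os;
//   dmlc::JSONWriter writer(&os);
//   writer.Write(data);                        // [["lr", 1]]
//   std::istringstream is(os.str());
//   dmlc::JSONReader reader(&is);
//   std::vector<std::pair<std::string, int> > loaded;
//   reader.Read(&loaded);                      // loaded now equals data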
CHECK(reader->NextArrayItem()) << "Expect array of length 2"; Handler::Read(reader, &(kv->first)); CHECK(reader->NextArrayItem()) << "Expect array of length 2"; Handler::Read(reader, &(kv->second)); CHECK(!reader->NextArrayItem()) << "Expect array of length 2"; } }; template struct Handler > : public ArrayHandler > { }; template struct Handler > : public MapHandler > { }; #if DMLC_USE_CXX11 template struct Handler > : public MapHandler > { }; #endif template struct Handler { inline static void Write(JSONWriter *writer, const T &data) { typedef typename dmlc::IfThenElseType::value, NumericHandler, CommonJSONSerializer >::Type THandler; THandler::Write(writer, data); } inline static void Read(JSONReader *reader, T *data) { typedef typename dmlc::IfThenElseType::value, NumericHandler, CommonJSONSerializer >::Type THandler; THandler::Read(reader, data); } }; } // namespace json // implementations of JSONReader/Writer inline int JSONReader::NextNonSpace() { int ch; do { ch = is_->get(); if (ch == '\n') ++line_count_n_; if (ch == '\r') ++line_count_r_; } while (isspace(ch)); return ch; } inline int JSONReader::PeekNextNonSpace() { int ch; while (true) { ch = is_->peek(); if (ch == '\n') ++line_count_n_; if (ch == '\r') ++line_count_r_; if (!isspace(ch)) break; is_->get(); } return ch; } inline void JSONReader::ReadString(std::string *out_str) { int ch = NextNonSpace(); CHECK_EQ(ch, '\"') << "Error at" << line_info() << ", Expect \'\"\' but get \'" << static_cast(ch) << '\''; std::ostringstream os; while (true) { ch = is_->get(); if (ch == '\"') break; if (ch == '\\') { os << is_->get(); } else { os << static_cast(ch); } if (ch == EOF || ch == '\r' || ch == '\n') { LOG(FATAL) << "Error at" << line_info() << ", Expect \'\"\' but reach end of line "; } } *out_str = os.str(); } template inline void JSONReader::ReadNumber(ValueType *out_value) { *is_ >> *out_value; CHECK(!is_->fail()) << "Error at" << line_info() << ", Expect number"; } inline void JSONReader::BeginObject() { int ch = NextNonSpace(); CHECK_EQ(ch, '{') << "Error at" << line_info() << ", Expect \'{\' but get \'" << static_cast(ch) << '\''; scope_counter_.push_back(0); } inline void JSONReader::BeginArray() { int ch = NextNonSpace(); CHECK_EQ(ch, '[') << "Error at" << line_info() << ", Expect \'{\' but get \'" << static_cast(ch) << '\''; scope_counter_.push_back(0); } inline bool JSONReader::NextObjectItem(std::string *out_key) { bool next = true; if (scope_counter_.back() != 0) { int ch = NextNonSpace(); if (ch == EOF) { next = false; } else if (ch == '}') { next = false; } else { CHECK_EQ(ch, ',') << "Error at" << line_info() << ", JSON object expect \'}\' or \',\' \'" << static_cast(ch) << '\''; } } else { int ch = PeekNextNonSpace(); if (ch == '}') { is_->get(); next = false; } } if (!next) { scope_counter_.pop_back(); return false; } else { scope_counter_.back() += 1; ReadString(out_key); int ch = NextNonSpace(); CHECK_EQ(ch, ':') << "Error at" << line_info() << ", Expect \':\' but get \'" << static_cast(ch) << '\''; return true; } } inline bool JSONReader::NextArrayItem() { bool next = true; if (scope_counter_.back() != 0) { int ch = NextNonSpace(); if (ch == EOF) { next = false; } else if (ch == ']') { next = false; } else { CHECK_EQ(ch, ',') << "Error at" << line_info() << ", JSON array expect \']\' or \',\'. 
Get \'" << static_cast(ch) << "\' instead"; } } else { int ch = PeekNextNonSpace(); if (ch == ']') { is_->get(); next = false; } } if (!next) { scope_counter_.pop_back(); return false; } else { scope_counter_.back() += 1; return true; } } template inline void JSONReader::Read(ValueType *out_value) { json::Handler::Read(this, out_value); } inline void JSONWriter::WriteNoEscape(const std::string &s) { *os_ << '\"' << s << '\"'; } inline void JSONWriter::WriteString(const std::string &s) { std::ostream &os = *os_; os << '\"'; for (size_t i = 0; i < s.length(); ++i) { char ch = s[i]; switch (ch) { case '\r': os << "\\r"; break; case '\n': os << "\\n"; break; case '\\': os << "\\\\"; break; case '\t': os << "\\t"; break; case '\"': os << "\\\""; break; default: os << ch; } } os << '\"'; } template inline void JSONWriter::WriteNumber(const ValueType &v) { *os_ << v; } inline void JSONWriter::BeginArray(bool multi_line) { *os_ << '['; scope_multi_line_.push_back(multi_line); scope_counter_.push_back(0); } inline void JSONWriter::EndArray() { CHECK_NE(scope_multi_line_.size(), 0); CHECK_NE(scope_counter_.size(), 0); bool newline = scope_multi_line_.back(); size_t nelem = scope_counter_.back(); scope_multi_line_.pop_back(); scope_counter_.pop_back(); if (newline && nelem != 0) WriteSeperator(); *os_ << ']'; } inline void JSONWriter::BeginObject(bool multi_line) { *os_ << "{"; scope_multi_line_.push_back(multi_line); scope_counter_.push_back(0); } inline void JSONWriter::EndObject() { CHECK_NE(scope_multi_line_.size(), 0); CHECK_NE(scope_counter_.size(), 0); bool newline = scope_multi_line_.back(); size_t nelem = scope_counter_.back(); scope_multi_line_.pop_back(); scope_counter_.pop_back(); if (newline && nelem != 0) WriteSeperator(); *os_ << '}'; } template inline void JSONWriter::WriteObjectKeyValue(const std::string &key, const ValueType &value) { std::ostream &os = *os_; if (scope_counter_.back() == 0) { WriteSeperator(); os << '\"' << key << "\": "; } else { os << ", "; WriteSeperator(); os << '\"' << key << "\": "; } scope_counter_.back() += 1; json::Handler::Write(this, value); } template inline void JSONWriter::WriteArrayItem(const ValueType &value) { std::ostream &os = *os_; if (scope_counter_.back() != 0) { os << ", "; } scope_counter_.back() += 1; WriteSeperator(); json::Handler::Write(this, value); } template inline void JSONWriter::Write(const ValueType &value) { size_t nscope = scope_multi_line_.size(); json::Handler::Write(this, value); CHECK_EQ(nscope, scope_multi_line_.size()) << "Uneven scope, did you call EndArray/EndObject after each BeginObject/Array?"; } inline void JSONWriter::WriteSeperator() { if (scope_multi_line_.size() == 0 || scope_multi_line_.back()) { *os_ << '\n' << std::string(scope_multi_line_.size() * 2, ' '); } } inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) { reader->BeginObject(); std::map visited; std::string key; while (reader->NextObjectItem(&key)) { if (map_.count(key) != 0) { std::pair kv = map_[key]; (*kv.first)(reader, kv.second); visited[key] = 0; } else { std::ostringstream os; os << "JSONReader: Unknown field " << key << ", candidates are: \n"; for (std::map >::iterator it = map_.begin(); it != map_.end(); ++it) { os << '\"' <first << "\"\n"; } LOG(FATAL) << os.str(); } } if (visited.size() != map_.size()) { for (std::map >::iterator it = map_.begin(); it != map_.end(); ++it) { CHECK_NE(visited.count(it->first), 0) << "JSONReader: Missing field \"" << it->first << "\"\n At " << reader->line_info(); } } } template inline void 
JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) { json::Handler::Read(reader, static_cast(addr)); } template inline void JSONObjectReadHelper::DeclareField(const std::string &key, T *addr) { CHECK_EQ(map_.count(key), 0) << "Adding duplicate field " << key; map_[key] = std::make_pair(ReaderFunction, static_cast(addr)); } //! \endcond } // namespace dmlc #endif // DMLC_JSON_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/json.h ===== namespace dmlc { // this file is backward compatible with non-c++11 /*! \brief Error throwed by parameter checking */ struct ParamError : public dmlc::Error { /*! * \brief constructor * \param msg error message */ explicit ParamError(const std::string &msg) : dmlc::Error(msg) {} }; /*! * \brief Get environment variable with default. * \param key the name of environment variable. * \param default_value the default value of environment vriable. * \return The value received */ template inline ValueType GetEnv(const char *key, ValueType default_value); /*! \brief internal namespace for parameter manangement */ namespace parameter { // forward declare ParamManager class ParamManager; // forward declare FieldAccessEntry class FieldAccessEntry; // forward declare FieldEntry template class FieldEntry; // forward declare ParamManagerSingleton template struct ParamManagerSingleton; } // namespace parameter /*! * \brief Information about a parameter field in string representations. */ struct ParamFieldInfo { /*! \brief name of the field */ std::string name; /*! \brief type of the field in string format */ std::string type; /*! * \brief detailed type information string * This include the default value, enum constran and typename. */ std::string type_info_str; /*! \brief detailed description of the type */ std::string description; }; /*! * \brief Parameter is the base type every parameter struct should inheritate from * The following code is a complete example to setup parameters. * \code * struct Param : public dmlc::Parameter { * float learning_rate; * int num_hidden; * std::string name; * // declare parameters in header file * DMLC_DECLARE_PARAMETER(Param) { * DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000); * DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f); * DMLC_DECLARE_FIELD(name).set_default("hello"); * } * }; * // register it in cc file * DMLC_REGISTER_PARAMETER(Param); * \endcode * * After that, the Param struct will get all the functions defined in Parameter. * \tparam PType the type of parameter struct * * \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER */ template struct Parameter { public: /*! * \brief initialize the parameter by keyword arguments. * This function will initialize the parameter struct, check consistency * and throw error if something wrong happens. * * \param kwargs map of keyword arguments, or vector of pairs * \tparam Container container type * \throw ParamError when something go wrong. */ template inline void Init(const Container &kwargs) { PType::__MANAGER__()->RunInit(static_cast(this), kwargs.begin(), kwargs.end(), NULL); } /*! * \brief initialize the parameter by keyword arguments. * This is same as Init, but allow unknown arguments. * * \param kwargs map of keyword arguments, or vector of pairs * \tparam Container container type * \throw ParamError when something go wrong. * \return vector of pairs of unknown arguments. 
*/ template inline std::vector > InitAllowUnknown(const Container &kwargs) { std::vector > unknown; PType::__MANAGER__()->RunInit(static_cast(this), kwargs.begin(), kwargs.end(), &unknown); return unknown; } /*! * \brief Return a dictionary representation of the parameters * \return A dictionary that maps key -> value */ inline std::map __DICT__() const { std::vector > vec = PType::__MANAGER__()->GetDict(this->head()); return std::map(vec.begin(), vec.end()); } /*! * \brief Write the parameters in JSON format. * \param writer JSONWriter used for writing. */ inline void Save(dmlc::JSONWriter *writer) const { writer->Write(this->__DICT__()); } /*! * \brief Load the parameters from JSON. * \param reader JSONReader used for loading. * \throw ParamError when something go wrong. */ inline void Load(dmlc::JSONReader *reader) { std::map kwargs; reader->Read(&kwargs); this->Init(kwargs); } /*! * \brief Get the fields of the parameters. * \return List of ParamFieldInfo of each field. */ inline static std::vector __FIELDS__() { return PType::__MANAGER__()->GetFieldInfo(); } /*! * \brief Print docstring of the parameter * \return the printed docstring */ inline static std::string __DOC__() { std::ostringstream os; PType::__MANAGER__()->PrintDocString(os); return os.str(); } protected: /*! * \brief internal function to allow declare of a parameter memember * \param manager the parameter manager * \param key the key name of the parameter * \param ref the reference to the parameter in the struct. */ template inline parameter::FieldEntry& DECLARE( parameter::ParamManagerSingleton *manager, const std::string &key, DType &ref) { // NOLINT(*) parameter::FieldEntry *e = new parameter::FieldEntry(); e->Init(key, this->head(), ref); manager->manager.AddEntry(key, e); return *e; } private: /*! \return Get head pointer of child structure */ inline PType *head() const { return static_cast(const_cast*>(this)); } }; //! \cond Doxygen_Suppress /*! * \brief macro used to declare parameter * * Example: * \code * struct Param : public dmlc::Parameter { * // declare parameters in header file * DMLC_DECLARE_PARAMETER(Param) { * // details of declarations * } * }; * \endcode * * This macro need to be put in a source file so that registeration only happens once. * Refer to example code in Parameter for details * * \param PType the name of parameter struct. * \sa Parameter */ #define DMLC_DECLARE_PARAMETER(PType) \ static ::dmlc::parameter::ParamManager *__MANAGER__(); \ inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton *manager) \ /*! * \brief macro to declare fields * \param FieldName the name of the field. */ #define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) /*! * \brief Macro used to register parameter. * * This macro need to be put in a source file so that registeration only happens once. * Refer to example code in Parameter for details * \param PType the type of parameter struct. * \sa Parameter */ #define DMLC_REGISTER_PARAMETER(PType) \ ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ static ::dmlc::parameter::ParamManagerSingleton inst(#PType); \ return &inst.manager; \ } \ static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \ (*PType::__MANAGER__()) \ //! \endcond /*! * \brief internal namespace for parameter manangement * There is no need to use it directly in normal case */ namespace parameter { /*! 
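// ---- Usage sketch: initializing a parameter struct from keyword arguments.
// ---- MyParam is hypothetical and assumed to be declared with
// ---- DMLC_DECLARE_PARAMETER and registered with DMLC_REGISTER_PARAMETER
// ---- as in the Parameter documentation above:
//
//   MyParam param;
//   std::map<std::string, std::string> kwargs;
//   kwargs["num_hidden"] = "100";
//   kwargs["learning_rate"] = "0.1";
//   param.Init(kwargs);                        // throws dmlc::ParamError on bad input
//   std::map<std::string, std::string> d = param.__DICT__();
//   std::string doc = MyParam::__DOC__();      // human-readable field docs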
* \brief FieldAccessEntry interface to help manage the parameters * Each entry can be used to access one parameter in the Parameter struct. * * This is an internal interface used that is used to manage parameters */ class FieldAccessEntry { public: FieldAccessEntry() : has_default_(false) {} /*! \brief destructor */ virtual ~FieldAccessEntry() {} /*! * \brief set the default value. * \param head the pointer to the head of the struct * \throw error if no default is presented */ virtual void SetDefault(void *head) const = 0; /*! * \brief set the parameter by string value * \param head the pointer to the head of the struct * \param value the value to be set */ virtual void Set(void *head, const std::string &value) const = 0; // check if value is OK virtual void Check(void *head) const {} /*! * \brief get the string representation of value. * \param head the pointer to the head of the struct */ virtual std::string GetStringValue(void *head) const = 0; /*! * \brief Get field information * \return the corresponding field information */ virtual ParamFieldInfo GetFieldInfo() const = 0; protected: /*! \brief whether this parameter have default value */ bool has_default_; /*! \brief positional index of parameter in struct */ size_t index_; /*! \brief parameter key name */ std::string key_; /*! \brief parameter type */ std::string type_; /*! \brief description of the parameter */ std::string description_; /*! * \brief print string representation of default value * \parma os the stream to print the docstring to. */ virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*) // allow ParamManager to modify self friend class ParamManager; }; /*! * \brief manager class to handle parameter setting for each type * An manager will be created for each parameter types. */ class ParamManager { public: /*! \brief destructor */ ~ParamManager() { for (size_t i = 0; i < entry_.size(); ++i) { delete entry_[i]; } } /*! * \brief find the access entry by parameter key * \param key the key of the parameter. * \return pointer to FieldAccessEntry, NULL if nothing is found. */ inline FieldAccessEntry *Find(const std::string &key) const { std::map::const_iterator it = entry_map_.find(key); if (it == entry_map_.end()) return NULL; return it->second; } /*! * \brief set parameter by keyword arguments. * \param head head to the parameter field. * \param begin begin iterator of original kwargs * \param end end iterator of original kwargs * \param unknown_args optional, used to hold unknown arguments * When it is specified, unknown arguments will be stored into here, instead of raise an error * \tparam RandomAccessIterator iterator type * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing. */ template inline void RunInit(void *head, RandomAccessIterator begin, RandomAccessIterator end, std::vector > *unknown_args) const { std::set selected_args; for (RandomAccessIterator it = begin; it != end; ++it) { FieldAccessEntry *e = Find(it->first); if (e != NULL) { e->Set(head, it->second); e->Check(head); selected_args.insert(e); } else { if (unknown_args != NULL) { unknown_args->push_back(*it); } else { std::ostringstream os; os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; os << "----------------\n"; PrintDocString(os); throw dmlc::ParamError(os.str()); } } } for (std::map::const_iterator it = entry_map_.begin(); it != entry_map_.end(); ++it) { if (selected_args.count(it->second) == 0) { it->second->SetDefault(head); } } } /*! 
* \brief internal function to add entry to manager, * The manager will take ownership of the entry. * \param key the key to the parameters * \param e the pointer to the new entry. */ inline void AddEntry(const std::string &key, FieldAccessEntry *e) { e->index_ = entry_.size(); // TODO(bing) better error message if (entry_map_.count(key) != 0) { LOG(FATAL) << "key " << key << " has already been registered in " << name_; } entry_.push_back(e); entry_map_[key] = e; } /*! * \brief set the name of parameter manager * \param name the name to set */ inline void set_name(const std::string &name) { name_ = name; } /*! * \brief get field information of each field. * \return field information */ inline std::vector GetFieldInfo() const { std::vector ret(entry_.size()); for (size_t i = 0; i < entry_.size(); ++i) { ret[i] = entry_[i]->GetFieldInfo(); } return ret; } /*! * \brief Print readible docstring to ostream, add newline. * \parma os the stream to print the docstring to. */ inline void PrintDocString(std::ostream &os) const { // NOLINT(*) for (size_t i = 0; i < entry_.size(); ++i) { ParamFieldInfo info = entry_[i]->GetFieldInfo(); os << info.name << " : " << info.type_info_str << '\n'; if (info.description.length() != 0) { os << " " << info.description << '\n'; } } } /*! * \brief Get internal parameters in vector of pairs. * \param head the head of the struct. * \param skip_default skip the values that equals default value. * \return the parameter dictionary. */ inline std::vector > GetDict(void * head) const { std::vector > ret; for (std::map::const_iterator it = entry_map_.begin(); it != entry_map_.end(); ++it) { ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head))); } return ret; } private: /*! \brief parameter struct name */ std::string name_; /*! \brief positional list of entries */ std::vector entry_; /*! \brief map of key to entry */ std::map entry_map_; }; //! 
\cond Doxygen_Suppress // The following piece of code will be template heavy and less documented // singleton parameter manager for certain type, used for initialization template struct ParamManagerSingleton { ParamManager manager; explicit ParamManagerSingleton(const std::string ¶m_name) { PType param; param.__DECLARE__(this); manager.set_name(param_name); } }; // Base class of FieldEntry // implement set_default template class FieldEntryBase : public FieldAccessEntry { public: // entry type typedef TEntry EntryType; // implement set value virtual void Set(void *head, const std::string &value) const { std::istringstream is(value); is >> this->Get(head); if (!is.fail()) { while (!is.eof()) { int ch = is.get(); if (ch == EOF) { is.clear(); break; } if (!isspace(ch)) { is.setstate(std::ios::failbit); break; } } } if (is.fail()) { std::ostringstream os; os << "Invalid Parameter format for " << key_ << " expect " << type_ << " but value=\'" << value<< '\''; throw dmlc::ParamError(os.str()); } } virtual std::string GetStringValue(void *head) const { std::ostringstream os; PrintValue(os, this->Get(head)); return os.str(); } virtual ParamFieldInfo GetFieldInfo() const { ParamFieldInfo info; std::ostringstream os; info.name = key_; info.type = type_; os << type_; if (has_default_) { os << ',' << " optional, default="; PrintDefaultValueString(os); } else { os << ", required"; } info.type_info_str = os.str(); info.description = description_; return info; } // implement set head to default value virtual void SetDefault(void *head) const { if (!has_default_) { std::ostringstream os; os << "Parameter " << key_ << " is not presented"; throw dmlc::ParamError(os.str()); } else { this->Get(head) = default_value_; } } // return reference of self as derived type inline TEntry &self() { return *(static_cast(this)); } // implement set_default inline TEntry &set_default(const DType &default_value) { default_value_ = default_value; has_default_ = true; // return self to allow chaining return this->self(); } // implement describe inline TEntry &describe(const std::string &description) { description_ = description; // return self to allow chaining return this->self(); } // initialization function inline void Init(const std::string &key, void *head, DType &ref) { // NOLINT(*) this->key_ = key; if (this->type_.length() == 0) { this->type_ = dmlc::type_name(); } this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*) } protected: // print the value virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*) os << value; } virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) PrintValue(os, default_value_); } // get the internal representation of parameter // for example if this entry corresponds field param.learning_rate // then Get(¶m) will return reference to param.learning_rate inline DType &Get(void *head) const { return *(DType*)((char*)(head) + offset_); // NOLINT(*) } // internal offset of the field ptrdiff_t offset_; // default value of field DType default_value_; }; // parameter base for numeric types that have range template class FieldEntryNumeric : public FieldEntryBase { public: FieldEntryNumeric() : has_begin_(false), has_end_(false) {} // implement set_range virtual TEntry &set_range(DType begin, DType end) { begin_ = begin; end_ = end; has_begin_ = true; has_end_ = true; return this->self(); } // implement set_range virtual TEntry &set_lower_bound(DType begin) { begin_ = begin; has_begin_ = true; return this->self(); } // consistency check for numeric ranges 
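// ---- Sketch of how the chaining setters below are used inside a
// ---- DMLC_DECLARE_PARAMETER block (field names are hypothetical):
//
//   DMLC_DECLARE_FIELD(num_hidden)
//       .set_range(1, 10000)                   // FieldEntryNumeric::set_range
//       .describe("Number of hidden units.");
//   DMLC_DECLARE_FIELD(momentum)
//       .set_lower_bound(0.0f)
//       .set_default(0.9f)
//       .describe("Momentum value.");
//   // Init() later calls Set()/Check() on each entry; a value outside the
//   // declared bounds raises dmlc::ParamError naming the offending bound.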
virtual void Check(void *head) const { FieldEntryBase::Check(head); DType v = this->Get(head); if (has_begin_ && has_end_) { if (v < begin_ || v >= end_) { std::ostringstream os; os << "value " << v << "for Parameter " << this->key_ << " exceed bound [" << begin_ << ',' << end_ <<')'; throw dmlc::ParamError(os.str()); } } else if (has_begin_ && v < begin_) { std::ostringstream os; os << "value " << v << "for Parameter " << this->key_ << " should be greater equal to " << begin_; throw dmlc::ParamError(os.str()); } else if (has_end_ && v >= end_) { std::ostringstream os; os << "value " << v << "for Parameter " << this->key_ << " should be smaller than " << end_; throw dmlc::ParamError(os.str()); } } protected: // whether it have begin and end range bool has_begin_, has_end_; // data bound DType begin_, end_; }; /*! * \brief FieldEntry defines parsing and checking behavior of DType. * This class can be specialized to implement specific behavior of more settings. * \tparam DType the data type of the entry. */ template class FieldEntry : public IfThenElseType::value, FieldEntryNumeric, DType>, FieldEntryBase, DType> >::Type { }; // specialize define for int(enum) template<> class FieldEntry : public FieldEntryNumeric, int> { public: // construct FieldEntry() : is_enum_(false) {} // parent typedef FieldEntryNumeric, int> Parent; // override set virtual void Set(void *head, const std::string &value) const { if (is_enum_) { std::map::const_iterator it = enum_map_.find(value); std::ostringstream os; if (it == enum_map_.end()) { os << "Invalid Input: \'" << value; os << "\', valid values are: "; PrintEnums(os); throw dmlc::ParamError(os.str()); } else { os << it->second; Parent::Set(head, os.str()); } } else { Parent::Set(head, value); } } virtual ParamFieldInfo GetFieldInfo() const { if (is_enum_) { ParamFieldInfo info; std::ostringstream os; info.name = key_; info.type = type_; PrintEnums(os); if (has_default_) { os << ',' << "optional, default="; PrintDefaultValueString(os); } else { os << ", required"; } info.type_info_str = os.str(); info.description = description_; return info; } else { return Parent::GetFieldInfo(); } } // add enum inline FieldEntry &add_enum(const std::string &key, int value) { if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ enum_back_map_.count(value) != 0) { std::ostringstream os; os << "Enum " << "(" << key << ": " << value << " exisit!" 
<< ")\n"; os << "Enums: "; for (std::map::const_iterator it = enum_map_.begin(); it != enum_map_.end(); ++it) { os << "(" << it->first << ": " << it->second << "), "; } throw dmlc::ParamError(os.str()); } enum_map_[key] = value; enum_back_map_[value] = key; is_enum_ = true; return this->self(); } protected: // enum flag bool is_enum_; // enum map std::map enum_map_; // enum map std::map enum_back_map_; // override print behavior virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) os << '\''; PrintValue(os, default_value_); os << '\''; } // override print default virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*) if (is_enum_) { CHECK_NE(enum_back_map_.count(value), 0) << "Value not found in enum declared"; os << enum_back_map_.at(value); } else { os << value; } } private: inline void PrintEnums(std::ostream &os) const { // NOLINT(*) os << '{'; for (std::map::const_iterator it = enum_map_.begin(); it != enum_map_.end(); ++it) { if (it != enum_map_.begin()) { os << ", "; } os << "\'" << it->first << '\''; } os << '}'; } }; // specialize define for string template<> class FieldEntry : public FieldEntryBase, std::string> { public: // parent class typedef FieldEntryBase, std::string> Parent; // override set virtual void Set(void *head, const std::string &value) const { this->Get(head) = value; } // override print default virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) os << '\'' << default_value_ << '\''; } }; // specialize define for bool template<> class FieldEntry : public FieldEntryBase, bool> { public: // parent class typedef FieldEntryBase, bool> Parent; // override set virtual void Set(void *head, const std::string &value) const { std::string lower_case; lower_case.resize(value.length()); std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower); bool &ref = this->Get(head); if (lower_case == "true") { ref = true; } else if (lower_case == "false") { ref = false; } else if (lower_case == "1") { ref = true; } else if (lower_case == "0") { ref = false; } else { std::ostringstream os; os << "Invalid Parameter format for " << key_ << " expect " << type_ << " but value=\'" << value<< '\''; throw dmlc::ParamError(os.str()); } } protected: // print default string virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*) if (value) { os << "True"; } else { os << "False"; } } }; } // namespace parameter //! \endcond // implement GetEnv template inline ValueType GetEnv(const char *key, ValueType default_value) { const char *val = getenv(key); if (val == NULL) return default_value; ValueType ret; parameter::FieldEntry e; e.Init(key, &ret, ret); e.Set(&ret, val); return ret; } } // namespace dmlc #endif // DMLC_PARAMETER_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/parameter.h ===== namespace dmlc { /*! * \brief Registry class. * Registry can be used to register global singletons. * The most commonly use case are factory functions. * * \tparam EntryType Type of Registry entries, * EntryType need to name a name field. */ template class Registry { public: /*! \return list of functions in the registry */ inline static const std::vector &List() { return Get()->entry_list_; } /*! * \brief Find the entry with corresponding name. 
* \param name name of the function * \return the corresponding function, can be NULL */ inline static const EntryType *Find(const std::string &name) { const std::map &fmap = Get()->fmap_; typename std::map::const_iterator p = fmap.find(name); if (p != fmap.end()) { return p->second; } else { return NULL; } } /*! * \brief Internal function to register a name function under name. * \param name name of the function * \return ref to the registered entry, used to set properties */ inline EntryType &__REGISTER__(const std::string& name) { CHECK_EQ(fmap_.count(name), 0) << name << " already registered"; EntryType *e = new EntryType(); e->name = name; fmap_[name] = e; entry_list_.push_back(e); return *e; } /*! * \brief Internal function to either register or get registered entry * \param name name of the function * \return ref to the registered entry, used to set properties */ inline EntryType &__REGISTER_OR_GET__(const std::string& name) { if (fmap_.count(name) == 0) { return __REGISTER__(name); } else { return *fmap_.at(name); } } /*! * \brief get a singleton of the Registry. * This function can be defined by DMLC_ENABLE_REGISTRY. * \return get a singleton */ static Registry *Get(); private: /*! \brief list of entry types */ std::vector entry_list_; /*! \brief map of name->function */ std::map fmap_; /*! \brief constructor */ Registry() {} /*! \brief destructor */ ~Registry() { for (typename std::map::iterator p = fmap_.begin(); p != fmap_.end(); ++p) { delete p->second; } } }; /*! * \brief Common base class for function registry. * * \code * // This example demonstrates how to use Registry to create a factory of trees. * struct TreeFactory : * public FunctionRegEntryBase > { * }; * * // in a independent cc file * namespace dmlc { * DMLC_REGISTRY_ENABLE(TreeFactory); * } * // register binary tree constructor into the registry. * DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree) * .describe("Constructor of BinaryTree") * .set_body([]() { return new BinaryTree(); }); * \endcode * * \tparam EntryType The type of subclass that inheritate the base. * \tparam FunctionType The function type this registry is registerd. */ template class FunctionRegEntryBase { public: /*! \brief name of the entry */ std::string name; /*! \brief description of the entry */ std::string description; /*! \brief additional arguments to the factory function */ std::vector arguments; /*! \brief Function body to create ProductType */ FunctionType body; /*! * \brief Set the function body. * \param body Function body to set. * \return reference to self. */ inline EntryType &set_body(FunctionType body) { this->body = body; return this->self(); } /*! * \brief Describe the function. * \param description The description of the factory function. * \return reference to self. */ inline EntryType &describe(const std::string &description) { this->description = description; return this->self(); } /*! * \brief Add argument information to the function. * \param name Name of the argument. * \param type Type of the argument. * \param description Description of the argument. * \return reference to self. */ inline EntryType &add_argument(const std::string &name, const std::string &type, const std::string &description) { ParamFieldInfo info; info.name = name; info.type = type; info.type_info_str = info.type; info.description = description; arguments.push_back(info); return this->self(); } /*! * \brief Append list if arguments to the end. * \param args Additional list of arguments. * \return reference to self. 
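// ---- Sketch of looking up and invoking a registered entry, continuing the
// ---- TreeFactory example from the registry documentation above (TreeFactory,
// ---- BinaryTree and Tree come from that example and assume
// ---- DMLC_REGISTRY_ENABLE(TreeFactory) has been done in a cc file):
//
//   const TreeFactory *entry =
//       dmlc::Registry<TreeFactory>::Find("BinaryTree");
//   CHECK(entry != NULL) << "unknown tree type";
//   Tree *tree = entry->body();   // invoke the registered factory function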
*/ inline EntryType &add_arguments(const std::vector &args) { arguments.insert(arguments.end(), args.begin(), args.end()); return this->self(); } protected: /*! * \return reference of self as derived type */ inline EntryType &self() { return *(static_cast(this)); } }; /*! * \brief Macro to enable the registry of EntryType. * This macro must be used under namespace dmlc, and only used once in cc file. * \param EntryType Type of registry entry */ #define DMLC_REGISTRY_ENABLE(EntryType) \ template<> \ Registry *Registry::Get() { \ static Registry inst; \ return &inst; \ } \ /*! * \brief Generic macro to register an EntryType * There is a complete example in FactoryRegistryEntryBase. * * \param EntryType The type of registry entry. * \param EntryTypeName The typename of EntryType, must do not contain namespace :: . * \param Name The name to be registered. * \sa FactoryRegistryEntryBase */ #define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ ::dmlc::Registry::Get()->__REGISTER__(#Name) \ } // namespace dmlc #endif // DMLC_REGISTRY_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/registry.h ===== //===== EXPANDIND: mxnet/include/mxnet/base.h ===== /*! * Copyright (c) 2015 by Contributors * \file base.h * \brief configuation of mxnet as well as basic data structure. */ #ifndef MXNET_BASE_H_ #define MXNET_BASE_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/io.h ===== /*! * Copyright (c) 2015 by Contributors * \file io.h * \brief defines serializable interface of dmlc */ #ifndef DMLC_IO_H_ #define DMLC_IO_H_ // include uint64_t only to make io standalone #ifdef _MSC_VER /*! \brief uint64 */ typedef unsigned __int64 uint64_t; #else #endif /*! \brief namespace for dmlc */ namespace dmlc { /*! * \brief interface of stream I/O for serialization */ class Stream { // NOLINT(*) public: /*! * \brief reads data from a stream * \param ptr pointer to a memory buffer * \param size block size * \return the size of data read */ virtual size_t Read(void *ptr, size_t size) = 0; /*! * \brief writes data to a stream * \param ptr pointer to a memory buffer * \param size block size */ virtual void Write(const void *ptr, size_t size) = 0; /*! \brief virtual destructor */ virtual ~Stream(void) {} /*! * \brief generic factory function * create an stream, the stream will close the underlying files upon deletion * * \param uri the uri of the input currently we support * hdfs://, s3://, and file:// by default file:// will be used * \param flag can be "w", "r", "a" * \param allow_null whether NULL can be returned, or directly report error * \return the created stream, can be NULL when allow_null == true and file do not exist */ static Stream *Create(const char *uri, const char* const flag, bool allow_null = false); // helper functions to write/read different data structures /*! * \brief writes a data to stream * * dmlc::Stream support Write/Read of most STL * composites and base types. * If the data type is not supported, a compile time error will * be issued. * * \param data data to be written * \tparam T the data type to be written */ template inline void Write(const T &data); /*! * \brief loads a data from stream. * * dmlc::Stream support Write/Read of most STL * composites and base types. * If the data type is not supported, a compile time error will * be issued. * * \param out_data place holder of data to be deserialized * \return whether the load was successful */ template inline bool Read(T *out_data); }; /*! 
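 *
 * A short round-trip sketch for the Stream interface declared above,
 * assuming a local file "data.bin" (the path is illustrative only):
 * \code
 * dmlc::Stream *fo = dmlc::Stream::Create("data.bin", "w");
 * std::vector<int> data(10, 1);
 * fo->Write(data);           // templated Write, dispatched via serializer::Handler
 * delete fo;                 // closes the underlying file
 *
 * dmlc::Stream *fi = dmlc::Stream::Create("data.bin", "r");
 * std::vector<int> loaded;
 * CHECK(fi->Read(&loaded));  // templated Read returns false on failure
 * delete fi;
 * \endcode
 *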
\brief interface of i/o stream that support seek */ class SeekStream: public Stream { public: // virtual destructor virtual ~SeekStream(void) {} /*! \brief seek to certain position of the file */ virtual void Seek(size_t pos) = 0; /*! \brief tell the position of the stream */ virtual size_t Tell(void) = 0; /*! * \brief generic factory function * create an SeekStream for read only, * the stream will close the underlying files upon deletion * error will be reported and the system will exit when create failed * \param uri the uri of the input currently we support * hdfs://, s3://, and file:// by default file:// will be used * \param allow_null whether NULL can be returned, or directly report error * \return the created stream, can be NULL when allow_null == true and file do not exist */ static SeekStream *CreateForRead(const char *uri, bool allow_null = false); }; /*! \brief interface for serializable objects */ class Serializable { public: /*! \brief virtual destructor */ virtual ~Serializable() {} /*! * \brief load the model from a stream * \param fi stream where to load the model from */ virtual void Load(Stream *fi) = 0; /*! * \brief saves the model to a stream * \param fo stream where to save the model to */ virtual void Save(Stream *fo) const = 0; }; /*! * \brief input split creates that allows reading * of records from split of data, * independent part that covers all the dataset * * see InputSplit::Create for definition of record */ class InputSplit { public: /*! \brief a blob of memory region */ struct Blob { /*! \brief points to start of the memory region */ void *dptr; /*! \brief size of the memory region */ size_t size; }; /*! * \brief hint the inputsplit how large the chunk size * it should return when implementing NextChunk * this is a hint so may not be enforced, * but InputSplit will try adjust its internal buffer * size to the hinted value * \param chunk_size the chunk size */ virtual void HintChunkSize(size_t chunk_size) {} /*! \brief reset the position of InputSplit to beginning */ virtual void BeforeFirst(void) = 0; /*! * \brief get the next record, the returning value * is valid until next call to NextRecord or NextChunk * caller can modify the memory content of out_rec * * For text, out_rec contains a single line * For recordio, out_rec contains one record content(with header striped) * * \param out_rec used to store the result * \return true if we can successfully get next record * false if we reached end of split * \sa InputSplit::Create for definition of record */ virtual bool NextRecord(Blob *out_rec) = 0; /*! * \brief get a chunk of memory that can contain multiple records, * the caller needs to parse the content of the resulting chunk, * for text file, out_chunk can contain data of multiple lines * for recordio, out_chunk can contain multiple records(including headers) * * This function ensures there won't be partial record in the chunk * caller can modify the memory content of out_chunk, * the memory is valid until next call to NextRecord or NextChunk * * Usually NextRecord is sufficient, NextChunk can be used by some * multi-threaded parsers to parse the input content * * \param out_chunk used to store the result * \return true if we can successfully get next record * false if we reached end of split * \sa InputSplit::Create for definition of record * \sa RecordIOChunkReader to parse recordio content from out_chunk */ virtual bool NextChunk(Blob *out_chunk) = 0; /*! \brief destructor*/ virtual ~InputSplit(void) {} /*! 
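 *
 * A minimal record-reading sketch for the InputSplit interface above,
 * assuming a plain text file "train.txt" read as a single part
 * (the file name is illustrative only):
 * \code
 * dmlc::InputSplit *in =
 *     dmlc::InputSplit::Create("train.txt", 0, 1, "text");
 * dmlc::InputSplit::Blob rec;
 * while (in->NextRecord(&rec)) {
 *   // rec.dptr / rec.size hold one line, valid until the next call
 * }
 * delete in;
 * \endcode
 *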
* \brief factory function: * create input split given a uri * \param uri the uri of the input, can contain hdfs prefix * \param part_index the part id of current input * \param num_parts total number of splits * \param type type of record * List of possible types: "text", "recordio" * - "text": * text file, each line is treated as a record * input split will split on '\\n' or '\\r' * - "recordio": * binary recordio file, see recordio.h * \return a new input split * \sa InputSplit::Type */ static InputSplit* Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type); }; /*! * \brief a std::ostream class that can can wrap Stream objects, * can use ostream with that output to underlying Stream * * Usage example: * \code * * Stream *fs = Stream::Create("hdfs:///test.txt", "w"); * dmlc::ostream os(fs); * os << "hello world" << std::endl; * delete fs; * \endcode */ class ostream : public std::basic_ostream { public: /*! * \brief construct std::ostream type * \param stream the Stream output to be used * \param buffer_size internal streambuf size */ explicit ostream(Stream *stream, size_t buffer_size = (1 << 10)) : std::basic_ostream(NULL), buf_(buffer_size) { this->set_stream(stream); } // explictly synchronize the buffer virtual ~ostream() { buf_.pubsync(); } /*! * \brief set internal stream to be stream, reset states * \param stream new stream as output */ inline void set_stream(Stream *stream) { buf_.set_stream(stream); this->rdbuf(&buf_); } /*! \return how many bytes we written so far */ inline size_t bytes_written(void) const { return buf_.bytes_out(); } private: // internal streambuf class OutBuf : public std::streambuf { public: explicit OutBuf(size_t buffer_size) : stream_(NULL), buffer_(buffer_size), bytes_out_(0) { if (buffer_size == 0) buffer_.resize(2); } // set stream to the buffer inline void set_stream(Stream *stream); inline size_t bytes_out() const { return bytes_out_; } private: /*! \brief internal stream by StreamBuf */ Stream *stream_; /*! \brief internal buffer */ std::vector buffer_; /*! \brief number of bytes written so far */ size_t bytes_out_; // override sync inline int_type sync(void); // override overflow inline int_type overflow(int c); }; /*! \brief buffer of the stream */ OutBuf buf_; }; /*! * \brief a std::istream class that can can wrap Stream objects, * can use istream with that output to underlying Stream * * Usage example: * \code * * Stream *fs = Stream::Create("hdfs:///test.txt", "r"); * dmlc::istream is(fs); * is >> mydata; * delete fs; * \endcode */ class istream : public std::basic_istream { public: /*! * \brief construct std::ostream type * \param stream the Stream output to be used * \param buffer_size internal buffer size */ explicit istream(Stream *stream, size_t buffer_size = (1 << 10)) : std::basic_istream(NULL), buf_(buffer_size) { this->set_stream(stream); } virtual ~istream() {} /*! * \brief set internal stream to be stream, reset states * \param stream new stream as output */ inline void set_stream(Stream *stream) { buf_.set_stream(stream); this->rdbuf(&buf_); } /*! 
\return how many bytes we read so far */ inline size_t bytes_read(void) const { return buf_.bytes_read(); } private: // internal streambuf class InBuf : public std::streambuf { public: explicit InBuf(size_t buffer_size) : stream_(NULL), bytes_read_(0), buffer_(buffer_size) { if (buffer_size == 0) buffer_.resize(2); } // set stream to the buffer inline void set_stream(Stream *stream); // return how many bytes read so far inline size_t bytes_read(void) const { return bytes_read_; } private: /*! \brief internal stream by StreamBuf */ Stream *stream_; /*! \brief how many bytes we read so far */ size_t bytes_read_; /*! \brief internal buffer */ std::vector buffer_; // override underflow inline int_type underflow(); }; /*! \brief input buffer */ InBuf buf_; }; } // namespace dmlc //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/serializer.h ===== /*! * Copyright (c) 2015 by Contributors * \file serializer.h * \brief serializer template class that helps serialization. * This file do not need to be directly used by most user. */ #ifndef DMLC_SERIALIZER_H_ #define DMLC_SERIALIZER_H_ #if DMLC_USE_CXX11 #endif namespace dmlc { /*! \brief internal namespace for serializers */ namespace serializer { /*! * \brief generic serialization handler * \tparam T the type to be serialized */ template struct Handler; //! \cond Doxygen_Suppress /*! * \brief Serializer that redirect calls by condition * \tparam cond the condition * \tparam Then the serializer used for then condition * \tparam Else the serializer used for else condition * \tparam Return the type of data the serializer handles */ template struct IfThenElse; template struct IfThenElse { inline static void Write(Stream *strm, const T &data) { Then::Write(strm, data); } inline static bool Read(Stream *strm, T *data) { return Then::Read(strm, data); } }; template struct IfThenElse { inline static void Write(Stream *strm, const T &data) { Else::Write(strm, data); } inline static bool Read(Stream *strm, T *data) { return Else::Read(strm, data); } }; /*! \brief Serializer for POD(plain-old-data) data */ template struct PODHandler { inline static void Write(Stream *strm, const T &data) { strm->Write(&data, sizeof(T)); } inline static bool Read(Stream *strm, T *dptr) { return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) } }; // serializer for class that have save/load function template struct SaveLoadClassHandler { inline static void Write(Stream *strm, const T &data) { data.Save(strm); } inline static bool Read(Stream *strm, T *data) { return data->Load(strm); } }; /*! * \brief dummy class for undefined serialization. * This is used to generate error message when user tries to * serialize something that is not supported. * \tparam T the type to be serialized */ template struct UndefinedSerializerFor { }; /*! * \brief Serializer handler for std::vector where T is POD type. * \tparam T element type */ template struct PODVectorHandler { inline static void Write(Stream *strm, const std::vector &vec) { uint64_t sz = static_cast(vec.size()); strm->Write(&sz, sizeof(sz)); if (sz != 0) { strm->Write(&vec[0], sizeof(T) * vec.size()); } } inline static bool Read(Stream *strm, std::vector *out_vec) { uint64_t sz; if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; size_t size = static_cast(sz); out_vec->resize(size); if (sz != 0) { size_t nbytes = sizeof(T) * size; return strm->Read(&(*out_vec)[0], nbytes) == nbytes; } return true; } }; /*! 
* \brief Serializer handler for std::vector where T can be composed type * \tparam T element type */ template struct ComposeVectorHandler { inline static void Write(Stream *strm, const std::vector &vec) { uint64_t sz = static_cast(vec.size()); strm->Write(&sz, sizeof(sz)); for (size_t i = 0; i < vec.size(); ++i) { Handler::Write(strm, vec[i]); } } inline static bool Read(Stream *strm, std::vector *out_vec) { uint64_t sz; if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; size_t size = static_cast(sz); out_vec->resize(size); for (size_t i = 0; i < size; ++i) { if (!Handler::Read(strm, &(*out_vec)[i])) return false; } return true; } }; /*! * \brief Serializer handler for std::basic_string where T is POD type. * \tparam T element type */ template struct PODStringHandler { inline static void Write(Stream *strm, const std::basic_string &vec) { uint64_t sz = static_cast(vec.length()); strm->Write(&sz, sizeof(sz)); if (sz != 0) { strm->Write(&vec[0], sizeof(T) * vec.length()); } } inline static bool Read(Stream *strm, std::basic_string *out_vec) { uint64_t sz; if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; size_t size = static_cast(sz); out_vec->resize(size); if (sz != 0) { size_t nbytes = sizeof(T) * size; return strm->Read(&(*out_vec)[0], nbytes) == nbytes; } return true; } }; /*! \brief Serializer for std::pair */ template struct PairHandler { inline static void Write(Stream *strm, const std::pair &data) { Handler::Write(strm, data.first); Handler::Write(strm, data.second); } inline static bool Read(Stream *strm, std::pair *data) { return Handler::Read(strm, &(data->first)) && Handler::Read(strm, &(data->second)); } }; // set type handler that can handle most collection type case template struct CollectionHandler { inline static void Write(Stream *strm, const ContainerType &data) { typedef typename ContainerType::value_type ElemType; // dump data to vector std::vector vdata(data.begin(), data.end()); // serialize the vector Handler >::Write(strm, vdata); } inline static bool Read(Stream *strm, ContainerType *data) { typedef typename ContainerType::value_type ElemType; std::vector vdata; if (!Handler >::Read(strm, &vdata)) return false; data->clear(); data->insert(vdata.begin(), vdata.end()); return true; } }; // handler that can handle most list type case // this type insert function takes additional iterator template struct ListHandler { inline static void Write(Stream *strm, const ListType &data) { typedef typename ListType::value_type ElemType; // dump data to vector std::vector vdata(data.begin(), data.end()); // serialize the vector Handler >::Write(strm, vdata); } inline static bool Read(Stream *strm, ListType *data) { typedef typename ListType::value_type ElemType; std::vector vdata; if (!Handler >::Read(strm, &vdata)) return false; data->clear(); data->insert(data->begin(), vdata.begin(), vdata.end()); return true; } }; //! \endcond /*! * \brief generic serialization handler for type T * * User can define specialization of this class to support * composite serialization of their own class. * * \tparam T the type to be serialized */ template struct Handler { /*! * \brief write data to stream * \param strm the stream we write the data. * \param data the data obeject to be serialized */ inline static void Write(Stream *strm, const T &data) { IfThenElse::value, PODHandler, IfThenElse::value, SaveLoadClassHandler, UndefinedSerializerFor, T>, T> ::Write(strm, data); } /*! * \brief read data to stream * \param strm the stream to read the data. 
* \param data the pointer to the data obeject to read * \return whether the read is successful */ inline static bool Read(Stream *strm, T *data) { return IfThenElse::value, PODHandler, IfThenElse::value, SaveLoadClassHandler, UndefinedSerializerFor, T>, T> ::Read(strm, data); } }; //! \cond Doxygen_Suppress template struct Handler > { inline static void Write(Stream *strm, const std::vector &data) { IfThenElse::value, PODVectorHandler, ComposeVectorHandler, std::vector > ::Write(strm, data); } inline static bool Read(Stream *strm, std::vector *data) { return IfThenElse::value, PODVectorHandler, ComposeVectorHandler, std::vector > ::Read(strm, data); } }; template struct Handler > { inline static void Write(Stream *strm, const std::basic_string &data) { IfThenElse::value, PODStringHandler, UndefinedSerializerFor, std::basic_string > ::Write(strm, data); } inline static bool Read(Stream *strm, std::basic_string *data) { return IfThenElse::value, PODStringHandler, UndefinedSerializerFor, std::basic_string > ::Read(strm, data); } }; template struct Handler > { inline static void Write(Stream *strm, const std::pair &data) { IfThenElse::value && dmlc::is_pod::value, PODHandler >, PairHandler, std::pair > ::Write(strm, data); } inline static bool Read(Stream *strm, std::pair *data) { return IfThenElse::value && dmlc::is_pod::value, PODHandler >, PairHandler, std::pair > ::Read(strm, data); } }; template struct Handler > : public CollectionHandler > { }; template struct Handler > : public CollectionHandler > { }; template struct Handler > : public CollectionHandler > { }; template struct Handler > : public CollectionHandler > { }; template struct Handler > : public ListHandler > { }; template struct Handler > : public ListHandler > { }; #if DMLC_USE_CXX11 template struct Handler > : public CollectionHandler > { }; template struct Handler > : public CollectionHandler > { }; template struct Handler > : public CollectionHandler > { }; template struct Handler > : public CollectionHandler > { }; #endif //! 
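/*!
 * A sketch of the user-side specialization mentioned in the generic Handler
 * documentation above; struct Line is illustrative and not part of dmlc-core.
 * It simply forwards to the handler of its member, mirroring PairHandler:
 * \code
 * struct Line { std::vector<float> pts; };
 * namespace dmlc {
 * namespace serializer {
 * template<>
 * struct Handler<Line> {
 *   inline static void Write(Stream *strm, const Line &d) {
 *     Handler<std::vector<float> >::Write(strm, d.pts);
 *   }
 *   inline static bool Read(Stream *strm, Line *d) {
 *     return Handler<std::vector<float> >::Read(strm, &d->pts);
 *   }
 * };
 * }  // namespace serializer
 * }  // namespace dmlc
 * \endcode
 */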
\endcond } // namespace serializer } // namespace dmlc #endif // DMLC_SERIALIZER_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/serializer.h ===== namespace dmlc { // implementations of inline functions template inline void Stream::Write(const T &data) { serializer::Handler::Write(this, data); } template inline bool Stream::Read(T *out_data) { return serializer::Handler::Read(this, out_data); } // implementations for ostream inline void ostream::OutBuf::set_stream(Stream *stream) { if (stream_ != NULL) this->pubsync(); this->stream_ = stream; this->setp(&buffer_[0], &buffer_[0] + buffer_.size() - 1); } inline int ostream::OutBuf::sync(void) { if (stream_ == NULL) return -1; std::ptrdiff_t n = pptr() - pbase(); stream_->Write(pbase(), n); this->pbump(-static_cast(n)); bytes_out_ += n; return 0; } inline int ostream::OutBuf::overflow(int c) { *(this->pptr()) = c; std::ptrdiff_t n = pptr() - pbase(); this->pbump(-static_cast(n)); if (c == EOF) { stream_->Write(pbase(), n); bytes_out_ += n; } else { stream_->Write(pbase(), n + 1); bytes_out_ += n + 1; } return c; } // implementations for istream inline void istream::InBuf::set_stream(Stream *stream) { stream_ = stream; this->setg(&buffer_[0], &buffer_[0], &buffer_[0]); } inline int istream::InBuf::underflow() { char *bhead = &buffer_[0]; if (this->gptr() == this->egptr()) { size_t sz = stream_->Read(bhead, buffer_.size()); this->setg(bhead, bhead, bhead + sz); bytes_read_ += sz; } if (this->gptr() == this->egptr()) { return traits_type::eof(); } else { return traits_type::to_int_type(*gptr()); } } } // namespace dmlc #endif // DMLC_IO_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/io.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/tensor.h ===== /*! * Copyright (c) 2014 by Contributors * \file tensor.h * \brief header file of tensor data structure and functions * This lib requires explicit memory allocation and de-allocation * all the data structure Tensor, Tensor are like handles(pointers), * no memory allocation is happening during calculation * * For STL style tensor, see tensor_container.h * \author Bing Xu, Tianqi Chen */ #ifndef MSHADOW_TENSOR_H_ #define MSHADOW_TENSOR_H_ //===== EXPANDIND: mxnet/mshadow/mshadow/base.h ===== /*! * Copyright (c) 2014 by Contributors * \file base.h * \brief definitions of base types, operators, macros functions * * \author Bing Xu, Tianqi Chen */ #ifndef MSHADOW_BASE_H_ #define MSHADOW_BASE_H_ #ifdef _MSC_VER #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS #endif #ifndef _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE #endif #define NOMINMAX #endif // macro defintiions /*! * \brief if this macro is define to be 1, * mshadow should compile without any of other libs */ #ifndef MSHADOW_STAND_ALONE #define MSHADOW_STAND_ALONE 0 #endif /*! \brief whether do padding during allocation */ #ifndef MSHADOW_ALLOC_PAD #define MSHADOW_ALLOC_PAD true #endif /*! * \brief * x dimension of data must be bigger pad_size * ratio to be alloced padded memory, * otherwise use tide allocation * for example, if pad_ratio=2, GPU memory alignement size is 32, * then we will only allocate padded memory if x dimension > 64 * set it to 0 then we will always allocate padded memory */ #ifndef MSHADOW_MIN_PAD_RATIO #define MSHADOW_MIN_PAD_RATIO 2 #endif #if MSHADOW_STAND_ALONE #define MSHADOW_USE_CBLAS 0 #define MSHADOW_USE_MKL 0 #define MSHADOW_USE_CUDA 0 #endif /*! 
* \brief force user to use GPU stream during computation * error will be shot when default stream NULL is used */ #ifndef MSHADOW_FORCE_STREAM #define MSHADOW_FORCE_STREAM 1 #endif /*! \brief use CBLAS for CBLAS */ #ifndef MSHADOW_USE_CBLAS #define MSHADOW_USE_CBLAS 0 #endif /*! \brief use MKL for BLAS */ #ifndef MSHADOW_USE_MKL #define MSHADOW_USE_MKL 1 #endif /*! * \brief use CUDA support, must ensure that the cuda include path is correct, * or directly compile using nvcc */ #ifndef MSHADOW_USE_CUDA #define MSHADOW_USE_CUDA 1 #endif /*! * \brief use CUDNN support, must ensure that the cudnn include path is correct */ #ifndef MSHADOW_USE_CUDNN #define MSHADOW_USE_CUDNN 0 #endif /*! * \brief seems CUDAARCH is deprecated in future NVCC * set this to 1 if you want to use CUDA version smaller than 2.0 */ #ifndef MSHADOW_OLD_CUDA #define MSHADOW_OLD_CUDA 0 #endif /*! * \brief macro to decide existence of c++11 compiler */ #ifndef MSHADOW_IN_CXX11 #define MSHADOW_IN_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ __cplusplus >= 201103L || defined(_MSC_VER)) #endif /*! \brief whether use SSE */ #ifndef MSHADOW_USE_SSE #define MSHADOW_USE_SSE 1 #endif /*! \brief whether use NVML to get dynamic info */ #ifndef MSHADOW_USE_NVML #define MSHADOW_USE_NVML 0 #endif // SSE is conflict with cudacc #ifdef __CUDACC__ #undef MSHADOW_USE_SSE #define MSHADOW_USE_SSE 0 #endif #if MSHADOW_USE_CBLAS extern "C" { } #elif MSHADOW_USE_MKL #endif #if MSHADOW_USE_CUDA #endif #if MSHADOW_USE_CUDNN == 1 #endif #if MSHADOW_USE_NVML #endif // -------------------------------- // MSHADOW_XINLINE is used for inlining template code for both CUDA and CPU code #ifdef MSHADOW_XINLINE #error "MSHADOW_XINLINE must not be defined" #endif #ifdef _MSC_VER #define MSHADOW_FORCE_INLINE __forceinline #pragma warning(disable : 4068) #else #define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) #endif #ifdef __CUDACC__ #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__ #else #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE #endif /*! \brief cpu force inline */ #define MSHADOW_CINLINE MSHADOW_FORCE_INLINE #if defined(__GXX_EXPERIMENTAL_CXX0X) ||\ defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L #define MSHADOW_CONSTEXPR constexpr #else #define MSHADOW_CONSTEXPR const #endif /*! * \brief default data type for tensor string * in code release, change it to default_real_t * during development, change it to empty string so that missing * template arguments can be detected */ #ifndef MSHADOW_DEFAULT_DTYPE #define MSHADOW_DEFAULT_DTYPE = default_real_t #endif /*! * \brief DMLC marco for logging */ #ifndef MSHADOW_USE_GLOG #define MSHADOW_USE_GLOG DMLC_USE_GLOG #endif // MSHADOW_USE_GLOG /*! * \brief Protected cuda call in mshadow * \param func Expression to call. * It checks for CUDA errors after invocation of the expression. */ #define MSHADOW_CUDA_CALL(func) \ { \ cudaError_t e = (func); \ if (e == cudaErrorCudartUnloading) { \ throw dmlc::Error(cudaGetErrorString(e)); \ } \ CHECK(e == cudaSuccess) \ << "CUDA: " << cudaGetErrorString(e); \ } /*! * \brief Run function and catch error, log unknown error. * \param func Expression to call. */ #define MSHADOW_CATCH_ERROR(func) \ { \ try { \ (func); \ } catch (const dmlc::Error &e) { \ std::string what = e.what(); \ if (what.find("driver shutting down") == std::string::npos) { \ LOG(ERROR) << "Ignore CUDA Error " << what; \ } \ } \ } /*! \brief namespace for mshadow */ namespace mshadow { /*! 
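 *
 * A usage sketch for the error-handling macros defined above; the cudaMalloc
 * call is illustrative and requires building with CUDA enabled, and stream
 * is assumed to be an existing GPU stream:
 * \code
 * void *dptr;
 * MSHADOW_CUDA_CALL(cudaMalloc(&dptr, 256));
 * MSHADOW_CATCH_ERROR(mshadow::DeleteStream(stream));
 * \endcode
 *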
\brief buffer size for each random number generator */ const unsigned kRandBufferSize = 1000000; /*! \brief pi */ const float kPi = 3.1415926f; /*! \brief type that will be used for index */ typedef unsigned index_t; /*! \brief float point type that will be used in default by mshadow */ typedef float default_real_t; /*! \brief namespace for operators */ namespace op { // binary operator /*! \brief mul operator */ struct mul{ /*! \brief map a, b to result using defined operation */ template MSHADOW_XINLINE static DType Map(DType a, DType b) { return a * b; } }; /*! \brief plus operator */ struct plus { /*! \brief map a, b to result using defined operation */ template MSHADOW_XINLINE static DType Map(DType a, DType b) { return a + b; } }; /*! \brief minus operator */ struct minus { /*! \brief map a, b to result using defined operation */ template MSHADOW_XINLINE static DType Map(DType a, DType b) { return a - b; } }; /*! \brief divide operator */ struct div { /*! \brief map a, b to result using defined operation */ template MSHADOW_XINLINE static DType Map(DType a, DType b) { return a / b; } }; /*! \brief get rhs */ struct right { /*! \brief map a, b to result using defined operation */ template MSHADOW_XINLINE static DType Map(DType a, DType b) { return b; } }; // unary operator/ function: example // these operators can be defined by user, // in the same style as binary and unary operator // to use, simply write F( src ) /*! \brief identity function that maps a real number to it self */ struct identity{ /*! \brief map a to result using defined operation */ template MSHADOW_XINLINE static DType Map(DType a) { return a; } }; } // namespace op /*! \brief namespace for savers */ namespace sv { /*! \brief save to saver: = */ struct saveto { /*! \brief save b to a using save method */ template MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) a = b; } /*! \brief helper constant to use BLAS, alpha */ inline static default_real_t AlphaBLAS(void) { return 1.0f; } /*! \brief helper constant to use BLAS, beta */ inline static default_real_t BetaBLAS(void) { return 0.0f; } /*! \brief corresponding binary operator type */ typedef op::right OPType; }; /*! \brief save to saver: += */ struct plusto { /*! \brief save b to a using save method */ template MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) a += b; } /*! \brief helper constant to use BLAS, alpha */ inline static default_real_t AlphaBLAS(void) { return 1.0f; } /*! \brief helper constant to use BLAS, beta */ inline static default_real_t BetaBLAS(void) { return 1.0f; } /*! \brief corresponding binary operator type */ typedef op::plus OPType; }; /*! \brief minus to saver: -= */ struct minusto { /*! \brief save b to a using save method */ template MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) a -= b; } /*! \brief helper constant to use BLAS, alpha */ inline static default_real_t AlphaBLAS(void) { return -1.0f; } /*! \brief helper constant to use BLAS, beta */ inline static default_real_t BetaBLAS(void) { return 1.0f; } /*! \brief corresponding binary operator type */ typedef op::minus OPType; }; /*! \brief multiply to saver: *= */ struct multo { /*! \brief save b to a using save method */ template MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) a *= b; } /*! \brief corresponding binary operator type */ typedef op::mul OPType; }; /*! \brief divide to saver: /= */ struct divto { /*! 
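 *
 * A sketch of a user-defined operator in the style described in namespace op
 * above; struct square is illustrative and not part of mshadow:
 * \code
 * struct square {
 *   template<typename DType>
 *   MSHADOW_XINLINE static DType Map(DType a) {
 *     return a * a;
 *   }
 * };
 * // later, given a tensor or expression src:  dst = F<square>(src);
 * \endcode
 *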
\brief save b to a using save method */ template MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*) a /= b; } /*! \brief corresponding binary operator type */ typedef op::div OPType; }; } // namespace sv /*! \brief namespace for potential reducer operations */ namespace red { namespace limits { /*! * \brief minimum value of certain types * \tparam DType data type */ template MSHADOW_XINLINE DType MinValue(void); /*! \brief minimum value of float */ template<> MSHADOW_XINLINE float MinValue(void) { return -FLT_MAX; } /*! \brief minimum value of double */ template<> MSHADOW_XINLINE double MinValue(void) { return -DBL_MAX; } /*! \brief minimum value of int */ template<> MSHADOW_XINLINE int MinValue(void) { return INT_MIN; } } // namespace limits /*! \brief sum reducer */ struct sum { /*! \brief do reduction into dst */ template MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) dst += src; } /*! *\brief calculate gradient of redres with respect to redsrc, * redres: reduced result, redsrc: one of reduction element */ template MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { return 1; } /*! *\brief set the initial value during reduction */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) initv = 0; } }; /*! \brief maximum reducer */ struct maximum { /*! \brief do reduction into dst */ template MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) using namespace std; dst = max(dst, src); } /*! * \brief calculate gradient of redres with respect to redsrc, * redres: reduced result, redsrc: one of reduction element */ template MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { return redres == redsrc ? 1: 0; } /*! *\brief set the initial value during reduction */ template MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) initv = limits::MinValue(); } }; } // namespace red } // namespace mshadow #endif // MSHADOW_BASE_H_ //===== EXPANDED: mxnet/mshadow/mshadow/base.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/expression.h ===== /*! * Copyright (c) 2014 by Contributors * \file expression.h * \brief definitions of abstract expressions and expressions template * \author Tianqi Chen, Bing Xu */ #ifndef MSHADOW_EXPRESSION_H_ #define MSHADOW_EXPRESSION_H_ namespace mshadow { /*! * \brief namespace for abstract expressions and expressions template, * have no dependecy on tensor.h, * These data structure takes no charge in computations, * they are only used to define operations and represent expression in a symbolic way */ namespace expr { /*! \brief type of expressions */ namespace type { // type expression type are defined as bitmask // subtype relationshop kRValue < kMapper < kPull < kComplex /*! * \brief this expression directly correspnds to a data class, * can be used to assign data */ const int kRValue = 0; /*! * \brief expression contains element-wise tensor operations, * map a expression to same shape */ const int kMapper = 1; /*! * \brief expression that can be chained with other expressiones * Usually it have function Eval(i,j) defined, which pulls the result (i, j) from input * expression and output the result at certain position. */ const int kChainer = 3; /*! \brief othercase: e.g dot product */ const int kComplex = 7; } // namespace type /*! 
* \brief expression engine that actually interprets these expressions * this is a function template that needed to be implemented for specific expressions * \tparam Saver the save method * \tparam RValue the type of RValue to be saved * \sa namespace sv */ template struct ExpEngine; /*! \brief defines how expression exp can be evaluated and stored into dst */ // template // inline static void Eval(RValue *dst, const EType &exp); /*! * \brief base class for expression * \tparam SubType inheritated class must put their type into this parameter * \tparam DType the data type of each element in the expression * \tparam exp_type expression type, see namespace type */ template struct Exp { public: /*! \return subtype instance of current class */ inline const SubType& self(void) const { return *static_cast(this); } /*! \return reference of subtype instance of current class */ inline SubType* ptrself(void) { return static_cast(this); } }; /*! * \brief scalar expression * \tparam DType the data type of the scalar */ template struct ScalarExp: public Exp, DType, type::kMapper> { /*! \brief scalar value */ DType scalar_; /*! \brief implicit constructor, MUST NOT BE explicit */ ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*) }; /*! \brief create an scalar expression */ template inline ScalarExp scalar(DType s) { return ScalarExp(s); } /*! * \brief typecast expression, cast the type of elements * \tparam DstDType the target type we want to cast into * \tparam SrcDType the target type we want to cast from * \tparam EType the type of the source expression * \tparam etype the type of expression after cast */ template struct TypecastExp: public Exp, DstDType, etype> { /*! \brief expression to be typecasted */ const EType &exp; /*! \brief constructor */ explicit TypecastExp(const EType &e) : exp(e) {} }; /*! \brief create an scalar expression */ template inline TypecastExp tcast(const Exp &exp) { return TypecastExp(exp.self()); } /*! \brief represent a transpose expression of a container */ template struct TransposeExp: public Exp, DType, type::kChainer> { /*! \brief expression to be transposed */ const EType &exp; /*! \brief constructor */ explicit TransposeExp(const EType &e) : exp(e) {} /*! \brief transpose expression */ inline const EType &T(void) const { return exp; } }; /*! * \brief base class of all rvalues * \tparam Container the actually class of data container, e.g. Tensor1D * \tparam DataType the element data type of each element in the container */ template class RValueExp: public Exp { public: /*! *\brief transpose of a matrix *\return transpose of current expression */ inline const TransposeExp T(void) const { return TransposeExp(this->self()); } /*! \brief operator overload */ inline Container &operator+=(DType s) { ExpEngine::Eval(this->ptrself(), scalar(s)); return *(this->ptrself()); } /*! \brief operator overload */ inline Container &operator-=(DType s) { ExpEngine::Eval(this->ptrself(), scalar(s)); return *(this->ptrself()); } /*! \brief operator overload */ inline Container &operator*=(DType s) { ExpEngine::Eval(this->ptrself(), scalar(s)); return *(this->ptrself()); } /*! \brief operator overload */ inline Container &operator/=(DType s) { ExpEngine::Eval(this->ptrself(), scalar(s)); return *(this->ptrself()); } /*! \brief operator overload */ inline Container &__assign(DType s) { ExpEngine::Eval(this->ptrself(), scalar(s)); return *(this->ptrself()); } /*! 
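 *
 * A small sketch of the scalar and typecast expressions introduced above;
 * mat and imat are assumed to be already-allocated tensors of matching shape:
 * \code
 * Tensor<cpu, 2, float> mat;   // allocated elsewhere
 * Tensor<cpu, 2, int>   imat;  // allocated elsewhere
 * mat += 2.0f;                 // compound assignment with a ScalarExp
 * mat = tcast<float>(imat);    // element-wise typecast expression
 * \endcode
 *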
\brief we can not define container = container */ template inline Container &__assign(const Exp &exp) { ExpEngine::Eval(this->ptrself(), exp.self()); return *(this->ptrself()); } /*! \brief operator overload, assign */ inline Container &__assign(const Exp &exp); /*! \brief implementation of operator+= */ template inline Container &operator+=(const Exp &exp) { ExpEngine::Eval(this->ptrself(), exp.self()); return *(this->ptrself()); } /*! \brief implementation of operator-= */ template inline Container &operator-=(const Exp &exp) { ExpEngine::Eval(this->ptrself(), exp.self()); return *(this->ptrself()); } /*! \brief implementation of operator*= */ template inline Container &operator*=(const Exp &exp) { ExpEngine::Eval(this->ptrself(), exp.self()); return *(this->ptrself()); } /*! \brief implementation of operator/= */ template inline Container &operator/=(const Exp &exp) { ExpEngine::Eval(this->ptrself(), exp.self()); return *(this->ptrself()); } }; /*! * \brief matrix multiplication expression dot(lhs[.T], rhs[.T]) * \tparam TA type of lhs * \tparam TB type of rhs * \tparam ltrans whether lhs is transposed * \tparam rtrans whether rhs is transposed * \tparam DType the data type of the scalar */ template struct DotExp: public Exp, DType, type::kComplex> { /*! \brief left operand */ const TA &lhs_; /*! \brief right operand */ const TB &rhs_; /*! \brief scale over result */ DType scale_; /*! \brief constructor */ explicit DotExp(const TA &lhs, const TB &rhs, DType scale) : lhs_(lhs), rhs_(rhs), scale_(scale) {} }; // definition of dot expression /*! \brief dot operator def */ template inline DotExp dot(const RValueExp &lhs, const RValueExp &rhs) { return DotExp(lhs.self(), rhs.self(), 1.0f); } /*! \brief dot operator def */ template inline DotExp dot(const TransposeExp &lhs, const RValueExp &rhs) { return DotExp(lhs.exp, rhs.self(), 1.0f); } /*! \brief dot operator def */ template inline DotExp dot(const RValueExp &lhs, const TransposeExp &rhs) { return DotExp(lhs.self(), rhs.exp, 1.0f); } /*! \brief dot operator def */ template inline DotExp dot(const TransposeExp &lhs, const TransposeExp &rhs) { return DotExp(lhs.exp, rhs.exp, 1.0f); } //--------------- // BinaryMapExp // -------------- /*! * \brief binary map expression lhs [op] rhs * \tparam OP operator * \tparam TA type of lhs * \tparam TB type of rhs * \tparam etype expression type, sa namespace::type */ template struct BinaryMapExp: public Exp, DType, etype> { /*! \brief left operand */ const TA &lhs_; /*! \brief right operand */ const TB &rhs_; /*! \brief constructor */ explicit BinaryMapExp(const TA &lhs, const TB &rhs) :lhs_(lhs), rhs_(rhs) {} }; /*! \brief make expression */ template inline BinaryMapExp MakeExp(const Exp &lhs, const Exp &rhs) { return BinaryMapExp(lhs.self(), rhs.self()); } /*! * \brief short hand for MakeExp, usage F(lhs, rhs). create a binary operation expression * \param lhs left operand * \param rhs right operand * \return the result expression * \tparam binary operator * \tparam TA lhs expression * \tparam ta lhs expression type * \tparam TB rhs expression * \tparam tb rhs expression type * \sa mshadow::op */ template inline BinaryMapExp F(const Exp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } // operator rules /*! \brief operator overload */ template inline BinaryMapExp operator+(const Exp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp operator-(const Exp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } /*! 
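 *
 * A short sketch of the dot and binary-map expressions defined above;
 * dst, lhs and rhs are assumed to be allocated 2D tensors of compatible shape:
 * \code
 * dst = dot(lhs, rhs);          // matrix multiplication
 * dst += dot(lhs.T(), rhs);     // accumulate, with lhs transposed
 * dst = F<op::mul>(lhs, rhs);   // element-wise product via BinaryMapExp
 * \endcode
 *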
\brief operator overload */ template inline BinaryMapExp operator*(const Exp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp operator/(const Exp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } //--------------- // UnaryMapExp // -------------- /*! * \brief unary map expression op(src) * \tparam OP operator * \tparam TA type of src * \tparam etype expression type, sa namespace::type */ template struct UnaryMapExp: public Exp, DType, etype> { /*! \brief source expression */ const TA &src_; /*! \brief constructor */ explicit UnaryMapExp(const TA &src) : src_(src) {} }; /*! \brief make expression */ template inline UnaryMapExp MakeExp(const Exp &src) { return UnaryMapExp(src.self()); } /*! * \brief short hand for MakeExp, usage F(src), create a unary operation expression * \param src source expression * \return the result expression * \tparam operator * \tparam TA source expression * \tparam ta source expression type * \sa mshadow::op */ template inline UnaryMapExp F(const Exp &src) { return MakeExp(src); } } // namespace expr } // namespace mshadow #endif // MSHADOW_EXPRESSION_H_ //===== EXPANDED: mxnet/mshadow/mshadow/expression.h ===== namespace mshadow { /*! \brief device name CPU */ struct cpu { /*! \brief whether this device is CPU or not */ static const bool kDevCPU = true; /*! \brief device flag number, identifies this device */ static const int kDevMask = 1 << 0; }; /*! \brief device name CPU */ struct gpu { /*! \brief whether this device is CPU or not */ static const bool kDevCPU = false; /*! \brief device flag number, identifies this device */ static const int kDevMask = 1 << 1; }; template struct Shape; /*! * \brief allow string printing of the shape * \param os the output stream * \param shape the shape * \return the ostream */ template inline std::ostream &operator<<(std::ostream &os, const Shape &shape); // NOLINT(*) /*! * \brief shape of a tensor * IMPORTANT NOTE: this shape is different from numpy.shape * shape[0] gives the lowest dimension, shape[dimension-1] gives the highest dimension * shape[k] corresponds to k-th dimension of tensor * \tparam dimension dimension of tensor */ template struct Shape { /*! \brief dimension of current shape */ static const int kDimension = dimension; /*! \brief dimension of current shape minus one */ static const int kSubdim = dimension - 1; /*! \brief storing the dimension information */ index_t shape_[kDimension]; /*! \brief default constructor, do nothing */ MSHADOW_XINLINE Shape(void) {} /*! \brief constuctor */ MSHADOW_XINLINE Shape(const Shape &s) { #pragma unroll for (int i = 0; i < kDimension; ++i) { this->shape_[i] = s[i]; } } /*! * \brief get corresponding index * \param idx dimension index * \return the corresponding dimension size */ MSHADOW_XINLINE index_t &operator[](index_t idx) { return shape_[idx]; } /*! * \brief get corresponding index * \param idx dimension index * \return the corresponding dimension size */ MSHADOW_XINLINE const index_t &operator[](index_t idx) const { return shape_[idx]; } /*! * \return whether two shape equals * \param s the shape to compare against */ MSHADOW_XINLINE bool operator==(const Shape &s) const { #pragma unroll for (int i = 0; i < kDimension; ++i) { if (s.shape_[i] != this->shape_[i]) return false; } return true; } /*! * \return whether two shape not equal * \param s the shape to compare against */ MSHADOW_XINLINE bool operator!=(const Shape &s) const { return !(*this == s); } /*! 
* flatten the higher dimension to second dimension, return a 2D shape * \return the flat 2d shape */ MSHADOW_XINLINE Shape<2> FlatTo2D(void) const { Shape<2> s; s.shape_[1] = this->shape_[kDimension - 1]; index_t ymax = 1; #pragma unroll for (int i = 0; i < kDimension - 1; ++i) { ymax *= this->shape_[i]; } s.shape_[0] = ymax; return s; } /*! \return number of valid elements */ MSHADOW_XINLINE size_t Size(void) const { size_t size = this->shape_[0]; #pragma unroll for (int i = 1; i < kDimension; ++i) { size *= this->shape_[i]; } return size; } /*! * \return product shape in [dimstart,dimend) * \param dimstart start dimension * \param dimend end dimension */ MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const { index_t num = 1; #pragma unroll for (int i = dimstart; i < dimend; ++i) { num *= this->shape_[i]; } return num; } /*! * \brief get subshape that takes off largest dimension v * \return subshape */ MSHADOW_XINLINE Shape SubShape(void) const { Shape s; // for cuda #pragma unroll for (int i = 0; i < kSubdim; ++i) { s.shape_[i] = this->shape_[i + 1]; } return s; } /*! * \brief slice the shape from start to end * \tparam dimstart start dimension * \tparam dimend end dimension * \return the sliced shape */ template MSHADOW_XINLINE Shape Slice(void) const { Shape s; #pragma unroll for (int i = dimstart; i < dimend; ++i) { s[i - dimstart] = this->shape_[i]; } return s; } //! \cond Doxygen_Suppress template friend std::ostream &operator<<(std::ostream &os, const Shape &shape); // NOLINT(*) //! \endcond }; // Shape //------------------------------------------------ // useful construction functions to generate shape //------------------------------------------------- /*! * \brief construct a one dimension shape, stride will equal s0 * \param s0 size of dimension 0 * \return the shape construction */ MSHADOW_XINLINE Shape<1> Shape1(index_t s0) { Shape<1> s; s[0] = s0; return s; } /*! * \brief construct a two dimension shape, stride will equal s0 * \param s0 size of dimension 0 * \param s1 size of dimension 1 * \return the shape construction */ MSHADOW_XINLINE Shape<2> Shape2(index_t s0, index_t s1) { Shape<2> s; s[0] = s0; s[1] = s1; return s; } /*! * \brief construct a three dimension shape, stride will equal s0 * \param s0 size of dimension 0 * \param s1 size of dimension 1 * \param s2 size of dimension 2 * \return the shape construction */ MSHADOW_XINLINE Shape<3> Shape3(index_t s0, index_t s1, index_t s2) { Shape<3> s; s[0] = s0; s[1] = s1; s[2] = s2; return s; } /*! * \brief construct a four dimension shape, stride will equal s0 * \param s0 size of dimension 0 * \param s1 size of dimension 1 * \param s2 size of dimension 2 * \param s3 size of dimension 3 * \return the shape construction */ MSHADOW_XINLINE Shape<4> Shape4(index_t s0, index_t s1, index_t s2, index_t s3) { Shape<4> s; s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; return s; } /*! * \brief computaion stream structure, used for asynchronize computation */ template struct Stream { // this is only a dummy implementation for CPU // for GPU, the actual implementation will be specialized in tensor_gpu-inl.h /*! * \brief wait for all the computation associated * with this stream to complete */ inline void Wait(void) {} /*! * \brief query whether the the stream is idle * \return true if the stream is idle and all the job have been completed */ inline bool CheckIdle(void) { return true; } /*! \brief create a blas handle */ inline void CreateBlasHandle() {} }; /*! 
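 *
 * A small sketch of the shape helpers defined above:
 * \code
 * Shape<3> s = Shape3(2, 3, 4);
 * size_t total = s.Size();       // 2 * 3 * 4 = 24 elements
 * Shape<2> flat = s.FlatTo2D();  // (6, 4): higher dims collapsed into dim 0
 * Shape<2> sub = s.SubShape();   // (3, 4): leading dimension dropped
 * \endcode
 *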
* \brief Tensor RValue, this is the super type of all kinds of possible tensors * \tparam Container the tensor type * \tparam Device which device the tensor is on * \tparam dimension dimension of the tensor * \tparam DType the type of elements in the tensor */ template struct TRValue: public expr::RValueExp { }; // more compact template /*! * \brief general tensor * \tparam Device which device the tensor is on * \tparam dimension dimension of the tensor * \tparam DType the type of elements in the tensor */ template struct Tensor: public TRValue, Device, dimension, DType> { public: //-------------------------------- // struct memembers //-------------------------------- /*! \brief whether current type lies in cpu */ static const bool kDevCPU = Device::kDevCPU; /*! \brief dimension of subtype */ static const int kSubdim = dimension - 1; //-------------------------------- // struct memembers //-------------------------------- /*! \brief pointer to the data */ DType *dptr_; /*! \brief shape of the tensor */ Shape shape_; /*! * \brief storing the stride information in x dimension * this is used to deal with pitch allocation in gpu or sse(align x dimension to 64bit) for efficiency */ index_t stride_; /*! * \brief stream where the computation lies * stream is a device dependency concept where each computation */ Stream *stream_; //-------------------------------- // functions //-------------------------------- /*! \brief default constructor */ MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} /*! \brief constructor from shape */ MSHADOW_XINLINE Tensor(const Shape &shape) : shape_(shape), stream_(NULL) {} /*! \brief constructor from data pointer and shape, without stride */ MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape) : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {} /*! \brief constructor from data pointer and shape, without stride */ MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape, Stream *stream) : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(stream) {} /*! \brief constructor from data pointer and shape */ MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape, index_t stride, Stream *stream) : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} /*! * \brief set the stream to do computation of current tensor * \param stream the computation stream */ inline void set_stream(Stream *stream) { this->stream_ = stream; } /*! * \return memory cost of the tensor, including the aligned x dimension * \tparam startdim the starting dimension */ template MSHADOW_XINLINE size_t MemSize(void) const { size_t memsz = this->stride_; #pragma unroll for (int i = startdim; i < kSubdim; ++i) { memsz *= this->shape_[i]; } return memsz; } /*! * \return whether the tensor's memory is continuous * x dimension same as stride */ MSHADOW_XINLINE bool CheckContiguous(void) const { return this->shape_[dimension - 1] == stride_; } /*! * \return memory cost of the tensor, including the aligned x dimension */ MSHADOW_XINLINE size_t MSize(void) const { return this->MemSize<0>(); } /*! * \brief return size of i-th dimension, start counting from highest dimension * \param idx the dimension count from the highest dimensin * \return the size */ MSHADOW_XINLINE index_t size(index_t idx) const { return shape_[idx]; } /*! * \brief flatten the tensor to 2 dimension, collapse the higher dimensions together * \return tensor after flatten */ MSHADOW_XINLINE Tensor FlatTo2D(void) const { return Tensor(dptr_, shape_.FlatTo2D(), stride_, stream_); } /*! 
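 *
 * A sketch of wrapping an existing buffer with the Tensor handle described
 * above; data is an illustrative caller-owned array, no allocation happens:
 * \code
 * float data[6] = {0, 1, 2, 3, 4, 5};
 * Tensor<cpu, 2, float> ts(data, Shape2(2, 3));
 * // ts[1][2] refers to data[5]; ts.FlatTo2D() views the same memory
 * \endcode
 *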
* \brief get a element of dimension - 1 * \param idx index * \return the result tensor */ MSHADOW_XINLINE Tensor operator[](index_t idx) const { return Tensor(dptr_ + this->MemSize<1>() * idx, shape_.SubShape(), stride_, stream_); } /*! * \brief slice the tensor in highest dimension [begin,end) * \param begin begin position of slice * \param end end position of slice * \return tensor after slice */ MSHADOW_XINLINE Tensor Slice(index_t begin, index_t end) const { Shape s = this->shape_; s[0] = end - begin; return Tensor(dptr_ + this->MemSize<1>() * begin, s, stride_, stream_); } /*!\brief implement the assignment of same type */ inline Tensor & operator=(const Tensor &exp) { dptr_ = exp.dptr_; shape_ = exp.shape_; stride_ = exp.stride_; stream_ = exp.stream_; return *this; } /*!\brief functions to fit expression template */ template inline Tensor & operator=(const expr::Exp &exp) { return this->__assign(exp); } /*!\brief functions to fit expression template */ inline Tensor &operator=(const DType &exp) { return this->__assign(exp); } }; /* * respecialized class Tensor1D, thei is due to different implementation in operator[] */ template struct Tensor: public TRValue, Device, 1, DType> { public: DType *dptr_; Shape<1> shape_; index_t stride_; Stream *stream_; // constructor MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} MSHADOW_XINLINE Tensor(const Shape<1> &shape) : shape_(shape), stream_(NULL) {} MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape) : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(NULL) {} MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, index_t stride, Stream *stream) : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} inline void set_stream(Stream *stream) { this->stream_ = stream; } MSHADOW_XINLINE Tensor FlatTo2D(void) const { return Tensor(dptr_, shape_.FlatTo2D(), stride_, stream_); } MSHADOW_XINLINE Tensor Slice(index_t begin, index_t end) const { Shape<1> s; s[0] = end - begin; return Tensor(dptr_ + begin, s, s[0], stream_); } MSHADOW_XINLINE bool CheckContiguous(void) const { return true; } MSHADOW_XINLINE size_t MSize(void) const { return shape_[0]; } MSHADOW_XINLINE index_t size(index_t i) const { return shape_[0]; } MSHADOW_XINLINE DType &operator[](index_t idx) { return dptr_[idx]; } MSHADOW_XINLINE const DType &operator[](index_t idx) const { return dptr_[idx]; } /*!\brief implement the assignment of same type */ inline Tensor & operator=(const Tensor &exp) { dptr_ = exp.dptr_; shape_ = exp.shape_; stride_ = exp.stride_; stream_ = exp.stream_; return *this; } template inline Tensor & operator=(const expr::Exp &exp) { return this->__assign(exp); } inline Tensor &operator=(const DType &exp) { return this->__assign(exp); } }; //------------------------ // Function Declarations //----------------------- /*! * \brief initialize tensor engine, used to call intialization functions of dependent libs * this function should be called before all GPU tensor operations, * for using tensors in CPU, this call is actually not needed * \param device_id GPU device id to be choosed * \tparam Device the device type */ template inline void InitTensorEngine(int device_id = 0); /*! * \brief Shutdown tensor engine on current device * this function should be called after all GPU tensor operations, * for using tensors in CPU, this call is actually not needed * \tparam Device the device type */ template inline void ShutdownTensorEngine(void); /*! 
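 *
 * A slicing sketch for the operators defined above; ts is assumed to be an
 * allocated 2D tensor with at least three rows (illustrative only):
 * \code
 * Tensor<cpu, 2, float> rows = ts.Slice(1, 3);  // rows 1 and 2, same memory
 * Tensor<cpu, 1, float> row0 = ts[0];           // first row as a 1D tensor
 * \endcode
 *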
* \brief set the device of current thread to work on * \param devid the device id * \tparam Device the device type */ template inline void SetDevice(int devid); /*! * \brief create a new stream from system * \param create_blas_handle whether create blas handle in stream * \param create_dnn_handle whether create cudnn handle in stream * \return a pointer to the created stream * \tparam Device the device type */ template inline Stream *NewStream(bool create_blas_handle, bool create_dnn_handle); /*! \brief default behavior: create cublas handle */ template inline Stream *NewStream() { return NewStream(true, false); } /*! * \brief delete the computing stream * \param stream the stream parameter to be deleted */ template inline void DeleteStream(Stream *stream); /*! * \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj * this function is responsible to set the stride_ in each obj.shape * \param obj the tensor object, with shape specified * \param pad whether padding dimension 0, to make last dimension aligned, * padding may help improve efficiency of matrix multiplications * if true, will allocate space with stride_ that may not equals shape[0] * if false, will allocate continuous space * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void AllocSpace(Tensor *obj, bool pad = MSHADOW_ALLOC_PAD); /*! * \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj * this function is responsible to set the stride_ in each obj.shape * \param obj the tensor object, with shape specified * \param pad whether padding dimension 0, to make last dimension aligned, * padding may help improve efficiency of matrix multiplications * if true, will allocate space with stride_ that may not equals shape[0] * if false, will allocate continuous space * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void AllocSpace(Tensor *obj, bool pad = MSHADOW_ALLOC_PAD); /*! * \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL * \param obj the tensor object * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void FreeSpace(Tensor *obj); /*! * \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL * \param obj the tensor object * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void FreeSpace(Tensor *obj); /*! * \brief CPU/GPU: short cut to allocate and initialize a Tensor * \param shape: shape of tensor * \param initv: initialization value * \param pad : padding option * \param stream : stream of tensor * \tparam Device device of tensor * \tparam DType type of element in tensor * \tparam dim dimention of tensor * \return a new allocated tensor * \sa AllocSpace */ template inline Tensor NewTensor(const Shape &shape, DType initv, bool pad = MSHADOW_ALLOC_PAD, Stream *stream = NULL); /*! * \brief copy data from one tensor to another, with same shape * \param dst target tensor * \param src source tensor * \param stream the stream, when specified, the copy can exhibit asynchronize behavior * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void Copy(Tensor dst, const Tensor &src, Stream *stream = NULL); /*! 
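 *
 * An allocation and copy sketch for the functions declared above and below
 * (CPU only; the shape is illustrative):
 * \code
 * Tensor<cpu, 2, float> src = NewTensor<cpu>(Shape2(2, 3), 1.0f);
 * Tensor<cpu, 2, float> dst(Shape2(2, 3));
 * AllocSpace(&dst);
 * Copy(dst, src);   // shapes must match
 * FreeSpace(&src);
 * FreeSpace(&dst);
 * \endcode
 *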
* \brief copy data from one tensor to another, with same shape * \param dst target tensor * \param src source tensor * \param stream the stream, when specified, the copy can exhibit asynchronize behavior * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void Copy(Tensor dst, const Tensor &src, Stream *stream = NULL); /*! * \brief copy data from one tensor to another, with same shape * \param dst target tensor * \param src source tensor * \param stream the stream, when specified, the copy can exhibit asynchronize behavior * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void Copy(Tensor dst, const Tensor &src, Stream *stream = NULL); /*! * \brief copy data from one tensor to another, with same shape * \param dst target tensor * \param src source tensor * \param stream the stream, when specified, the copy can exhibit asynchronize behavior * \tparam dim specify the dim of tensor * \tparam DType type of element in tensor */ template inline void Copy(Tensor dst, const Tensor &src, Stream *stream = NULL); /*! * \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) * \param dst destination * \param energy input energy */ template inline void Softmax(Tensor dst, const Tensor &energy); /*! * \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) * \param dst destination * \param energy input energy */ template inline void Softmax(Tensor dst, const Tensor &energy); /*! * \brief CPU/GPU: softmax gradient * \param dst destination * \param src source output * \param label label info */ template inline void SoftmaxGrad(Tensor dst, const Tensor &src, const Tensor &label); /*! * \brief CPU/GPU: softmax gradient * \param dst destination * \param src source output * \param label label info */ template inline void SoftmaxGrad(Tensor dst, const Tensor &src, const Tensor &label); // function declarations to support expression, no need to understand them // these functions do not need to be directly used /*! * \brief CPU/GPU: map a expression to a tensor, this function calls MapPlan * \tparam Saver specify storage method * \tparam R specifies the storage type of the tensor * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter * \tparam DType the type of elements in the tensor * \tparam E specifies the expression type, not need to specify this parameter during usage * \tparam etype expression type * \param dst destination * \param exp expression * \sa namespace mshadow:sv, mshadow::op, mshadow::expr */ template inline void MapExp(TRValue *dst, const expr::Exp &exp); /*! * \brief CPU/GPU: map a expression to a tensor, this function calls MapPlan * \tparam Saver specify storage method * \tparam R specifies the storage type of the tensor * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter * \tparam DType the type of elements in the tensor * \tparam E specifies the expression type, not need to specify this parameter during usage * \tparam etype expression type * \param dst destination * \param exp expression * \sa namespace mshadow:sv, mshadow::op, mshadow::expr */ template inline void MapExp(TRValue *dst, const expr::Exp &exp); /*! 
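*/
// ---------------------------------------------------------------------------
// Softmax sketch for the declarations above (illustrative only, CPU build with
// float elements; not compiled).
#if 0
void ExampleSoftmax() {
  using namespace mshadow;
  Tensor<cpu, 2, float> energy = NewTensor<cpu>(Shape2(2, 3), 0.0f);
  Tensor<cpu, 2, float> prob   = NewTensor<cpu>(Shape2(2, 3), 0.0f);
  Softmax(prob, energy);  // prob[i][j] = exp(energy[i][j]) / sum_j exp(energy[i][j])
  // gradient w.r.t. energy, given one class label per row (stored as float)
  Tensor<cpu, 1, float> label = NewTensor<cpu>(Shape1(2), 0.0f);
  Tensor<cpu, 2, float> grad  = NewTensor<cpu>(Shape2(2, 3), 0.0f);
  SoftmaxGrad(grad, prob, label);
  FreeSpace(&energy); FreeSpace(&prob); FreeSpace(&label); FreeSpace(&grad);
}
#endif
/*!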
* \brief CPU/GPU: map a expression, do reduction to 1D Tensor in lowest dimension (dimension 0) * \tparam Saver specify storage method * \tparam Reducer specify a reducer method * \tparam R specifies the storage type of the tensor * \tparam DType the type of elements in the tensor * \tparam E specifies the expression type, not need to specify this parameter during usage * \tparam etype expression type * \param dst destination * \param exp expression * \param scale scale the result before save * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr */ template inline void MapReduceKeepLowest(TRValue *dst, const expr::Exp &exp, DType scale = 1); /*! * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in lowest dimension (dimension 0) * \tparam Saver specify storage method * \tparam Reducer specify a reducer method * \tparam R specifies the storage type of the tensor * \tparam DType the type of elements in the tensor * \tparam E specifies the expression type, not need to specify this parameter during usage * \tparam etype expression type * \param dst destination * \param exp expression * \param scale scale the result before save * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr */ template inline void MapReduceKeepLowest(TRValue *dst, const expr::Exp &exp, DType scale = 1); /*! * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in third dimension (dimension 2) * \tparam Saver specify storage method * \tparam Reducer specify a reducer method * \tparam R specifies the storage type of the tensor * \tparam DType the type of elements in the tensor * \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest * \tparam E specifies the expression type, not need to specify this parameter during usage * \tparam etype expression type * \param dst destination * \param exp expression * \param scale scale the result before save * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr */ template inline void MapReduceKeepHighDim(TRValue *dst, const expr::Exp &exp, DType scale = 1); /*! * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in third dimension (dimension 2) * \tparam Saver specify storage method * \tparam Reducer specify a reducer method * \tparam R specifies the storage type of the tensor * \tparam DType the type of elements in the tensor * \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest * \tparam E specifies the expression type, not need to specify this parameter during usage * \tparam etype expression type * \param dst destination * \param exp expression * \param scale scale the result before save * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr */ template inline void MapReduceKeepHighDim(TRValue *dst, const expr::Exp &exp, DType scale = 1); /*! * \brief CPU/GPU: 1 dimension vector dot * \param dst Length 1 vector, used to hold the result. * \param lhs Left operand vector * \param rhs right operand vector */ template inline void VectorDot(Tensor dst, const Tensor &lhs, const Tensor &rhs); } // namespace mshadow // include headers //===== EXPANDIND: mxnet/mshadow/mshadow/stream_gpu-inl.h ===== /*! * Copyright (c) 2014 by Contributors * \file stream_gpu-inl.h * \brief implementation of GPU code * \author Bing Xu, Tianqi Chen */ #ifndef MSHADOW_STREAM_GPU_INL_H_ #define MSHADOW_STREAM_GPU_INL_H_ //===== EXPANDIND: mxnet/mshadow/mshadow/logging.h ===== /*! 
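*/
// ---------------------------------------------------------------------------
// Reduction sketch for MapReduceKeepLowest / VectorDot declared above
// (illustrative only; assumes the sv::saveto saver and red::sum reducer that
// ship with mshadow; not compiled).
#if 0
void ExampleReduce() {
  using namespace mshadow;
  Tensor<cpu, 2, float> mat = NewTensor<cpu>(Shape2(3, 4), 1.0f);
  Tensor<cpu, 1, float> col = NewTensor<cpu>(Shape1(4), 0.0f);
  // col[j] = sum_i mat[i][j]: reduce everything except the lowest dimension
  MapReduceKeepLowest<sv::saveto, red::sum>(&col, mat, 1.0f);
  Tensor<cpu, 1, float> res = NewTensor<cpu>(Shape1(1), 0.0f);
  VectorDot(res, col, col);  // res[0] = dot(col, col)
  FreeSpace(&mat); FreeSpace(&col); FreeSpace(&res);
}
#endif
/*!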
* Copyright (c) 2015 by Contributors * \file logging.h * \brief defines logging macros of dmlc * allows use of GLOG, fall back to internal * implementation when disabled */ #ifndef MSHADOW_LOGGING_H_ #define MSHADOW_LOGGING_H_ #ifndef DMLC_LOGGING_H_ #define DMLC_LOGGING_H_ namespace dmlc { /*! \brief taken from DMLC directly */ /*! * \brief exception class that will be thrown by * default logger if DMLC_LOG_FATAL_THROW == 1 */ struct Error : public std::runtime_error { /*! * \brief constructor * \param s the error message */ explicit Error(const std::string &s) : std::runtime_error(s) {} }; } // namespace dmlc #if defined(_MSC_VER) && _MSC_VER < 1900 #define noexcept(a) #endif #if DMLC_USE_CXX11 #define DMLC_THROW_EXCEPTION noexcept(false) #else #define DMLC_THROW_EXCEPTION #endif #if DMLC_USE_GLOG namespace dmlc { /*! \brief taken from DMLC directly */ inline void InitLogging(const char* argv0) { google::InitGoogleLogging(argv0); } } // namespace dmlc #else // use a light version of glog #if defined(_MSC_VER) #pragma warning(disable : 4722) #endif namespace dmlc { inline void InitLogging(const char* argv0) { // DO NOTHING } // Always-on checking #define CHECK(x) \ if (!(x)) \ dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ "failed: " #x << ' ' #define CHECK_LT(x, y) CHECK((x) < (y)) #define CHECK_GT(x, y) CHECK((x) > (y)) #define CHECK_LE(x, y) CHECK((x) <= (y)) #define CHECK_GE(x, y) CHECK((x) >= (y)) #define CHECK_EQ(x, y) CHECK((x) == (y)) #define CHECK_NE(x, y) CHECK((x) != (y)) #define CHECK_NOTNULL(x) \ ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) // Debug-only checking. #ifdef NDEBUG #define DCHECK(x) \ while (false) CHECK(x) #define DCHECK_LT(x, y) \ while (false) CHECK((x) < (y)) #define DCHECK_GT(x, y) \ while (false) CHECK((x) > (y)) #define DCHECK_LE(x, y) \ while (false) CHECK((x) <= (y)) #define DCHECK_GE(x, y) \ while (false) CHECK((x) >= (y)) #define DCHECK_EQ(x, y) \ while (false) CHECK((x) == (y)) #define DCHECK_NE(x, y) \ while (false) CHECK((x) != (y)) #else #define DCHECK(x) CHECK(x) #define DCHECK_LT(x, y) CHECK((x) < (y)) #define DCHECK_GT(x, y) CHECK((x) > (y)) #define DCHECK_LE(x, y) CHECK((x) <= (y)) #define DCHECK_GE(x, y) CHECK((x) >= (y)) #define DCHECK_EQ(x, y) CHECK((x) == (y)) #define DCHECK_NE(x, y) CHECK((x) != (y)) #endif // NDEBUG #define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) #define LOG_ERROR LOG_INFO #define LOG_WARNING LOG_INFO #define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) #define LOG_QFATAL LOG_FATAL // Poor man version of VLOG #define VLOG(x) LOG_INFO.stream() #define LOG(severity) LOG_##severity.stream() #define LG LOG_INFO.stream() #define LOG_IF(severity, condition) \ !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) #ifdef NDEBUG #define LOG_DFATAL LOG_ERROR #define DFATAL ERROR #define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) #define DLOG_IF(severity, condition) \ (true || !(condition)) ? 
(void)0 : dmlc::LogMessageVoidify() & LOG(severity) #else #define LOG_DFATAL LOG_FATAL #define DFATAL FATAL #define DLOG(severity) LOG(severity) #define DLOG_IF(severity, condition) LOG_IF(severity, condition) #endif // Poor man version of LOG_EVERY_N #define LOG_EVERY_N(severity, n) LOG(severity) class DateLogger { public: DateLogger() { #if defined(_MSC_VER) _tzset(); #endif } const char* HumanDate() { #if defined(_MSC_VER) _strtime_s(buffer_, sizeof(buffer_)); #else time_t time_value = time(NULL); struct tm now; localtime_r(&time_value, &now); snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", now.tm_hour, now.tm_min, now.tm_sec); #endif return buffer_; } private: char buffer_[9]; }; class LogMessage { public: LogMessage(const char* file, int line) : #ifdef __ANDROID__ log_stream_(std::cout) #else log_stream_(std::cerr) #endif { log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" << line << ": "; } ~LogMessage() { log_stream_ << "\n"; } std::ostream& stream() { return log_stream_; } protected: std::ostream& log_stream_; private: DateLogger pretty_date_; LogMessage(const LogMessage&); void operator=(const LogMessage&); }; #if DMLC_LOG_FATAL_THROW == 0 class LogMessageFatal : public LogMessage { public: LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} ~LogMessageFatal() { log_stream_ << "\n"; abort(); } private: LogMessageFatal(const LogMessageFatal&); void operator=(const LogMessageFatal&); }; #else class LogMessageFatal { public: LogMessageFatal(const char* file, int line) { log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" << line << ": "; } std::ostringstream &stream() { return log_stream_; } ~LogMessageFatal() DMLC_THROW_EXCEPTION { // throwing out of destructor is evil // hopefully we can do it here throw Error(log_stream_.str()); } private: std::ostringstream log_stream_; DateLogger pretty_date_; LogMessageFatal(const LogMessageFatal&); void operator=(const LogMessageFatal&); }; #endif // This class is used to explicitly ignore values in the conditional // logging macros. This avoids compiler warnings like "value computed // is not used" and "statement has no effect". class LogMessageVoidify { public: LogMessageVoidify() {} // This has to be an operator with a precedence lower than << but // higher than "?:". See its usage. void operator&(std::ostream&) {} }; } // namespace dmlc #endif #endif // DMLC_LOGGING_H_ #endif // MSHADOW_LOGGING_H_ //===== EXPANDED: mxnet/mshadow/mshadow/logging.h ===== namespace mshadow { #if MSHADOW_USE_CUDA == 1 // Stream alocation // actual implementation of GPU stream in CUDA template<> struct Stream { /*! \brief handle state */ enum HandleState { NoHandle = 0, OwnHandle = 1, }; /*! \brief cudaStream */ cudaStream_t stream_; /*! \brief cublas handle */ cublasHandle_t blas_handle_; /*! \brief cudnn handle */ #if MSHADOW_USE_CUDNN == 1 cudnnHandle_t dnn_handle_; #endif /*! \brief cublas handle ownership */ HandleState blas_handle_ownership_; /*! \brief cudnn handle ownership */ HandleState dnn_handle_ownership_; Stream(void) : stream_(0), blas_handle_ownership_(NoHandle), dnn_handle_ownership_(NoHandle) {} /*! * \brief wait for all the computation associated * with this stream to complete */ inline void Wait(void) { MSHADOW_CUDA_CALL(cudaStreamSynchronize(stream_)); } /*! 
* \brief query whether the stream is idle * \return true if the stream is idle and all the jobs have been completed */ inline bool CheckIdle(void) { cudaError_t err = cudaStreamQuery(stream_); if (err == cudaSuccess) return true; if (err == cudaErrorNotReady) return false; LOG(FATAL) << cudaGetErrorString(err); return false; } /*! * \brief returns the actual cudaStream_t given an input GPU stream pointer * \param stream pointer to GPU stream */ inline static cudaStream_t GetStream(Stream *stream) { if (stream == NULL) { #if MSHADOW_FORCE_STREAM LOG(FATAL) << "Default GPU stream was used when MSHADOW_FORCE_STREAM was on"; #endif return 0; } else { return stream->stream_; } } /*! * \brief return the actual cublasHandle * \param stream pointer to GPU stream */ inline static cublasHandle_t GetBlasHandle(Stream *stream) { if (stream == NULL) { return 0; } else { CHECK_NE(stream->blas_handle_ownership_, NoHandle) << "No handle exists in source stream"; return stream->blas_handle_; } } /*! \brief Destroy the cublas handle if this stream owns it */ inline void DestoryBlasHandle() { if (blas_handle_ownership_ == OwnHandle) { cublasStatus_t err = cublasDestroy(blas_handle_); blas_handle_ownership_ = NoHandle; CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Destroy cublas handle failed"; } } /*! \brief Destroy the original blas handle and create a new one */ inline void CreateBlasHandle() { this->DestoryBlasHandle(); cublasStatus_t err = cublasCreate(&blas_handle_); blas_handle_ownership_ = OwnHandle; CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Create cublas handle failed"; } // #if MSHADOW_USE_CUDNN && defined(__CUDACC__) #if MSHADOW_USE_CUDNN == 1 inline static cudnnHandle_t GetDnnHandle(Stream *stream) { if (stream == NULL) { return 0; } else { CHECK_NE(stream->dnn_handle_ownership_, NoHandle) << "No handle exists in source stream"; return stream->dnn_handle_; } } #endif inline void DestroyDnnHandle() { // #if MSHADOW_USE_CUDNN && defined(__CUDACC__) #if MSHADOW_USE_CUDNN == 1 if (dnn_handle_ownership_ == OwnHandle) { cudnnStatus_t err = cudnnDestroy(dnn_handle_); CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); } #endif } inline void CreateDnnHandle() { // #if MSHADOW_USE_CUDNN == 1 && defined(__CUDACC__) #if MSHADOW_USE_CUDNN == 1 this->DestroyDnnHandle(); cudnnStatus_t err = cudnnCreate(&dnn_handle_); CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); err = cudnnSetStream(dnn_handle_, stream_); CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); this->dnn_handle_ownership_ = OwnHandle; #endif } }; template<> inline Stream *NewStream(bool create_blas_handle, bool create_dnn_handle) { Stream *st = new Stream(); MSHADOW_CUDA_CALL(cudaStreamCreate(&st->stream_)); if (create_blas_handle) { st->CreateBlasHandle(); } if (create_dnn_handle) { st->CreateDnnHandle(); } return st; } template<> inline void DeleteStream(Stream *stream) { MSHADOW_CUDA_CALL(cudaStreamDestroy(stream->stream_)); stream->DestoryBlasHandle(); stream->DestroyDnnHandle(); delete stream; } #endif } // namespace mshadow #endif // MSHADOW_STREAM_GPU_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/stream_gpu-inl.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension.h ===== /*! * Copyright by Contributors * \file extension.h * \brief some extensions of expressions, * used to support operations beyond elementwise op * \author Tianqi Chen, Bing Xu */ #ifndef MSHADOW_EXTENSION_H_ #define MSHADOW_EXTENSION_H_ //===== EXPANDIND: mxnet/mshadow/mshadow/expr_engine-inl.h ===== /*!
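*/
// ---------------------------------------------------------------------------
// Stream usage sketch for the GPU stream code above (illustrative only;
// assumes a CUDA-enabled build, i.e. MSHADOW_USE_CUDA == 1; not compiled).
#if 0
void ExampleGpuStream() {
  using namespace mshadow;
  InitTensorEngine<gpu>(0);
  Stream<gpu> *s = NewStream<gpu>(true, false);  // cublas handle, no cudnn handle
  Tensor<gpu, 2, float> mat = NewTensor<gpu>(Shape2(2, 3), 0.0f, MSHADOW_ALLOC_PAD, s);
  mat = 1.0f;        // queued on the stream associated with mat
  s->Wait();         // block until the stream has drained
  FreeSpace(&mat);
  DeleteStream(s);
  ShutdownTensorEngine<gpu>();
}
#endif
/*!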
* Copyright (c) 2014 by Contributors * \file expr_engine-inl.h * \brief definitions of how expressions should be evaluated * \author Tianqi Chen, Bing Xu */ #ifndef MSHADOW_EXPR_ENGINE_INL_H_ #define MSHADOW_EXPR_ENGINE_INL_H_ namespace mshadow { namespace expr { /*! * \brief a general class that allows extension that makes tensors of some shape * \tparam SubType type of subclass * \tparam SrcExp source expression of the MakeTensorExp, the source of operation * \tparam dim dimension of the expression * \tparam DType the type of elements */ template struct MakeTensorExp : public Exp, DType, type::kChainer> { /*! \brief the shape of this expression */ Shape shape_; /*! \brief true self of subtype */ inline const SubType& real_self(void) const{ return *static_cast(this); } }; //---------------------------------------------------------------------- // This part of code gives plan that can be used to carry out execution //--------------------------------------------------------------------- // Declarations of plans template class Plan { public: /*! * \brief evaluate the expression at index [y][x] * to be implemented by SubType, for RValue, the return type will be DType & */ MSHADOW_XINLINE DType Eval(index_t y, index_t x) const; }; // tensor plan template class Plan, DType> { public: explicit Plan(const Tensor &t) : dptr_(t.dptr_), stride_(t.stride_) {} // for RValue, the return type should be reference MSHADOW_XINLINE DType &REval(index_t y, index_t x) { return dptr_[y * stride_ + x]; } // const evaluation MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { return dptr_[y * stride_ + x]; } private: DType *dptr_; index_t stride_; }; // special evaluation case for 1d tensor, no stride template class Plan, DType> { public: explicit Plan(const Tensor &t) : dptr_(t.dptr_) {} MSHADOW_XINLINE DType &REval(index_t y, index_t x) { return dptr_[x]; } MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { return dptr_[x]; } private: DType *dptr_; }; // scalar template class Plan, DType> { public: explicit Plan(DType scalar) : scalar_(scalar) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return scalar_; } private: DType scalar_; }; // unary expression template class Plan, DstDType> { public: explicit Plan(const Plan &src) : src_(src) {} MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const { return static_cast(src_.Eval(y, x)); } private: Plan src_; }; // binary expression template class Plan, DType> { public: explicit Plan(const Plan &lhs, const Plan &rhs) : lhs_(lhs), rhs_(rhs) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); } private: Plan lhs_; Plan rhs_; }; // unary expression template class Plan, DType> { public: explicit Plan(const Plan &src) : src_(src) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return OP::Map(src_.Eval(y, x)); } private: Plan src_; }; // remaps map tensor expression to subtype's plan template struct Plan, DType> { public: Plan(const Plan &src) : src_(src) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return src_.Eval(y, x); } private: Plan src_; }; // tranpsoe template class Plan, DType> { public: explicit Plan(const Plan &src) : src_(src) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return src_.Eval(x, y); } private: Plan src_; }; //---------------------------------------------------------------------- // Mappings from expression to plans //--------------------------------------------------------------------- template inline Plan, 
DType> MakePlan(const BinaryMapExp &e); template inline Plan, DType> MakePlan(const ScalarExp &e) { return Plan, DType>(e.scalar_); } template inline Plan, DstDType> MakePlan(const TypecastExp &e) { return Plan, DstDType>(MakePlan(e.exp)); } template inline Plan MakePlan(const RValueExp &e) { return Plan(e.self()); } template inline Plan, DType> MakePlan(const TransposeExp &e) { return Plan, DType>(MakePlan(e.exp)); } template inline Plan MakePlan(const MakeTensorExp &e) { return Plan(e.real_self()); } template inline Plan, DType> MakePlan(const UnaryMapExp &e) { return Plan, DType>(MakePlan(e.src_)); } template inline Plan, DType> MakePlan(const BinaryMapExp &e) { return Plan, DType>(MakePlan(e.lhs_), MakePlan(e.rhs_)); } //---------------------------------------------------------------- // Static Type inference and Type Checking //---------------------------------------------------------------- /*! * \brief static type inference template, * used to get the dimension of each expression, * if ExpInfo::kDim == -1, this means here are mismatch in expression * if (ExpInfo::kDevMask & cpu::kDevMask) != 0, this means this expression can be assigned to cpu * \tparam E expression */ template struct ExpInfo { static const int kDim = -1; static const int kDevMask = 0; }; template struct ExpInfo< ScalarExp > { static const int kDim = 0; static const int kDevMask = 0xffff; }; template struct ExpInfo > { static const int kDim = ExpInfo::kDim; static const int kDevMask = ExpInfo::kDevMask; }; template struct ExpInfo > { static const int kDim = ExpInfo::kDim; static const int kDevMask = ExpInfo::kDevMask; }; template struct ExpInfo > { static const int kDim = dim; static const int kDevMask = Device::kDevMask; }; template struct ExpInfo > { static const int kDimSrc = ExpInfo::kDim; static const int kDim = kDimSrc >= 0 ? dim : -1; static const int kDevMask = ExpInfo::kDevMask; }; template struct ExpInfo > { static const int kDim = ExpInfo::kDim; static const int kDevMask = ExpInfo::kDevMask; }; template struct ExpInfo > { static const int kDimLhs = ExpInfo::kDim; static const int kDimRhs = ExpInfo::kDim; static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ (kDimLhs == 0 ?\ kDimRhs :\ ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; }; /*! \brief template to do type check */ template struct TypeCheck { /*! \brief dimension of expression*/ static const int kExpDim = ExpInfo::kDim; /*! \brief whether the expression device type matches */ static const bool kDevPass = (ExpInfo::kDevMask & Device::kDevMask) != 0; /*! \brief whether the expression can be mapped to expression of dim */ static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass; /*! \brief whether the expression can be reduced to expression of dim */ static const bool kRedPass = (kExpDim > dim) && kDevPass; }; /*! 
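*/
// ---------------------------------------------------------------------------
// What the machinery above does, in user-facing terms (illustrative only, CPU
// float build; not compiled). The right-hand side below is composed of
// BinaryMapExp / ScalarExp nodes; ExpInfo / TypeCheck verify device and
// dimension at compile time, ShapeCheck verifies shapes at run time, and
// MapExp then evaluates the resulting plan in a single elementwise pass.
#if 0
void ExampleExpressionTemplates() {
  using namespace mshadow;
  Tensor<cpu, 2, float> A = NewTensor<cpu>(Shape2(2, 3), 1.0f);
  Tensor<cpu, 2, float> B = NewTensor<cpu>(Shape2(2, 3), 2.0f);
  Tensor<cpu, 2, float> C = NewTensor<cpu>(Shape2(2, 3), 0.0f);
  C = A + B * 2.0f;   // no temporaries: one fused elementwise evaluation
  FreeSpace(&A); FreeSpace(&B); FreeSpace(&C);
}
#endif
/*!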
\brief used to help static type check*/ template struct TypeCheckPass; // Todo : add static assert using C++11 template<> struct TypeCheckPass {}; template<> struct TypeCheckPass { inline static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void) {} inline static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void) {} inline static void Error_Expression_Does_Not_Meet_Dimension_Req(void) {} }; //---------------------------------------------------------------- // Runtime Stream Getting //---------------------------------------------------------------- template struct StreamInfo { inline static Stream *Get(const E &t); }; template struct StreamInfo > { inline static Stream *Get(const Tensor &t) { return t.stream_; } }; //---------------------------------------------------------------- // Runtime Shape Checking //---------------------------------------------------------------- /*! * \brief runtime shape checking template * get the shape of an expression, report error if shape mismatch * \tparam dim the dimension of the shape * \tparam E expression */ template struct ShapeCheck { inline static Shape Check(const E &t); }; template struct ShapeCheck > { inline static Shape Check(const ScalarExp &exp) { // use lowest dimension to mark scalar exp Shape shape; shape[0] = 0; return shape; } }; template struct ShapeCheck > { inline static Shape Check(const TypecastExp &exp) { return ShapeCheck::Check(exp.exp); } }; template struct ShapeCheck > { inline static Shape Check(const TransposeExp &e) { // swap the lowest two dimensions Shape s = ShapeCheck::Check(e.exp); std::swap(s[0], s[1]); return s; } }; template struct ShapeCheck > { inline static Shape Check(const Tensor &t) { return t.shape_; } }; template struct ShapeCheck > { inline static Shape Check(const MakeTensorExp &t) { return t.shape_; } }; template struct ShapeCheck > { inline static Shape Check(const UnaryMapExp &t) { Shape s = ShapeCheck::Check(t.src_); return s; } }; template struct ShapeCheck > { inline static Shape Check(const BinaryMapExp &t) { Shape shape1 = ShapeCheck::Check(t.lhs_); Shape shape2 = ShapeCheck::Check(t.rhs_); if (shape1[0] == 0) return shape2; if (shape2[0] == 0) return shape1; CHECK_EQ(shape1, shape2) << "BinaryMapExp: Shapes of operands are not the same"; return shape1; } }; } // namespace expr } // namespace mshadow // include definition of dot engine //===== EXPANDIND: mxnet/mshadow/mshadow/dot_engine-inl.h ===== /*! * Copyright (c) 2014 by Contributors * \file dot_engine-inl.h * \brief definitions of how Matrix Multiplications can be evaluated * \author Tianqi Chen */ #ifndef MSHADOW_DOT_ENGINE_INL_H_ #define MSHADOW_DOT_ENGINE_INL_H_ //===== EXPANDIND: mxnet/mshadow/mshadow/extension/implicit_gemm.h ===== /*! * Copyright (c) 2014 by Contributors * \file implicit_gemm.h * \brief support for implicit GEMM operation * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ #define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ //===== EXPANDIND: mxnet/mshadow/mshadow/packet-inl.h ===== /*! * Copyright (c) 2014 by Contributors * \file packet-inl.h * \brief Generic packet vectorization code */ #ifndef MSHADOW_PACKET_INL_H_ #define MSHADOW_PACKET_INL_H_ #ifdef __APPLE__ #else #endif namespace mshadow { /*! \brief namespace of packet math*/ namespace packet { enum PacketArch { kPlain, kSSE2, }; #if MSHADOW_USE_SSE #define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kSSE2 #else #define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kPlain #endif // whether packet operator is enabled. /*! 
* \brief Generic packet type * \tparam DType The data type of the packet. * \tparam Arch the Arch of the packet. */ template struct Packet; template struct AlignBytes { static const index_t value = 4; }; } // namespace packet } // namespace mshadow namespace mshadow { namespace packet { /*! * \brief analog to cudaMallocPitch, allocate a aligned space with num_line * lspace cells * \param out_pitch output parameter, the actuall space allocated for each line * \param lspace number of cells required for each line * \param num_line number of lines to be allocated */ inline void* AlignedMallocPitch(size_t *out_pitch, size_t lspace, size_t num_line) { const index_t bits = AlignBytes::value; const index_t mask = (1 << bits) - 1; size_t pitch = ((lspace + mask) >> bits) << bits; *out_pitch = pitch; #ifdef _MSC_VER void *res = _aligned_malloc(pitch * num_line, 1 << bits); #else void *res; int ret = posix_memalign(&res, 1 << bits, pitch * num_line); CHECK_EQ(ret, 0) << "AlignedMallocPitch failed"; #endif if (res == NULL) { LOG(FATAL) << "AlignedMallocPitch failed"; } return res; } /*! * \brief free aligned space * \param ptr pointer to space to be freed */ inline void AlignedFree(void *ptr) { #ifdef _MSC_VER _aligned_free(ptr); #else free(ptr); #endif } /*! \brief check if a pointer is aligned */ template inline bool CheckAlign(size_t pitch) { const index_t bits = AlignBytes::value; return !(pitch & ((1 << bits) - 1)); } /*! \brief check if a pointer is aligned */ template inline bool CheckAlign(void *ptr) { return CheckAlign(reinterpret_cast(ptr)); } /*! * \brief get upper bound of aligned index of size * \param size size of the array * \param fsize size of float */ template inline index_t UpperAlign(index_t size) { const index_t bits = AlignBytes::value; const index_t mask = (1 << bits) - 1; const index_t fsize = sizeof(DType); return (((size * fsize + mask) >> bits) << bits) / fsize; } /*! * \brief get lower bound of aligned index of size * \param size size of the array * \param fsize size of float */ template inline index_t LowerAlign(index_t size) { const index_t bits = AlignBytes::value; const index_t fsize = sizeof(DType); return (((size * fsize) >> bits) << bits) / fsize; } /*! * \brief generic Packet operator * \tparam OP The operator * \tparam DType The data type * \tparam Arch The architecture. 
*/ template struct PacketOp { static const bool kEnabled = false; }; // specialization of operators template struct PacketOp { static const bool kEnabled = true; MSHADOW_CINLINE static Packet Map(const Packet& lhs, const Packet& rhs) { return lhs + rhs; } }; template struct PacketOp { static const bool kEnabled = true; MSHADOW_CINLINE static Packet Map(const Packet& lhs, const Packet& rhs) { return lhs - rhs; } }; template struct PacketOp { static const bool kEnabled = true; MSHADOW_CINLINE static Packet Map(const Packet& lhs, const Packet& rhs) { return lhs * rhs; } }; template struct PacketOp { static const bool kEnabled = true; MSHADOW_CINLINE static Packet Map(const Packet& lhs, const Packet& rhs) { return lhs / rhs; } }; template struct PacketOp { static const bool kEnabled = true; MSHADOW_CINLINE static Packet Map(const Packet& src) { return src; } }; // savers to do storage template struct Saver{ MSHADOW_CINLINE static void Save(TFloat *dst, const Packet& src) { Packet lhs = Packet::Load(dst); Packet ans = PacketOp::Map(lhs, src); ans.Store(dst); } }; template struct Saver { MSHADOW_CINLINE static void Save(TFloat *dst, const Packet& src) { src.Store(dst); } }; } // namespace packet } // namespace mshadow //===== EXPANDIND: mxnet/mshadow/mshadow/packet/plain-inl.h ===== /*! * Copyright (c) 2014 by Contributors * \file plain-inl.h * \brief support of plain packet that use the plain datatype. */ #ifndef MSHADOW_PACKET_PLAIN_INL_H_ #define MSHADOW_PACKET_PLAIN_INL_H_ namespace mshadow { namespace packet { template struct Packet { public: /*! \brief number of float in vector */ static const index_t kSize = 1; /*! \brief The internal data */ DType data_; // enable default copy constructor Packet(void) {} // constructor from the intrinsic type explicit Packet(DType data) : data_(data) {} // create a fill with the target value s MSHADOW_CINLINE static Packet Fill(DType s) { return Packet(s); } // load from address MSHADOW_CINLINE static Packet Load(const DType* src) { return Packet(*src); } // load from address MSHADOW_CINLINE static Packet LoadUnAligned(const DType* src) { return Packet(*src); } // fill it with value s MSHADOW_CINLINE Packet& operator=(DType s) { data_ = s; return *this; } // store data into dst MSHADOW_CINLINE void Store(DType* dst) const { *dst = data_; } // get the sum of all contents MSHADOW_CINLINE DType Sum() const { return data_; } }; template MSHADOW_CINLINE Packet operator+(const Packet& lhs, const Packet& rhs) { return Packet(lhs.data_ + rhs.data_); } template MSHADOW_CINLINE Packet operator-(const Packet& lhs, const Packet& rhs) { return Packet(lhs.data_ - rhs.data_); } template MSHADOW_CINLINE Packet operator*(const Packet& lhs, const Packet& rhs) { return Packet(lhs.data_ * rhs.data_); } template MSHADOW_CINLINE Packet operator/(const Packet& lhs, const Packet& rhs) { return Packet(lhs.data_ / rhs.data_); } } // namespace packet } // namespace mshadow #endif // MSHADOW_PACKET_PLAIN_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/packet/plain-inl.h ===== #if MSHADOW_USE_SSE && !defined(__CUDACC__) //===== EXPANDIND: mxnet/mshadow/mshadow/packet/sse-inl.h ===== /*! * Copyright (c) 2014 by Contributors * \file sse-inl.h * \brief support of sse2 packet optimization of some operations * \author Tianqi Chen */ #ifndef MSHADOW_PACKET_SSE_INL_H_ #define MSHADOW_PACKET_SSE_INL_H_ namespace mshadow { namespace packet { template<> struct Packet { public: /*! \brief number of float in vector */ static const index_t kSize = 4; /*! 
\brief The internal data */ __m128 data_; // enable default copy constructor Packet(void) {} // constructor from the intrinsic type explicit Packet(__m128 data) : data_(data) {} // create a fill with the target value s MSHADOW_CINLINE static Packet Fill(float s) { return Packet(_mm_set1_ps(s)); } // load from address MSHADOW_CINLINE static Packet Load(const float* src) { return Packet(_mm_load_ps(src)); } // load from address MSHADOW_CINLINE static Packet LoadUnAligned(const float* src) { return Packet(_mm_loadu_ps(src)); } // fill it with value s MSHADOW_CINLINE Packet& operator=(float s) { data_ = _mm_set1_ps(s); return *this; } // store data into dst MSHADOW_CINLINE void Store(float* dst) const { _mm_store_ps(dst, data_); } // get the sum of all contents MSHADOW_CINLINE float Sum() const { __m128 ans = _mm_add_ps(data_, _mm_movehl_ps(data_, data_)); __m128 rst = _mm_add_ss(ans, _mm_shuffle_ps(ans, ans, 1)); #if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) return rst.m128_f32[0]; #else float rr = _mm_cvtss_f32(rst); return rr; #endif } }; /*! \brief vector real type for float */ template<> struct Packet { /*! \brief number of float in vector */ static const index_t kSize = 2; // internal data __m128d data_; // constructor Packet(void) {} explicit Packet(__m128d data) : data_(data) {} // create a fill with the target value s MSHADOW_CINLINE static Packet Fill(double s) { return Packet(_mm_set1_pd(s)); } // load from address MSHADOW_CINLINE static Packet Load(const double* src) { return Packet(_mm_load_pd(src)); } MSHADOW_CINLINE static Packet LoadUnAligned(const double* src) { return Packet(_mm_loadu_pd(src)); } // fill it with value s MSHADOW_CINLINE Packet& operator=(double s) { data_ = _mm_set1_pd(s); return *this; } // store data into dst MSHADOW_CINLINE void Store(double* dst) const { _mm_store_pd(dst, data_); } // get sum of all content inline double Sum(void) const { __m128d tmp = _mm_add_sd(data_, _mm_unpackhi_pd(data_, data_)); #if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) return tmp.m128d_f64[0]; #else double ans = _mm_cvtsd_f64(tmp); return ans; #endif } }; MSHADOW_CINLINE Packet operator+(const Packet& lhs, const Packet& rhs) { return Packet(_mm_add_ps(lhs.data_, rhs.data_)); } MSHADOW_CINLINE Packet operator+(const Packet& lhs, const Packet& rhs) { return Packet(_mm_add_pd(lhs.data_, rhs.data_)); } MSHADOW_CINLINE Packet operator-(const Packet& lhs, const Packet& rhs) { return Packet(_mm_sub_ps(lhs.data_, rhs.data_)); } MSHADOW_CINLINE Packet operator-(const Packet& lhs, const Packet& rhs) { return Packet(_mm_sub_pd(lhs.data_, rhs.data_)); } MSHADOW_CINLINE Packet operator*(const Packet& lhs, const Packet& rhs) { return Packet(_mm_mul_ps(lhs.data_, rhs.data_)); } MSHADOW_CINLINE Packet operator*(const Packet& lhs, const Packet& rhs) { return Packet(_mm_mul_pd(lhs.data_, rhs.data_)); } MSHADOW_CINLINE Packet operator/(const Packet& lhs, const Packet& rhs) { return Packet(_mm_div_ps(lhs.data_, rhs.data_)); } MSHADOW_CINLINE Packet operator/(const Packet& lhs, const Packet& rhs) { return Packet(_mm_div_pd(lhs.data_, rhs.data_)); } } // namespace packet } // namespace mshadow #endif // MSHADOW_PACKET_SSE_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/packet/sse-inl.h ===== #endif namespace mshadow { namespace expr { typedef packet::PacketArch PacketArch; // same as plan, but use packet template class PacketPlan { public: /*! 
* \brief evaluate the expression at index [y][x], * x will be aligned to Packet::kSize */ MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const; MSHADOW_CINLINE DType Eval(index_t y, index_t x) const; }; template class PacketPlan, DType, Arch> { public: explicit PacketPlan(const Tensor &t) :dptr_(t.dptr_), stride_(t.stride_) {} MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { return packet::Packet::Load(&dptr_[y * stride_ + x]); } MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { return dptr_[y * stride_ + x]; } private: const DType *dptr_; index_t stride_; }; template class PacketPlan, DType, Arch> { public: explicit PacketPlan(DType scalar) : scalar_(scalar) {} MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { return packet::Packet::Fill(scalar_); } MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { return scalar_; } private: DType scalar_; }; template class PacketPlan, DType, Arch> { public: PacketPlan(const PacketPlan &lhs, const PacketPlan &rhs) : lhs_(lhs), rhs_(rhs) {} MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { return packet::PacketOp::Map(lhs_.EvalPacket(y, x), rhs_.EvalPacket(y, x)); } MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); } private: PacketPlan lhs_; PacketPlan rhs_; }; template class PacketPlan, DType, Arch> { public: PacketPlan(const PacketPlan &src) : src_(src) {} MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { return packet::PacketOp::Map(src_.EvalPacket(y, x)); } MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { return OP::Map(src_.Eval(y, x)); } private: PacketPlan src_; }; template inline PacketPlan, DType, Arch> MakePacketPlan(const BinaryMapExp &e); template inline PacketPlan, DType, Arch> MakePacketPlan(const ScalarExp &e) { return PacketPlan, DType, Arch>(e.scalar_); } template inline PacketPlan MakePacketPlan(const RValueExp &e) { return PacketPlan(e.self()); } template inline PacketPlan MakePacketPlan(const MakeTensorExp &e) { return PacketPlan(e.real_self()); } template inline PacketPlan, DType, Arch> MakePacketPlan(const UnaryMapExp &e) { return PacketPlan, DType, Arch>(MakePacketPlan(e.src_)); } template inline PacketPlan, DType, Arch> MakePacketPlan(const BinaryMapExp &e) { return PacketPlan, DType, Arch>(MakePacketPlan(e.lhs_), MakePacketPlan(e.rhs_)); } /*! 
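*/
// ---------------------------------------------------------------------------
// Packet sketch for the SSE2 specialization above (illustrative only; assumes
// MSHADOW_USE_SSE is enabled and uses unaligned loads to sidestep the 16-byte
// alignment requirement of Load/Store; not compiled).
#if 0
float ExamplePacketAdd(const float *a, const float *b) {
  using mshadow::packet::Packet;
  using mshadow::packet::kSSE2;
  Packet<float, kSSE2> pa = Packet<float, kSSE2>::LoadUnAligned(a);  // 4 floats
  Packet<float, kSSE2> pb = Packet<float, kSSE2>::LoadUnAligned(b);
  Packet<float, kSSE2> ps = pa + pb;   // element-wise SSE add
  return ps.Sum();                     // horizontal sum of the 4 lanes
}
#endif
/*!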
* \brief static check packet enable * * \tparam Device the type of Device * \tparam dim dimension of the tensor * \tparam E expression */ template struct PacketCheck{ static const bool kPass = false; }; template struct PacketCheck, Arch> { static const bool kPass = true; }; template struct PacketCheck, Arch> { static const bool kPass = true; }; template struct PacketCheck, Arch> { static const bool kPass = PacketCheck::kPass && packet::PacketOp::kEnabled; }; template struct PacketCheck< BinaryMapExp, Arch> { static const bool kPass = packet::PacketOp::kEnabled && PacketCheck::kPass && PacketCheck::kPass; }; //---------------------------------------------------- // Check if data is aligned and allow packet operation //---------------------------------------------------- template struct PacketAlignCheck { inline static bool Check(const E &exp) { return false; } }; template struct PacketAlignCheck, Arch> { inline static bool Check(const ScalarExp &exp) { return true; } }; template struct PacketAlignCheck, Arch> { inline static bool Check(const Tensor &t) { return packet::CheckAlign(t.dptr_) && packet::CheckAlign(t.stride_ * sizeof(DType)); } }; template struct PacketAlignCheck, Arch> { inline static bool Check(const UnaryMapExp &t) { return PacketAlignCheck::Check(t.src_); } }; template struct PacketAlignCheck, Arch> { inline static bool Check(const BinaryMapExp &t) { return PacketAlignCheck::Check(t.lhs_) && PacketAlignCheck::Check(t.rhs_); } }; /*! * \brief use PacketPlan to compute result */ template inline void MapPacketPlan(Tensor _dst, const expr::PacketPlan& plan) { Tensor dst = _dst.FlatTo2D(); const index_t xlen = packet::LowerAlign(dst.size(1)); for (index_t y = 0; y < dst.size(0); ++y) { for (index_t x = 0; x < xlen; x += packet::Packet::kSize) { packet::Saver::Save(&dst[y][x], plan.EvalPacket(y, x)); } for (index_t x = xlen; x < dst.size(1); ++x) { SV::Save(dst[y][x], plan.Eval(y, x)); } } } } // namespace expr } // namespace mshadow #endif // MSHADOW_PACKET_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/packet-inl.h ===== namespace mshadow { namespace expr { /*! * \brief Matrix multiplication. * \tparam LhsExp type of lhs expression * \tparam LhsExp type of rhs expression * \tparam DType the type of elements */ template struct ImplicitGEMMExp: public Exp, DType, type::kChainer> { /*! \brief lhs operand */ const LhsExp &lhs_; /*! \brief rhs operand */ const RhsExp &rhs_; /*! \brief internal production size*/ index_t prod_size_; /*! \brief the shape of this expression */ Shape<2> shape_; /*! 
\brief constructor */ ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs) : lhs_(lhs), rhs_(rhs) { Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_); Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_); this->shape_ = mshadow::Shape2(slhs[0], srhs[1]); prod_size_ = slhs[1]; } }; template inline ImplicitGEMMExp implicit_dot(const Exp &lhs, const Exp &rhs) { TypeCheckPass::kDim == 2 && ExpInfo::kDim == 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return ImplicitGEMMExp(lhs.self(), rhs.self()); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const ImplicitGEMMExp &e) : lhs_(MakePlan(e.lhs_)), rhs_(MakePlan(e.rhs_)), prod_size_(e.prod_size_), prod_size_lower_align_(packet::LowerAlign(e.prod_size_)) { } MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { typedef packet::Packet Packet; Packet sum = Packet::Fill(0); DType lhs_temp[Packet::kSize], rhs_temp[Packet::kSize]; for (index_t i = 0; i < prod_size_lower_align_; i += packet::Packet::kSize) { // unroll for (index_t j = 0; j < Packet::kSize; ++j) { lhs_temp[j] = lhs_.Eval(y, i + j); } for (index_t j = 0; j < Packet::kSize; ++j) { rhs_temp[j] = rhs_.Eval(i + j, x); } sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp); } DType ret_result = sum.Sum(); for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) { ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x); } return ret_result; } private: expr::Plan lhs_; expr::Plan rhs_; const index_t prod_size_; const index_t prod_size_lower_align_; }; template inline Plan, DType> MakePlan(const ImplicitGEMMExp &exp) { return Plan, DType>(exp); } template struct ShapeCheck > { inline static Shape Check(const ImplicitGEMMExp &t) { CHECK(dim == 2) << "ImplicitGEMMExp only support 2 dimension"; Shape shape1 = ShapeCheck::Check(t.lhs_); Shape shape2 = ShapeCheck::Check(t.rhs_); CHECK_EQ(shape1[1], shape2[0]) << "implicit_dot The matrix shape do not match"; return t.shape_; } }; template struct ExpInfo > { static const int kDim = 2; static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/implicit_gemm.h ===== namespace mshadow { namespace expr { //--------------------------------------------------------------------- // Matrix Multiplications, depends on BLAS Engine //--------------------------------------------------------------------- template struct DotEngine { inline static void Eval(Tensor *p_dst, const Tensor &lhs, const Tensor &rhs, DType scale); }; // handles the dot template struct BLASEngine; #if (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) template<> struct BLASEngine { inline static CBLAS_TRANSPOSE GetT(bool t) { return t ? 
CblasTrans : CblasNoTrans; } inline static void SetStream(Stream *stream) { } inline static void gemm(Stream *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc) { cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline static void gemm(Stream *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc) { cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline static void gemv(Stream *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY) { cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } inline static void gemv(Stream *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY) { cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } inline static void ger(Stream *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda) { cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); } inline static void ger(Stream *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda) { cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); } inline static void dot(Stream *stream, int n, const float* X, int incX, const float* Y, int incY, float* ret) { *ret = cblas_sdot(n, X, incX, Y, incY); } inline static void dot(Stream *stream, int n, const double* X, int incX, const double* Y, int incY, double* ret) { *ret = cblas_ddot(n, X, incX, Y, incY); } }; #elif MSHADOW_STAND_ALONE == 1 template<> struct BLASEngine { inline static bool GetT(bool t) { return t ? 
true : false; } inline static void SetStream(Stream *stream) { } inline static void gemm(Stream *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc) { LOG(FATAL) << "Not implemented!"; } inline static void gemm(Stream *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc) { LOG(FATAL) << "Not implemented!"; } inline static void gemv(Stream *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY) { LOG(FATAL) << "Not implemented!"; } inline static void gemv(Stream *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY) { LOG(FATAL) << "Not implemented!"; } inline static void ger(Stream *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda) { LOG(FATAL) << "Not implemented!"; } inline static void ger(Stream *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda) { LOG(FATAL) << "Not implemented!"; } inline static void dot(Stream *stream, int n, const float* X, int incX, const float* Y, int incY, float* ret) { LOG(FATAL) << "Not implemented!"; } inline static void dot(Stream *stream, int n, const double* X, int incX, const double* Y, int incY, double* ret) { LOG(FATAL) << "Not implemented!"; } }; #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE // CuBLAS redirect code #if MSHADOW_USE_CUDA // All cuBLAS calls go here; uses the legacy API, which is not thread-safe template<> struct BLASEngine { inline static cublasOperation_t GetT(bool t) { return t ?
CUBLAS_OP_T : CUBLAS_OP_N; } inline static void SetStream(Stream *stream) { cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), Stream::GetStream(stream)); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail"; } inline static void gemm(Stream *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc) { cublasStatus_t err = cublasSgemm(Stream::GetBlasHandle(stream), GetT(transa), GetT(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas Sgemm fail"; } inline static void gemm(Stream *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc) { cublasStatus_t err = cublasDgemm(Stream::GetBlasHandle(stream), GetT(transa), GetT(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail"; } inline static void gemv(Stream *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY) { cublasStatus_t err = cublasSgemv(Stream::GetBlasHandle(stream), GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail"; } inline static void gemv(Stream *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY) { cublasStatus_t err = cublasDgemv(Stream::GetBlasHandle(stream), GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail"; } inline static void ger(Stream *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda) { cublasStatus_t err = cublasSger(Stream::GetBlasHandle(stream), m, n, &alpha, X, incX, Y, incY, A, lda); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail"; } inline static void ger(Stream *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda) { cublasStatus_t err = cublasDger(Stream::GetBlasHandle(stream), m, n, &alpha, X, incX, Y, incY, A, lda); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail"; } inline static void dot(Stream *stream, int n, const float* X, int incX, const float* Y, int incY, float *ret) { cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_DEVICE); cublasStatus_t err = cublasSdot(Stream::GetBlasHandle(stream), n, X, incX, Y, incY, ret); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_HOST); } inline static void dot(Stream *stream, int n, const double* X, int incX, const double* Y, int incY, double *ret) { cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_DEVICE); cublasStatus_t err = cublasDdot(Stream::GetBlasHandle(stream), n, X, incX, Y, incY, ret); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_HOST); } }; #endif // MSHADOW_USE_CUDA // helper function to decide which shape we are in inline static Shape<2> GetShape(const Shape<2> &shape, bool transpose) { return transpose ? 
Shape2(shape[1], shape[0]) : shape; } // dst = dot(lhs[.T], rhs[.T]) template struct DotEngine { inline static void Eval(Tensor *p_dst, const Tensor &lhs, const Tensor &rhs, DType scale) { Tensor &dst = *p_dst; #if MSHADOW_STAND_ALONE if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) { if (!transpose_left && !transpose_right) { dst = expr::implicit_dot(lhs, rhs); return; } else if (!transpose_left && transpose_right) { dst = expr::implicit_dot(lhs, rhs.T()); return; } else if (transpose_left && !transpose_right) { dst = expr::implicit_dot(lhs.T(), rhs); return; } } #endif // set kernel stream // if there is no stream, crush BLASEngine::SetStream(dst.stream_); Shape<2> sleft = GetShape(lhs.shape_, transpose_left); Shape<2> sright = GetShape(rhs.shape_, transpose_right); CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0]) << "dot-gemm: matrix shape mismatch"; // use column major argument to compatible with most BLAS BLASEngine::gemm (dst.stream_, transpose_right , transpose_left, transpose_right ? rhs.size(0) : rhs.size(1), transpose_left ? lhs.size(1) : lhs.size(0), transpose_right ? rhs.size(1) : rhs.size(0), scale * SV::AlphaBLAS(), rhs.dptr_, rhs.stride_, lhs.dptr_, lhs.stride_, SV::BetaBLAS(), dst.dptr_, dst.stride_); } }; template struct DotEngine { inline static void Eval(Tensor *p_dst, const Tensor &lhs, const Tensor &rhs, DType scale) { Tensor &dst = *p_dst; // set kernel stream // if there is no stream, crush BLASEngine::SetStream(dst.stream_); Shape<2> sright = GetShape(rhs.shape, transpose_right); CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0]) << "dot-gemv: matrix shape mismatch" << "dst: " << dst.shape_ << "\n" << "lhs: " << lhs.shape_ << "\n" << "rhs: " << sright << "\n"; BLASEngine::gemv (dst.stream_, transpose_right, rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(), rhs.dptr_, rhs.stride_, lhs.dptr_, 1, SV::BetaBLAS(), dst.dptr_, 1); } }; template struct DotEngine { inline static void Eval(Tensor *p_dst, const Tensor &lhs, const Tensor &rhs, DType scale) { Tensor &dst = *p_dst; // set kernel stream // if there is no stream, crush BLASEngine::SetStream(dst.stream_); CHECK_EQ(dst.size(0), lhs.size(0) && dst.size(1) == rhs.size(0)) << "dot-ger: matrix shape mismatch" << "dst: " << dst.shape_ << "\n" << "lhs: " << lhs.shape_ << "\n" << "rhs: " << rhs.shape_; if (SV::BetaBLAS() == 0.0f) { BLASEngine::ger (dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(), rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_); } else { DotEngine::Eval(dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale); } } }; } // namespace expr } // namespace mshadow #endif // MSHADOW_DOT_ENGINE_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/dot_engine-inl.h ===== namespace mshadow { namespace expr { /*! \brief some engine that evaluate complex expression */ template struct ExpComplexEngine { inline static void Eval(RV *dst, const E &exp); }; /*! 
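*/
// ---------------------------------------------------------------------------
// Matrix-product sketch for the DotEngine / BLASEngine code above
// (illustrative only, CPU float build with CBLAS or MKL enabled; not
// compiled). The assignment is dispatched through ExpComplexEngine to
// DotEngine, which calls the column-major gemm wrapper.
#if 0
void ExampleDot() {
  using namespace mshadow;
  Tensor<cpu, 2, float> lhs = NewTensor<cpu>(Shape2(3, 4), 1.0f);
  Tensor<cpu, 2, float> rhs = NewTensor<cpu>(Shape2(4, 5), 1.0f);
  Tensor<cpu, 2, float> dst = NewTensor<cpu>(Shape2(3, 5), 0.0f);
  dst = dot(lhs, rhs);       // (3x4) * (4x5) -> (3x5)
  Tensor<cpu, 2, float> dst2 = NewTensor<cpu>(Shape2(4, 4), 0.0f);
  dst2 = dot(lhs.T(), lhs);  // transposed operand: (4x3) * (3x4) -> (4x4)
  FreeSpace(&lhs); FreeSpace(&rhs); FreeSpace(&dst); FreeSpace(&dst2);
}
#endif
/*!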
\brief the engine that dispatches simple operations*/ template struct ExpEngine { template inline static void Eval(RV *dst, const Exp &exp) { MapExp(dst, exp); } template inline static void Eval(RV *dst, const Exp &exp) { MapExp(dst, exp); } template inline static void Eval(RV *dst, const Exp &exp) { MapExp(dst, exp); } template inline static void Eval(RV *dst, const Exp &exp) { ExpComplexEngine::Eval(dst->ptrself(), exp.self()); } }; template struct ExpComplexEngine, DotExp, Tensor, ltrans, rtrans, DType>, DType> { inline static void Eval(Tensor *dst, const DotExp, Tensor, ltrans, rtrans, DType> &exp) { DotEngine::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_); } }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXPR_ENGINE_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/expr_engine-inl.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/broadcast.h ===== /*! * Copyright (c) 2014 by Contributors * \file broadcast.h * \brief support for broadcast and repmat * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_BROADCAST_H_ #define MSHADOW_EXTENSION_BROADCAST_H_ namespace mshadow { namespace expr { /*! * \brief broadcast Tensor1D into a higher dimension Tensor * input: Tensor: ishape[0] * output: Tensor : oshape[dimcast] = ishape[0] * \tparam SrcExp type of input expression * \tparam DType the type of elements * \tparam dimdst target tensor dimension * \tparam dimcast_m_dst dimcast - dimdst */ template struct Broadcast1DExp: public MakeTensorExp, SrcExp, dimdst, DType> { /*! \brief source operand */ const SrcExp &src_; /*! \brief constructor */ Broadcast1DExp(const SrcExp &src, Shape shape) : src_(src) { this->shape_ = shape; } }; /*! * \brief a expression that replicate a 1 dimension tensor in dimension dimcast * \param src Tensor: shape[0] * \param shape shape of output * \return a expresion with type Tensor * \tparam dimcast target dimension where the 1D tensor will be broadcasted * \tparam SrcExp type of input expression * \tparam DType the type of elements * \tparam dimdst dimension of destination tensor * \tparam dimcast_lowest the dimension we want to cast the data into */ template inline Broadcast1DExp broadcast(const expr::Exp &src, Shape shape) { TypeCheckPass::kDim == 1> ::Error_Expression_Does_Not_Meet_Dimension_Req(); typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp; CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast]) << "broadcast, shape mismatch"; return Broadcast1DExp(src.self(), shape); } // short cut functions /*! * \brief a expression that replicate a 1 dimension tensor for nrow times * \param src Tensor: shape[0] * \param nrow number of rows to replicate * \return a expresion with type Tensor size(1), size(0) = nrow * \tparam Device which device it lies */ template inline Broadcast1DExp repmat(const expr::Exp &src, index_t nrow) { return broadcast<1> (src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0])); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: static const int dimcast = dimdst - dimdst_m_cast; explicit Plan(const Broadcast1DExp &e) : src_(MakePlan(e.src_)), ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)), length_(e.shape_[dimcast]) { TypeCheckPass ::Error_Expression_Does_Not_Meet_Dimension_Req(); } MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return src_.Eval(0, (y / ystride_) % length_); } private: expr::Plan src_; const index_t ystride_, length_; }; /*! 
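*/
// ---------------------------------------------------------------------------
// Broadcast sketch for broadcast<> / repmat defined above (illustrative only,
// CPU float build; not compiled).
#if 0
void ExampleBroadcast() {
  using namespace mshadow;
  Tensor<cpu, 1, float> bias = NewTensor<cpu>(Shape1(3), 0.5f);
  Tensor<cpu, 2, float> out  = NewTensor<cpu>(Shape2(4, 3), 0.0f);
  out = repmat(bias, 4);                  // replicate the length-3 vector into 4 rows
  out = broadcast<1>(bias, out.shape_);   // equivalent form with an explicit target dim
  FreeSpace(&bias); FreeSpace(&out);
}
#endif
/*!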
\brief execution plan of Broadcast1DExp */ template struct Plan, DType>{ public: explicit Plan(const Broadcast1DExp &e) : src_(MakePlan(e.src_)) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return src_.Eval(0, x); } private: expr::Plan src_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_BROADCAST_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/broadcast.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/unpack_patch2col.h ===== /*! * Copyright (c) 2014 by Contributors * \file unpack_patch2col.h * \brief support for unpack * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ #define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ namespace mshadow { namespace expr { /*! * \brief unpack local (overlap) patches of image to column of mat, * can be used to implement convolution, this expression allow unpack of a batch * this is a version support unpacking multiple images * after getting unpacked mat, we can use: output = dot(weight, mat) to get covolved results, the relations: * \tparam SrcExp source expression * \tparam dstdim destination dimension */ template struct UnpackPatchToColXExp: public MakeTensorExp, SrcExp, 2, DType>{ /*! \brief source operand */ const SrcExp &img_; /*! \brief patch height */ index_t psize_y_; /*! \brief patch width */ index_t psize_x_; /*! \brief patch stride */ index_t pstride_y_; index_t pstride_x_; /*! \brief number of input channel */ index_t i_channel_; /*! \brief height of img */ index_t i_height_; /*! \brief width of img */ index_t i_width_; /*! \brief constructor */ UnpackPatchToColXExp(const SrcExp &img, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x) : img_(img), psize_y_(psize_y), psize_x_(psize_x), pstride_y_(pstride_y), pstride_x_(pstride_x) { Shape imshape = ShapeCheck::Check(img_); CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y) << "UnpackPatchToCol:image shape smaller than patch size"; this->i_channel_ = imshape[srcdim - 3]; this->i_height_ = imshape[srcdim - 2]; this->i_width_ = imshape[srcdim - 1]; // calculate number of batches const index_t num = imshape.ProdShape(0, srcdim - 3); const index_t o_height = (i_height_ - psize_y) / pstride_y + 1; const index_t o_width = (i_width_ - psize_x) / pstride_x + 1; this->shape_[1] = o_height * o_width * num; this->shape_[0] = psize_y * psize_x * i_channel_; } }; /*! 
* \brief unpack local (overlap) patches of image to column of mat, can be used to implement convolution * after getting unpacked mat, we can use: output = dot(weight, mat) to get covolved results, the relations: * * weight; shape[0]: out_channel, shape[1]: ichannel * psize_y * psize_x * output; shape[0]: out_channel, shape[1]: out_height * out_width * num_of_images * out_height = (in_height - psize_y) / pstride + 1, this means we pad inperfect patch with 0 * out_width = (in_width - psize_x) / pstride + 1 * * \return mat target matrix; shape[0]: in_channel*psize_y*psize_x shape[1]: out_height*out_width * num_of_images * \param img source image; shape[-3]: in_channels, shape[-2]: in_height, shape[-1]: in_width, can be 3D or 4D tensor(multiple images) * \param psize_y height of each patch * \param psize_x width of each patch * \param pstride stride of each patch * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline UnpackPatchToColXExp::kDim> unpack_patch2col(const Exp &img, index_t psize_y, index_t psize_x, index_t pstride) { TypeCheckPass::kDim >= 3> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return UnpackPatchToColXExp::kDim> (img.self(), psize_y, psize_x, pstride, pstride); } /*! *if you want to specify stride_x and stride_y */ template inline UnpackPatchToColXExp::kDim> unpack_patch2col(const Exp &img, index_t psize_y, index_t psize_x, index_t pstride_y_, index_t pstride_x_) { TypeCheckPass::kDim >= 3> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return UnpackPatchToColXExp::kDim> (img.self(), psize_y, psize_x, pstride_y_, pstride_x_); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const UnpackPatchToColXExp &e) :src_(MakePlan(e.img_)), psize_y_(e.psize_y_), psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), i_channel_(e.i_channel_), i_height_(e.i_height_), i_width_(e.i_width_), o_height_((i_height_ - psize_y_) / pstride_y_ + 1), o_width_((i_width_ - psize_x_) / pstride_x_ + 1) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { const index_t x_offset = i % psize_x_; const index_t idivp = i / psize_x_; const index_t y_offset = idivp % psize_y_; const index_t c = idivp / psize_y_; const index_t x = (j % o_width_) * pstride_x_ + x_offset; const index_t jdivw = j / o_width_; const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset; const index_t n = jdivw / o_height_; if (x < i_width_ && y < i_height_) { return src_.Eval((n * i_channel_ + c) * i_height_ + y, x); } else { return 0.0f; } } private: Plan src_; const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; const index_t i_height_, i_width_, o_height_, o_width_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/unpack_patch2col.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/pack_col2patch.h ===== /*! * Copyright (c) 2014 by Contributors * \file pack_col2patch.h * \brief support for pack * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_ #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_ namespace mshadow { namespace expr { /*! 
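// Usage sketch: convolution as unpack_patch2col + gemm, with pack_col2patch (below)
// as the reverse mapping used in the backward pass (illustrative only; names,
// shapes and allocation are hypothetical).
//
//   // batch of 2 RGB 32x32 images, 16 output channels, 5x5 kernel, stride 1
//   // output spatial size: (32 - 5) / 1 + 1 = 28
//   Tensor<cpu, 4, float> data(Shape4(2, 3, 32, 32)), gdata(Shape4(2, 3, 32, 32));
//   Tensor<cpu, 2, float> weight(Shape2(16, 3 * 5 * 5));
//   Tensor<cpu, 2, float> col(Shape2(3 * 5 * 5, 28 * 28 * 2)), gcol(col.shape_);
//   Tensor<cpu, 2, float> out(Shape2(16, 28 * 28 * 2));
//   col = unpack_patch2col(data, 5, 5, 1);   // (75, 1568) patch matrix
//   out = dot(weight, col);                  // whole convolution as one gemm
//   // backward: given gcol (the gradient of col), scatter it back to image layout
//   gdata = pack_col2patch(gcol, gdata.shape_, 5, 5, 1);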
* \brief reverse operation of UnpackPatchToCol, * used to backprop gradient back * this is a version supporting multiple images * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam dstdim destination dimension */ template struct PackColToPatchXExp: public MakeTensorExp, SrcExp, dstdim, DType> { /*! \brief source operand */ const SrcExp &src_; /*! \brief patch height */ index_t psize_y_; /*! \brief patch height */ index_t psize_x_; /*! \brief patch stride */ index_t pstride_y_; index_t pstride_x_; /*! \brief constructor */ PackColToPatchXExp(const SrcExp &src, Shape imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x) :src_(src), psize_y_(psize_y), psize_x_(psize_x), pstride_y_(pstride_y), pstride_x_(pstride_x){ this->shape_ = imshape; const index_t o_height = (imshape[dstdim - 2] - psize_y) / pstride_y + 1; const index_t o_width = (imshape[dstdim - 1] - psize_x) / pstride_x + 1; Shape<2> sshape = ShapeCheck<2, SrcExp>::Check(src_); CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3)) << "PackColToPatchExp: src.size(1) mismatch"; CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3]) << "PackColToPatchExp: src.size(0) mismatch"; } }; /*! * \brief reverse operation of pack_col2patch, can be used to implement deconvolution * \return packed img expression * \param mat source matrix * \param imshape shape of target img * \param psize_y height of each patch * \param psize_x height of each patch * \param pstride stride of each patch * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam dstdim destination dimension * \tparam etype type of expression */ template inline PackColToPatchXExp pack_col2patch(const expr::Exp &src, Shape imshape, index_t psize_y, index_t psize_x, index_t pstride) { TypeCheckPass::kDim == 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) << "PackColToPatch:image shape smaller than patch size"; return PackColToPatchXExp(src.self(), imshape, psize_y, psize_x, pstride, pstride); } /*! *if you want to specify kstride_y and kstride_x */ template inline PackColToPatchXExp pack_col2patch(const expr::Exp &src, Shape imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x) { TypeCheckPass::kDim == 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) << "PackColToPatch:image shape smaller than patch size"; return PackColToPatchXExp(src.self(), imshape, psize_y, psize_x, pstride_y, pstride_x); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const PackColToPatchXExp &e) :src_(MakePlan(e.src_)), psize_y_(e.psize_y_), psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), i_channel_(e.shape_[dstdim - 3]), i_height_(e.shape_[dstdim - 2]), o_height_((e.shape_[dstdim - 2] - psize_y_) / pstride_y_ + 1), o_width_((e.shape_[dstdim - 1] - psize_x_) / pstride_x_ + 1) { // note: i/o convention are same as unpack } MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { using namespace std; const index_t y = i % i_height_; const index_t idivh = i / i_height_; const index_t c = idivh % i_channel_; const index_t n = idivh / i_channel_; const index_t x = j; const index_t py_min = y < psize_y_ ? 0 : (y-psize_y_ + pstride_y_) / pstride_y_; const index_t px_min = x < psize_x_ ? 
0 : (x-psize_x_ + pstride_x_) / pstride_x_; const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_); const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_); DType res = static_cast(0); for (index_t py = py_min; py < py_max; ++py) { for (index_t px = px_min; px < px_max; ++px) { res += src_.Eval(((c * psize_y_ + y - py*pstride_y_) * psize_x_ + x - px * pstride_x_), (n * o_height_ + py) * o_width_ + px); } } return res; } private: Plan src_; const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; const index_t i_height_, o_height_, o_width_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/pack_col2patch.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/reshape.h ===== /*! * Copyright (c) 2014 by Contributors * \file reshape.h * \brief support for reshape * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_RESHAPE_H_ #define MSHADOW_EXTENSION_RESHAPE_H_ namespace mshadow { namespace expr { /*! * \brief reshape the content to another shape * input: Tensor: ishape * output: Tensor ishape.Size() == oshape.Size() * \tparam SrcExp source expression * \tparam dimdst target dimension * \tparam dimsrc source dimension */ template struct ReshapeExp: public MakeTensorExp, SrcExp, dimdst, DType> { /*! \brief source expression */ const SrcExp &src_; /*! \brief smallest dimension of input */ index_t ishapex_; /*! \brief constructor */ ReshapeExp(const SrcExp &src, Shape shape) : src_(src) { Shape ishape = ShapeCheck::Check(src_); CHECK_EQ(ishape.Size(), shape.Size()) << "reshape size must match"; ishapex_ = ishape[dimsrc - 1]; this->shape_ = shape; } }; /*! * \brief a expression that reshapes a tensor to another shape * \param src Tensor: * \param oshape target shape * \return a expresion with type Tensor * \tparam SrcExp source expression * \tparam etype source expression type * \tparam dimdst target dimension */ template inline ReshapeExp::kDim> reshape(const Exp &src, Shape oshape) { return ReshapeExp::kDim> (src.self(), oshape); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const ReshapeExp &e) : src_(MakePlan(e.src_)), oshapex_(e.shape_[dimdst - 1]), ishapex_(e.ishapex_) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { const index_t idx = y * oshapex_ + x; return src_.Eval(idx / ishapex_, idx % ishapex_); } private: Plan src_; const index_t oshapex_, ishapex_; }; // special work plan for 1 dimensional data template struct Plan, DType> { public: explicit Plan(const ReshapeExp &e) : src_(MakePlan(e.src_)), oshapex_(e.shape_[dimdst - 1]) { } MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return src_.Eval(0, y * oshapex_ + x); } private: Plan src_; const index_t oshapex_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_RESHAPE_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/reshape.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/swapaxis.h ===== /*! * Copyright (c) 2014 by Contributors * \file swapaxis.h * \brief support for swapaxis * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_SWAPAXIS_H_ #define MSHADOW_EXTENSION_SWAPAXIS_H_ namespace mshadow { namespace expr { /*! 
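// Usage sketch for reshape, defined above (illustrative only; names and allocation
// are hypothetical).  Only the total number of elements has to match:
//
//   Tensor<cpu, 2, float> mat(Shape2(6, 4));
//   Tensor<cpu, 3, float> cube(Shape3(2, 3, 4));
//   cube = reshape(mat, cube.shape_);   // 6*4 == 2*3*4 elements
//   mat  = reshape(cube, mat.shape_);   // and back again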
* \brief swap two axis of a tensor * input: Tensor: ishape * output: Tensor oshape[a1],oshape[a2] = ishape[a2],oshape[a1] * * \tparam SrcExp type of source expression * \tparam DType the type of elements * \tparam dimsrc source dimension, assert a1 > a2 * \tparam m_a1 one dimension to be swapped, encoded by dimsrc - a1 * \tparam a2 second dimension to be swapped, encoded by a2 */ template struct SwapAxisExp: public MakeTensorExp, SrcExp, dimsrc, DType> { // decode the a1, a2 static const int a1 = dimsrc - m_a1; /*! \brief source expression */ const SrcExp &src_; /*! \brief constructor */ explicit SwapAxisExp(const SrcExp &src) : src_(src) { this->shape_ = ShapeCheck::Check(src); std::swap(this->shape_[a1], this->shape_[a2]); } }; /*! * \brief a expression that reshapes a tensor to another shape * \param src Tensor: * \return a expresion with type Tensor * \tparam a1 higher dimension to be swapped, assert a1 > a2 * \tparam a2 lower dimension to be swapped * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype source expression type */ template inline SwapAxisExp::kDim, ExpInfo::kDim - a1, a2> swapaxis(const Exp &src) { typedef ExpInfo Info; TypeCheckPass= a1 + 1 && Info::kDim >= a2 + 1 && a2 < a1>::Error_Expression_Does_Not_Meet_Dimension_Req(); return SwapAxisExp::kDim, ExpInfo::kDim - a1, a2>(src.self()); } template struct Plan, DType> { public: // decode the a1 static const int a1 = dimsrc - m_a1; explicit Plan(const SwapAxisExp &e) : src_(MakePlan(e.src_)), shapey_(e.shape_.ProdShape(a1 + 1, dimsrc - 1)), shapez_(e.shape_[a1]), shapec_(e.shape_.ProdShape(a2 + 1, a1)), shapen_(e.shape_[a2]) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { const index_t y = i % shapey_; i /= shapey_; const index_t z = i % shapez_; i /= shapez_; const index_t c = i % shapec_; i /= shapec_; const index_t n = i % shapen_; // swap z and n return src_.Eval(((((i / shapen_) * shapez_ + z) * shapec_ + c) * shapen_ + n) * shapey_ + y, j); } private: Plan src_; const index_t shapey_, shapez_, shapec_, shapen_; }; template struct Plan, DType> { public: explicit Plan(const SwapAxisExp &e) : src_(MakePlan(e.src_)), shapex_(e.shape_[dimsrc - 1]), shapey_(e.shape_.ProdShape(a2 + 1, dimsrc - 1)), shapez_(e.shape_[a2]) {} MSHADOW_XINLINE DType Eval(index_t i, index_t x) const { // swap x and z const index_t y = i % shapey_; i /= shapey_; const index_t z = i % shapez_; const index_t n = i / shapez_; return src_.Eval((n * shapex_ + x) * shapey_ + y , z); } private: Plan src_; const index_t shapex_, shapey_, shapez_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_SWAPAXIS_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/swapaxis.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/reduceto1d.h ===== /*! * Copyright (c) 2014 by Contributors * \file reduceto1d.h * \brief support for sum_rows and sumall_except_dim * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_REDUCETO1D_H_ #define MSHADOW_EXTENSION_REDUCETO1D_H_ namespace mshadow { namespace expr { /*! * \brief reduction to 1 dimension tensor * input: Tensor: ishape * output: Tensor shape[0] = ishape[dimkeep]; * * \tparam SrcExp type of expression to be reduced * \tparam DType the data type of the scalar * \tparam Reducer which reducer to use * \tparam m_dimkeep which dimension to be kept, encoded with dimsrc - dimkeep */ template struct ReduceTo1DExp: public Exp, DType, type::kComplex> { /*! \brief source operand */ const SrcExp &src_; /*! 
\brief source operand, scale of the */ DType scale_; /*! \brief construct a repmat expression from src and nrow */ ReduceTo1DExp(const SrcExp& src, DType scale) : src_(src), scale_(scale) {} }; /*! * \brief a sum over all dimensions, except dimkeep * \param exp input expression that must be a matrix Tensor * \return a expresion with type Tensor * \tparam dimkeep the dimension that will be kept * \tparam SrcExp expression * \tparam etype type of expression */ template inline ReduceTo1DExp::kDim - dimkeep> sumall_except_dim(const Exp &exp) { return ReduceTo1DExp::kDim - dimkeep>(exp.self(), 1); } /*! * \brief a expression that sum over rows of a matrix * \param exp input expression that must be a matrix Tensor * \return a expresion with type Tensor * \tparam SrcExp expression * \tparam etype type of expression */ template inline ReduceTo1DExp sum_rows(const Exp &exp) { TypeCheckPass::kDim ==2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return sumall_except_dim<1>(exp); } template struct ExpComplexEngine, ReduceTo1DExp, DType> { static const int dimkeep = ExpInfo::kDim - m_dimkeep; inline static void Eval(Tensor *dst, const ReduceTo1DExp &exp) { TypeCheckPass ::Error_Expression_Does_Not_Meet_Dimension_Req(); MapReduceKeepHighDim(dst, exp.src_, exp.scale_); } }; template struct ExpComplexEngine, ReduceTo1DExp, DType> { inline static void Eval(Tensor *dst, const ReduceTo1DExp &exp) { MapReduceKeepLowest(dst, exp.src_, exp.scale_); } }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_REDUCETO1D_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/reduceto1d.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/spatial_pool.h ===== /*! * Copyright (c) 2014 by Contributors * \file spatial_pool.h * \brief support for spatial pooling * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_SPATIAL_POOL_H_ #define MSHADOW_EXTENSION_SPATIAL_POOL_H_ namespace mshadow { namespace expr { /*! * \brief pooling expression, do reduction over local patches of a image * \tparam Reducer reduction method during pooling * \tparam SrcExp source expression to be pooled from * \tparam DType the content data type * \tparam srcdim dimension of src */ template struct PoolingExp: public MakeTensorExp, SrcExp, srcdim, DType> { /*! \brief source operand */ const SrcExp &src_; /*! \brief kernel size in height */ index_t ksize_y_; /*! \brief kernel size in width */ index_t ksize_x_; /*! \brief kernel stride */ index_t kstride_; /*! \brief source height shape[1] */ index_t src_height_; /*! \brief source width shape[0] */ index_t src_width_; /*! \brief constructor */ PoolingExp(const SrcExp &src, index_t ksize_y, index_t ksize_x, index_t kstride) : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), kstride_(kstride) { Shape sshape = ShapeCheck::Check(src_); CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) << "PoolingExp: kernel must be smaller than image"; this->src_height_ = sshape[srcdim - 2]; this->src_width_ = sshape[srcdim - 1]; this->shape_ = sshape; this->shape_[srcdim - 2] = (src_height_ - ksize_y) / kstride + 1; this->shape_[srcdim - 1] = (src_width_ - ksize_x) / kstride + 1; } /*! 
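// Usage sketch for the reductions above (illustrative only; names, shapes and
// allocation are hypothetical).  A typical use is reducing a gradient back onto a
// parameter that was broadcast, such as a bias:
//
//   Tensor<cpu, 4, float> grad_out(Shape4(8, 16, 28, 28));
//   Tensor<cpu, 1, float> gbias(Shape1(16));
//   gbias = sumall_except_dim<1>(grad_out);   // keep dim 1 (channels), sum the rest
//
//   Tensor<cpu, 2, float> m(Shape2(100, 32));
//   Tensor<cpu, 1, float> colsum(Shape1(32));
//   colsum = sum_rows(m);                     // same as sumall_except_dim<1>(m)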
\brief constructor, specify shape */ PoolingExp(const SrcExp &src, Shape<2> pshape, index_t ksize_y, index_t ksize_x, index_t kstride) : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), kstride_(kstride) { Shape sshape = ShapeCheck::Check(src_); CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) << "PoolingExp: kernel must be smaller than image"; this->src_height_ = sshape[srcdim - 2]; this->src_width_ = sshape[srcdim - 1]; this->shape_ = sshape; this->shape_[srcdim - 2] = pshape[0]; this->shape_[srcdim - 1] = pshape[1]; } }; /*! * \brief pooling subregion results together * \param src source image, shape: (batch, channel, height, width) * \param ksize_y kernel size in height * \param ksize_x kernel size in width * \param kstride stride for each kernel * \return expression of pooled result * \tparam Reducer reducer type * \tparam SrcExp source expression * \tparam DType the content data type * \tparam etype type of expression */ template inline PoolingExp::kDim> pool(const Exp &src, index_t ksize_y, index_t ksize_x, index_t kstride) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return PoolingExp::kDim> (src.self(), ksize_y, ksize_x, kstride); } /*! * \brief same as pool, except the output shape is specified by pshape * \param src source image * \param pshape ouput shape * \param ksize_y kernel size in y * \param ksize_x kernel size in x * \param kstride stride for each kernel * \return expression of pooled result * \tparam Reducer reducer type * \tparam SrcExp source expression * \tparam DType the content data type * \tparam etype type of expression */ template inline PoolingExp::kDim> pool(const Exp &src, Shape<2> pshape, index_t ksize_y, index_t ksize_x, index_t kstride) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return PoolingExp::kDim> (src.self(), pshape, ksize_y, ksize_x, kstride); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const PoolingExp &e) : src_(MakePlan(e.src_)), ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), kstride_(e.kstride_), src_height_(e.src_height_), src_width_(e.src_width_), new_height_(e.shape_[srcdim - 2]) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { using namespace std; const index_t py = i % new_height_; const index_t y_start = py * kstride_; const index_t y_end = min(y_start + ksize_y_, src_height_); const index_t px = j; const index_t x_start = px * kstride_; const index_t x_end = min(x_start + ksize_x_, src_width_); const index_t c = i / new_height_; DType res; Reducer::SetInitValue(res); for (index_t y = y_start; y < y_end; ++y) { for (index_t x = x_start; x < x_end; ++x) { Reducer::Reduce(res, src_.Eval(c * src_height_ + y, x)); } } return res; } private: Plan src_; const index_t ksize_y_, ksize_x_, kstride_; const index_t src_height_, src_width_; const index_t new_height_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_SPATIAL_POOL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/spatial_pool.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/spatial_unpool.h ===== /*! * Copyright (c) 2014 by Contributors * \file spatial_unpool.h * \brief support for unpool * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ #define MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ namespace mshadow { namespace expr { /*! 
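// Usage sketch for pool, defined above, and unpool, defined just below
// (illustrative only; names, shapes and allocation are hypothetical).
//
//   Tensor<cpu, 4, float> img(Shape4(8, 16, 28, 28)), gimg(Shape4(8, 16, 28, 28));
//   Tensor<cpu, 4, float> pooled(Shape4(8, 16, 14, 14)), gpooled(Shape4(8, 16, 14, 14));
//   pooled = pool<red::maximum>(img, 2, 2, 2);   // 2x2 max pooling, stride 2: (28-2)/2+1 = 14
//   // backward pass: route gpooled back to the positions that produced the maxima
//   gimg = unpool<red::maximum>(img, pooled, gpooled, 2, 2, 2);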
* \brief unpooling expr reverse operation of pooling, used to pass gradient back * \tparam Reducer reduction method during pooling * \tparam SrcExp source expression to be pooled from * \tparam DType the content data type * \tparam srcdim dimension of src */ template struct UnPoolingExp: public MakeTensorExp, SrcExp, srcdim, DType> { /*! \brief source input, corresponds to src in pooling */ const SrcExp &data_src_; /*! \brief result of pooled data, corresponds to result of pooling */ const SrcExp &data_pooled_; /*! \brief gradient data of pooled part, to be propgate down */ const SrcExp &grad_pooled_; /*! \brief shape of pooled expression */ index_t pshape_y_; /*! \brief shape of pooled expression */ index_t pshape_x_; /*! \brief kernel size in height */ index_t ksize_y_; /*! \brief kernel size in width */ index_t ksize_x_; /*! \brief kernel stride */ index_t kstride_; /*! \brief constructor */ UnPoolingExp(const SrcExp &data_src, const SrcExp &data_pooled, const SrcExp &grad_pooled, index_t ksize_y, index_t ksize_x, index_t kstride) : data_src_(data_src), data_pooled_(data_pooled), grad_pooled_(grad_pooled), ksize_y_(ksize_y), ksize_x_(ksize_x), kstride_(kstride) { Shape pshape = ShapeCheck::Check(grad_pooled); typedef ShapeCheck ShapeCheckSrcDimSrcExp; CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) << "UnPoolingExp: pooled shape mismatch"; Shape sshape = ShapeCheck::Check(data_src); for (int k = 0; k < srcdim - 2; ++k) { CHECK_EQ(pshape[k], sshape[k]) << "UnPoolingExp: pool and src shape mismatch"; } pshape_x_ = pshape[srcdim - 1]; pshape_y_ = pshape[srcdim - 2]; this->shape_ = sshape; } }; /*! * \brief unpooling gradient for 4D, backprop gradient value back, revserse operation of pooling, * same as unpooling, but allows unequal size of kernel * \param data_src source input, corresponds to src in pooling * \param data_pooled result of pooled data, corresponds to result of pooling * \param grad_pooled gradient data of pooled part, to be propgate down * \param ksize_y kernel height * \param ksize_x kernel width * \param kstride stride for each kernel * \return expression corresponding to unpooled 4D Tensor, storing backproped gradient * \tparam Reducer reducer type * \tparam SrcExp source expression * \tparam DType the content data type * \tparam etype type of expression */ template inline UnPoolingExp::kDim> unpool(const Exp &data_src, const Exp &data_pooled, const Exp &grad_pooled, index_t ksize_y, index_t ksize_x, index_t kstride) { return UnPoolingExp::kDim> (data_src.self(), data_pooled.self(), grad_pooled.self(), ksize_y, ksize_x, kstride); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const UnPoolingExp &e) : data_src_(MakePlan(e.data_src_)), data_pooled_(MakePlan(e.data_pooled_)), grad_pooled_(MakePlan(e.grad_pooled_)), sshape_y_(e.shape_[srcdim - 2]), pshape_y_(e.pshape_y_), pshape_x_(e.pshape_x_), ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), kstride_(e.kstride_) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { using namespace std; const index_t x = j; const index_t y = i % sshape_y_; const index_t c = i / sshape_y_; const DType vsrc = data_src_.Eval(i, j); const index_t py_min = y < ksize_y_ ? 0 : (y - ksize_y_ + kstride_) / kstride_; const index_t px_min = x < ksize_x_ ? 
0 : (x - ksize_x_ + kstride_) / kstride_; const index_t py_max = min((y + kstride_) / kstride_, pshape_y_); const index_t px_max = min((x + kstride_) / kstride_, pshape_x_); DType val = static_cast(0); for (index_t py = py_min; py < py_max; ++py) { for (index_t px = px_min; px < px_max; ++px) { val += Reducer::PartialGrad(vsrc, data_pooled_.Eval(c * pshape_y_ + py, px)) * grad_pooled_.Eval(c * pshape_y_ + py, px); } } return val; } private: Plan data_src_, data_pooled_, grad_pooled_; const index_t sshape_y_, pshape_y_, pshape_x_; const index_t ksize_y_, ksize_x_; const index_t kstride_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/spatial_unpool.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/channel_pool.h ===== /*! * Copyright (c) 2014 by Contributors * \file channel_pool.h * \brief support for chpool * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_CHANNEL_POOL_H_ #define MSHADOW_EXTENSION_CHANNEL_POOL_H_ namespace mshadow { namespace expr { /*! * \brief channel pooling expression, do reduction over (local nearby) channels, * used to implement local response normalization * \tparam Reducer reduction method during pooling * \tparam SrcExp source expression to be pooled from * \tparam DType the type of elements * \tparam srcdim dimension of src */ template struct ChannelPoolingExp: public MakeTensorExp, SrcExp, srcdim, DType> { /*! \brief source operand */ const SrcExp &src_; /*! \brief neighbor size */ index_t nsize_; /*! \brief stride of pooling */ index_t stride_; /*! \brief pad of pooling of each side */ index_t pad_; index_t src_channel_; /*! \brief constructor */ ChannelPoolingExp(const SrcExp &src, index_t nsize, index_t stride, index_t pad) : src_(src), nsize_(nsize), stride_(stride), pad_(pad) { this->shape_ = ShapeCheck::Check(src_); this->src_channel_ = this->shape_[srcdim - 3]; CHECK_GE(this->shape_[srcdim - 3], nsize_) << "chpool: local size must be smaller than nchannels"; this->shape_[srcdim - 3] = (this->src_channel_ - nsize + pad * 2 + 1) / stride; } }; /*! 
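// Usage sketch for chpool, defined just below (illustrative only; names, shapes and
// allocation are hypothetical).  With the default stride 1 and pad nsize/2 the
// channel dimension keeps its size, which is what local response normalization needs:
//
//   Tensor<cpu, 4, float> in(Shape4(8, 32, 28, 28)), sqr_sum(Shape4(8, 32, 28, 28));
//   sqr_sum = chpool<red::sum>(in * in, 5);   // sum of squares over 5 nearby channels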
* \brief channel pooling, do reduction over (local nearby) channels, * used to implement local response normalization * \param src source data * \param nsize neighbor size * \return expression of pooled result * \tparam Reducer reducer type * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline ChannelPoolingExp::kDim> chpool(const Exp &src, index_t nsize) { TypeCheckPass::kDim >= 3> ::Error_Expression_Does_Not_Meet_Dimension_Req(); CHECK_EQ(nsize % 2, 1) << "chpool: if no pad is specified, local size must be odd"; return ChannelPoolingExp::kDim>(src.self(), nsize, 1, nsize / 2); } template inline ChannelPoolingExp::kDim> chpool(const Exp &src, index_t nsize, index_t stride, index_t pad) { TypeCheckPass::kDim >= 3> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return ChannelPoolingExp::kDim>(src.self(), nsize, stride, pad); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const ChannelPoolingExp &e) : src_(MakePlan(e.src_)), channel_(e.shape_[srcdim - 3]), height_(e.shape_[srcdim - 2]), width_(e.shape_[srcdim - 1]), hnsize_(e.nsize_), stride_(e.stride_), pad_(e.pad_), src_channel_(e.src_channel_) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { using namespace std; const index_t y = i % height_; i /= height_; const index_t c = i % channel_; const index_t n = i / channel_; const index_t x = j; const index_t cstart = c * stride_ < pad_ ? 0 : c * stride_ - pad_; const index_t cend = min(cstart + hnsize_, channel_); DType res; Reducer::SetInitValue(res); for (index_t cc = cstart; cc < cend; ++cc) { Reducer::Reduce(res, src_.Eval((n * src_channel_ + cc) * height_ + y, x)); } return res; } private: Plan src_; const index_t channel_, height_, width_, hnsize_, stride_, pad_, src_channel_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_CHANNEL_POOL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/channel_pool.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/channel_unpool.h ===== /*! * Copyright (c) 2014 by Contributors * \file channel_pool.h * \brief support for chpool * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ #define MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ namespace mshadow { namespace expr { /*! * \brief channel pooling expression, do reduction over (local nearby) channels, * used to implement local response normalization * \tparam Reducer reduction method during pooling * \tparam SrcExp source expression to be pooled from * \tparam DType the type of elements * \tparam srcdim dimension of src */ template struct ChannelUnpoolingExp: public MakeTensorExp, SrcExp, srcdim, DType> { /*! \brief source input, corresponds to src in pooling */ const SrcExp &data_src_; /*! \brief result of pooled data, corresponds to result of pooling */ const SrcExp &data_pooled_; /*! \brief gradient data of pooled part, to be propgate down */ const SrcExp &grad_pooled_; /*! \brief channel of pooled expression */ index_t pchannel_; /*! \brief kernel size in height */ index_t nsize_; /*! \brief kernel size in width */ index_t kstride_; /*! \brief pad */ index_t pad_; /*! 
\brief constructor */ ChannelUnpoolingExp(const SrcExp &data_src, const SrcExp &data_pooled, const SrcExp &grad_pooled, index_t nsize, index_t kstride, index_t pad) : data_src_(data_src), data_pooled_(data_pooled), grad_pooled_(grad_pooled), nsize_(nsize), kstride_(kstride), pad_(pad) { Shape pshape = ShapeCheck::Check(grad_pooled); typedef ShapeCheck ShapeCheckSrcDimSrcExp; CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) << "ChannelUnPoolingExp: data and grad shape mismatch"; Shape sshape = ShapeCheck::Check(data_src); for (int k = 0; k < srcdim; ++k) { if (k == 1) { continue; } CHECK_EQ(pshape[k], sshape[k]) << "ChannelUnPoolingExp: pooled tensor and src tensor shape mismatch" << pshape[k] << " vs " << sshape[k]; } pchannel_ = pshape[1]; this->shape_ = sshape; } }; /*! * \brief channel unpooling, do unroll over (local nearby) channels * \param src source data * \param nsize neighbor size * \param stride stride of the pooling * \param pad number of padding at each side * \return expression of pooled result * \tparam Reducer reducer type * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline ChannelUnpoolingExp::kDim> ch_unpool(const Exp &data_src, const Exp &data_pooled, const Exp &grad_pooled, index_t nsize, index_t stride, index_t pad) { TypeCheckPass::kDim >= 3> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return ChannelUnpoolingExp::kDim> (data_src.self(), data_pooled.self(), grad_pooled.self(), nsize, stride, pad); } template inline ChannelUnpoolingExp::kDim> ch_unpool(const Exp &data_src, const Exp &data_pooled, const Exp &grad_pooled, index_t nsize) { return ch_unpool(data_src, data_pooled, grad_pooled, nsize, 1, nsize / 2); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const ChannelUnpoolingExp &e) : data_src_(e.data_src_), data_pooled_(e.data_pooled_), grad_pooled_(e.grad_pooled_), channel_(e.shape_[srcdim - 3]), height_(e.shape_[srcdim - 2]), pchannel_(e.pchannel_), hnsize_(e.nsize_), stride_(e.kstride_), pad_(e.pad_) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { using namespace std; const DType vsrc = data_src_.Eval(i, j); const index_t y = i % height_; i /= height_; const index_t c = i % channel_; const index_t n = i / channel_; const index_t x = j; const index_t cstart = c < hnsize_ - pad_ ? 0 : (c - (hnsize_ - pad_) + stride_) / stride_; const index_t cend = min((c + pad_ + stride_) / stride_, channel_); DType val = static_cast(0); for (index_t cc = cstart; cc < cend; ++cc) { val += Reducer::PartialGrad(vsrc, data_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x)) * grad_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x); } return val; } private: Plan data_src_, data_pooled_, grad_pooled_; const index_t channel_, height_, pchannel_, hnsize_, stride_, pad_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/channel_unpool.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/pad.h ===== /*! * Copyright (c) 2014 by Contributors * \file pad.h * \brief support for pad * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_PAD_H_ #define MSHADOW_EXTENSION_PAD_H_ namespace mshadow { namespace expr { /*! 
* \brief padding expression, pad a image with zeros * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam srcdim dimension of src */ template struct PaddingExp: public MakeTensorExp, SrcExp, srcdim, DType> { /*! \brief source operand */ const SrcExp &src_; /*! \brief pad size in y */ index_t pad_y_; /*! \brief pad size in x */ index_t pad_x_; /*! \brief source tensor height */ index_t src_height_; /*! \brief source tensor width */ index_t src_width_; /*! \brief constructor */ PaddingExp(const SrcExp &src, index_t pad_y, index_t pad_x) : src_(src), pad_y_(pad_y), pad_x_(pad_x) { this->shape_ = ShapeCheck::Check(src_); src_height_ = this->shape_[srcdim - 2]; src_width_ = this->shape_[srcdim - 1]; this->shape_[srcdim - 2] += pad_y * 2; // height this->shape_[srcdim - 1] += pad_x * 2; // width } }; /*! * \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1] * \param src original image batches * \param pad padding size * \return expression corresponding to padded result * \tparam SrcExp source expression * \tparam DType the content data type * \tparam etype type of expression */ template inline PaddingExp::kDim> pad(const Exp &src, index_t pad) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return PaddingExp::kDim>(src.self(), pad, pad); } /*! * \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1] * \param src original image batches * \param pad_y padding size in y * \param pad_x padding size in x * \return expression corresponding to padded result * \tparam SrcExp source expression * \tparam DType the content data type * \tparam etype type of expression */ template inline PaddingExp::kDim> pad(const Exp &src, index_t pad_y, index_t pad_x) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return PaddingExp::kDim> (src.self(), pad_y, pad_x); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const PaddingExp &e) : src_(MakePlan(e.src_)), pad_y_(e.pad_y_), pad_x_(e.pad_x_), new_height_(e.shape_[srcdim - 2]), src_height_(e.src_height_), src_width_(e.src_width_) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { const index_t x = j; const index_t y = i % new_height_; const index_t c = i / new_height_; if (y < pad_y_ || x < pad_x_) return static_cast(0); const index_t h = y - pad_y_; const index_t w = x - pad_x_; if (h < src_height_ && w < src_width_) { return src_.Eval(c * src_height_ + h, w); } else { return static_cast(0); } } private: Plan src_; const index_t pad_y_; const index_t pad_x_; const index_t new_height_; const index_t src_height_; const index_t src_width_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_PAD_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/pad.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/crop.h ===== /*! * Copyright (c) 2014 by Contributors * \file crop.h * \brief support for crop * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_CROP_H_ #define MSHADOW_EXTENSION_CROP_H_ namespace mshadow { namespace expr { /*! * \brief crop expression, cut off the boundary region, reverse operation of padding * \tparam SrcExp source expression to be pooled from * \tparam DType the type of elements * \tparam srcdim dimension of src */ template struct CroppingExp: public MakeTensorExp, SrcExp, srcdim, DType> { /*! \brief source operand */ const SrcExp &src_; /*! 
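// Usage sketch for pad, defined above, together with crop, defined just below
// (illustrative only; names, shapes and allocation are hypothetical).
//
//   Tensor<cpu, 4, float> img(Shape4(1, 3, 28, 28));
//   Tensor<cpu, 4, float> padded(Shape4(1, 3, 32, 32));
//   padded = pad(img, 2);                 // 2 rows/columns of zeros on every border
//   img = crop(padded, Shape2(28, 28));   // center crop recovers the original region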
\brief pad height */ index_t pad_height_; /*! \brief pad height */ index_t pad_width_; /*! \brief src height */ index_t src_height_; /*! \brief constructor */ explicit CroppingExp(const SrcExp &src, Shape<2> cshape) : src_(src) { this->shape_ = ShapeCheck::Check(src_); CHECK_GE(this->shape_[srcdim - 2], cshape[0]) << "CroppingExp: height requirement not met"; CHECK_GE(this->shape_[srcdim - 1], cshape[1]) << "CroppingExp: width requirement not met"; pad_height_ = (this->shape_[srcdim - 2] - cshape[0]) / 2; pad_width_ = (this->shape_[srcdim - 1] - cshape[1]) / 2; src_height_ = this->shape_[srcdim - 2]; this->shape_[srcdim - 2] = cshape[0]; // height this->shape_[srcdim - 1] = cshape[1]; // width } /*! \brief constructor */ explicit CroppingExp(const SrcExp &src, Shape<2> cshape, index_t start_height, index_t start_width) : src_(src), pad_height_(start_height), pad_width_(start_width) { this->shape_ = ShapeCheck::Check(src_); CHECK_GE(this->shape_[srcdim - 2], cshape[0] + start_height) << "CroppingExp: height requirement not met"; CHECK_GE(this->shape_[srcdim - 1], cshape[1] + start_width) << "CroppingExp: width requirement not met"; src_height_ = this->shape_[srcdim - 2]; this->shape_[srcdim - 2] = cshape[0]; // height this->shape_[srcdim - 1] = cshape[1]; // width } }; // struct CroppingExp /*! * \brief revserse operationg of padding, cut off boundaries, * crop output from center of input * \param src original image batches * \param oshape output shape to be cropped * \return expression corresponding to padded result * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline CroppingExp::kDim> crop(const Exp &src, Shape<2> oshape) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return CroppingExp::kDim>(src.self(), oshape); } /*! * \brief same as crop, but can specify starting position to do cropping * \param src original image batches * \param oshape output shape to be cropped * \param start_height start height position to do cropping * \param start_width start width position to do cropping * \return expression corresponding to padded result * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline CroppingExp::kDim> crop(const Exp &src, Shape<2> oshape, index_t start_height, index_t start_width) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return CroppingExp::kDim> (src.self(), oshape, start_height, start_width); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const CroppingExp &e) : src_(MakePlan(e.src_)), pad_height_(e.pad_height_), pad_width_(e.pad_width_), new_height_(e.shape_[srcdim - 2]), src_height_(e.src_height_) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { const index_t x = j; const index_t y = i % new_height_; const index_t c = i / new_height_; const index_t h = y + pad_height_; const index_t w = x + pad_width_; return src_.Eval(c * src_height_ + h, w); } private: Plan src_; const index_t pad_height_, pad_width_; const index_t new_height_; const index_t src_height_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_CROP_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/crop.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/mirror.h ===== /*! 
* Copyright (c) 2014 by Contributors * \file mirror.h * \brief support for mirror * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_MIRROR_H_ #define MSHADOW_EXTENSION_MIRROR_H_ namespace mshadow { namespace expr { /*! * \brief mirror expression, mirror a image in width * \tparam SrcExp source expression to be mirrored * \tparam DType the type of elements * \tparam srcdim dimension of src */ template struct MirroringExp: public MakeTensorExp, SrcExp, srcdim, DType> { /*! \brief source operand */ const SrcExp &src_; /*! \brief constructor */ explicit MirroringExp(const SrcExp &src) : src_(src) { this->shape_ = ShapeCheck::Check(src_); } }; /*! * \brief mirroring expression, mirror images in width * \param src original image batches * \return expression corresponding to mirrored result * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline MirroringExp::kDim> mirror(const Exp &src) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return MirroringExp::kDim>(src.self()); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const MirroringExp &e) : src_(MakePlan(e.src_)), width_(e.shape_[srcdim - 1]) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { return src_.Eval(i, width_ - j - 1); } private: Plan src_; const index_t width_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_MIRROR_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/mirror.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/concat.h ===== /*! * Copyright (c) 2014 by Contributors * \file concat.h * \brief support for concatenation */ #ifndef MSHADOW_EXTENSION_CONCAT_H_ #define MSHADOW_EXTENSION_CONCAT_H_ namespace mshadow { namespace expr { /*! * \brief concat expression, concat two tensor's channel * \tparam LhsExp left expression * \tparam RhsExp right expression * \tparam DType the type of elements * \tparam srcdim dimension of src * \tparam dimsrc_m_cat dimsrc - dimcat */ template struct ConcatExp : public TRValue, Device, srcdim, DType> { static const int dimcat = srcdim - dimsrc_m_cat; const LhsExp &src1_; const RhsExp &src2_; index_t dcat_src1_; index_t dcat_src2_; Shape<4> shape_; ConcatExp(const LhsExp &src1, const RhsExp &src2) : src1_(src1), src2_(src2) { Shape sshape1 = ShapeCheck::Check(src1_); Shape sshape2 = ShapeCheck::Check(src2_); #pragma unroll for (int i = 0; i < srcdim; ++i) { if (i != dimcat) { CHECK_EQ(sshape1[i], sshape2[i]) << "ConcatExp: shape mismatch"; } } this->shape_ = sshape1; this->shape_[dimcat] = sshape1[dimcat] + sshape2[dimcat]; this->dcat_src1_ = sshape1[dimcat]; this->dcat_src2_ = sshape2[dimcat]; } template inline void operator=(const expr::Exp &exp) { this->__assign(exp); } inline void operator=(const DType &exp) { this->__assign(exp); } }; // struct ConcatExp /*! 
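// Usage sketch for mirror, defined above (illustrative only; names, shapes and
// allocation are hypothetical).  Mirroring reverses only the last (width) dimension,
// e.g. for horizontal-flip data augmentation:
//
//   Tensor<cpu, 3, float> img(Shape3(3, 32, 32)), flipped(Shape3(3, 32, 32));
//   flipped = mirror(img);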
* \brief concat two 4D tensor * \param src1 source tensor1 * \param src2 source tensor2 * \return concated 4D tensor * \tparam cdim the dimension to concatnate on * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline ConcatExp concat(const TRValue &src1, const TRValue &src2) { TypeCheckPass::kDim == ExpInfo::kDim> ::Error_Expression_Does_Not_Meet_Dimension_Req(); TypeCheckPass::kDim == srcdim> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return ConcatExp (src1.self(), src2.self()); } //------------------------ // engine plugin //------------------------ // runtime shapecheck template struct ShapeCheck >{ inline static Shape Check(const ConcatExp &t) { return t.shape_; } }; template struct StreamInfo >{ inline static Stream * Get(const ConcatExp &t) { Stream *lhs = StreamInfo::Get(t.src1_); Stream *rhs = StreamInfo::Get(t.src2_); if (lhs != rhs) return NULL; return lhs; } }; // static typecheck template struct ExpInfo >{ static const int kDimLhs = ExpInfo::kDim; static const int kDimRhs = ExpInfo::kDim; // copy from binarymap static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ (kDimLhs == 0 ?\ kDimRhs :\ ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; }; //---------------------- // Execution plan //--------------------- template struct Plan, DType> { public: static const int dimcat = srcdim - dimsrc_m_cat; explicit Plan(const ConcatExp &e) : src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), height_(e.shape_.ProdShape(dimcat + 1, srcdim - 1)), ch_src1_(e.dcat_src1_), ch_src2_(e.dcat_src2_), ch_(e.shape_[dimcat]) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { const index_t y = i % height_; i /= height_; const index_t c = i % ch_; const index_t b = i / ch_; const index_t x = j; if (c < ch_src1_) { return src1_.Eval((b * ch_src1_ + c) * height_ + y, x); } else { return src2_.Eval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); } } MSHADOW_XINLINE DType &REval(index_t i, index_t j) { const index_t y = i % height_; i /= height_; const index_t c = i % ch_; const index_t b = i / ch_; const index_t x = j; if (c < ch_src1_) { return src1_.REval((b * ch_src1_ + c) * height_ + y, x); } else { return src2_.REval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); } } private: Plan src1_; Plan src2_; const index_t height_, ch_src1_, ch_src2_, ch_; }; // struct Plan // specialize for concat in x template struct Plan, DType> { public: explicit Plan(const ConcatExp &e) : src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), width_src1_(e.dcat_src1_) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { if (x < width_src1_) { return src1_.Eval(y, x); } else { return src2_.Eval(y, x - width_src1_); } } MSHADOW_XINLINE DType &REval(index_t y, index_t x) { if (x < width_src1_) { return src1_.REval(y, x); } else { return src2_.REval(y, x - width_src1_); } } private: Plan src1_; Plan src2_; const index_t width_src1_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_CONCAT_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/concat.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/choose.h ===== /*! * Copyright (c) 2014 by Contributors * \file choose.h * \brief support for implicit array selection operation * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_CHOOSE_H_ #define MSHADOW_EXTENSION_CHOOSE_H_ namespace mshadow { namespace expr { /*! 
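// Usage sketch for concat, defined above, and slice, defined further below in
// extension/slice.h (illustrative only; names, shapes and allocation are
// hypothetical).  Both take the dimension to operate on as a template argument:
//
//   Tensor<cpu, 4, float> a(Shape4(2, 8, 14, 14)), b(Shape4(2, 24, 14, 14));
//   Tensor<cpu, 4, float> ab(Shape4(2, 32, 14, 14));
//   ab = concat<1>(a, b);       // concatenate along the channel dimension
//   a  = slice<1>(ab, 0, 8);    // take channels [0, 8) back out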
* \brief Make a choice of index in the lowest changing dimension. * \tparam SrcExp type of lhs expression * \tparam IndexExp type of index expression * \tparam DType the type of elements */ template struct MatChooseRowElementExp: public Exp, DType, type::kChainer> { /*! \brief source operand */ const SrcExp &src_; /*! \brief index operand */ const IndexExp &index_; /*! \brief constructor */ MatChooseRowElementExp(const SrcExp &src, const IndexExp &index) : src_(src), index_(index) {} }; template inline MatChooseRowElementExp mat_choose_row_element(const Exp &src, const Exp &index) { TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return MatChooseRowElementExp(src.self(), index.self()); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const MatChooseRowElementExp &e) : src_(MakePlan(e.src_)), index_(MakePlan(e.index_)) { } MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { index_t idx = static_cast(index_.Eval(0, x)); return src_.Eval(x, idx); } private: expr::Plan src_; expr::Plan index_; }; template inline Plan, DType> MakePlan(const MatChooseRowElementExp &exp) { return Plan, DType>(exp); } template struct ShapeCheck > { inline static Shape Check(const MatChooseRowElementExp &t) { CHECK(dim == 1) << "MatChooseRowElementExp only support 1 dimension output"; Shape<2> shape1 = ShapeCheck<2, SrcExp>::Check(t.src_); Shape shape2 = ShapeCheck::Check(t.index_); CHECK_EQ(shape1[0], shape2[0]) << "mat_choose_row_element index length and number of rows in matrix"; return shape2; } }; template struct ExpInfo > { static const int kDim = 1; static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_CHOOSE_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/choose.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/one_hot.h ===== /*! * Copyright (c) 2014 by Contributors * \file one_hot.h * \brief Create one-hot indicator array based on the index. * \author Tianqi Chen */ #ifndef MSHADOW_EXTENSION_ONE_HOT_H_ #define MSHADOW_EXTENSION_ONE_HOT_H_ namespace mshadow { namespace expr { /*! * \brief Create a one-hot indicator array. * \tparam IndexExp type of index expression * \tparam DType the type of elements */ template struct OneHotEncodeExp: public Exp, DType, type::kChainer> { /*! \brief index operand */ const IndexExp &index_; /*! \brief number of choices we can have. */ index_t num_choices_; /*! 
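// Usage sketch for mat_choose_row_element, defined above, and one_hot_encode,
// defined just below (illustrative only; names, shapes and allocation are
// hypothetical).  Labels are stored as float indices:
//
//   Tensor<cpu, 2, float> prob(Shape2(100, 10));     // e.g. softmax output
//   Tensor<cpu, 1, float> label(Shape1(100));
//   Tensor<cpu, 1, float> picked(Shape1(100));
//   Tensor<cpu, 2, float> onehot(Shape2(100, 10));
//   picked = mat_choose_row_element(prob, label);    // picked[i] = prob[i][label[i]]
//   onehot = one_hot_encode(label, 10);              // onehot[i][j] = (j == label[i])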
\brief constructor */ OneHotEncodeExp(const IndexExp &index, index_t num_choices) : index_(index), num_choices_(num_choices) {} }; template inline OneHotEncodeExp one_hot_encode(const Exp &index, index_t num_choices) { TypeCheckPass::kDim == 1> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return OneHotEncodeExp(index.self(), num_choices); } //---------------------- // Execution plan //---------------------- template struct Plan, DType> { public: explicit Plan(const OneHotEncodeExp &e) : index_(MakePlan(e.index_)) { } MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { index_t idx = static_cast(index_.Eval(0, y)); return static_cast(x == idx); } private: expr::Plan index_; }; template inline Plan, DType> MakePlan(const OneHotEncodeExp &exp) { return Plan, DType>(exp); } template struct ShapeCheck > { inline static Shape Check(const OneHotEncodeExp &t) { CHECK(dim == 2) << "OneHotEncodeExp only support 2 dimension output"; Shape<1> shape = ShapeCheck<1, IndexExp>::Check(t.index_); Shape ret; ret[0] = shape[0]; ret[1] = t.num_choices_; return ret; } }; template struct ExpInfo > { static const int kDim = 2; static const int kDevMask = ExpInfo::kDevMask; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_ONE_HOT_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/one_hot.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/extension/slice.h ===== /*! * Copyright (c) 2014 by Contributors * \file slice.h * \brief support for slice a certain dimension. */ #ifndef MSHADOW_EXTENSION_SLICE_H_ #define MSHADOW_EXTENSION_SLICE_H_ namespace mshadow { namespace expr { /*! * \brief slice expression, slice a tensor's channel * \tparam SrcExp left expression * \tparam DType the type of elements * \tparam srcdim dimension of src * \tparam dimsrc_m_cat dimsrc - dimcat */ template struct SliceExp : public TRValue, Device, srcdim, DType> { static const int dimslice = srcdim - dimsrc_m_slice; const SrcExp &src_; index_t ch_begin_; index_t ch_old_; Shape shape_; SliceExp(const SrcExp &src, index_t begin, index_t end) : src_(src), ch_begin_(begin) { shape_ = ShapeCheck::Check(src_); ch_old_ = shape_[dimslice]; CHECK(begin < shape_[dimslice] && end <= shape_[dimslice]) << "The slice went out of range"; shape_[dimslice] = end - begin; } template inline void operator=(const expr::Exp &exp) { this->__assign(exp); } inline void operator=(const DType &exp) { this->__assign(exp); } }; // struct Slice /*! * \brief Slice a Tensor * \param src source tensor * \param begin The beginning slice. * \param end The end slice. 
* \return sliced tensor * \tparam sdim the dimension to slice on * \tparam SrcExp source expression * \tparam DType the type of elements * \tparam etype type of expression */ template inline SliceExp slice(const TRValue &src, index_t begin, index_t end) { TypeCheckPass::kDim == srcdim> ::Error_Expression_Does_Not_Meet_Dimension_Req(); return SliceExp(src.self(), begin, end); } //------------------------ // engine plugin //------------------------ // runtime shapecheck template struct ShapeCheck >{ inline static Shape Check(const SliceExp &t) { return t.shape_; } }; template struct StreamInfo >{ inline static Stream * Get(const SliceExp &t) { return StreamInfo::Get(t.src_); } }; // static typecheck template struct ExpInfo >{ static const int kDim = ExpInfo::kDim; static const int kDevMask = ExpInfo::kDevMask; }; //---------------------- // Execution plan //--------------------- template struct Plan, DType> { public: static const int dimslice = srcdim - dimsrc_m_slice; explicit Plan(const SliceExp &e) : src_(MakePlan(e.src_)), height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)), ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { const index_t y = i % height_; i /= height_; const index_t c = i % ch_ + ch_begin_; const index_t b = i / ch_; const index_t x = j; return src_.Eval((b * ch_old_ + c) * height_ + y, x); } MSHADOW_XINLINE DType &REval(index_t i, index_t j) { const index_t y = i % height_; i /= height_; const index_t c = i % ch_ + ch_begin_; const index_t b = i / ch_; const index_t x = j; return src_.REval((b * ch_old_ + c) * height_ + y, x); } private: Plan src_; const index_t height_, ch_begin_, ch_old_, ch_; }; // struct Plan template struct Plan, DType> { public: explicit Plan(const SliceExp &e) : src_(MakePlan(e.src_)), ch_begin_(e.ch_begin_) {} MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { return src_.Eval(y, x + ch_begin_); } MSHADOW_XINLINE DType &REval(index_t y, index_t x) { return src_.REval(y, x + ch_begin_); } private: Plan src_; const index_t ch_begin_; }; } // namespace expr } // namespace mshadow #endif // MSHADOW_EXTENSION_SLICE_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension/slice.h ===== #endif // MSHADOW_EXTENSION_H_ //===== EXPANDED: mxnet/mshadow/mshadow/extension.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/tensor_cpu-inl.h ===== /*! 
* Copyright (c) 2014 by Contributors * \file tensor_cpu-inl.h * \brief implementation of CPU host code * \author Bing Xu, Tianqi Chen */ #ifndef MSHADOW_TENSOR_CPU_INL_H_ #define MSHADOW_TENSOR_CPU_INL_H_ namespace mshadow { template<> inline void InitTensorEngine(int dev_id) { } template<> inline void ShutdownTensorEngine(void) { } template<> inline void SetDevice(int devid) { } template<> inline Stream *NewStream(bool create_blas_handle, bool create_dnn_handle) { return new Stream(); } template<> inline void DeleteStream(Stream *stream) { delete stream; } template inline std::ostream &operator<<(std::ostream &os, const Shape &shape) { // NOLINT(*) os << "("; for (int i = 0; i < ndim; ++i) { if (i != 0) os << ","; os << shape[i]; } os << ")"; return os; } template inline void *AllocHost_(size_t size); template inline void FreeHost_(void * dptr); #ifdef __CUDACC__ template<> inline void *AllocHost_(size_t size) { void *dptr; MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable)); return dptr; } template<> inline void FreeHost_(void *dptr) { MSHADOW_CUDA_CALL(cudaFreeHost(dptr)); } #endif template<> inline void *AllocHost_(size_t size) { size_t pitch; return packet::AlignedMallocPitch(&pitch, size, 1); } template<> inline void FreeHost_(void *dptr) { packet::AlignedFree(dptr); } template inline void AllocHost(Tensor *obj) { obj->stride_ = obj->size(dim - 1); CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost"; void *dptr = AllocHost_(obj->MSize() * sizeof(DType)); obj->dptr_ = reinterpret_cast(dptr); } template inline void FreeHost(Tensor *obj) { if (obj->dptr_ == NULL) { LOG(FATAL) << "FreeHost:: double free"; } FreeHost_(obj->dptr_); obj->dptr_ = NULL; } template inline void AllocSpace(Tensor *obj, bool pad) { size_t pitch; void *dptr; if (pad) { dptr = packet::AlignedMallocPitch (&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]); obj->stride_ = static_cast(pitch / sizeof(DType)); } else { obj->stride_ = obj->size(dim - 1); dptr = packet::AlignedMallocPitch (&pitch, obj->shape_.Size() * sizeof(DType), 1); } obj->dptr_ = reinterpret_cast(dptr); } template inline Tensor NewTensor(const Shape &shape, DType initv, bool pad, Stream *stream_) { Tensor obj(shape); obj.stream_ = stream_; AllocSpace(&obj, pad); MapExp(&obj, expr::ScalarExp(initv)); return obj; } template inline void FreeSpace(Tensor *obj) { packet::AlignedFree(obj->dptr_); obj->dptr_ = NULL; } template inline void Copy(Tensor _dst, const Tensor &_src, Stream *stream) { CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_; if (_dst.CheckContiguous() && _src.CheckContiguous()) { memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size()); } else { Tensor dst = _dst.FlatTo2D(); Tensor src = _src.FlatTo2D(); for (index_t y = 0; y < dst.size(0); ++y) { memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1)); } } } template inline void MapPlan(TRValue *dst, const expr::Plan &plan) { Shape<2> shape = expr::ShapeCheck::Check(dst->self()).FlatTo2D(); expr::Plan dplan = expr::MakePlan(dst->self()); for (index_t y = 0; y < shape[0]; ++y) { for (index_t x = 0; x < shape[1]; ++x) { // trust your compiler! 
-_- they will optimize it Saver::Save(dplan.REval(y, x), plan.Eval(y, x)); } } } // code to handle SSE optimization template struct MapExpCPUEngine { inline static void Map(TRValue *dst, const expr::Exp &exp) { MapPlan(dst, MakePlan(exp.self())); } }; template struct MapExpCPUEngine, dim, DType, E, etype> { inline static void Map(Tensor *dst, const expr::Exp &exp) { if (expr::PacketAlignCheck::Check(exp.self()) && expr::PacketAlignCheck, MSHADOW_DEFAULT_PACKET>::Check(*dst)) { expr::MapPacketPlan(dst->self(), expr::MakePacketPlan(exp.self())); } else { MapPlan(dst, MakePlan(exp.self())); } } }; template inline void MapExp(TRValue *dst, const expr::Exp &exp) { expr::TypeCheckPass::kMapPass> ::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); Shape eshape = expr::ShapeCheck::Check(exp.self()); Shape dshape = expr::ShapeCheck::Check(dst->self()); CHECK(eshape[0] == 0 || eshape == dshape) << "Assignment: Shape of Tensors are not consistent with target"; MapExpCPUEngine::kPass, Saver, R, dim, DType, E, etype> ::Map(dst->ptrself(), exp); } template inline void MapReduceKeepLowest(TRValue *dst, const expr::Exp &exp, DType scale) { expr::TypeCheckPass::kRedPass> ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); Shape<2> eshape = expr::ShapeCheck::kDim, E> ::Check(exp.self()).FlatTo2D(); Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match"; CHECK_NE(eshape[0], 0) << "can not reduce over empty tensor"; // execution expr::Plan dplan = MakePlan(dst->self()); expr::Plan splan = MakePlan(exp.self()); for (index_t x = 0; x < eshape[1]; ++x) { DType res = splan.Eval(0, x); for (index_t y = 1; y < eshape[0]; ++y) { Reducer::Reduce(res, splan.Eval(y, x)); } Saver::Save(dplan.REval(0, x), res * scale); } } template inline void MapReduceKeepHighDim(TRValue *dst, const expr::Exp &exp, DType scale) { expr::TypeCheckPass::kRedPass> ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); typedef Shape::kDim> EShape; EShape eshape = expr::ShapeCheck::kDim, E> ::Check(exp.self()); Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimension do not match"; // use equvalent form Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), eshape[dimkeep], eshape.ProdShape(dimkeep + 1, EShape::kSubdim), eshape[EShape::kSubdim]); // execution expr::Plan dplan = MakePlan(dst->self()); expr::Plan splan = MakePlan(exp.self()); for (index_t c = 0; c < pshape[1]; ++c) { DType res; Reducer::SetInitValue(res); for (index_t n = 0; n < pshape[0]; ++n) { DType tres; Reducer::SetInitValue(tres); for (index_t y = 0; y < pshape[2]; ++y) { for (index_t x = 0; x < pshape[3]; ++x) { Reducer::Reduce(tres, splan.Eval((n * pshape[1] + c) * pshape[2] + y, x)); } } Reducer::Reduce(res, tres); } Saver::Save(dplan.REval(0, c), res * scale); } } template inline void Softmax(Tensor dst, const Tensor &energy) { DType mmax = energy[0]; for (index_t x = 1; x < dst.size(0); ++x) { if (mmax < energy[x]) mmax = energy[x]; } DType sum = 0.0f; for (index_t x = 0; x < dst.size(0); ++x) { dst[x] = std::exp(energy[x] - mmax); sum += dst[x]; } for (index_t x = 0; x < dst.size(0); ++x) { dst[x] /= sum; } } template inline void SoftmaxGrad(Tensor dst, const Tensor &src, const Tensor &label) { for (index_t y = 0; y < dst.size(0); ++y) { const index_t k = static_cast(label[y]); for (index_t x = 0; x < dst.size(1); ++x) { if (x == k) { dst[y][k] = src[y][k] - 1.0f; } else { dst[y][x] = src[y][x]; } } } } 
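/*!
 * A minimal usage sketch for the CPU Softmax/SoftmaxGrad helpers above (illustrative only,
 * not from the upstream sources; assumes float tensors and the TensorContainer class
 * defined later in this amalgamation):
 * \code
 * mshadow::TensorContainer<mshadow::cpu, 2, float> energy(mshadow::Shape2(4, 10), 0.0f);
 * mshadow::TensorContainer<mshadow::cpu, 2, float> prob(mshadow::Shape2(4, 10));
 * mshadow::TensorContainer<mshadow::cpu, 1, float> label(mshadow::Shape1(4), 1.0f);
 * mshadow::Softmax(prob, energy);            // row-wise softmax of the energies
 * mshadow::SoftmaxGrad(energy, prob, label); // cross-entropy gradient w.r.t. the energies
 * \endcode
 */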
template inline void SoftmaxGrad(Tensor dst, const Tensor &src, const Tensor &label) { for (index_t n = 0; n < dst.size(2); ++n) { for (index_t y = 0; y < dst.size(0); ++y) { const index_t k = static_cast(label[y][n]); for (index_t x = 0; x < dst.size(1); ++x) { if (x == k) { dst[y][k][n] = src[y][k][n] - 1.0f; } else { dst[y][x][n] = src[y][x][n]; } } } } } template inline void Softmax(Tensor dst, const Tensor &energy) { CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; for (index_t y = 0; y < dst.size(0); ++y) { Softmax(dst[y], energy[y]); } } template inline void Softmax(Tensor dst, const Tensor &energy) { CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; for (index_t y = 0; y < dst.size(0); ++y) { for (index_t n = 0; n < dst.size(2); ++n) { DType mmax = energy[y][0][n]; for (index_t x = 1; x < dst.size(1); ++x) { if (mmax < energy[y][x][n]) mmax = energy[y][x][n]; } DType sum = 0.0f; for (index_t x = 0; x < dst.size(1); ++x) { dst[y][x][n] = std::exp(energy[y][x][n] - mmax); sum += dst[y][x][n]; } for (index_t x = 0; x < dst.size(1); ++x) { dst[y][x][n] /= sum; } } } } // blas related template inline void VectorDot(Tensor dst, const Tensor &lhs, const Tensor &rhs) { CHECK_EQ(lhs.size(0), rhs.size(0)) << "VectorDot: Shape mismatch"; CHECK_EQ(dst.size(0), 1) << "VectorDot: expect dst to be scalar"; expr::BLASEngine::SetStream(lhs.stream_); mshadow::expr::BLASEngine::dot( lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_); } } // namespace mshadow #endif // MSHADOW_TENSOR_CPU_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/tensor_cpu-inl.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/tensor_gpu-inl.h ===== /*! * Copyright (c) 2014 by Contributors * \file tensor_gpu-inl.h * \brief implementation of GPU host code * \author Bing Xu, Tianqi Chen */ #ifndef MSHADOW_TENSOR_GPU_INL_H_ #define MSHADOW_TENSOR_GPU_INL_H_ namespace mshadow { #if MSHADOW_USE_CUDA template<> inline void InitTensorEngine(int dev_id) { cudaDeviceProp prop; int device_id = 0; int device_count = 0; cudaGetDeviceCount(&device_count); CHECK_GT(device_count, 0) << "Cannot find CUDA device. 
Please check CUDA-Configuration"; if (dev_id < 0) { device_id = 0; } else { device_id = dev_id; } CHECK_LT(device_id, device_count) << "Incorrect Device ID"; MSHADOW_CUDA_CALL(cudaSetDevice(device_id)); MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); } template<> inline void ShutdownTensorEngine(void) { } template<> inline void SetDevice(int devid) { MSHADOW_CUDA_CALL(cudaSetDevice(devid)); } template inline void AllocSpace(Tensor *obj, bool pad) { size_t pitch; // common choice for cuda mem align unit is 32 if (pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32) { MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast(&(obj->dptr_)), &pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0])); obj->stride_ = static_cast(pitch / sizeof(DType)); } else { obj->stride_ = obj->size(dim - 1); MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast(&(obj->dptr_)), &pitch, obj->shape_.Size() * sizeof(DType), 1)); } } template inline void FreeSpace(Tensor *obj) { MSHADOW_CUDA_CALL(cudaFree(obj->dptr_)); obj->dptr_ = NULL; } template inline void Copy(Tensor _dst, Tensor _src, cudaMemcpyKind kind, Stream *stream) { CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch"; Tensor dst = _dst.FlatTo2D(); Tensor src = _src.FlatTo2D(); MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType), src.dptr_, src.stride_ * sizeof(DType), dst.size(1) * sizeof(DType), dst.size(0), kind, Stream::GetStream(stream))); // use synchronize call behavior for zero stream if (stream == NULL) { MSHADOW_CUDA_CALL(cudaStreamSynchronize(0)); } } template inline void Copy(Tensor dst, const Tensor &src, Stream *stream) { Copy(dst, src, cudaMemcpyDeviceToHost, stream); } template inline void Copy(Tensor dst, const Tensor &src, Stream *stream) { Copy(dst, src, cudaMemcpyDeviceToDevice, stream); } template inline void Copy(Tensor dst, const Tensor &src, Stream *stream) { Copy(dst, src, cudaMemcpyHostToDevice, stream); } #endif // MSHADOW_USE_CUDA } // namespace mshadow // the following part is included only if compiler is nvcc #ifdef __CUDACC__ namespace mshadow { template inline void MapExp(TRValue *dst, const expr::Exp &exp) { expr::TypeCheckPass::kMapPass> ::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); Shape eshape = expr::ShapeCheck::Check(exp.self()); Shape dshape = expr::ShapeCheck::Check(dst->self()); CHECK(eshape[0] == 0 || eshape == dshape) << "Assignment: Shape of Tensors are not consistent with target"; cuda::MapPlan(MakePlan(dst->self()), MakePlan(exp.self()), dshape.FlatTo2D(), Stream::GetStream(expr::StreamInfo::Get(dst->self()))); } template inline void MapReduceKeepLowest(TRValue *dst, const expr::Exp &exp, DType scale) { expr::TypeCheckPass::kRedPass> ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); Shape<2> eshape = expr::ShapeCheck::kDim, E> ::Check(exp.self()).FlatTo2D(); Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match"; CHECK_NE(eshape[0], 0) << "can not reduce over empty tensor"; cuda::MapReduceKeepLowest (MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape, Stream::GetStream(expr::StreamInfo::Get(dst->self()))); } template inline void MapReduceKeepHighDim(TRValue *dst, const expr::Exp &exp, DType scale) { expr::TypeCheckPass::kRedPass> ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); typedef Shape::kDim> EShape; EShape eshape = expr::ShapeCheck::kDim, E> ::Check(exp.self()); Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); CHECK_EQ(eshape[dimkeep], 
dshape[0]) << "MapReduceKeepHighDim::reduction dimension do not match"; // use equvalent form Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), eshape[dimkeep], eshape.ProdShape(dimkeep + 1, EShape::kSubdim), eshape[EShape::kSubdim]); // call equavalent map red dim 2 cuda::MapReduceKeepDim1 (MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape, Stream::GetStream(expr::StreamInfo::Get(dst->self()))); } template inline void Softmax(Tensor dst, const Tensor& src) { cuda::Softmax(dst, src); } template inline void Softmax(Tensor dst, const Tensor& src) { cuda::Softmax(dst, src); } template inline void SoftmaxGrad(Tensor dst, const Tensor &src, const Tensor &label) { cuda::SoftmaxGrad(dst, src, label); } template inline void SoftmaxGrad(Tensor dst, const Tensor &src, const Tensor &label) { cuda::SoftmaxGrad(dst, src, label); } } // namespace mshadow #endif // __CUDACC__ #endif // MSHADOW_TENSOR_GPU_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/tensor_gpu-inl.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/io.h ===== /*! * Copyright (c) 2014 by Contributors * \file io.h * \brief definitions of I/O functions for mshadow tensor * \author Tianqi Chen */ #ifndef MSHADOW_IO_H_ #define MSHADOW_IO_H_ namespace mshadow { namespace utils { /*! * \brief interface of stream I/O, used to serialize data, * mshadow does not restricted to only this interface in SaveBinary/LoadBinary * mshadow accept all class that implements Read and Write */ class IStream { public: /*! * \brief read data from stream * \param ptr pointer to memory buffer * \param size size of block * \return usually is the size of data readed */ virtual size_t Read(void *ptr, size_t size) = 0; /*! * \brief write data to stream * \param ptr pointer to memory buffer * \param size size of block */ virtual void Write(const void *ptr, size_t size) = 0; /*! \brief virtual destructor */ virtual ~IStream(void) {} }; } // namespace utils /*! * \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor storage will be allocated * \param fo output binary stream * \param src source data file * \tparam dim dimension of tensor * \tparam DType type of element in tensor * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. */ template inline void SaveBinary(TStream &fo, const Tensor &src); // NOLINT(*) /*! * \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor storage will be allocated * \param fo output binary stream * \param src source data file * \tparam dim dimension of tensor * \tparam DType type of element in tensor * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. */ template inline void SaveBinary(TStream &fo, const Tensor &src); // NOLINT(*) /*! * \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor storage will be allocated * if pre_alloc is true , then space in dst is preallocated, and must have same shape of the tensor loaded * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst * \param fi output binary stream * \param dst destination file * \param pre_alloc whether space is pre-allocated, if false, space allocation will happen * \tparam dim dimension of tensor * \tparam DType type of element in tensor * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. */ template inline void LoadBinary(TStream &fi, // NOLINT(*) Tensor *dst, bool pre_alloc); /*! 
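 * A minimal round-trip sketch for the CPU overloads above (illustrative only, not from the
 * upstream sources; `fo` and `fi` are placeholder stream objects providing Read/Write):
 * \code
 * mshadow::TensorContainer<mshadow::cpu, 2, float> src(mshadow::Shape2(2, 3), 0.5f);
 * mshadow::SaveBinary(fo, src);          // writes shape followed by row data
 * mshadow::Tensor<mshadow::cpu, 2, float> dst;
 * mshadow::LoadBinary(fi, &dst, false);  // allocates dst; release later with FreeSpace(&dst)
 * \endcode
 *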
* \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor storage will be allocated * if pre_alloc is true , then space in dst is preallocated, and must have same shape of the tensor loaded * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst * \param fi output binary stream * \param dst destination file * \param pre_alloc whether space is pre-allocated, if false, space allocation will happen * \tparam dim dimension of tensor * \tparam DType type of element in tensor * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. */ template inline void LoadBinary(TStream &fi, // NOLINT(*) Tensor *dst, bool pre_alloc); // implementations template inline void SaveBinary(TStream &fo, const Tensor &src_) { // NOLINT(*) fo.Write(&src_.shape_, sizeof(src_.shape_)); Tensor src = src_.FlatTo2D(); for (index_t i = 0; i < src.size(0); ++i) { fo.Write(src[i].dptr_, sizeof(DType) * src.size(1)); } } template inline void SaveBinary(TStream &fo, const Tensor &src) { // NOLINT(*) // copy to CPU, then save Tensor tmp(src.shape_); AllocSpace(&tmp); Stream stream; Copy(tmp, src, &stream); SaveBinary(fo, tmp); FreeSpace(&tmp); } template inline void LoadBinary(TStream &fi, // NOLINT(*) Tensor *dst_, bool pre_alloc) { Shape shape; CHECK_NE(fi.Read(&shape, sizeof(shape)), 0) << "mshadow::LoadBinary"; if (pre_alloc) { CHECK_EQ(shape, dst_->shape_) << "LoadBinary, shape do not match pre-allocated shape"; } else { dst_->shape_ = shape; AllocSpace(dst_); } Tensor dst = dst_->FlatTo2D(); if (dst.size(0) == 0) return; for (index_t i = 0; i < dst.size(0); ++i) { CHECK_NE(fi.Read(dst[i].dptr_, sizeof(DType) * dst.size(1)), 0) << "mshadow::LoadBinary"; } } template inline void LoadBinary(TStream &fi, // NOLINT(*) Tensor *dst, bool pre_alloc) { Tensor tmp; LoadBinary(fi, &tmp, false); if (pre_alloc) { CHECK_EQ(tmp.shape, dst->shape_) << "LoadBinary, shape do not match pre-allocated shape"; } else { dst->shape = tmp.shape; AllocSpace(dst); } Stream stream; Copy(*dst, tmp, &stream); FreeSpace(&tmp); } } // namespace mshadow #endif // MSHADOW_IO_H_ //===== EXPANDED: mxnet/mshadow/mshadow/io.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/tensor_container.h ===== /*! * Copyright (c) 2014 by Contributors * \file tensor_container.h * \brief tensor container that does memory allocation and resize like STL * \author Tianqi Chen */ #ifndef MSHADOW_TENSOR_CONTAINER_H_ #define MSHADOW_TENSOR_CONTAINER_H_ namespace mshadow { /*! * \brief tensor container that does memory allocation and resize like STL, * use it to save the lines of FreeSpace in class. * Do not abuse it, efficiency can come from pre-allocation and no re-allocation * * \tparam Device which device the tensor is on * \tparam dimension dimension of the tensor */ template class TensorContainer: public Tensor { public: /*! * \brief constructor * \param pad whether use padding alignment in space allocation */ explicit TensorContainer(bool pad = MSHADOW_ALLOC_PAD) { this->pad_ = pad; this->dptr_ = data_.dptr_ = NULL; this->shape_[0] = 0; this->stride_ = 0; this->data_.stride_ = 0; this->data_.shape_[0] = 0; } /*! * \brief constructor * \param shape intial shape */ explicit TensorContainer(const Shape &shape) { this->pad_ = MSHADOW_ALLOC_PAD; data_.dptr_ = NULL; this->AllocByShape(shape); } /*! 
* \brief constructor * \param shape intial shape * \param initv intial value */ explicit TensorContainer(const Shape &shape, DType initv) { this->pad_ = MSHADOW_ALLOC_PAD; data_.dptr_ = NULL; this->AllocByShape(shape); (*this) = initv; } /*! * \brief copy constructor * \param src source value */ TensorContainer (const TensorContainer &src) : pad_(src.pad_) { this->dptr_ = data_.dptr_ = NULL; this->shape_[0] = 0; this->stride_ = 0; this->data_.stride_ = 0; this->data_.shape_[0] = 0; this->stream_ = src.stream_; if (src.dptr_ != NULL) { this->AllocByShape(src.shape_); mshadow::Copy(*this, src, this->stream_); } } ~TensorContainer(void) { this->Release(); } /*! * \brief resize the container to given shape, content is NOT preserved * \param shape target shape */ inline void Resize(const Shape &shape) { Shape<2> s2 = shape.FlatTo2D(); if (s2.shape_[1] > data_.stride_ || s2.shape_[0] > data_.size(0)) { this->AllocByShape(shape); } else { this->shape_ = shape; if (this->pad_) { this->stride_ = data_.stride_; } else { this->stride_ = s2.shape_[1]; } } } /*! * \brief resize the container to given shape, and initialize, content is NOT preserved * \param shape target shape * \param initv initialization value */ inline void Resize(const Shape &shape, DType initv) { this->Resize(shape); (*this) = initv; } /*! \brief set whether padding is allowed in tensor */ inline void set_pad(bool pad) { this->pad_ = pad; } /*! * \brief save by binary format * \param fo output binary stream * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. */ template inline void SaveBinary(TStream &fo) const { // NOLINT(*) mshadow::SaveBinary(fo, *this); } /*! * \brief load by binary format, a temp Tensor storage will be allocated * \param fi input binary stream * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. */ template inline void LoadBinary(TStream &fi) { // NOLINT(*) Tensor tmp; mshadow::LoadBinary(fi, &tmp, false); this->Resize(tmp.shape_); Stream stream; Copy(*this, tmp, &stream); mshadow::FreeSpace(&tmp); } /*! * \brief assign operator from TensorContainer * \param src source value * \return reference of self */ inline TensorContainer &operator= (const TensorContainer &src) { this->pad_ = src.pad_; this->stream_ = src.stream_; if (src.dptr_ != NULL) { this->Resize(src.shape_); mshadow::Copy(*this, src, this->stream_); } return *this; } /*!\brief functions to fit expression template */ inline Tensor &operator=(DType s) { return this->__assign(s); } /*!\brief functions to fit expression template */ template inline Tensor & operator=(const expr::Exp &exp) { return this->__assign(exp); } /*!\brief functions to fit expression template */ template inline Tensor & operator=(const expr::Exp &exp) { return this->__assign(exp); } /*!\brief functions to fit expression template */ template inline Tensor & operator=(const expr::Exp &exp) { return this->__assign(exp); } /*! * \brief Release the llocated space, * The TensorContainer is still functionable, * but will restart allocating space when Resize is called. */ inline void Release(void) { if (data_.dptr_ != NULL) { mshadow::FreeSpace(&data_); this->dptr_ = data_.dptr_ = NULL; this->shape_[0] = 0; this->stride_ = 0; this->data_.stride_ = 0; this->data_.shape_[0] = 0; } } private: /*! \brief whether we do padding in the space */ bool pad_; /*! 
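 * A minimal sketch of how Resize reuses the space recorded in data_ (illustrative only,
 * not from the upstream sources): re-allocation happens only when the requested 2-D
 * footprint no longer fits.
 * \code
 * mshadow::TensorContainer<mshadow::cpu, 2, float> buf(mshadow::Shape2(8, 8));
 * buf.Resize(mshadow::Shape2(4, 8));    // fits the existing allocation, no new memory
 * buf.Resize(mshadow::Shape2(16, 16));  // larger footprint, goes through AllocByShape
 * \endcode
 *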
\brief the shape of data_ is actually current data space */ Tensor data_; inline void AllocByShape(const Shape& shape) { if (data_.dptr_ != NULL) this->Release(); data_.shape_ = shape.FlatTo2D(); mshadow::AllocSpace(&data_, pad_); this->dptr_ = data_.dptr_; this->shape_ = shape; if (this->pad_) { this->stride_ = data_.stride_; } else { this->stride_ = data_.size(1); } } }; } // namespace mshadow #endif // MSHADOW_TENSOR_CONTAINER_H_ //===== EXPANDED: mxnet/mshadow/mshadow/tensor_container.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/tensor_blob.h ===== /*! * Copyright (c) 2014 by Contributors * \file tensor_blob.h * \brief TBlob class that holds common representation of * arbirary dimension tensor, can be used to transformed * to normal fixed dimenson tensor * \author Tianqi Chen */ #ifndef MSHADOW_TENSOR_BLOB_H_ #define MSHADOW_TENSOR_BLOB_H_ namespace mshadow { /*! * \brief dynamic shape class that can hold shape * of arbirary dimension */ struct TShape { public: /*! \brief constructor */ TShape() : ndim_(0), num_heap_allocated_(0), data_heap_(NULL) {} /*! * \brief constructor from TShape * \param s the source shape */ TShape(const TShape &s) : ndim_(s.ndim_) { if (ndim_ <= kStackCache) { data_heap_ = NULL; num_heap_allocated_ = 0; std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_); } else { data_heap_ = new index_t[ndim_]; num_heap_allocated_ = ndim_; std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_); } } /*! * \brief construct the TShape from content of iterator * \param begin the beginning of iterator * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ template TShape(RandomAccessIterator begin, RandomAccessIterator end) : ndim_(0), num_heap_allocated_(0), data_heap_(NULL) { this->CopyFrom(begin, end); } #if MSHADOW_IN_CXX11 /*! * \brief move constructor from TShape * \param s the source shape */ TShape(TShape &&s) : ndim_(s.ndim_), num_heap_allocated_(s.num_heap_allocated_), data_heap_(s.data_heap_) { if (ndim_ <= kStackCache) { std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_); } // remove data heap space from s s.data_heap_ = NULL; } /*! * \brief move constructor from Shape * \param s the source shape */ template TShape(Shape &&s) // NOLINT(*) : ndim_(0), num_heap_allocated_(0), data_heap_(NULL) { this->CopyFrom(s.shape_, s.shape_ + dim); } #endif /*! \brief destructor */ ~TShape() { // data_heap_ can be NULL delete [] data_heap_; } /*! * \brief copy shape from content betwen two iterators * \param begin the beginning of iterator * \param end the end of the iterator * \tparam RandomAccessIterator iterator type */ template inline void CopyFrom(RandomAccessIterator begin, RandomAccessIterator end) { this->SetDim(end - begin); std::copy(begin, end, data()); } /*! * \brief assignment from shape * \param shape source shape * \return reference of self */ inline TShape &operator=(const TShape &shape) { this->SetDim(shape.ndim_); const index_t *src = shape.data(); std::copy(src, src + ndim_, data()); return *this; } /*! * \brief assignment from vector * \param shape source shape * \return reference of self */ inline TShape &operator=(const std::vector &shape) { this->CopyFrom(shape.begin(), shape.end()); return *this; } /*! * \brief assignment from shape * \param shape source shape * \tparam dim shape dimension * \return reference of self */ template inline TShape &operator=(const Shape &shape) { this->SetDim(dim); index_t *d = dim <= kStackCache ? 
data_stack_ : data_heap_; for (int i = 0; i < dim; ++i) { d[i] = shape[i]; } return *this; } /*! \return the data content of the shape */ inline const index_t *data() const { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \return the data content of the shape */ inline index_t *data() { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \brief return number of dimension of the tensor inside */ inline index_t ndim(void) const { return ndim_; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ inline index_t &operator[](index_t i) { return data()[i]; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ inline const index_t &operator[](index_t i) const { return data()[i]; } /*! \brief total number of elements in the tensor */ inline size_t Size(void) const { size_t size = 1; const index_t *d = this->data(); for (index_t i = 0; i < ndim_; ++i) { size *= d[i]; } return size; } /*! * flatten the higher dimension to second dimension, return a 2D shape * \return the flat 2d shape */ inline Shape<2> FlatTo2D(void) const { Shape<2> s; if (ndim_ == 0) return Shape2(0, 0); const index_t *d = this->data(); s.shape_[1] = d[ndim_ - 1]; index_t ymax = 1; for (index_t i = 1; i < ndim_; ++i) { ymax *= d[i - 1]; } s.shape_[0] = ymax; return s; } /*! * \brief get the shape of tensor specifying dim * \return the shape requested * \tparam dim dimension of the tensor */ template inline Shape get(void) const { CHECK_EQ(dim, ndim_) << "dimension do not match target dimension " << dim << " vs " << ndim_; const index_t *d = this->data(); Shape s; for (int i = 0; i < dim; ++i) { s[i] = d[i]; } return s; } /*! * \return whether two shape equals * \param s the shape to compare against */ inline bool operator==(const TShape &s) const { if (ndim_ != s.ndim_) return false; if (ndim_ <= kStackCache) { for (index_t i = 0; i < ndim_; ++i) { if (data_stack_[i] != s.data_stack_[i]) return false; } } else { for (index_t i = 0; i < ndim_; ++i) { if (data_heap_[i] != s.data_heap_[i]) return false; } } return true; } /*! * \return whether two shape not equals * \param s the shape to compare against */ inline bool operator!=(const TShape &s) const { return !(*this == s); } /*! * \return whether two shape equals * \param s the shape to compare against * \tparam dim dimension of the shape */ template inline bool operator==(const Shape &s) const { if (ndim_ != dim) return false; const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_; for (index_t i = 0; i < dim; ++i) { if (d[i] != s.shape_[i]) return false; } return true; } /*! * \return whether two shape not equals * \param s the shape to compare against * \tparam dim dimension of the shape */ template inline bool operator!=(const Shape &s) const { return !(*this == s); } /*! * \brief save the content into binary stream * \param strm the output stream * \tparam TStream any stream type that have write */ template inline void Save(TStream *strm) const { strm->Write(&ndim_, sizeof(ndim_)); strm->Write(data(), sizeof(index_t) * ndim_); } /*! 
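 * A minimal Save/Load round-trip sketch (illustrative only, not from the upstream sources;
 * `strm` is a placeholder pointer to a stream with Read/Write, rewound between the calls):
 * \code
 * mshadow::TShape s = mshadow::Shape3(2, 3, 4);
 * s.Save(strm);                   // writes ndim_ followed by the dimension values
 * mshadow::TShape t;
 * CHECK(t.Load(strm) && t == s);  // reads them back and compares
 * \endcode
 *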
* \brief load the content from binary stream * \param strm the output stream * \tparam TStream any stream type that have write * \return whether the load is successful */ template inline bool Load(TStream *strm) { if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; this->SetDim(ndim_); size_t nread = sizeof(index_t) * ndim_; if (strm->Read(data(), nread) != nread) return false; return true; } friend std::ostream &operator<<(std::ostream &os, const TShape &shape); friend std::istream &operator>>(std::istream &is, TShape &shape); private: // the shape will be stored in data_stack_ // when dimension is smaller than kStackCache // when it is bigger, it will be stored in data_heap_; /*! \brief size of in stack space */ static const index_t kStackCache = 4; /*! \brief number of dimnsion of the shape */ index_t ndim_; /*! \brief number of cells allocated in data_heap_ */ index_t num_heap_allocated_; /*! \brief in stack space used to store shape when it is small */ index_t data_stack_[kStackCache]; /*! \brief space to store shape when dimension is big*/ index_t *data_heap_; /*! * \brief internal function to set the dimension * \param dim the dimension of the shape */ inline void SetDim(index_t dim) { if (dim > kStackCache && dim > num_heap_allocated_) { // data_heap_ can be NULL delete [] data_heap_; data_heap_ = new index_t[dim]; num_heap_allocated_ = dim; } ndim_ = dim; } }; /*! * \brief allow string printing of the shape * \param os the output stream * \param shape the shape * \return the ostream */ inline std::ostream &operator<<(std::ostream &os, const TShape &shape) { os << '('; for (index_t i = 0; i < shape.ndim(); ++i) { if (i != 0) os << ", "; os << shape[i]; } // python style tuple if (shape.ndim() == 1) os << ','; os << ')'; return os; } /*! * \brief read shape from the istream * \param is the input stream * \param shape the shape * \return the istream */ inline std::istream &operator>>(std::istream &is, TShape &shape) { // get ( while (true) { char ch = is.get(); if (ch == '(') break; if (!isspace(ch)) { is.setstate(std::ios::failbit); return is; } } index_t idx; std::vector tmp; while (is >> idx) { tmp.push_back(idx); char ch; do { ch = is.get(); } while (isspace(ch)); if (ch == ',') { while (true) { ch = is.peek(); if (isspace(ch)) { is.get(); continue; } if (ch == ')') { is.get(); break; } break; } if (ch == ')') break; } else if (ch == ')') { break; } else { is.setstate(std::ios::failbit); return is; } } shape.CopyFrom(tmp.begin(), tmp.end()); return is; } /*! \brief data type flag */ template struct DataType; template<> struct DataType { static const int kFlag = 0; }; template<> struct DataType { static const int kFlag = 1; }; /*! * \brief tensor blob class that can be used to hold tensor of any dimension, * any device and any data type, * This is a weak type that can be used to transfer data through interface * TBlob itself do not involve any arithmentic operations, * but it can be converted to tensor of fixed dimension for further operations * * Like tensor, this data structure is like a pointer class and do not * implicit allocated, de-allocate space. * This data structure can be helpful to hold tensors of different dimensions * and wait for further processing */ class TBlob { public: /*! \brief pointer to the data */ void *dptr_; /*! \brief shape of the tensor */ TShape shape_; /*! * \brief storing the stride information in x dimension */ index_t stride_; /*! \brief device mask of the corresponding device */ int dev_mask_; /*! 
\brief type flag of the tensor blob */ int type_flag_; /*! \brief default constructor, default copy assign will work */ TBlob(void) : dptr_(NULL), dev_mask_(cpu::kDevMask), type_flag_(DataType::kFlag) {} /*! * \brief constructor that construct TBlob from contiguous memory * \param dptr the pointer to the memory * \param shape the shape of the data * \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask */ template TBlob(DType *dptr, const TShape &shape, int dev_mask) : dptr_(dptr), shape_(shape), stride_(shape[shape.ndim() - 1]), dev_mask_(dev_mask), type_flag_(DataType::kFlag) {} /*! * \brief constructor from tensor * \param src source tensor * \tparam Device which device the tensor is on * \tparam dim tensor dimension * \tparam DType the type of elements in the tensor */ template TBlob(const Tensor &src) { // NOLINT(*) *this = src; } /*! * \brief assignment from tensor * \param src source tensor * \tparam Device which device the tensor is on * \tparam dim tensor dimension * \tparam DType the type of elements in the tensor * \return reference of self */ template inline TBlob &operator=(const Tensor &src) { dptr_ = src.dptr_; shape_ = src.shape_; stride_ = src.stride_; dev_mask_ = Device::kDevMask; type_flag_ = DataType::kFlag; return *this; } /*! * \return whether the tensor's memory is continuous */ inline bool CheckContiguous(void) const { return shape_[shape_.ndim() - 1] == stride_; } /*! * \brief flatten the tensor to 2 dimension, collapse the higher dimensions together * \param stream the possible stream target tensor should reside on * \tparam Device which device the tensor is on * \tparam DType the type of elements in the tensor * \return tensor after flatten */ template inline Tensor FlatTo2D(Stream *stream = NULL) const { CHECK(Device::kDevMask == dev_mask_ && DataType::kFlag == type_flag_) << "TBlob.get: device type do not match specified type"; return Tensor(static_cast(dptr_), shape_.FlatTo2D(), stride_, stream); } /*! \brief return number of dimension of the tensor inside */ inline int ndim(void) const { return shape_.ndim(); } /*! * \brief return size of i-th dimension, start counting from highest dimension * \param idx the dimension count from the highest dimensin * \return the size */ inline index_t size(index_t idx) const { return shape_[idx]; } /*! \brief total number of elements in the tensor */ inline index_t Size(void) const { return shape_.Size(); } /*! * \brief fetch the tensor, with respect to specific dimension * if dim do not match the stored dimension, an error will be issued * \return the tensor requested * \param stream the possible stream target tensor should reside on * \tparam Device which device the tensor is on * \tparam dim dimension of the tensor * \tparam DType the type of elements in the tensor */ template inline Tensor get(Stream *stream = NULL) const { CHECK(Device::kDevMask == dev_mask_ && DataType::kFlag == type_flag_) << "TBlob.get: device type do not match specified type"; return Tensor(static_cast(dptr_), shape_.get(), stride_, stream); } /*! 
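 * A minimal sketch (illustrative only, not from the upstream sources): given a contiguous
 * float TBlob `blob` on CPU and a row count `n` that divides blob.Size(), both placeholders,
 * \code
 * mshadow::Tensor<mshadow::cpu, 2, float> mat =
 *     blob.get_with_shape<mshadow::cpu, 2, float>(mshadow::Shape2(n, blob.Size() / n));
 * \endcode
 * reinterprets the same memory as a 2-D tensor without copying.
 *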
* \brief fetch a tensor in given shape * If size do not match the stored size, an error will be issued * \return the tensor requested * \param shape the shape required * \param stream the possible stream target tensor should reside on * \tparam Device which device the tensor is on * \tparam dim dimension of the tensor * \tparam DType the type of elements in the tensor */ template inline Tensor get_with_shape(const Shape &shape, Stream *stream = NULL) const { CHECK(Device::kDevMask == dev_mask_ && DataType::kFlag == type_flag_) << "TBlob.get_with_shape: device type do not match specified type"; CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be contiguous"; CHECK_EQ(this->shape_.Size(), shape.Size()) << "TBlob.get_with_shape: new and old shape do not match total elements"; return Tensor(static_cast(dptr_), shape, shape[dim - 1], stream); } }; } // namespace mshadow #endif // MSHADOW_TENSOR_BLOB_H_ //===== EXPANDED: mxnet/mshadow/mshadow/tensor_blob.h ===== //===== EXPANDIND: mxnet/mshadow/mshadow/random.h ===== /*! * Copyright (c) 2014 by Contributors * \file random.h * \brief Random inline functions for tensor. * \author Bing Xu, Tianqi Chen * Based on curand|MKL|stdlib */ #ifndef MSHADOW_RANDOM_H_ #define MSHADOW_RANDOM_H_ #if MSHADOW_IN_CXX11 #endif #if _MSC_VER #define rand_r(x) rand() #endif namespace mshadow { /*! * \brief random number generator * \tparam Device the device of random number generator * \tparam DType the target data type of random number can be float for double */ template class Random {}; /*! \brief CPU random number generator */ template class Random { public: /*! * \brief constructor of random engine * \param seed random number seed */ explicit Random(int seed) { this->Seed(seed); buffer_.Resize(Shape1(kRandBufferSize)); } ~Random(void) { } /*! * \brief seed random number generator using this seed * \param seed seed of prng */ inline void Seed(int seed) { #if MSHADOW_IN_CXX11 rnd_engine_.seed(seed); #else this->rseed_ = static_cast(seed); #endif } /*! * \brief set the stream of computation * \param stream computation stream */ inline void set_stream(Stream *stream) { } /*! * \brief generate data from uniform [a,b) * \param dst destination * \param a lower bound of uniform * \param b upper bound of uniform * \tparam dim dimension of tensor */ template inline void SampleUniform(Tensor *dst, DType a = 0.0f, DType b = 1.0f) { if (dst->CheckContiguous()) { this->GenUniform(dst->dptr_, dst->shape_.Size(), a, b); } else { Tensor mat = dst->FlatTo2D(); for (index_t i = 0; i < mat.size(0); ++i) { this->GenUniform(mat[i].dptr_, mat.size(1), a, b); } } } /*! * \brief generate data from standard gaussian * \param dst destination * \param mu mean variable * \param sigma standard deviation * \tparam dim dimension of tensor */ template inline void SampleGaussian(Tensor *dst, DType mu = 0.0f, DType sigma = 1.0f) { if (sigma <= 0.0f) { *dst = mu; return; } if (dst->CheckContiguous()) { this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); } else { Tensor mat = dst->FlatTo2D(); for (index_t i = 0; i < mat.size(0); ++i) { this->GenGaussian(mat[i].dptr_, mat.size(1), mu, sigma); } } } /*! 
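 * A minimal sketch (illustrative only, not from the upstream sources): the temporal
 * expression should be consumed by a single assignment, e.g.
 * \code
 * mshadow::Random<mshadow::cpu, float> rnd(0);
 * mshadow::TensorContainer<mshadow::cpu, 2, float> A(mshadow::Shape2(3, 4));
 * A = rnd.gaussian(A.shape_) * 0.1f + 1.0f;  // one gaussian() per expression, see the caution below
 * \endcode
 *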
* \brief return a temporal expression storing standard gaussian random variables * the temporal tensor is only valid before next call of gaussian or uniform * can be used as part of expression * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, * since second call of gaussian(s2) makes gaussian(s1) invalid * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression * \param shape shape of the tensor * \return a temporal expression storing standard gaussian random variables * \tparam dim dimension of tensor */ template inline expr::ReshapeExp, DType, dim, 1> gaussian(Shape shape) { buffer_.Resize(Shape1(shape.Size())); this->SampleGaussian(&buffer_, 0.0f, 1.0f); return expr::reshape(buffer_, shape); } /*! * \brief return a temporal expression storing standard uniform [0,1) * the temporal tensor is only valid before next call of gaussian or uniform * can be used as part of expression * Caution: this means expression such as A = uniform(s1) * uniform(s2) will give invalid result, * since second call of gaussian(s2) makes gaussian(s1) invalid * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression * \param shape shape of the tensor * \return a temporal expression storing standard uniform [0,1) * \tparam dim dimension of tensor */ template inline expr::ReshapeExp, DType, dim, 1> uniform(Shape shape) { buffer_.Resize(Shape1(shape.Size())); this->SampleUniform(&buffer_, 0.0f, 1.0f); return expr::reshape(buffer_, shape); } private: #if MSHADOW_IN_CXX11 /*! \brief use c++11 random engine. */ std::mt19937 rnd_engine_; // implementing generators. inline void GenUniform(DType *dptr, index_t size, DType a, DType b) { std::uniform_real_distribution dist_uniform(a, b); for (size_t i = 0; i < size; ++i) { dptr[i] = dist_uniform(rnd_engine_); } } inline void GenGaussian(DType *dptr, index_t size, DType mu, DType sigma) { std::normal_distribution dist_normal(mu, sigma); for (size_t i = 0; i < size; ++i) { dptr[i] = dist_normal(rnd_engine_); } } #else /*! \brief random number seed used by PRNG */ unsigned rseed_; // functions inline void GenUniform(float *dptr, index_t size, float a, float b) { for (index_t j = 0; j < size; ++j) { dptr[j] = static_cast(RandNext()) * (b - a) + a; } } inline void GenUniform(double *dptr, index_t size, double a, double b) { for (index_t j = 0; j < size; ++j) { dptr[j] = static_cast(RandNext()) * (b - a) + a; } } inline void GenGaussian(float *dptr, index_t size, float mu, float sigma) { this->GenGaussianX(dptr, size, mu, sigma); } inline void GenGaussian(double *dptr, index_t size, double mu, double sigma) { this->GenGaussianX(dptr, size, mu, sigma); } inline void GenGaussianX(DType *dptr, index_t size, DType mu, DType sigma) { DType g1 = 0.0f, g2 = 0.0f; for (index_t j = 0; j < size; ++j) { if ((j & 1) == 0) { this->SampleNormal2D(&g1, &g2); dptr[j] = mu + g1 * sigma; } else { dptr[j] = mu + g2 * sigma; } } } /*! \brief get next random number from rand */ inline DType RandNext(void) { return static_cast(rand_r(&rseed_)) / (static_cast(RAND_MAX) + 1.0f); } /*! \brief return a real numer uniform in (0,1) */ inline DType RandNext2(void) { return (static_cast(rand_r(&rseed_)) + 1.0f) / (static_cast(RAND_MAX) + 2.0f); } /*! 
* \brief sample iid xx,yy ~N(0,1) * \param xx first gaussian output * \param yy second gaussian output */ inline void SampleNormal2D(DType *xx_, DType *yy_) { DType &xx = *xx_, &yy = *yy_; DType x, y, s; do { x = 2.0f * RandNext2() - 1.0f; y = 2.0f * RandNext2() - 1.0f; s = x * x + y * y; } while (s >= 1.0f || s == 0.0f); DType t = std::sqrt(-2.0f * std::log(s) / s); xx = x * t; yy = y * t; } #endif /*! \brief temporal space used to store random numbers */ TensorContainer buffer_; }; // class Random // only allow GPU PRNG when cuda is enabled #if MSHADOW_USE_CUDA /*! \brief GPU random number generator */ template class Random { public: /*! * \brief constructor of random engine * \param seed random number seed */ explicit Random(int seed) { curandStatus_t status; status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Can not create CURAND Generator"; this->Seed(seed); buffer_.Resize(Shape1(kRandBufferSize)); } ~Random(void) DMLC_THROW_EXCEPTION { curandStatus_t status; status = curandDestroyGenerator(gen_); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destory CURAND Gen failed"; } /*! * \brief set the stream of computation * \param stream computation stream */ inline void set_stream(Stream *stream) { curandStatus_t status; status = curandSetStream(gen_, Stream::GetStream(stream)); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed"; } /*! * \brief seed random number generator using this seed * \param seed seed of prng */ inline void Seed(int seed) { curandStatus_t status; status = curandSetPseudoRandomGeneratorSeed(gen_, seed); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed."; } /*! * \brief generate data from uniform [a,b) * \param dst destination * \param a lower bound of uniform * \param b upper bound of uniform * \tparam dim dimension of tensor */ template inline void SampleUniform(Tensor *dst, DType a = 0.0f, DType b = 1.0f); /*! * \brief generate data from standard gaussian * \param dst destination * \param mu mean variable * \param sigma standard deviation * \tparam dim dimension of tensor */ template inline void SampleGaussian(Tensor *dst, DType mu = 0.0f, DType sigma = 1.0f); /*! * \brief return a temporal expression storing standard gaussian random variables * the temporal tensor is only valid before next call of gaussian or uniform * can be used as part of expression * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, * since second call of gaussian(s2) makes gaussian(s1) invalid * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression * \param shape shape of the tensor * \param mu mean * \param sigma variance * \return a temporal expression storing standard gaussian random variables * \tparam dim dimension of tensor */ template inline expr::ReshapeExp, DType, dim, 1> gaussian(Shape shape, DType mu = 0.0f, DType sigma = 1.0f); /*! 
* \brief return a temporal expression storing standard uniform [0,1) * the temporal tensor is only valid before next call of gaussian or uniform * can be used as part of expression * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, * since second call of gaussian(s2) makes gaussian(s1) invalid * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression * \param shape shape of the tensor * \return a temporal expression storing standard uniform [0,1) * \tparam dim dimension of tensor */ template inline expr::ReshapeExp, DType, dim, 1> uniform(Shape shape); private: inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) { curandStatus_t status; status = curandGenerateNormal(gen_, dptr, size, mu, sigma); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform failed"; } inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) { curandStatus_t status; status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform failed"; } inline void GenUniform(float *dptr, size_t size) { curandStatus_t status; status = curandGenerateUniform(gen_, dptr, size); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform failed"; } inline void GenUniform(double *dptr, size_t size) { curandStatus_t status; status = curandGenerateUniformDouble(gen_, dptr, size); CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform failed"; } /*! \brief random numbeer generator */ curandGenerator_t gen_; /*! \brief templ buffer */ TensorContainer buffer_; }; // class Random #endif // MSHADOW_USE_CUDA #ifdef __CUDACC__ // implementations that depends on cuda kernels template template inline void Random::SampleUniform( Tensor *dst, DType a, DType b) { if (a == 0.0f && b == 1.0f) { if (dst->CheckContiguous()) { this->GenUniform(dst->dptr_, dst->shape_.Size()); } else { *dst = this->uniform(dst->shape_); } } else { *dst = this->uniform(dst->shape_) * (b - a) + a; } } template template inline void Random::SampleGaussian( Tensor *dst, DType mu, DType sigma) { if (dst->CheckContiguous()) { this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); } else { *dst = this->gaussian(dst->shape_, mu, sigma); } } template template inline expr::ReshapeExp, DType, dim, 1> Random::gaussian(Shape shape, DType mu, DType sigma) { size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1; // allocate alligned size buffer_.Resize(Shape1(aligned_sz)); buffer_.Resize(Shape1(shape.Size())); this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma); return expr::reshape(buffer_, shape); } template template inline expr::ReshapeExp, DType, dim, 1> Random::uniform(Shape shape) { buffer_.Resize(Shape1(shape.Size())); this->GenUniform(buffer_.dptr_, buffer_.size(0)); return expr::reshape(buffer_, shape); } #endif // __CUDACC__ } // namespace mshadow #endif // MSHADOW_RANDOM_H_ //===== EXPANDED: mxnet/mshadow/mshadow/random.h ===== // add definition of scalar related operators #ifdef MSAHDOW_SCALAR_ #error "MSHADOW_SCALAR_ must not be defined" #endif // enumerate all the scalar data type we aim to be good at #define MSHADOW_SCALAR_ float //===== EXPANDIND: mxnet/mshadow/mshadow/expr_scalar-inl.h ===== /*! 
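 * A minimal sketch (illustrative only, not from the upstream sources): once this header has
 * been expanded for a scalar type, mixed tensor/scalar expressions of that type become valid,
 * \code
 * mshadow::TensorContainer<mshadow::cpu, 2, float> A(mshadow::Shape2(2, 3), 1.0f);
 * mshadow::TensorContainer<mshadow::cpu, 2, float> B(mshadow::Shape2(2, 3), 2.0f);
 * A = B * 2.0f + 0.5f;  // routed through the ScalarExp operator overloads defined below
 * \endcode
 *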
* Copyright (c) 2014 by Contributors * \file expression-inl.h * \brief definitions of operators in expression with respect to scalar * this file will be included several times, each time with MACRO MSHADOW_SCALAR_ to be different types * * DO NOT add pragma once or macro guard * \author Tianqi Chen, Bing Xu */ // macro guard is harmful, used to pass the cpplint #ifndef MSHADOW_EXPR_SCALAR_INL_H_ #define MSHADOW_EXPR_SCALAR_INL_H_ // undef the guard so it can be included multiple times #undef MSHADOW_EXPR_SCALAR_INL_H_ namespace mshadow { namespace expr { // DotExp /*! \brief dot operator def */ template inline DotExp operator*(const DotExp &lhs, MSHADOW_SCALAR_ rhs) { return DotExp(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs); } /*! \brief scale of dot operation */ template inline DotExp operator*(MSHADOW_SCALAR_ lhs, const DotExp &rhs) { return DotExp(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs); } /*! \brief operator overload */ template inline ReduceTo1DExp operator*(const ReduceTo1DExp &e, MSHADOW_SCALAR_ scale) { return ReduceTo1DExp(e.src_, e.scale_ * scale); } /*! \brief operator overload */ template inline ReduceTo1DExp operator*(MSHADOW_SCALAR_ scale, const ReduceTo1DExp &e) { return ReduceTo1DExp(e.src_, e.scale_ * scale); } /*! \brief operator overload for const */ template inline BinaryMapExp, MSHADOW_SCALAR_, (ta|type::kMapper)> F(const Exp &lhs, const ScalarExp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload for const */ template inline BinaryMapExp, TB, MSHADOW_SCALAR_, (tb|type::kMapper)> F(const ScalarExp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } // constant operators /*! \brief operator overload */ template inline BinaryMapExp, MSHADOW_SCALAR_, (ta|type::kMapper)> operator+(const Exp &lhs, const ScalarExp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp, MSHADOW_SCALAR_, (ta|type::kMapper)> operator-(const Exp &lhs, const ScalarExp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp, MSHADOW_SCALAR_, (ta|type::kMapper)> operator*(const Exp &lhs, const ScalarExp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp, MSHADOW_SCALAR_, (ta|type::kMapper)> operator/(const Exp &lhs, const ScalarExp &rhs) { return MakeExp(lhs, rhs); } // constant operators 2 /*! \brief operator overload */ template inline BinaryMapExp, TB, MSHADOW_SCALAR_, (tb|type::kMapper)> operator+(const ScalarExp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp, TB, MSHADOW_SCALAR_, (tb|type::kMapper)> operator-(const ScalarExp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp, TB, MSHADOW_SCALAR_, (tb|type::kMapper)> operator*(const ScalarExp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } /*! \brief operator overload */ template inline BinaryMapExp, TB, MSHADOW_SCALAR_, (tb|type::kMapper)> operator/(const ScalarExp &lhs, const Exp &rhs) { return MakeExp(lhs, rhs); } } // namespace expr } // namespace mshadow #endif // MSHADOW_EXPR_SCALAR_INL_H_ //===== EXPANDED: mxnet/mshadow/mshadow/expr_scalar-inl.h ===== #undef MSHADOW_SCALAR_ #define MSHADOW_SCALAR_ double #undef MSHADOW_SCALAR_ #define MSHADOW_SCALAR_ int #undef MSHADOW_SCALAR_ #endif // MSHADOW_TENSOR_H_ //===== EXPANDED: mxnet/mshadow/mshadow/tensor.h ===== /*! *\brief whether to use opencv support */ #ifndef MXNET_USE_OPENCV #define MXNET_USE_OPENCV 1 #endif /*! 
*\brief whether to use cuda support */ #ifndef MXNET_USE_CUDA #define MXNET_USE_CUDA MSHADOW_USE_CUDA #endif /*! *\brief whether to use cudnn library for convolution */ #ifndef MXNET_USE_CUDNN #define MXNET_USE_CUDNN MSHADOW_USE_CUDNN #endif /*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */ #define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled" /*! * \brief define compatible keywords in g++ * Used to support g++-4.6 and g++4.7 */ #if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) #if __GNUC__ == 4 && __GNUC_MINOR__ < 8 #error "Currently we need g++ 4.8 or higher to fully support c++11 features" #define override #define final #endif #endif /*! * \brief define dllexport for Visual Studio */ #ifdef _MSC_VER #ifdef MXNET_EXPORTS #define MXNET_API __declspec(dllexport) #else #define MXNET_API __declspec(dllimport) #endif #else #define MXNET_API #endif /*! * \brief define prediction only */ #ifndef MXNET_PREDICT_ONLY #define MXNET_PREDICT_ONLY 0 #endif /*! \brief namespace of mxnet */ namespace mxnet { /*! \brief mxnet cpu */ typedef mshadow::cpu cpu; /*! \brief mxnet gpu */ typedef mshadow::gpu gpu; /*! \brief index type usually use unsigned */ typedef mshadow::index_t index_t; /*! \brief data type that will be used to store ndarray */ typedef mshadow::default_real_t real_t; /*! \brief dynamic shape type */ typedef mshadow::TShape TShape; /*! \brief storage container type */ typedef mshadow::TBlob TBlob; /*! \brief Context information about the execution enviroment */ struct Context { /*! \brief Type of device */ enum DeviceType { kCPU = cpu::kDevMask, kGPU = gpu::kDevMask, kCPUPinned = 3 }; /*! \brief the device type we run the op on */ DeviceType dev_type; /*! \brief device id we are going to run it on */ int32_t dev_id; /*! \brief default constructor */ Context() : dev_type(kCPU), dev_id(0) {} /*! * \brief Get corresponding device mask * \return cpu::kDevMask or gpu::kDevMask */ inline int dev_mask() const { if (dev_type == kCPUPinned) return cpu::kDevMask; return dev_type; } /*! * \brief Comparator, used to enable Context as std::map key. * \param b another context to compare * \return compared result */ inline bool operator<(const Context &b) const; /*! * \brief check if current context equals another one * \param b another context to compare * \return whether dev mask and id are same */ inline bool operator==(const Context &b) const { return dev_type == b.dev_type && dev_id == b.dev_id; } /*! * \brief check if current context not equals another one * \param b another context to compare * \return whether they are not the same */ inline bool operator!=(const Context &b) const { return !(*this == b); } /*! * \brief save the content into binary stream * \param strm the output stream */ inline void Save(dmlc::Stream *strm) const { strm->Write(&dev_type, sizeof(dev_type)); strm->Write(&dev_id, sizeof(dev_id)); } /*! * \brief load the content from binary stream * \param strm the output stream * \return whether the load is successful */ inline bool Load(dmlc::Stream *strm) { if (strm->Read(&dev_type, sizeof(dev_type)) != sizeof(dev_type)) return false; if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false; return true; } /*! \brief the maximal device type */ static const int32_t kMaxDevType = 4; /*! \brief the maximal device index */ static const int32_t kMaxDevID = 16; /*! * \brief Create a new context. * \param dev_type device type. * \param dev_id device id. */ inline static Context Create(DeviceType dev_type, int32_t dev_id); /*! 
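 * A minimal sketch (illustrative only, not from the upstream sources): contexts are normally
 * created through the static helpers rather than filled in by hand,
 * \code
 * mxnet::Context cpu_ctx = mxnet::Context::CPU();
 * mxnet::Context gpu0 = mxnet::Context::GPU(0);
 * CHECK(cpu_ctx != gpu0 && cpu_ctx.dev_mask() == mshadow::cpu::kDevMask);
 * \endcode
 *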
\return CPU Context */ inline static Context CPU(); /*! * Create a GPU context. * \param dev_id the device id. * \return GPU Context. */ inline static Context GPU(int32_t dev_id); /*! * Create a pinned CPU context. * \param dev_id the device id for corresponding GPU. * \return Pinned CPU context. */ inline static Context CPUPinned(int32_t dev_id); }; /*! * \brief execution time context. * The information needed in runtime for actual execution. */ struct RunContext { /*! * \brief the stream of the device, can be NULL or Stream* in GPU mode */ void *stream; /*! * \brief get mshadow stream from Context * \return the mshadow stream * \tparam xpu the device type of the stream */ template inline mshadow::Stream* get_stream() const { return static_cast*>(stream); } }; } // namespace mxnet //! \cond Doxygen_Suppress namespace mxnet { // implementing Context inline bool Context::operator<(const Context &b) const { if (dev_type == b.dev_type) { return dev_id < b.dev_id; } else { return dev_type < b.dev_type; } } inline Context Context::Create(DeviceType dev_type, int32_t dev_id) { Context ctx; ctx.dev_type = dev_type; ctx.dev_id = dev_id; return ctx; } inline Context Context::CPU() { return Create(kCPU, 0); } inline Context Context::CPUPinned(int32_t dev_id) { return Create(kCPUPinned, dev_id); } inline Context Context::GPU(int32_t dev_id) { return Create(kGPU, dev_id); } } // namespace mxnet namespace dmlc { // Add a few patches to support TShape in dmlc/parameter. DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)"); namespace parameter { template<> class FieldEntry : public FieldEntryBase, mxnet::TShape> { public: FieldEntry() : enforce_nonzero_(false), expect_ndim_(0) {} // parent class typedef FieldEntryBase, mxnet::TShape> Parent; virtual void Check(void *head) const { Parent::Check(head); mxnet::TShape &v = this->Get(head); if (expect_ndim_ != 0 && v.ndim() != expect_ndim_) { std::ostringstream os; os << "value " << v << "for Parameter " << this->key_ << " has wrong dimensions, expected dimension=" << expect_ndim_; throw dmlc::ParamError(os.str()); } if (enforce_nonzero_) { for (mxnet::index_t i = 0; i < v.ndim(); ++i) { if (v[i] == 0U) { std::ostringstream os; os << "value " << v << "for Parameter " << this->key_ << " is invalid, the input shape must be nonzero in all dimensions"; throw dmlc::ParamError(os.str()); } } } } inline FieldEntry &enforce_nonzero() { this->enforce_nonzero_ = true; return this->self(); } inline FieldEntry &set_expect_ndim(mshadow::index_t ndim) { expect_ndim_ = ndim; return this->self(); } private: // whether all the entries need to be nonzero bool enforce_nonzero_; // expected number of dimension, default = 0 means no restriction. mxnet::index_t expect_ndim_; }; } // namespace parameter } // namespace dmlc //! \endcond #endif // MXNET_BASE_H_ //===== EXPANDED: mxnet/include/mxnet/base.h ===== //===== EXPANDIND: mxnet/include/mxnet/operator.h ===== /*! * Copyright (c) 2015 by Contributors * \file operator.h * \brief Operator interface of mxnet. * \author Naiyan Wang */ #ifndef MXNET_OPERATOR_H_ #define MXNET_OPERATOR_H_ //===== EXPANDIND: mxnet/include/mxnet/resource.h ===== /*! * Copyright (c) 2015 by Contributors * \file resource.h * \brief Global resource allocation handling. */ #ifndef MXNET_RESOURCE_H_ #define MXNET_RESOURCE_H_ //===== EXPANDIND: mxnet/include/mxnet/engine.h ===== /*! * Copyright (c) 2015 by Contributors * \file engine.h * \brief Engine that schedules all the operations according to dependency. 
*/ #ifndef MXNET_ENGINE_H_ #define MXNET_ENGINE_H_ #if DMLC_USE_CXX11 #endif namespace mxnet { /*! \brief namespace of engine internal types. */ namespace engine { /*! \brief Internal representation of variable. */ struct Var; /*! \brief Internal representation of operator. */ struct Opr; /*! \brief Variable pointer type, usually hold by user used to specify dependencies. */ typedef Var* VarHandle; /*! \brief Operator pointer type, usually hold by user.*/ typedef Opr* OprHandle; } // namespace engine #if DMLC_USE_CXX11 /*! \brief Function property, used to hint what action is pushed to engine. */ enum class FnProperty { /*! \brief Normal operation */ kNormal, /*! \brief Copy operation from GPU to other devices */ kCopyFromGPU, /*! \brief Copy operation from CPU to other devices */ kCopyToGPU, /*! \brief Prioritized sync operation on CPU */ kCPUPrioritized, /*! \brief Asynchronous function call */ kAsync }; // enum class FnProperty /*! * \brief Dependency engine that schedules operations. */ class MXNET_API Engine { public: /*! * \brief OnComplete Callback to the engine, * called by AsyncFn when action completes */ class CallbackOnComplete { public: // use implicit copy and assign /*! \brief involve the callback */ inline void operator()() const { (*callback_)(engine_, param_); } private: /*! \brief engine can see content of callback */ friend class ::mxnet::Engine; /*! \brief the real callback */ void (*callback_)(Engine *, void *); /*! \brief the engine class passed to callback */ Engine* engine_; /*! \brief the parameter set on callback */ void* param_; }; /*! \brief Synchronous operation to pass to engine. */ typedef std::function SyncFn; /*! \brief Asynchronous operation to pass to engine. */ typedef std::function AsyncFn; /*! \brief Variable pointer */ typedef engine::VarHandle VarHandle; /*! \brief Operator pointer */ typedef engine::OprHandle OprHandle; /*! * \brief Notify the engine about a shutdown, * This can help engine to print less messages into display. * * User do not have to call this function. * \return 0 when success, -1 when failure happens. */ virtual void NotifyShutdown() = 0; /*! * \brief Allocate a new variable, the variable can then * be used to schedule the operation concurrently via dependency * patterns. * \return The new variable allocated. */ virtual VarHandle NewVariable() = 0; /*! * \brief Create a new operator. The returned operator could be saved * externally so that it could be resued for scheduling. * \param fn The execution function. * \param const_vars The variables that current operation will use but not * mutate. * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. * \return The new operator allocated. */ virtual OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop = FnProperty::kNormal) = 0; /*! * \brief Delete the given operator. * \param op The operator to delete. * * The delete will not happen immediately, but will wait until all the * operations using this operator are completed. */ virtual void DeleteOperator(OprHandle op) = 0; /*! * \brief Push an operator to the engine. * \param op The operator to push. * \param exec_ctx Execution context. * \param priority Priority of the action, as hint to the engine. */ virtual void Push(OprHandle op, Context exec_ctx, int priority = 0) = 0; /*! * \brief Push an asynchronous operation to the engine. 
* \param exec_fun Execution function, this function takes a parameter * on_complete that must be called when the execution * completes. * \param exec_ctx Execution context. * \param const_vars The variables that current operation will use but not * mutate. * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. * \param priority Priority of the action, as hint to the engine. */ virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop = FnProperty::kNormal, int priority = 0) = 0; /*! * \brief Schedule the deletion of a variable. * * The delete will not happen immediately, but will wait until all the * operations depending on var are completed. * * \param delete_fn A function that will be called after the variable is * deleted. * \param exec_ctx Execution context. * \param var The variable to be deleted. */ virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) = 0; /*! * \brief Wait for a variable. * \param var The variable we should wait for. This function returns when the * variable is ready. */ virtual void WaitForVar(VarHandle var) = 0; /*! * \brief Wait until all the activity of engine finishes. */ virtual void WaitForAll() = 0; /*!\brief virtual destructor */ virtual ~Engine() noexcept(false) {} /*! * \return Engine singleton. */ static Engine* Get(); /*! * \brief Get shared pointer reference to engine singleton. * Most user should not call this function. * This function is called by another singleton X who requires * engine to be destructed after X. * * \return A shared pointer to Engine singleton. */ static std::shared_ptr _GetSharedRef(); /*! * \brief Push an synchronous operation to the engine. * \param exec_fn Execution function that executes the operation. * \param exec_ctx Execution context. * \param const_vars The variables that current operation will use but not * mutate. * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. * \param priority Priority of the action, as hint to the engine. * \tparam SyncFn the synchronous function to be pushed. */ template inline void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop = FnProperty::kNormal, int priority = 0) { this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) { exec_fn(ctx); on_complete(); }, exec_ctx, const_vars, mutable_vars, prop, priority); } protected: /*! * \brief factory function to create OnComplete callback. * \param callback th static callback function. * \param param the paramter passed to callback. */ inline CallbackOnComplete CreateCallback( void (*callback)(Engine *, void *), void *param) { CallbackOnComplete ret; ret.callback_ = callback; ret.engine_ = this; ret.param_ = param; return ret; } }; // class Engine #endif // DMLC_USE_CXX11 } // namespace mxnet #endif // MXNET_ENGINE_H_ //===== EXPANDED: mxnet/include/mxnet/engine.h ===== namespace mxnet { /*! * \brief The resources that can be requested by Operator */ struct ResourceRequest { /*! \brief Resource type, indicating what the pointer type is */ enum Type { /*! \brief mshadow::Random object */ kRandom, /*! \brief A dynamic temp space that can be arbitrary size */ kTempSpace }; /*! \brief type of resources */ Type type; /*! \brief default constructor */ ResourceRequest() {} /*! 
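 *
 * Because the constructor below is intentionally implicit, a request can be
 * spelled directly as its Type value, for example when listing the
 * resources an operator needs (illustrative sketch only):
 * \code
 * std::vector<ResourceRequest> reqs = {ResourceRequest::kTempSpace,
 *                                      ResourceRequest::kRandom};
 * \endcode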
* \brief constructor, allow implicit conversion * \param type type of resources */ ResourceRequest(Type type) // NOLINT(*) : type(type) {} }; /*! * \brief Resources used by mxnet operations. * A resource is something special other than NDArray, * but will still participate */ struct Resource { /*! \brief The original request */ ResourceRequest req; /*! \brief engine variable */ engine::VarHandle var; /*! \brief identifier of id information, used for debug purpose */ int32_t id; /*! * \brief pointer to the resource, do not use directly, * access using member functions */ void *ptr_; /*! \brief default constructor */ Resource() : id(0) {} /*! * \brief Get random number generator. * \param stream The stream to use in the random number generator. * \return the mshadow random number generator requested. * \tparam xpu the device type of random number generator. */ template inline mshadow::Random* get_random( mshadow::Stream *stream) const { CHECK_EQ(req.type, ResourceRequest::kRandom); mshadow::Random *ret = static_cast*>(ptr_); ret->set_stream(stream); return ret; } /*! * \brief Get space requested as mshadow Tensor. * The caller can request arbitrary size. * * \param shape the Shape of returning tensor. * \param stream the stream of retruning tensor. * \return the mshadow tensor requested. * \tparam xpu the device type of random number generator. * \tparam ndim the number of dimension of the tensor requested. */ template inline mshadow::Tensor get_space( mshadow::Shape shape, mshadow::Stream *stream) const { CHECK_EQ(req.type, ResourceRequest::kTempSpace); mshadow::TensorContainer *space = static_cast*>(ptr_); space->Resize(mshadow::Shape1(shape.Size())); return mshadow::Tensor( space->dptr_, shape, shape[ndim - 1], stream); } }; /*! \brief Global resource manager */ class ResourceManager { public: /*! * \brief Get resource of requested type. * \param ctx the context of the request. * \param req the resource request. * \return the requested resource. * \note The returned resource's ownership is * still hold by the manager singleton. */ virtual Resource Request(Context ctx, const ResourceRequest &req) = 0; /*! * \brief Seed all the allocated random numbers. * \param seed the seed to the random number generators on all devices. */ virtual void SeedRandom(uint32_t seed) = 0; /*! \brief virtual destructor */ virtual ~ResourceManager() DMLC_THROW_EXCEPTION {} /*! * \return Resource manager singleton. */ static ResourceManager *Get(); }; } // namespace mxnet #endif // MXNET_RESOURCE_H_ //===== EXPANDED: mxnet/include/mxnet/resource.h ===== namespace mxnet { /*! \brief operation request type to Forward and Backward */ enum OpReqType { /*! \brief no operation, do not write anything */ kNullOp, /*! \brief write gradient to provided space */ kWriteTo, /*! * \brief perform an inplace write, * Target shares memory with one of input arguments. * This option only happen when */ kWriteInplace, /*! \brief add to the provided space */ kAddTo }; /*! * \brief All the possible information needed by Operator.Forward and Backward * This is the superset of RunContext. * We use this data structure to bookkeep everything needed by Forward and Backward. * \sa Resource */ struct OpContext { /*! \brief whether it is training phase */ int is_train; /*! \brief RunContext related resources */ RunContext run_ctx; /*! \brief Resources requested by the operator */ std::vector requested; /*! 
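 *
 * A sketch of how an operator body typically pulls the stream and requested
 * resources out of the context; `xpu` and the enclosing function are
 * assumptions of the example:
 * \code
 * template<typename xpu>
 * void ForwardImpl(const OpContext &ctx) {
 *   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
 *   // use s with mshadow expressions; ctx.requested holds the resources
 *   // declared by ForwardResource/BackwardResource, in request order
 * }
 * \endcode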
* \brief get mshadow stream from Context * \return the mshadow stream * \tparam xpu the device type of the stream */ template inline mshadow::Stream* get_stream() const { return run_ctx.get_stream(); } }; /*! * \brief Operator interface. * Operator defins basic operation unit of optimized computation graph in mxnet. * This interface relies on pre-allocated memory in TBlob, the caller need to set * the memory region in TBlob correctly before calling Forward and Backward. * * Operator is generated by OperatorProperty. * To add new operator(aka. layers of neural nets) to mxnet, developer need to create * a new OperatorProperty and its corresponding Operator. * * \sa TBlob, TShape, OperatorProperty */ class Operator { public: /*! \brief destructor */ virtual ~Operator() {} /*! * \brief perform a forward operation of Operator, save the output to TBlob. * \param ctx runtime context available to this call * \param in_data array of input data, it is const * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace. * \param out_data array of output data, pointer is used to indicate that this is holder * the space of TBlob in out_data must be pre-allocated with InferShape * \param aux_states Auxiliary states of operator. Normally operator doesn't * need, epecial case like Batch Norm requires. * \sa OpReqType, OpContext */ virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_states) = 0; /*! * \brief Perform a Backward Operation, write gradient to the in_grad. * * \note * Convention: * out_grad.size() == OperatorProperty.NumVisibleOutputs() * out_data.size() == OperatorProperty.NumOutputs() * out_data can contain additional invisible returns that remembers the * state carried from the Forward pass. For example mask in the dropout. * The gradients are passed from visible returns in this function. * * \par * Not all the TBlobs in the arguments will be available * if you override the DeclareBackwardDependency of corresponding OperatorProperty class. * Only the dependencies you declared will be available at corresponding position, * the rest of the parameters are simply dummy where you will get a nullptr. * You will be safe if you use the default DeclareBackwardDependency. * But only declare what you need will give engine more chance for optimization. * * \param ctx runtime context available to this call * \param out_grad the gradient value we get from of the Operator. * \param in_data the array of input data. * \param out_data the array of output data. * \param req request types of the saving operation, can be all types. * \param in_grad the array of gradient we need to write to. * \param aux_states Auxiliary states of operator. Normally operator doesn't need * \sa OperatorProperty, OpReqType, OpContext */ virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_states) { LOG(FATAL) << "Backward is not implemented"; } }; #if DMLC_USE_CXX11 // OperatorProperty allows C++11, while Operator do not rely on it. /*! * \brief OperatorProperty is a object that stores all information about Operator. * It also contains method to generate context(device) specific operators. * * It also contains various functions that can be optimally overriden to * provide optimization chance for computation engine. */ class OperatorProperty { public: /*! 
* \brief virtual destructor */ virtual ~OperatorProperty() {} /*! * \brief Initialize the Operator by setting the parameters * This function need to be called before all other functions. * \param kwargs the keyword arguments parameters */ virtual void Init(const std::vector >& kwargs) = 0; /*! * \brief Get a map representation of internal parameters. * This can be used by Init to recover the state of OperatorProperty. */ virtual std::map GetParams() const = 0; /*! * \brief Get input arguments of the Operator. * \return vector of arguments. */ virtual std::vector ListArguments() const { return {"data"}; } /*! * \brief Get name of output values of Operator * \return name of output values. */ virtual std::vector ListOutputs() const { return {"output"}; } /*! * \brief Get name of auxilary states of Operator * \return name of return values. */ virtual std::vector ListAuxiliaryStates() const { return {}; } /*! \return number of real return values of the Operator */ virtual int NumOutputs() const { return 1; } /*! * \brief get number of visible return values during Symbol creation. * If NumVisibleOutputs() = k, and NumOutputs() = n. * The first k returns will be presented in the resulting symbol. * * The rest of the returns can be used for auxiliary states for Backward. * For example, Dropout will return [data, mask], with NumVisibleOutputs() == 1. * So when user call sym = Dropout(input), only data is presented in sym. * But all the returns will be presented in out_data parameter of Backward if requested. * * \return number of default return values */ virtual int NumVisibleOutputs() const { return NumOutputs(); } /*! * \brief infer the shapes of outputs and unknown input arguments * \param in_shape the shape of input arguments of the operator * this should be of same length as the vector returned by DescribeArgs * in_shape allows unknown elements, which are checked by shape.ndim() == 0. * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape * For known shapes, InferShape will check shape consistency * * common practice: set the shape of data input, and usually weight's shape can be infered * * \param out_shape the shape of outputs of the operator * InferShape will modify the vector to fill output TShape * \param aux_shape the shape of auxiliary states of the operator * InferShape will modify the vector to fill output TShape * \return true if the shape inference is successful, false if there is not enough information. * \throws dmlc::Error if the known arg_shapes are inconsistent. */ virtual bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const = 0; /*! * \brief Copy this OperatorProperty. * \return a pointer to the copied OperatorProperty */ virtual OperatorProperty* Copy() const = 0; /*! * \brief Create a Operator on specific context */ virtual Operator* CreateOperator(Context ctx) const = 0; /*! * \brief return the type string of the Operator * subclasses override this function. * \return The type string. */ virtual std::string TypeString() const = 0; //-------------------------------------------------------- // All the below functions are optional to override. //-------------------------------------------------------- /*! * \brief Declare additional resource required in forward pass. * These additional resources will be presented in OpContext.requested * in the same order of the returned Resource. * \param in_shape The input shape to the operator, corresponds to shapes of in_data. 
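 *
 * A typical override that asks for temporary workspace (a sketch only; the
 * OperatorProperty subclass it would live in is not shown):
 * \code
 * std::vector<ResourceRequest> ForwardResource(
 *     const std::vector<TShape> &in_shape) const {
 *   return {ResourceRequest::kTempSpace};
 * }
 * \endcode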
* \return Additional resource request */ virtual std::vector ForwardResource( const std::vector &in_shape) const { return std::vector(); } /*! * \brief Decalre additional resource required in backward pass. * These additional resources will be presented in OpContext.requested * in the same order of the returned Resource. * \param in_shape The input shape to the operator, corresponds to shapes of in_data. * \return Additional resource request */ virtual std::vector BackwardResource( const std::vector &in_shape) const { return std::vector(); } /*! * \brief Declare the input requirement of Backward pass. * * Only the returned list of variables will be used in Backward. * This function is used for memory optimization. * It is adviced to override and only return what is actually needed. * If this function is not overriden, all the variables will be valid in Backward. * * \code * // The following code declares Backward need out_grad[0], in_data[0],in_data[1] * vector BackwardInputs(const vector &out_grad, * const vector &in_data, * const vector &out_data) const { * return {out_grad[0], in_data[0], in_data[1]}; * } * \endcode * \param out_grad gradient of outputs in backward pass. * \param in_data the input data in forward pass. * \param out_data the output data in forward pass. * \return an integer vector indicating the input requirments * \sa BackwardInputs */ virtual std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const { // By default requires to see all the things. // remember to override this function to get a better performance. std::vector ret = out_grad; ret.insert(ret.end(), in_data.begin(), in_data.end()); ret.insert(ret.end(), out_data.begin(), out_data.end()); return ret; } /*! * \brief Get possible forward inplace options. * This function enables optimization to reuse memory of inputs in output. * Only override when necessary, by default in-place is disabled. * * The reason for void* type in the out_data is to distinguish the order * of mappings between the two, compiler will report error when * in_data and out_data's order in the pair get reversed. * * \code * // The following code says out_data[0] can share data with in_data[0] * vector > ForwardInplaceOption(const vector &in_data, * const vector &out_data) const { * return {{in_data[0], out_data[0]}}; * } * \endcode * \param in_data The input data in forward pass. * \param out_data The output data in forward pass. * \return list of pair of that maps input->output, * indicating possible in place operations. */ virtual std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const { return std::vector >(); } /*! * \brief Get possible backward inplace options. * This function enables optimization to reuse memory of inputs in output. * Only override when necessary, by default in-place is disabled. * * The reason for void* type in the in_grad is to distinguish the order * of mappings between the two, compiler will report error when * in_data and out_data's order in the pair get reversed. * * \code * // The following code says in_grad[0] can share data with in_data[0] * vector > BackwardInplaceOption( * const std::vector &out_grad, * const std::vector &in_data, * const std::vector &out_data, * const std::vector &in_grad) const { * return {in_data[0], in_grad[0]}}; * } * \endcode * \param in_data The input data in forward pass. * \param out_data The output data in forward pass. 
* \param in_grad Gradient of inputs in backward pass. * \param out_grad Gradient of outputs in backward pass. * \return list of pair of that maps input->output, * indicating possible in place operations. */ virtual std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const { return std::vector >(); } /*! * \brief Get Backward Input Dependency for generic types of data. * Normally T can be pointer of Symbol::DataEntry, or NDArray. * This function will select the result list of T according to DeclareBackwardDependency. * * \param in_data the input data in forward pass. * \param out_data the output data in forward pass. * \param out_grad gradient of outputs in backward pass. * \tparam T the generic type parameter. * \return vector of inputs the Backward Operation depends on. * \sa DeclareBackwardDependency */ template inline std::vector BackwardInputs(const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const { int counter = 0; std::vector out_grad_index(out_grad.size()); std::vector in_data_index(in_data.size()); std::vector out_data_index(out_data.size()); for (size_t i = 0; i < out_grad_index.size(); ++i) { out_grad_index[i] = counter++; } for (size_t i = 0; i < in_data_index.size(); ++i) { in_data_index[i] = counter++; } for (size_t i = 0; i < out_data_index.size(); ++i) { out_data_index[i] = counter++; } std::vector all_data; all_data.insert(all_data.end(), out_grad.begin(), out_grad.end()); all_data.insert(all_data.end(), in_data.begin(), in_data.end()); all_data.insert(all_data.end(), out_data.begin(), out_data.end()); std::vector ret_index = this->DeclareBackwardDependency( out_grad_index, in_data_index, out_data_index); std::vector ret(ret_index.size()); for (size_t i = 0; i < ret_index.size(); ++i) { ret[i] = all_data[ret_index[i]]; } return ret; } /*! * \brief create OperatorProperty * \param type_name the type string of the OperatorProperty * \return a new constructed OperatorProperty */ static OperatorProperty *Create(const char* type_name); }; /*! \brief typedef the factory function of operator property */ typedef std::function OperatorPropertyFactory; /*! * \brief Registry entry for OperatorProperty factory functions. */ struct OperatorPropertyReg : public dmlc::FunctionRegEntryBase { /*! * \brief Set key_var_num_args * When this is set, the API caller is required to pass in a * argument with key=key_num_args.c_str(), and value=num_args. * num_args is number of positional argument when calling the function. * * This is used to pass in length of positional arguments * for operators that can take variable length of input. * Most operators do not need to set this property. * * \param key the key name to be set */ inline OperatorPropertyReg& set_key_var_num_args(const std::string &key) { // NOLINT(*) this->key_var_num_args = key; return *this; } /*! * \brief Check if TypeString of the type matches the registered name */ inline OperatorPropertyReg& check_name() { OperatorProperty *p = this->body(); std::string type = p->TypeString(); delete p; CHECK_EQ(this->name, type) << "Register Name and TypeString mismatch, name=\"" << this->name << "\"," << " but TypeString=\"" << type <<"\""; return *this; } /*! \brief The key num_args name. */ std::string key_var_num_args; }; //-------------------------------------------------------------- // The following part are API Registration of Operators //-------------------------------------------------------------- /*! 
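 *
 * The OperatorPropertyType passed to the macro documented below is a
 * subclass of OperatorProperty that implements its pure virtual members.
 * A minimal sketch, with the class name and the identity behaviour invented
 * purely for illustration:
 * \code
 * class IdentityOpProp : public OperatorProperty {
 *  public:
 *   void Init(const std::vector<std::pair<std::string, std::string> > &kwargs) {}
 *   std::map<std::string, std::string> GetParams() const {
 *     return std::map<std::string, std::string>();
 *   }
 *   bool InferShape(std::vector<TShape> *in_shape,
 *                   std::vector<TShape> *out_shape,
 *                   std::vector<TShape> *aux_shape) const {
 *     out_shape->clear();
 *     out_shape->push_back(in_shape->at(0));  // output shape == input shape
 *     return true;
 *   }
 *   OperatorProperty* Copy() const { return new IdentityOpProp(); }
 *   Operator* CreateOperator(Context ctx) const;  // defined with the Operator
 *   std::string TypeString() const { return "Identity"; }
 * };
 * \endcode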
* \brief Macro to register OperatorProperty * * \code * // example of registering a fully connected operator * REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedOpProp) * .describe("Fully connected layer"); * * \endcode */ #define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \ DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \ .set_body([]() { return new OperatorPropertyType(); }) \ .check_name() #endif // DMLC_USE_CXX11 } // namespace mxnet #endif // MXNET_OPERATOR_H_ //===== EXPANDED: mxnet/include/mxnet/operator.h ===== #if DMLC_USE_CXX11 #endif namespace mxnet { namespace common { /*! \brief namespace of arguments */ namespace arg { /*! \brief super class of all gradient function argument */ struct GradFunctionArgument { /*! \brief The real data */ TBlob data; }; /*! \brief First input to the function */ struct Input0 : GradFunctionArgument {}; /*! \brief Second input to the function */ struct Input1 : GradFunctionArgument {}; /*! \brief Ouput value of the function to the function */ struct OutValue : GradFunctionArgument {}; /*! \brief Gradient of output value */ struct OutGrad : GradFunctionArgument {}; } // namespace arg /*! \brief registry for function entry */ class TBlobOpRegEntry { public: typedef void (*UnaryFunction)(const TBlob &src, TBlob* ret, OpReqType req, RunContext ctx); typedef TShape (*UnaryShapeInfer)(const TShape &src); typedef void (*UnaryGradType1)(const arg::OutGrad& out_grad, const arg::OutValue& out_value, TBlob* in_grad, OpReqType req, RunContext ctx); typedef void (*UnaryGradType2)(const arg::OutGrad& out_grad, const arg::Input0& in_data0, TBlob* in_grad, OpReqType req, RunContext ctx); /*! \brief declare self type */ typedef TBlobOpRegEntry TSelf; /*! \brief name of the entry */ std::string name; /*! * \brief set shape inference function, by default use same shape. * \param fshapeinfer The unary function that peforms the operation. */ virtual TSelf& set_shape_infer(UnaryShapeInfer fshapeinfer) = 0; /*! * \brief set function of the function to be funary * \param dev_mask The device mask of the function can act on. * \param funary The unary function that peforms the operation. * \param inplace_in_out Whether do inplace optimization on in and out. * \param register_symbolic Whether register a symbolic operator as well. */ virtual TSelf& set_function(int dev_mask, UnaryFunction funary, bool inplace_in_out, bool register_symbolic = true) = 0; /*! * \brief set gradient of the function of this function. * \param dev_mask The device mask of the function can act on. * \param fgrad The gradient function to be set. * \param inplace_out_in_grad whether out_grad and in_grad can share memory. */ virtual TSelf& set_gradient(int dev_mask, UnaryGradType1 fgrad, bool inplace_out_in_grad) = 0; virtual TSelf& set_gradient(int dev_mask, UnaryGradType2 fgrad, bool inplace_out_in_grad) = 0; /*! * \brief Describe the function. * \param description The description of the function. * \return reference to self. */ virtual TSelf& describe(const std::string &description) = 0; /*! \brief destructor */ virtual ~TBlobOpRegEntry() {} }; /*! \brief registry for TBlob functions */ class TBlobOpRegistry { public: /*! * \brief Internal function to register a name function under name. * \param name name of the function * \return ref to the registered entry, used to set properties */ TBlobOpRegEntry &__REGISTER_OR_FIND__(const std::string& name); /*! * \brief Find the entry with corresponding name. 
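 *
 * A lookup sketch: "square" is one of the unary functions registered later
 * in this translation unit, and the NULL check follows the documented
 * contract below:
 * \code
 * const TBlobOpRegEntry *entry = TBlobOpRegistry::Find("square");
 * if (entry != NULL) {
 *   // entry describes the registered "square" function
 * }
 * \endcode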
* \param name name of the function * \return the corresponding function, can be NULL */ inline static const TBlobOpRegEntry *Find(const std::string &name) { return Get()->fmap_.at(name); } /*! \return global singleton of the registry */ static TBlobOpRegistry* Get(); private: // destructor ~TBlobOpRegistry(); /*! \brief internal registry map */ std::map fmap_; }; #define MXNET_REGISTER_TBLOB_FUN(Name, DEV) \ static ::mxnet::common::TBlobOpRegEntry & \ __make_ ## TBlobOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \ ::mxnet::common::TBlobOpRegistry::Get()->__REGISTER_OR_FIND__(#Name) } // namespace common } // namespace mxnet #endif // MXNET_COMMON_TBLOB_OP_REGISTRY_H_ //===== EXPANDED: mxnet/src/common/tblob_op_registry.h ===== //===== EXPANDIND: mxnet/src/operator/mshadow_op.h ===== /*! * Copyright (c) 2015 by Contributors * \file mshadow_op.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_MSHADOW_OP_H_ #define MXNET_OPERATOR_MSHADOW_OP_H_ namespace mxnet { namespace op { namespace mshadow_op { /*! \brief identity Operation */ struct identity { MSHADOW_XINLINE static real_t Map(real_t a) { return a; } }; struct identity_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return 1.0f; } }; struct negation { MSHADOW_XINLINE static real_t Map(real_t a) { return -a; } }; /*! \brief sigmoid unit */ struct sigmoid { MSHADOW_XINLINE static real_t Map(real_t a) { return 1.0f / (1.0f + expf(-a)); } }; struct sigmoid_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return a * (1.0f - a); } }; /*! \brief Rectified Linear Operation */ struct relu { MSHADOW_XINLINE static real_t Map(real_t a) { return a > 0.0f ? a : 0.0f; } }; struct relu_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return a > 0.0f ? 1.0f : 0.0f; } }; /*! \brief Leaky ReLU Operation */ struct xelu { MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { return a > 0.0f ? a : a * b; } }; struct xelu_grad { MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { return a > 0.0f ? 1.0f : b; } }; struct tanh { MSHADOW_XINLINE static real_t Map(real_t a) { return tanhf( a ); } }; struct tanh_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return 1.0f - a * a; } }; struct exp { MSHADOW_XINLINE static real_t Map(real_t a) { return expf(a); } }; struct log { MSHADOW_XINLINE static real_t Map(real_t a) { return logf(a); } }; struct log_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return 1.0f / a; } }; struct square { MSHADOW_XINLINE static real_t Map(real_t a) { return a * a; } }; struct square_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return 2.0f * a; } }; /*! \brief used for generate Bernoulli mask */ struct threshold { MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { return a < b ? 1.0f : 0.0f; } }; /*! \brief used for generate element of power */ struct power { MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { return powf( a, b ); } }; /*!\ \brief used for generate element sqrt */ struct square_root { MSHADOW_XINLINE static real_t Map(real_t a) { return sqrt(a); } }; struct square_root_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return 0.5f / a; } }; } // namespace mshadow_op } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_MSHADOW_OP_H_ //===== EXPANDED: mxnet/src/operator/mshadow_op.h ===== //===== EXPANDIND: mxnet/src/operator/operator_common.h ===== /*! 
 * Copyright (c) 2015 by Contributors
 * \file operator_common.h
 * \brief common internal header of most operators
 *   this header includes utility functions operator can use
 * \author Bing Xu
 */
#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_
#define MXNET_OPERATOR_OPERATOR_COMMON_H_
namespace mxnet {
namespace op {
/*!
 * \brief assign the expression to out according to request
 * \param out the data to be assigned
 * \param req the assignment request
 * \param exp the expression
 * \tparam OType output type
 * \tparam Exp expression type
 */
template<typename OType, typename Exp>
inline void Assign(OType &out,  // NOLINT(*)
                   OpReqType req,
                   const Exp &exp) {
  switch (req) {
    case kNullOp: break;
    case kWriteTo:
    case kWriteInplace: out = exp; break;
    case kAddTo: out += exp; break;
    default: LOG(FATAL) << "not reached";
  }
}
/*! \brief exception thrown by InferShape error */
struct InferShapeError {
  /*! \brief analyze message */
  std::string msg;
  /*! \brief corresponding input index */
  int index;
  // constructor
  InferShapeError(std::string msg, int index)
      : msg(msg), index(index) {}
};
/*!
 * \brief macro to assign shape to out if out is unknown, otherwise check consistency
 *  Use macro so we can see the error file more clearly
 * \param shape_array the shape array to store the result
 * \param index the index of the shape in the array
 * \param shape the inferred shape
 */
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape)                   \
  {                                                                     \
    auto &out = (shape_array)[index];                                   \
    if (out.ndim() == 0) {                                              \
      out = shape;                                                      \
    } else {                                                            \
      if (out != shape) {                                               \
        std::ostringstream os;                                          \
        os << "Shape inconsistent, Provided " << '=' << out << ','      \
           << " inferred shape=" << shape;                              \
        throw ::mxnet::op::InferShapeError(os.str(), index);            \
      }                                                                 \
    }                                                                   \
  }
// helper macro to implement bind dispatch
#if MXNET_USE_CUDA
#define DO_BIND_DISPATCH(Method, ...)                                   \
  if (ctx.dev_mask() == cpu::kDevMask) {                                \
    return Method<cpu>(__VA_ARGS__);                                    \
  } else {                                                              \
    return Method<gpu>(__VA_ARGS__);                                    \
  }
#else
#define DO_BIND_DISPATCH(Method, ...)
\ if (ctx.dev_mask() == cpu::kDevMask) { \ return Method(__VA_ARGS__); \ } else { \ LOG(FATAL) << "GPU is not enabled"; \ return nullptr; \ } #endif } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ //===== EXPANDED: mxnet/src/operator/operator_common.h ===== #if defined(__CUDACC__) #define XPU gpu #else #define XPU cpu #endif namespace mxnet { namespace ndarray { using namespace common; // NOLINT(*) template void UnaryForward_(const TBlob &src, TBlob *ret, OpReqType req, RunContext ctx) { using namespace mxnet::op; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); mshadow::Tensor out = ret->FlatTo2D(s); Assign(out, req, F(src.FlatTo2D(s))); } // backward function that takes input value of the op template void UnaryBackwardUseIn_(const arg::OutGrad& out_grad, const arg::Input0& in_data0, TBlob *in_grad, OpReqType req, RunContext ctx) { using namespace mxnet::op; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); mshadow::Tensor igrad = in_grad->FlatTo2D(s); Assign(igrad, req, F(in_data0.data.FlatTo2D(s)) * out_grad.data.FlatTo2D()); } // backward function that takes output value of the op template void UnaryBackwardUseOut_(const arg::OutGrad& out_grad, const arg::OutValue& out_value, TBlob *in_grad, OpReqType req, RunContext ctx) { using namespace mxnet::op; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); mshadow::Tensor igrad = in_grad->FlatTo2D(s); Assign(igrad, req, F(out_value.data.FlatTo2D(s)) * out_grad.data.FlatTo2D()); } // return a shape of scalar inline TShape ScalarShape(const TShape& ishape) { mshadow::index_t shape[] = {1}; return TShape(shape, shape + 1); } template void L2Norm(const TBlob &src, TBlob *ret, OpReqType req, RunContext ctx) { mshadow::Stream *s = ctx.get_stream(); mshadow::Tensor out = ret->get(s); mshadow::Tensor in = src.get_with_shape(mshadow::Shape1(src.shape_.Size()), s); mshadow::VectorDot(out, in, in); out = mshadow::expr::F(out); } // Register all unary operations here // The true means inplace can be enabled. // square MXNET_REGISTER_TBLOB_FUN(square, XPU) .set_function(XPU::kDevMask, UnaryForward_, true) .set_gradient(XPU::kDevMask, UnaryBackwardUseIn_, true) .describe("Take square of the src"); // sqrt MXNET_REGISTER_TBLOB_FUN(sqrt, XPU) .set_function(XPU::kDevMask, UnaryForward_, true) .set_gradient(XPU::kDevMask, UnaryBackwardUseOut_, true) .describe("Take sqrt of the src"); // exp MXNET_REGISTER_TBLOB_FUN(exp, XPU) .set_function(XPU::kDevMask, UnaryForward_, true) .set_gradient(XPU::kDevMask, UnaryBackwardUseOut_, true) .describe("Take exp of the src"); // log MXNET_REGISTER_TBLOB_FUN(log, XPU) .set_function(XPU::kDevMask, UnaryForward_, true) .set_gradient(XPU::kDevMask, UnaryBackwardUseIn_, true) .describe("Take log of the src"); // L2 norm MXNET_REGISTER_TBLOB_FUN(norm, XPU) .set_function(XPU::kDevMask, L2Norm, false, false) .set_shape_infer(ScalarShape) .describe("Take L2 norm of the src." "The result will be ndarray of shape (1,) on the same device."); } // namespace ndarray } // namespace mxnet #endif // MXNET_NDARRAY_UNARY_FUNCTION_INL_H_ //===== EXPANDED: mxnet/src/ndarray/unary_function-inl.h ===== //===== EXPANDED: mxnet/src/ndarray/unary_function.cc ===== //===== EXPANDIND: mxnet/src/ndarray/ndarray_function.cc ===== /*! * Copyright (c) 2015 by Contributors * \file ndarray_function_cpu.cc * \brief CPU Implementation of ndarray function. 
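 *
 * The unary registration pattern expanded just above (unary_function-inl.h)
 * extends naturally to further functions; a hedged sketch that reuses the
 * sigmoid functors from mshadow_op.h, with the template arguments written
 * out as the pattern presumably instantiates them:
 * \code
 * MXNET_REGISTER_TBLOB_FUN(sigmoid, XPU)
 * .set_function(XPU::kDevMask, UnaryForward_<XPU, op::mshadow_op::sigmoid>, true)
 * .set_gradient(XPU::kDevMask, UnaryBackwardUseOut_<XPU, op::mshadow_op::sigmoid_grad>, true)
 * .describe("Take sigmoid of the src");
 * \endcode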
*/ // this will be invoked by gcc and compile CPU version //===== EXPANDIND: mxnet/src/ndarray/ndarray_function.h ===== /*! * Copyright (c) 2015 by Contributors * \file ndarray_op.h * \brief the real execution functions of ndarray operations */ #ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_H_ #define MXNET_NDARRAY_NDARRAY_FUNCTION_H_ namespace mxnet { /*! \brief namespace to support all possible Ndarray operator */ namespace ndarray { struct BinaryBase { inline static TShape GetShape(const TShape &lshape, const TShape &rshape) { CHECK(lshape == rshape) << "operands shape mismatch"; CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape"; return lshape; } }; // operators struct Plus : public BinaryBase { typedef mshadow::op::plus mshadow_op; }; struct Minus : public BinaryBase { typedef mshadow::op::minus mshadow_op; }; struct Mul : public BinaryBase { typedef mshadow::op::mul mshadow_op; }; struct Div : public BinaryBase { typedef mshadow::op::div mshadow_op; }; struct ClipMin : public BinaryBase { struct mshadow_op { MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { if (a < b) { return b; } else { return a; } } }; }; struct ClipMax : public BinaryBase { struct mshadow_op { MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { if (a > b) { return b; } else { return a; } } }; }; struct Dot { inline static TShape GetShape(const TShape &lshape, const TShape &rshape) { CHECK(lshape.ndim() == 2 && rshape.ndim() == 2) << "dot only support 2D Array"; CHECK_EQ(lshape[1], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; size_t target_shape[] = {lshape[0], rshape[1]}; return TShape(target_shape, target_shape + 2); } }; struct OneHotEncode { inline static TShape GetShape(const TShape &index, const TShape &proptype) { CHECK(index.ndim() == 1 && proptype.ndim() == 2) << "OneHotEncode only support 1d index."; CHECK_EQ(index[0], proptype[0]) << "OneHotEncode shape inconsistent"; return proptype; } }; struct MatChooseRowElem { inline static TShape GetShape(const TShape &lshape, const TShape &rshape) { CHECK(lshape.ndim() == 2 && rshape.ndim() == 1) << "choose_row_element only support 2D Matrix and 1D index"; CHECK_EQ(lshape[0], rshape[0]) << "choose_row_element index and matrix shape mismatch"; return rshape; } }; // type holder for random number generators struct UniformDistribution {}; struct GaussianDistribution {}; template void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, TBlob *ret, RunContext ctx); template void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); template void Eval(const TBlob &src, TBlob *ret, RunContext ctx); template void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx); template void Eval(const real_t &rhs, TBlob *ret, RunContext ctx); template void EvalRandom(const real_t &a, const real_t &b, const Resource &resource, TBlob *ret, RunContext ctx); // copy function when only cpu is involved template void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx); template void ElementwiseSum(const std::vector source, TBlob *out, RunContext ctx); } // namespace ndarray } // namespace mxnet #endif // MXNET_NDARRAY_NDARRAY_FUNCTION_H_ //===== EXPANDED: mxnet/src/ndarray/ndarray_function.h ===== //===== EXPANDIND: mxnet/src/ndarray/ndarray_function-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file ndarray_function-inl.h * \brief The real implementation of NDArray functions. 
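 *
 * The specializations generated below are reached through the generic
 * declarations in ndarray_function.h; a hedged call-site sketch, where the
 * blobs and the RunContext are assumed to be prepared by the caller (as
 * NDArray does in ndarray.cc further down) and the template arguments are
 * written as the declarations presumably intend:
 * \code
 * ndarray::Eval<cpu, ndarray::Plus>(lhs_blob, rhs_blob, &out_blob, run_ctx);
 * \endcode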
*/ #ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ #define MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ // this file will be included twice by CPU and GPU // macro to help specialize evaluation function #ifndef DECL_BINARY #define DECL_BINARY(XPU, OP, FUN) \ template<> \ void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ FUN(lhs, rhs, ret, ctx); \ } #endif #ifndef DECL_SCALAR #define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ template<> \ void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ FUN(lhs, rhs, ret, ctx); \ } #endif #if defined(__CUDACC__) #define DEVICE gpu #else #define DEVICE cpu #endif namespace mxnet { namespace ndarray { // true implementation template inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->FlatTo2D(s) = F(lhs.FlatTo2D(s), rhs.FlatTo2D(s)); } template inline void EvalDot_(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->FlatTo2D(s) = dot(lhs.FlatTo2D(s), rhs.FlatTo2D(s)); } template inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->get(s) = one_hot_encode(index.get(s), rhs.shape_[1]); } template inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->get(s) = mat_choose_row_element(lhs.get(s), rhs.get(s)); } template inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); if (reverse) { ret->FlatTo2D(s) = F(rhs, lhs.FlatTo2D(s)); } else { ret->FlatTo2D(s) = F(lhs.FlatTo2D(s), rhs); } } template<> void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, TBlob *ret, RunContext ctx) { typedef DEVICE xpu; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->FlatTo2D(s) = F( F(src.FlatTo2D(s), a_min), a_max); } template<> void EvalRandom( const real_t &a, const real_t &b, const Resource &resource, TBlob *ret, RunContext ctx) { typedef DEVICE xpu; mshadow::Stream *s = ctx.get_stream(); mshadow::Tensor tmp = ret->FlatTo2D(s); mshadow::Random *prnd = resource.get_random(s); prnd->SampleUniform(&tmp, a, b); } template<> void EvalRandom( const real_t &mu, const real_t &sigma, const Resource &resource, TBlob *ret, RunContext ctx) { typedef DEVICE xpu; mshadow::Stream *s = ctx.get_stream(); mshadow::Tensor tmp = ret->FlatTo2D(s); mshadow::Random *prnd = resource.get_random(s); prnd->SampleGaussian(&tmp, mu, sigma); } template<> void Eval(const real_t &rhs, TBlob *ret, RunContext ctx) { mshadow::Stream *s = ctx.get_stream(); ret->FlatTo2D(s) = rhs; } template<> void ElementwiseSum(const std::vector source, TBlob *dst, RunContext ctx) { typedef DEVICE xpu; using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); Tensor out = dst->FlatTo2D(s); switch (source.size()) { case 2: { Tensor in_0 = source[0].FlatTo2D(s); Tensor in_1 = source[1].FlatTo2D(s); out = in_0 + in_1; break; } case 3: { Tensor in_0 = source[0].FlatTo2D(s); Tensor in_1 = source[1].FlatTo2D(s); Tensor in_2 = source[2].FlatTo2D(s); out = in_0 + in_1 + in_2; break; } case 4: { Tensor in_0 = source[0].FlatTo2D(s); Tensor in_1 = source[1].FlatTo2D(s); Tensor in_2 = 
source[2].FlatTo2D(s); Tensor in_3 = source[3].FlatTo2D(s); out = in_0 + in_1 + in_2 + in_3; break; } default: { Tensor in_0 = source[0].FlatTo2D(s); out = F(in_0); for (size_t i = 1; i < source.size(); ++i) { out += source[i].FlatTo2D(s); } break; } } } // declarations DECL_BINARY(DEVICE, MatChooseRowElem, EvalMatChooseRowElem_) DECL_BINARY(DEVICE, Dot, EvalDot_) DECL_BINARY(DEVICE, OneHotEncode, EvalOneHot_) DECL_BINARY(DEVICE, Plus, EvalBinary_) DECL_BINARY(DEVICE, Minus, EvalBinary_) DECL_BINARY(DEVICE, Mul, EvalBinary_) DECL_BINARY(DEVICE, Div, EvalBinary_) DECL_SCALAR(DEVICE, Plus, EvalScalar_, true) DECL_SCALAR(DEVICE, Minus, EvalScalar_, true) DECL_SCALAR(DEVICE, Mul, EvalScalar_, true) DECL_SCALAR(DEVICE, Div, EvalScalar_, true) // for reverse seq DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) DECL_SCALAR(DEVICE, Mul, EvalScalar_, false) DECL_SCALAR(DEVICE, Div, EvalScalar_, false) } // namespace ndarray } // namespace mxnet #endif // MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ //===== EXPANDED: mxnet/src/ndarray/ndarray_function-inl.h ===== namespace mxnet { namespace ndarray { template<> void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx) { mshadow::Copy(to->FlatTo2D(), from.FlatTo2D()); } } // namespace ndarray } // namespace mxnet //===== EXPANDED: mxnet/src/ndarray/ndarray_function.cc ===== //===== EXPANDIND: mxnet/src/ndarray/ndarray.cc ===== /*! * Copyright (c) 2015 by Contributors * \file ndarray.cc * \brief ndarry module of mxnet */ //===== EXPANDIND: mxnet/include/mxnet/ndarray.h ===== /*! * Copyright (c) 2015 by Contributors * \file ndarray.h * \brief NDArray interface that handles array arithematics. */ #ifndef MXNET_NDARRAY_H_ #define MXNET_NDARRAY_H_ //===== EXPANDIND: mxnet/include/mxnet/storage.h ===== /*! * Copyright (c) 2015 by Contributors * \file storage.h * \brief Storage manager across multiple devices. */ #ifndef MXNET_STORAGE_H_ #define MXNET_STORAGE_H_ namespace mxnet { /*! * \brief Storage manager across multiple devices. */ class MXNET_API Storage { public: /*! * \brief Storage handle. */ struct Handle { /*! * \brief Pointer to the data. */ void* dptr; /*! * \brief Size of the storage. */ size_t size; /*! * \brief Context information about device and ID. */ Context ctx; }; /*! * \brief Allocate a new contiguous memory for a given size. * \param size Total size of memory in bytes. * \param ctx Context information about the device and ID. * \return Handle struct. */ virtual Handle Alloc(size_t size, Context ctx) = 0; /*! * \brief Free storage. * \param handle Handle struect. */ virtual void Free(Handle handle) = 0; /*! * \brief Destructor. */ virtual ~Storage() {} /*! * \return Storage singleton. */ static Storage* Get(); /*! * \brief Get shared pointer reference to engine singleton. * Most user should not call this function. * This function is called by another singleton X who requires * Storage to be destructed after X. * * \return A shared pointer to Storage singleton. */ static std::shared_ptr _GetSharedRef(); }; // class Storage } // namespace mxnet #endif // MXNET_STORAGE_H_ //===== EXPANDED: mxnet/include/mxnet/storage.h ===== // check c++11 #if DMLC_USE_CXX11 == 0 #error "cxx11 was required for ndarray module" #endif namespace mxnet { /*! * \brief ndarray interface */ class NDArray { public: /*! \brief default cosntructor */ NDArray() {} /*! 
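 *
 * A short usage sketch of the constructor documented below together with a
 * few of the operations declared later in this header; the shape and the
 * scalar values are arbitrary example choices:
 * \code
 * mxnet::index_t dims[] = {2, 3};
 * TShape shape(dims, dims + 2);
 * NDArray a(shape, Context::CPU());
 * a = 1.0f;         // fill with a scalar
 * a += 2.0f;        // elementwise add a scalar in place
 * a.WaitToRead();   // block until pending writes are finished
 * \endcode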
* \brief constructing a new dynamic NDArray * \param shape the shape of array * \param ctx context of NDArray * \param delay_alloc whether delay the allocation */ NDArray(const TShape &shape, Context ctx, bool delay_alloc = false) : ptr_(std::make_shared(shape.Size(), ctx, delay_alloc)), shape_(shape), offset_(0) { } /*! * \brief constructing a static NDArray that shares data with TBlob * Use with caution: allocate ONLY ONE NDArray for each TBlob, * make sure the memory region is available through out the life of NDArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at */ NDArray(const TBlob &data, int dev_id) : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), offset_(0) { } /*! * \return the shape of current NDArray */ inline const TShape &shape() const { return shape_; } /*! * \return the data TBlob */ inline TBlob data() const { return TBlob(static_cast(ptr_->shandle.dptr) + offset_, \ shape_, ptr_->shandle.ctx.dev_mask()); } /*! * \return the context of NDArray, this function is only valid when the NDArray is not empty */ inline Context ctx() const { return ptr_->shandle.ctx; } /*! \return whether this ndarray is not initialized */ inline bool is_none() const { return ptr_.get() == nullptr; } /*! * \brief Block until all the pending write operations with respect * to current NDArray are finished, and read can be performed. */ inline void WaitToRead() const { if (is_none()) return; Engine::Get()->WaitForVar(ptr_->var); } /*! * \brief Block until all the pending read/write operations with respect * to current NDArray are finished, and write can be performed. */ inline void WaitToWrite() const { if (is_none()) return; /*! * Push an empty mutable function to flush all preceding reads to the * variable. */ Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var}); Engine::Get()->WaitForVar(ptr_->var); } /*! \return the associated variable of the ndarray.*/ inline Engine::VarHandle var() const { return ptr_->var; } /*! * \brief save the content into binary stream * \param strm the output stream */ void Save(dmlc::Stream *strm) const; /*! * \brief load the content from binary stream * \param strm the output stream * \return whether the load is successful */ bool Load(dmlc::Stream *strm); /*! * \brief set all the elements in ndarray to be scalar * \param scalar the scalar to set * \return reference of self */ NDArray &operator=(real_t scalar); /*! * \brief elementwise add to current space * this mutate the current NDArray * \param src the data to add * \return reference of self */ NDArray &operator+=(const NDArray &src); /*! * \brief elementwise add to current space * this mutate the current NDArray * \param src the data to add * \return reference of self */ NDArray &operator+=(const real_t &src); /*! * \brief elementwise subtract from current ndarray * this mutate the current NDArray * \param src the data to substract * \return reference of self */ NDArray &operator-=(const NDArray &src); /*! * \brief elementwise subtract from current ndarray * this mutate the current NDArray * \param src the data to substract * \return reference of self */ NDArray &operator-=(const real_t &src); /*! * \brief elementwise multiplication to current ndarray * this mutate the current NDArray * \param src the data to substract * \return reference of self */ NDArray &operator*=(const NDArray &src); /*! 
* \brief elementwise multiplication to current ndarray * this mutate the current NDArray * \param src the data to substract * \return reference of self */ NDArray &operator*=(const real_t &src); /*! * \brief elementwise division from current ndarray * this mutate the current NDArray * \param src the data to substract * \return reference of self */ NDArray &operator/=(const NDArray &src); /*! * \brief elementwise division from current ndarray * this mutate the current NDArray * \param src the data to substract * \return reference of self */ NDArray &operator/=(const real_t &src); /*! * \brief return transpose of current NDArray * \return a new transposed NDArray */ NDArray T() const; /*! * \brief return a new copy this NDArray * \param ctx the new context of this NDArray * \return the new copy */ NDArray Copy(Context ctx) const; /*! * \brief Do a synchronize copy from a continugous CPU memory region. * * This function will call WaitToWrite before the copy is performed. * This is useful to copy data from existing memory region that are * not wrapped by NDArray(thus dependency not being tracked). * * \param data the data source to copy from. * \param size the memory size we want to copy from. */ void SyncCopyFromCPU(const real_t *data, size_t size) const; /*! * \brief Do a synchronize copy to a continugous CPU memory region. * * This function will call WaitToRead before the copy is performed. * This is useful to copy data from existing memory region that are * not wrapped by NDArray(thus dependency not being tracked). * * \param data the data source to copyinto. * \param size the memory size we want to copy into. */ void SyncCopyToCPU(real_t *data, size_t size) const; /*! * \brief Slice a NDArray * \param begin begin index in first dim * \param end end index in first dim * \return sliced NDArray */ inline NDArray Slice(index_t begin, index_t end) const { NDArray ret = *this; CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_[0], end) << "Slice end index out of range"; size_t length = 1; for (index_t i = 1; i < shape_.ndim(); ++i) { length *= shape_[i]; } ret.offset_ += begin * length; ret.shape_[0] = end - begin; return ret; } /*! * \brief Get an reshaped NDArray * \param shape new shape * \return NDArray in new shape */ inline NDArray Reshape(const TShape &shape) const { CHECK_GE(shape_.Size(), shape.Size()) << "NDArray.Reshape: target shape size is different from current shape"; NDArray ret = *this; ret.shape_ = shape; return ret; } /*! * \brief Allocate the space if it is delayed allocated. * This is an internal function used by system that normal user should not use */ inline void CheckAndAlloc() const { ptr_->CheckAndAlloc(); } /*! * \brief Save list of narray into the Stream.x * \param fo The stream of output. * \param data the NDArrays to be saved. * \param names the name of the NDArray, optional, can be zero length. */ static void Save(dmlc::Stream* fo, const std::vector& data, const std::vector& names); /*! * \brief Load list of narray into from the stream. * \param fi The stream of the input file. * \param data the NDArrays to be loaded * \param keys the name of the NDArray, if saved in the file. */ static void Load(dmlc::Stream* fi, std::vector* data, std::vector* keys); private: /*! \brief the real data chunk that backs NDArray */ struct Chunk { /*! \brief storage handlefrom storage engine */ Storage::Handle shandle; /*! \brief variable from engine */ Engine::VarHandle var; /*! 
* \brief if this is true, this means the data do not come * from Storage, and do not need to be freed */ bool static_data; /*! \brief whether allocation is delayed */ bool delay_alloc; /*! \brief default cosntructor */ Chunk() : static_data(true), delay_alloc(false) { var = Engine::Get()->NewVariable(); } /*! \brief construct from static data */ Chunk(const TBlob &data, int dev_id) : static_data(true), delay_alloc(false) { var = Engine::Get()->NewVariable(); if (data.dev_mask_ == cpu::kDevMask) { shandle.ctx = Context::CPU(); } else { CHECK_EQ(data.dev_mask_, gpu::kDevMask); shandle.ctx = Context::GPU(dev_id); } shandle.dptr = data.dptr_; shandle.size = data.shape_.Size() * sizeof(real_t); } /*! \brief construct a new chunk */ Chunk(uint64_t size, Context ctx, bool delay_alloc_) : static_data(false), delay_alloc(true) { var = Engine::Get()->NewVariable(); shandle.size = size * sizeof(real_t); shandle.ctx = ctx; if (!delay_alloc_) this->CheckAndAlloc(); } /*! \brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { if (delay_alloc) { shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx); delay_alloc = false; } } /*! \brief destructor */ ~Chunk() { if (static_data || delay_alloc) { Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var); } else { Storage::Handle h = this->shandle; Engine::Get()->DeleteVariable([h](RunContext s) { Storage::Get()->Free(h); }, shandle.ctx, var); } } }; /*! \brief internal data of NDArray */ std::shared_ptr ptr_; /*! \brief shape of current NDArray */ TShape shape_; /*! \brief offset in chunk */ size_t offset_; }; /*! * \brief issue an copy operation from one NDArray to another * the two ndarray can sit on different devices * this operation will be scheduled by the engine * * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. */ void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); /*! * \brief Perform elementwise sum over each data from source, store result into out. * \param source the ndarray we want to sum * \param out the target ndarray * \param priority Priority of the action. */ void ElementwiseSum(const std::vector &source, NDArray *out, int priority = 0); /*! * \brief elementwise add * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator+(const NDArray &lhs, const NDArray &rhs); /*! * \brief elementwise add * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator+(const NDArray &lhs, const real_t &rhs); /*! * \brief elementwise substraction * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator-(const NDArray &lhs, const NDArray &rhs); /*! * \brief elementwise substraction * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator-(const NDArray &lhs, const real_t &rhs); /*! * \brief elementwise multiplication * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator*(const NDArray &lhs, const NDArray &rhs);\ /*! * \brief elementwise multiplication * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator*(const NDArray &lhs, const real_t &rhs); /*! 
* \brief elementwise division * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator/(const NDArray &lhs, const NDArray &rhs); /*! * \brief elementwise division * \param lhs left operand * \param rhs right operand * \return a new result ndarray */ NDArray operator/(const NDArray &lhs, const real_t &rhs); /*! * \brief Seed the random number generator. * \param seed the seed to set to global random number generators. */ void RandomSeed(uint32_t seed); /*! * \brief Sample uniform distribution for each elements of out. * \param begin lower bound of distribution. * \param end upper bound of distribution. * \param out output NDArray. */ void SampleUniform(real_t begin, real_t end, NDArray *out); /*! * \brief Sample gaussian distribution for each elements of out. * \param mu mean of gaussian distribution. * \param sigma standard deviation of gaussian distribution. * \param out output NDArray. */ void SampleGaussian(real_t mu, real_t sigma, NDArray *out); //-------------------------------------------------------------- // The following part are API Registration of NDArray functions. //-------------------------------------------------------------- /*! \brief definition of NDArray function */ typedef std::function NDArrayAPIFunction; /*! \brief mask information on how functions can be exposed */ enum NDArrayFunctionTypeMask { /*! \brief all the use_vars should go before scalar */ kNDArrayArgBeforeScalar = 1, /*! \brief all the scalar should go before use_vars */ kScalarArgBeforeNDArray = 1 << 1, /*! * \brief whether this function allows the handles in the target to * be empty NDArray that are not yet initialized, and will initialize * them when the function is invoked. * * most function should support this, except copy between different * devices, which requires the NDArray to be pre-initialized with context */ kAcceptEmptyMutateTarget = 1 << 2 }; /*! \brief Registry entry for NDArrayFunction */ struct NDArrayFunctionReg : public dmlc::FunctionRegEntryBase { /*! \brief number of variable used by this function */ unsigned num_use_vars; /*! \brief number of variable mutated by this function */ unsigned num_mutate_vars; /*! \brief number of scalars used by this function */ unsigned num_scalars; /*! \brief information on how function should be called from API */ int type_mask; /*! * \brief constructor */ NDArrayFunctionReg() : num_use_vars(0), num_mutate_vars(0), num_scalars(0), type_mask(0) {} /*! * \brief set the function body to a NDArray setvalue function * this will also auto set the parameters correctly * \param fsetvalue function body to set * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void (*fsetvalue)(const real_t &rhs, NDArray *out)) { body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { (*fsetvalue)(s[0], mutate_vars[0]); }; num_mutate_vars = 1; num_scalars = 1; this->add_argument("src", "real_t", "Source input to the function."); return *this; } /*! 
* \brief set the function body to a binary NDArray function * this will also auto set the parameters correctly * \param fbinary function body to set * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void (*fbinary)(const NDArray &lhs, const NDArray &rhs, NDArray *out)) { body = [fbinary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { (*fbinary)(*used_vars[0], *used_vars[1], mutate_vars[0]); }; num_use_vars = 2; num_mutate_vars = 1; type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; this->add_argument("lhs", "NDArray", "Left operand to the function."); this->add_argument("rhs", "NDArray", "Right operand to the function."); return *this; } /*! * \brief set the function body to a binary NDArray-scalar function * this will also auto set the parameters correctly * \param fscalar function body to set * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void (*fscalar)(const NDArray &lhs, const real_t &rhs, NDArray *out)) { body = [fscalar] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { (*fscalar)(*used_vars[0], s[0], mutate_vars[0]); }; num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1; type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; this->add_argument("lhs", "NDArray", "Left operand to the function."); this->add_argument("rhs", "real_t", "Right operand to the function."); return *this; } /*! * \brief set the function body to a unary NDArray function * this will also auto set the parameters correctly * \param funary function body to set * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_function(void (*funary)(const NDArray &src, NDArray *out)) { body = [funary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { (*funary)(*used_vars[0], mutate_vars[0]); }; num_use_vars = 1; num_mutate_vars = 1; type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; this->add_argument("src", "NDArray", "Source input to the function."); return *this; } /*! * \brief set the number of use variables * \param n number of use variables * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_num_use_vars(unsigned n) { num_use_vars = n; return *this; } /*! * \brief set the number of mutate variables * \param n number of mutate variables * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_num_mutate_vars(unsigned n) { num_mutate_vars = n; return *this; } /*! * \brief set the number of scalar arguments * \param n number of scalar arguments * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_num_scalars(unsigned n) { num_scalars = n; return *this; } /*! * \brief set the type mask * \param tmask type mask * \return ref to the registered entry, used to set properties */ inline NDArrayFunctionReg &set_type_mask(int tmask) { type_mask = tmask; return *this; } }; // NDArrayFunctionReg /*!
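 *
 * The setter chain above is normally combined with set_body or set_function
 * through the registration macro documented below. Illustrative sketch only
 * (`_my_clip` is a made-up name; the pattern mirrors the real `clip`
 * registration near the end of this file):
 * \code
 *   MXNET_REGISTER_NDARRAY_FUN(_my_clip)
 *   .set_body([](NDArray **u, real_t *s, NDArray **out) {
 *       // invoke the actual kernel here, e.g. a clip implementation
 *     })
 *   .set_num_use_vars(1)
 *   .set_num_scalars(2)
 *   .set_num_mutate_vars(1)
 *   .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget);
 * \endcode
 *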
* \brief Macro to register NDArray function * * Example: the following code is example to register a plus * \code * * REGISTER_NDARRAY_FUN(Plus) * .set_function(Plus); * * \endcode */ #define MXNET_REGISTER_NDARRAY_FUN(name) \ DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name) } // namespace mxnet namespace dmlc { /*!\brief traits */ DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true); } // namespace dmlc #endif // MXNET_NDARRAY_H_ //===== EXPANDED: mxnet/include/mxnet/ndarray.h ===== namespace dmlc { DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg); } // namespace dmlc namespace mxnet { /*! * \brief run a binary operation * \param lhs left operand * \param rhs right operand * \param out the output ndarray * \param binary_op the real */ template void BinaryOp(const NDArray &lhs, const NDArray &rhs, NDArray *out) { // no check if both of them are on cpu if (lhs.ctx().dev_mask() != cpu::kDevMask || rhs.ctx().dev_mask() != cpu::kDevMask) { CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch"; } // if out is none, allocate space if (out->is_none()) { *out = NDArray(OP::GetShape(lhs.shape(), rhs.shape()), lhs.ctx(), true); } else { // no check if both of them are on cpu if (lhs.ctx().dev_mask() != cpu::kDevMask || out->ctx().dev_mask() != cpu::kDevMask) { CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; } CHECK(out->shape() == OP::GetShape(lhs.shape(), rhs.shape())) << "target shape mismatch"; } // important: callback must always capture by value NDArray ret = *out; // get the const variables std::vector const_vars; if (lhs.var() != ret.var()) const_vars.push_back(lhs.var()); if (rhs.var() != ret.var()) const_vars.push_back(rhs.var()); // redirect everything to mshadow operations switch (lhs.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); }, lhs.ctx(), const_vars, {ret.var()}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, {ret.var()}); break; } #endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } void SetValueOp(const real_t &rhs, NDArray *out) { CHECK_NE(out->is_none(), true) << "Set value target must not be empty"; // important: callback must always capture by value NDArray ret = *out; switch (ret.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); }, ret.ctx(), {}, {ret.var()}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, ret.ctx(), {}, {ret.var()}); break; } #endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } /*! 
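 *
 * Note (added for clarity): SetValueOp above shows the dispatch pattern used
 * throughout this file. A condensed sketch, assuming `out` points to a
 * non-empty NDArray:
 * \code
 *   NDArray ret = *out;                        // capture by value, not by pointer
 *   Engine::Get()->PushSync([ret](RunContext ctx) {
 *       ret.CheckAndAlloc();                   // allocate if allocation was delayed
 *       TBlob tmp = ret.data();
 *       // ... write the result into tmp ...
 *     }, ret.ctx(), {}, {ret.var()});          // no const vars, one mutated var
 * \endcode
 *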
* \brief run a binary operation * \param lhs left operand * \param rhs right operand * \param out the output ndarray * \param binary_op the real */ template void ScalarOp(const NDArray &lhs, const real_t &rhs, NDArray *out) { if (out->is_none()) { *out = NDArray(lhs.shape(), lhs.ctx(), true); } else { CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; CHECK(out->shape() == lhs.shape()) << "target shape mismatch"; } // important: callback must always capture by value NDArray ret = *out; // get the const variables std::vector const_vars; if (lhs.var() != ret.var()) const_vars.push_back(lhs.var()); // redirect everything to mshadow operations switch (lhs.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs, &tmp, ctx); }, lhs.ctx(), const_vars, {ret.var()}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, {ret.var()}); break; } #endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } void CopyFromTo(const NDArray &from, NDArray *to, int priority) { CHECK(from.shape() == to->shape()) << "operands shape mismatch"; CHECK(from.shape().ndim() != 0) << "source operands have zero dimension shape"; // important: callback must always capture by value NDArray ret = *to; int a = from.ctx().dev_mask(); int b = to->ctx().dev_mask(); std::vector const_vars; if (from.var() != ret.var()) const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kNormal, priority); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU, priority); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU, priority); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU, priority); } else { LOG(FATAL) << "unknown device mask"; } #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } } void ElementwiseSum(const std::vector &source, NDArray *out, int priority) { std::vector const_vars; const_vars.reserve(source.size()); for (size_t i = 0; i < source.size(); ++i) { if (source[i].var() != out->var()) { const_vars.push_back(source[i].var()); } CHECK_EQ(source[i].shape() , out->shape()) << "operands shape mismatch"; if (out->ctx().dev_mask() == cpu::kDevMask) { 
CHECK_EQ(source[i].ctx().dev_mask(), cpu::kDevMask) << "operands context mismatch"; } else { CHECK(source[i].ctx() == out->ctx()) << "operands context mismatch"; } } // important: callback must always capture by value NDArray ret = *out; switch (out->ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([source, ret](RunContext ctx) { std::vector source_tblob(source.size()); for (size_t i = 0; i < source.size(); ++i) { source_tblob[i] = source[i].data(); } ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::ElementwiseSum(source_tblob, &tmp, ctx); }, out->ctx(), const_vars, {ret.var()}, FnProperty::kNormal, priority); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([source, ret](RunContext ctx) { std::vector source_tblob(source.size()); for (size_t i = 0; i < source.size(); ++i) { source_tblob[i] = source[i].data(); } ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::ElementwiseSum(source_tblob, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, out->ctx(), const_vars, {ret.var()}, FnProperty::kNormal, priority); break; } #endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } void ClipOp(const NDArray &src, const real_t &a_min, const real_t &a_max, NDArray *out) { if (out->is_none()) { *out = NDArray(src.shape(), src.ctx(), true); } else { CHECK(out->ctx() == src.ctx()) << "target context mismatch"; CHECK(out->shape() == src.shape()) << "target shape mismatch"; } NDArray ret = *out; std::vector const_vars; if (src.var() != ret.var()) const_vars.push_back(src.var()); switch (src.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([src, a_min, a_max, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::EvalClip(src.data(), a_min, a_max, &tmp, ctx); }, src.ctx(), const_vars, {ret.var()}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([src, a_min, a_max, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::EvalClip(src.data(), a_min, a_max, &tmp, ctx); }, src.ctx(), const_vars, {ret.var()}); break; } #endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } inline void CopyFromToSimple(const NDArray &from, NDArray *to) { CopyFromTo(from, to, 0); } template void SampleOP(const real_t &a, const real_t &b, NDArray *out) { CHECK(!out->is_none()); Resource resource = ResourceManager::Get()->Request( out->ctx(), ResourceRequest::kRandom); // important: callback must always capture by value NDArray ret = *out; // redirect everything to mshadow operations switch (out->ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([a, b, resource, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::EvalRandom(a, b, resource, &tmp, ctx); }, out->ctx(), {}, {ret.var(), resource.var}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync([a, b, resource, ret](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::EvalRandom(a, b, resource, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, out->ctx(), {}, {ret.var(), resource.var}); break; } #endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } void SampleUniform(real_t begin, real_t end, NDArray *out) { SampleOP(begin, end, out); } void SampleGaussian(real_t mu, real_t sigma, NDArray *out) { SampleOP(mu, sigma, out); } void RandomSeed(uint32_t seed) { ResourceManager::Get()->SeedRandom(seed); } template inline NDArray BinaryOpRet(const NDArray &lhs, const NDArray &rhs) { NDArray ret; 
BinaryOp(lhs, rhs, &ret); return ret; } template inline NDArray ScalarOpRet(const NDArray &lhs, const real_t &rhs) { NDArray ret; ScalarOp(lhs, rhs, &ret); return ret; } template inline NDArray &BinaryOpApply(NDArray *dst, const NDArray &src) { BinaryOp(*dst, src, dst); return *dst; } template inline NDArray &ScalarOpApply(NDArray *dst, const real_t &src) { ScalarOp(*dst, src, dst); return *dst; } // Binary NDArray operator+(const NDArray &lhs, const NDArray &rhs) { return BinaryOpRet(lhs, rhs); } NDArray operator-(const NDArray &lhs, const NDArray &rhs) { return BinaryOpRet(lhs, rhs); } NDArray operator*(const NDArray &lhs, const NDArray &rhs) { return BinaryOpRet(lhs, rhs); } NDArray operator/(const NDArray &lhs, const NDArray &rhs) { return BinaryOpRet(lhs, rhs); } // Scalar NDArray operator+(const NDArray &lhs, const real_t &rhs) { return ScalarOpRet(lhs, rhs); } NDArray operator-(const NDArray &lhs, const real_t &rhs) { return ScalarOpRet(lhs, rhs); } NDArray operator*(const NDArray &lhs, const real_t &rhs) { return ScalarOpRet(lhs, rhs); } NDArray operator/(const NDArray &lhs, const real_t &rhs) { return ScalarOpRet(lhs, rhs); } // Binary NDArray &NDArray::operator=(real_t scalar) { SetValueOp(scalar, this); return *this; } NDArray &NDArray::operator+=(const NDArray &src) { return BinaryOpApply(this, src); } NDArray &NDArray::operator-=(const NDArray &src) { return BinaryOpApply(this, src); } NDArray &NDArray::operator*=(const NDArray &src) { return BinaryOpApply(this, src); } NDArray &NDArray::operator/=(const NDArray &src) { return BinaryOpApply(this, src); } // Scalar NDArray &NDArray::operator+=(const real_t &src) { return ScalarOpApply(this, src); } NDArray &NDArray::operator-=(const real_t &src) { return ScalarOpApply(this, src); } NDArray &NDArray::operator*=(const real_t &src) { return ScalarOpApply(this, src); } NDArray &NDArray::operator/=(const real_t &src) { return ScalarOpApply(this, src); } void NDArray::Save(dmlc::Stream *strm) const { // save shape shape_.Save(strm); if (is_none()) return; // save context Context ctx = this->ctx(); ctx.Save(strm); TBlob save_data; NDArray temp; if (ctx.dev_mask() != cpu::kDevMask) { temp = this->Copy(Context::CPU()); temp.WaitToRead(); save_data = temp.data(); } else { this->WaitToRead(); save_data = this->data(); } // save type flag int32_t type_flag = save_data.type_flag_; CHECK(type_flag == mshadow::DataType::kFlag) << "Only support float NDArray so far"; strm->Write(&type_flag, sizeof(type_flag)); CHECK(save_data.CheckContiguous()); // save data: need to change this after more type mask is supported size_t type_size = sizeof(real_t); strm->Write(save_data.dptr_, type_size * shape_.Size()); } bool NDArray::Load(dmlc::Stream *strm) { // load shape TShape shape; if (!shape.Load(strm)) return false; if (shape.ndim() == 0) { *this = NDArray(); return true; } // load context Context ctx; if (!ctx.Load(strm)) return false; // load type flag int32_t type_flag; if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false; CHECK(type_flag == mshadow::DataType::kFlag) << "Only support float NDArray so far"; // load data into CPU NDArray temp(shape, Context::CPU()); TBlob load_data = temp.data(); size_t type_size = sizeof(real_t); size_t nread = type_size * shape.Size(); if (strm->Read(load_data.dptr_, nread) != nread) return false; if (ctx.dev_mask() == cpu::kDevMask) { *this = std::move(temp); return true; } else { *this = temp.Copy(ctx); return true; } } const uint64_t kMXAPINDArrayListMagic = 0x112; void 
NDArray::Save(dmlc::Stream* fo, const std::vector& data, const std::vector& names) { uint64_t header = kMXAPINDArrayListMagic, reserved = 0; fo->Write(&header, sizeof(header)); fo->Write(&reserved, sizeof(reserved)); fo->Write(data); fo->Write(names); } void NDArray::Load(dmlc::Stream* fi, std::vector* data, std::vector* keys) { uint64_t header, reserved; CHECK(fi->Read(&header)) << "Invalid NDArray file format"; CHECK(fi->Read(&reserved)) << "Invalid NDArray file format"; CHECK(header == kMXAPINDArrayListMagic) << "Invalid NDArray file format"; CHECK(fi->Read(data)) << "Invalid NDArray file format"; CHECK(fi->Read(keys)) << "Invalid NDArray file format"; CHECK(keys->size() == 0 || keys->size() == data->size()) << "Invalid NDArray file format"; } NDArray NDArray::Copy(Context ctx) const { NDArray ret(shape(), ctx, true); CopyFromTo(*this, &ret); return ret; } void NDArray::SyncCopyFromCPU(const real_t *data, size_t size) const { this->WaitToWrite(); TShape dshape = this->shape(); CHECK_EQ(dshape.Size(), size) << "Memory size do not match"; Context ctx = this->ctx(); TBlob dst = this->data(); TBlob src((real_t*)data, dshape, cpu::kDevMask); // NOLINT(*) RunContext run_ctx; run_ctx.stream = nullptr; if (ctx.dev_mask() == cpu::kDevMask) { ndarray::Copy(src, &dst, Context::CPU(), ctx, run_ctx); } else { #if MXNET_USE_CUDA // use empty stream to do sync copy // TODO(bing, yutian) consider use a Real Stream, so it is not blocking others // Maybe move to engine part mshadow::Stream zero_stream; run_ctx.stream = &zero_stream; ndarray::Copy(src, &dst, Context::CPU(), ctx, run_ctx); #else LOG(FATAL) << "GPU is not enabled"; #endif } } void NDArray::SyncCopyToCPU(real_t *data, size_t size) const { this->WaitToRead(); TShape dshape = this->shape(); CHECK_EQ(dshape.Size(), size) << "Memory size do not match"; Context ctx = this->ctx(); TBlob src = this->data(); TBlob dst(data, dshape, cpu::kDevMask); // NOLINT(*) RunContext run_ctx; run_ctx.stream = nullptr; if (ctx.dev_mask() == cpu::kDevMask) { ndarray::Copy(src, &dst, ctx, Context::CPU(), run_ctx); } else { #if MXNET_USE_CUDA // use empty stream to do sync copy // TODO(bing, yutian) consider use a Real Stream, so it is not blocking others // Maybe move to engine part mshadow::Stream zero_stream; run_ctx.stream = &zero_stream; ndarray::Copy(src, &dst, ctx, Context::CPU(), run_ctx); #else LOG(FATAL) << "GPU is not enabled"; #endif } } #if MXNET_PREDICT_ONLY == 0 // register API function // those with underscore will be registered at NDArray MXNET_REGISTER_NDARRAY_FUN(_set_value).set_function(SetValueOp); MXNET_REGISTER_NDARRAY_FUN(_plus).set_function(BinaryOp); MXNET_REGISTER_NDARRAY_FUN(_minus).set_function(BinaryOp); MXNET_REGISTER_NDARRAY_FUN(_mul).set_function(BinaryOp); MXNET_REGISTER_NDARRAY_FUN(_div).set_function(BinaryOp); MXNET_REGISTER_NDARRAY_FUN(dot).set_function(BinaryOp) .describe("Calcuate 2D matrix multiplication"); MXNET_REGISTER_NDARRAY_FUN(_onehot_encode).set_function(BinaryOp); MXNET_REGISTER_NDARRAY_FUN(choose_element_0index) .set_function(BinaryOp) .describe("Choose one element from each line(row for python, column for R/Julia)" " in lhs according to index indicated by rhs." 
" This function assume rhs uses 0-based index."); // register API function // those with underscore will be registered at NDArray MXNET_REGISTER_NDARRAY_FUN(_plus_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_minus_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_mul_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_div_scalar).set_function(ScalarOp); // register API function // scalar, reverse scalar MXNET_REGISTER_NDARRAY_FUN(_rminus_scalar).set_function(ScalarOp); MXNET_REGISTER_NDARRAY_FUN(_rdiv_scalar).set_function(ScalarOp); // copy function is special // that we need to remove kAcceptEmptyMutateTarget from it MXNET_REGISTER_NDARRAY_FUN(_copyto) .set_function(CopyFromToSimple) .set_type_mask(kNDArrayArgBeforeScalar); // register random number generators MXNET_REGISTER_NDARRAY_FUN(_random_uniform) .set_body([](NDArray **u, real_t *s, NDArray **out) { SampleUniform(s[0], s[1], out[0]); }) .set_num_scalars(2) .set_num_mutate_vars(1); MXNET_REGISTER_NDARRAY_FUN(_random_gaussian) .set_body([](NDArray **u, real_t *s, NDArray **out) { SampleGaussian(s[0], s[1], out[0]); }) .set_num_scalars(2) .set_num_mutate_vars(1); MXNET_REGISTER_NDARRAY_FUN(clip) .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) .set_body([](NDArray **u, real_t *s, NDArray **out) { ClipOp(*u[0], s[0], s[1], out[0]); }) .set_num_use_vars(1) .set_num_scalars(2) .set_num_mutate_vars(1) .describe("Clip ndarray elements to range (a_min, a_max)") .add_argument("src", "NDArray", "Source input") .add_argument("a_min", "real_t", "Minimum value") .add_argument("a_max", "real_t", "Maximum value"); #endif } // namespace mxnet //===== EXPANDED: mxnet/src/ndarray/ndarray.cc ===== //===== EXPANDIND: mxnet/src/engine/engine.cc ===== /*! * Copyright (c) 2015 by Contributors * \file engine.cc * \brief Implementation of engine. */ //===== EXPANDIND: mxnet/src/engine/engine_impl.h ===== /*! * Copyright (c) 2015 by Contributors * \file engine_impl.h * \brief Internal implementation header of engine components. */ #ifndef MXNET_ENGINE_ENGINE_IMPL_H_ #define MXNET_ENGINE_ENGINE_IMPL_H_ /*! \brief MACRO on whether or not enable debug option*/ #define ENGINE_DEBUG 0 namespace mxnet { namespace engine { /*! \brief base class of engine variables, used for type checking */ struct Var { #if ENGINE_DEBUG virtual ~Var() = default; #endif // ENGINE_DEBUG /*! * \brief cast variable to derived type T * \tparam T the type we want to cast into. * \return A casted variable. */ template inline T* Cast(); }; // struct Var /*! \brief base class of engine operators, used for type checking */ struct Opr { #if ENGINE_DEBUG virtual ~Opr() = default; #endif /*! * \brief cast variable to derived type T * \tparam T the type we want to cast into. * \return A casted variable. */ template inline T* Cast(); }; // struct Opr // implementation of the inline functions template inline T* Var::Cast() { static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Var`"); #if ENGINE_DEBUG return dynamic_cast(this); #else return static_cast(this); #endif } template inline T* Opr::Cast() { static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Opr`"); #if ENGINE_DEBUG return dynamic_cast(this); #else return static_cast(this); #endif } /*! \brief Maximum number of GPUs */ static constexpr std::size_t kMaxNumGPUs = 16; // predeclare factory function for each type of engine /*! \return NaiveEngine instance */ Engine *CreateNaiveEngine(); #if MXNET_PREDICT_ONLY == 0 /*! 
\return ThreadedEnginePooled instance */ Engine *CreateThreadedEnginePooled(); /*! \return ThreadedEnginePerDevie instance */ Engine *CreateThreadedEnginePerDevice(); #endif } // namespace engine } // namespace mxnet #endif // MXNET_ENGINE_ENGINE_IMPL_H_ //===== EXPANDED: mxnet/src/engine/engine_impl.h ===== namespace mxnet { namespace engine { inline Engine* CreateEngine() { const char *type = getenv("MXNET_ENGINE_TYPE"); const bool default_engine = (type == nullptr); if (type == nullptr) type = "ThreadedEnginePerDevice"; std::string stype = type; Engine *ret = nullptr; #if MXNET_PREDICT_ONLY == 0 if (stype == "NaiveEngine") { ret = CreateNaiveEngine(); } else if (stype == "ThreadedEngine") { ret = CreateThreadedEnginePooled(); } else if (stype == "ThreadedEnginePerDevice") { ret = CreateThreadedEnginePerDevice(); } #else ret = CreateNaiveEngine(); #endif if (ret ==nullptr) { LOG(FATAL) << "Cannot find Engine " << type; } if (!default_engine) { LOG(INFO) << "MXNet start using engine: " << type; } return ret; } } // namespace engine std::shared_ptr Engine::_GetSharedRef() { static std::shared_ptr sptr(engine::CreateEngine()); return sptr; } Engine* Engine::Get() { static Engine *inst = _GetSharedRef().get(); return inst; } } // namespace mxnet //===== EXPANDED: mxnet/src/engine/engine.cc ===== //===== EXPANDIND: mxnet/src/engine/naive_engine.cc ===== /*! * Copyright (c) 2015 by Contributors * \file naive_engine.cc * \brief Implementation of NaiveEngine */ namespace mxnet { namespace engine { // implement naive engine class NaiveEngine final : public Engine { public: struct NaiveOpr : public Opr { AsyncFn fn; std::vector const_vars; std::vector mutable_vars; FnProperty prop; }; NaiveEngine() { } // virtual destructor virtual ~NaiveEngine() { #if MXNET_USE_CUDA LOG(INFO) << "Engine shutdown"; for (size_t i = 0; i < streams_.size(); ++i) { if (streams_[i] != nullptr) { // Catch exception for CUDA driver shutdown MSHADOW_CATCH_ERROR(mshadow::DeleteStream(streams_[i])); streams_[i] = nullptr; } } #endif } // new variables VarHandle NewVariable() override { return nullptr; } OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop) override { NaiveOpr *opr = new NaiveOpr(); opr->fn = fn; opr->const_vars = const_vars; opr->mutable_vars = mutable_vars; opr->prop = prop; return opr; } void DeleteOperator(OprHandle op) override { NaiveOpr *opr = op->Cast(); delete opr; } void Push(OprHandle op, Context exec_ctx, int priority) override { NaiveOpr *opr = op->Cast(); this->PushAsync(opr->fn, exec_ctx, opr->const_vars, opr->mutable_vars, opr->prop); } void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop, int priority = 0) override { CallbackOnComplete callback = CreateCallback( NaiveEngine::OnComplete, nullptr); this->req_completed_ = false; if (exec_ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA size_t dev_id = static_cast(exec_ctx.dev_id); MSHADOW_CATCH_ERROR(mshadow::SetDevice(exec_ctx.dev_id)); if (streams_.size() <= dev_id) { streams_.resize(dev_id + 1, nullptr); } if (streams_[dev_id] == nullptr) { streams_[dev_id] = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); } ctx_.stream = streams_[dev_id]; exec_fun(ctx_, callback); #else LOG(FATAL) << "GPU is not enabled"; #endif } else { ctx_.stream = &cpu_stream_; exec_fun(ctx_, callback); } CHECK(this->req_completed_) << "NaiveEngine only support synchronize Push so far"; } void DeleteVariable(SyncFn delete_fn, 
Context exec_ctx, VarHandle var) override { this->PushSync(delete_fn, exec_ctx, {}, {var}, FnProperty::kNormal); } void WaitForVar(VarHandle var) override { } void WaitForAll() override { } void NotifyShutdown() override { shutdown_phase_.store(true); } private: // callback to oncomplete static void OnComplete(Engine *engine, void *param) { static_cast(engine)->req_completed_ = true; } // runtime contetxt RunContext ctx_; // whether action is completed bool req_completed_; /*! \brief whether it is during shutdown phase*/ std::atomic shutdown_phase_{false}; // CPU stream mshadow::Stream cpu_stream_; // GPU streams std::vector*> streams_; }; // class NaiveEngine Engine *CreateNaiveEngine() { return new NaiveEngine(); } } // namespace engine } // namespace mxnet //===== EXPANDED: mxnet/src/engine/naive_engine.cc ===== //===== EXPANDIND: mxnet/src/engine/threaded_engine.cc ===== /*! * Copyright (c) 2015 by Contributors * \file threaded_engine.cc * \brief implements base threaded engine. * \author Yutian Li */ //===== EXPANDIND: mxnet/src/engine/threaded_engine.h ===== /*! * Copyright (c) 2015 by Contributors * \file threaded_engine.h * \brief Implements base class of threaded engine * that tracks the dependency and pushes actions to execute. * \author Yutian Li */ #ifndef MXNET_ENGINE_THREADED_ENGINE_H_ #define MXNET_ENGINE_THREADED_ENGINE_H_ //===== EXPANDIND: mxnet/src/common/object_pool.h ===== /*! * Copyright (c) 2015 by Contributors */ #ifndef MXNET_COMMON_OBJECT_POOL_H_ #define MXNET_COMMON_OBJECT_POOL_H_ namespace mxnet { namespace common { /*! * \brief Object pool for fast allocation and deallocation. */ template class ObjectPool { public: /*! * \brief Destructor. */ ~ObjectPool(); /*! * \brief Create new object. * \return Pointer to the new object. */ template T* New(Args&&... args); /*! * \brief Delete an existing object. * \param ptr The pointer to delete. * * Make sure the pointer to delete is allocated from this pool. */ void Delete(T* ptr); /*! * \brief Get singleton instance of pool. * \return Object Pool. */ static ObjectPool* Get(); /*! * \brief Get a shared ptr of the singleton instance of pool. * \return Shared pointer to the Object Pool. */ static std::shared_ptr _GetSharedRef(); private: /*! * \brief Internal structure to hold pointers. */ struct LinkedList { #if defined(_MSC_VER) T t; LinkedList* next{nullptr}; #else union { T t; LinkedList* next{nullptr}; }; #endif }; /*! * \brief Page size of allocation. * * Currently defined to be 4KB. */ constexpr static std::size_t kPageSize = 1 << 12; /*! \brief internal mutex */ std::mutex m_; /*! * \brief Head of free list. */ LinkedList* head_{nullptr}; /*! * \brief Pages allocated. */ std::vector allocated_; /*! * \brief Private constructor. */ ObjectPool(); /*! * \brief Allocate a page of raw objects. * * This function is not protected and must be called with caution. */ void AllocateChunk(); DISALLOW_COPY_AND_ASSIGN(ObjectPool); }; // class ObjectPool /*! * \brief Helper trait class for easy allocation and deallocation. */ template struct ObjectPoolAllocatable { /*! * \brief Create new object. * \return Pointer to the new object. */ template static T* New(Args&&... args); /*! * \brief Delete an existing object. * \param ptr The pointer to delete. * * Make sure the pointer to delete is allocated from this pool. 
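 *
 * Usage sketch (added for illustration; `MyNode` is a made-up type, but the
 * pattern mirrors how OprBlock and VersionedVarBlock use this trait below,
 * assuming the class derives from ObjectPoolAllocatable of itself):
 * \code
 *   MyNode *n = MyNode::New();   // draws storage from the per-type ObjectPool
 *   MyNode::Delete(n);           // runs the destructor and recycles the slot
 * \endcode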
*/ static void Delete(T* ptr); }; // struct ObjectPoolAllocatable template ObjectPool::~ObjectPool() { // TODO(hotpxl): mind destruction order // for (auto i : allocated_) { // free(i); // } } template template T* ObjectPool::New(Args&&... args) { LinkedList* ret; { std::lock_guard lock{m_}; if (head_->next == nullptr) { AllocateChunk(); } ret = head_; head_ = head_->next; } return new (static_cast(ret)) T(std::forward(args)...); } template void ObjectPool::Delete(T* ptr) { ptr->~T(); auto linked_list_ptr = reinterpret_cast(ptr); { std::lock_guard lock{m_}; linked_list_ptr->next = head_; head_ = linked_list_ptr; } } template ObjectPool* ObjectPool::Get() { return _GetSharedRef().get(); } template std::shared_ptr > ObjectPool::_GetSharedRef() { static std::shared_ptr > inst_ptr(new ObjectPool()); return inst_ptr; } template ObjectPool::ObjectPool() { AllocateChunk(); } template void ObjectPool::AllocateChunk() { static_assert(sizeof(LinkedList) <= kPageSize, "Object too big."); static_assert(sizeof(LinkedList) % alignof(LinkedList) == 0, "ObjectPooll Invariant"); static_assert(alignof(LinkedList) % alignof(T) == 0, "ObjectPooll Invariant"); static_assert(kPageSize % alignof(LinkedList) == 0, "ObjectPooll Invariant"); void* new_chunk_ptr; #ifdef _MSC_VER new_chunk_ptr = _aligned_malloc(kPageSize, kPageSize); CHECK_NE(new_chunk_ptr, NULL) << "Allocation failed"; #else int ret = posix_memalign(&new_chunk_ptr, kPageSize, kPageSize); CHECK_EQ(ret, 0) << "Allocation failed"; #endif allocated_.emplace_back(new_chunk_ptr); auto new_chunk = static_cast(new_chunk_ptr); auto size = kPageSize / sizeof(LinkedList); for (std::size_t i = 0; i < size - 1; ++i) { new_chunk[i].next = &new_chunk[i + 1]; } new_chunk[size - 1].next = head_; head_ = new_chunk; } template template T* ObjectPoolAllocatable::New(Args&&... args) { return ObjectPool::Get()->New(std::forward(args)...); } template void ObjectPoolAllocatable::Delete(T* ptr) { ObjectPool::Get()->Delete(ptr); } } // namespace common } // namespace mxnet #endif // MXNET_COMMON_OBJECT_POOL_H_ //===== EXPANDED: mxnet/src/common/object_pool.h ===== namespace mxnet { namespace engine { // Define helper macros for debug information. #if ENGINE_DEBUG #define DEFINE_ENGINE_DEBUG_INFO(Type) \ static std::atomic counter; \ Type() { LOG(INFO) << __func__ << " " << ++counter; } \ ~Type() { LOG(INFO) << __func__ << " " << --counter; } #else #define DEFINE_ENGINE_DEBUG_INFO(Type) #endif // Forward declarations struct ThreadedOpr; /*! * \brief Operation block in the scheduler. * Each OprBlock corresponds to an operation pushed to the engine. */ struct OprBlock : public common::ObjectPoolAllocatable { /*! * \brief wait number of pending tasks this OprBlock is waiting for. */ std::atomic wait{0}; /*! \brief Pointer to information on performing real operation */ ThreadedOpr* opr{nullptr}; /*! \brief The context this operator */ Context ctx; /*! \brief priority of the function */ int priority; // define possible debug information DEFINE_ENGINE_DEBUG_INFO(OprBlock); /*! * \brief call this function to decrease the wait counter. * \return the wait counter after the decreasement. */ inline int decr_wait() { // chack invariant, avoid over trigger int ret = --wait; CHECK_GE(ret, 0); return ret; } }; // struct OprBlock /*! * \brief VersionedVarBlock that corresponding to a variable version. * This is a basic unit of LinkedList in the ThreadedVar. */ struct VersionedVarBlock : public common::ObjectPoolAllocatable { /*! 
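 * Scheduling note (added for clarity): an OprBlock's wait counter is set to
 * the number of const vars plus mutable vars plus one when it is pushed (see
 * ThreadedEngine::Push below). Every satisfied dependency calls decr_wait(),
 * and the block is handed to PushToExecute once the counter reaches zero:
 * \code
 *   if (opr_block->decr_wait() == 0) this->PushToExecute(opr_block, true);
 * \endcode
 * The VersionedVarBlock described next is the per-variable queue node that
 * holds such a pending OprBlock.
 *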
\brief next block in the LinkedList */ VersionedVarBlock* next{nullptr}; /*! \brief the operation this block triggers */ OprBlock* trigger{nullptr}; /*! \brief whether this operation is a write (mutate) operation. */ bool write{false}; /*! \brief define possible debug information */ DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock); }; // struct VersionedVarBlock /*! * \brief Variable implementation. * Each ThreadedVar is a linked list (queue) of operations to be performed. */ class ThreadedVar final : public Var, public common::ObjectPoolAllocatable<ThreadedVar> { public: /*! * \brief constructor * \param head head block of the LinkedList, * need to be initialized with next==nullptr and trigger=nullptr. */ explicit ThreadedVar(VersionedVarBlock* head); /*! * \brief Schedule a read operation on this variable. * If the opr_block can be run right away, * the wait counter of opr_block will be decreased. * Otherwise, the opr_block will be added to the waiting queue. * \param opr_block The operation to be scheduled. */ inline void AppendReadDependency(OprBlock* opr_block); /*! * \brief Schedule a write operation on this variable. * If the opr_block can be run right away, * the wait counter of opr_block will be decreased. * Otherwise, the opr_block will be added to the waiting queue. * \param opr_block The operation to be scheduled. */ inline void AppendWriteDependency(OprBlock* opr_block); /*! * \brief A read operation is completed on this variable. * This function may trigger subsequent waiting operations on this variable. * * \param dispatcher the function called to trigger the operation, * when all of its dependencies are satisfied. * \tparam Dispatcher the function called to trigger an operation. */ template <typename Dispatcher> inline void CompleteReadDependency(Dispatcher dispatcher); /*! * \brief A write operation is completed on this variable. * This function may trigger subsequent waiting operations on this variable. * * \param dispatcher the function called to trigger the operation, * when all of its dependencies are satisfied. * \tparam Dispatcher the function called to trigger an operation. * \return to_delete, whether this Variable can be deleted after this function. */ template <typename Dispatcher> inline bool CompleteWriteDependency(Dispatcher dispatcher); /*! \brief Mark this variable to be deleted. */ inline void SetToDelete(); /*! \return whether this variable is ready to read. */ inline bool ready_to_read(); /*! * \brief Cast a Var pointer to ThreadedVar pointer * \param ptr pointer from base. * \return a casted pointer. */ inline static ThreadedVar* CastFromBase(Var* ptr) { return ptr->Cast<ThreadedVar>(); } // code for debug. #if ENGINE_DEBUG static std::atomic counter; ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } #endif // ENGINE_DEBUG private: // TODO(hotpxl) change this to spinlock for faster runtime // TODO(hotpxl) consider rename head /*! \brief internal mutex of the ThreadedVar */ std::mutex m_; /*! * \brief number of pending read operations on the variable. * will be marked as -1 when there is an already triggered pending write. */ int num_pending_reads_{0}; /*! * \brief Points to the last VersionedVarBlock in the queue. * head_ always points to an empty VersionedVarBlock. * So when we want to append an operation to the queue: * 1) update head_->trigger to be new op * 2) update head_->next to be a new VersionedVarBlock * 3) move head to head->next. */ VersionedVarBlock* head_{nullptr}; /*! * \brief The pointer to next write to perform. * This pointer will only be updated when the write completes.
* This is actually the head(oldest operation) in the queue. */ VersionedVarBlock* pending_write_{nullptr}; /*! * \brief If true, delete after operation completes. */ bool to_delete_{false}; /*! \brief special const on num_pending_reads_ to mark write being triggered */ static constexpr int kWriteTriggered = -1; /*! * \brief derived invariant of ready to ready, without lock. * \return whether the current variable is ready to read. */ inline bool is_ready_to_read() const { return pending_write_ == nullptr; } }; // struct ThreadedVar /*! * \brief Operator used in ThreadedEngine. */ struct ThreadedOpr final : public Opr, public common::ObjectPoolAllocatable { /*! \brief The function to be invoked each time. */ Engine::AsyncFn fn; /*! \brief The variable this operation will read from. */ std::vector const_vars; /*! \brief The variable this operation will mutate. */ std::vector mutable_vars; /*! \brief the property of the operator */ FnProperty prop; /*! * \brief Whether this is an temporary operator * that can be deleted right after the operation completed. */ bool temporary{false}; /*! * \brief Cast a Opr pointer to ThreadedOpr pointer * \param ptr pointer from base. * \return a casted pointer. */ inline static ThreadedOpr* CastFromBase(Opr* ptr) { return ptr->Cast(); } // define possible debug information DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr); }; // struct ThreadedOpr /*! * \brief Base class of all ThreadedEngine. * This class implements a thread safe version of engine. * The engine tracks the dependencies, and will call PushToExecute * to execute a specific task. * * Subclass can implement PushToExecute to design specific * execution policy for the tasks. */ class ThreadedEngine : public Engine { public: // implementing all the functions from Engine. ThreadedVar* NewVariable() override; ThreadedOpr* NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop) override; void DeleteOperator(OprHandle op) override; void Push(OprHandle op, Context exec_ctx, int priority) override; void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop, int priority) override; void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override; void WaitForVar(VarHandle var) override; void WaitForAll() override; void NotifyShutdown() override { shutdown_phase_.store(true); } ThreadedEngine() { engine_info_ = dmlc::GetEnv("MXNET_ENGINE_INFO", false); objpool_opr_ref_ = common::ObjectPool::_GetSharedRef(); objpool_blk_ref_ = common::ObjectPool::_GetSharedRef(); objpool_varblk_ref_ = common::ObjectPool::_GetSharedRef(); objpool_var_ref_ = common::ObjectPool::_GetSharedRef(); } ~ThreadedEngine() { { std::unique_lock lock{finished_m_}; kill_.store(true); } finished_cv_.notify_all(); } protected: /*! * \brief Push the opr block to execution queue to be executed. * This function is implemented by the corresponding subclass * for specific policy. * * \param opr_block The operator block. * \param pusher_thread whether the caller is the thread that calls push */ virtual void PushToExecute(OprBlock* opr_block, bool pusher_thread) = 0; /*! * \brief Call this function to actually execute an opr_block * This function also deletes the opr_block after execution. * \param run_ctx runtime context used to execute the function. * \param opr_block the opr_block to be executed and deleted. 
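 *
 * A concrete engine's PushToExecute typically routes the block to a worker
 * that later calls this helper (illustrative sketch only, not the actual
 * per-device scheduler implemented further below):
 * \code
 *   void PushToExecute(OprBlock *opr_block, bool pusher_thread) override {
 *     // enqueue opr_block for a worker thread, which eventually runs
 *     // ExecuteOprBlock(run_ctx, opr_block);
 *   }
 * \endcode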
*/ void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) { ThreadedOpr* threaded_opr = opr_block->opr; CallbackOnComplete callback = this->CreateCallback( ThreadedEngine::OnCompleteStatic, threaded_opr); if (!shutdown_phase_) { try { threaded_opr->fn(run_ctx, callback); } catch(dmlc::Error &e) { std::string what = e.what(); if (what.find("driver shutting down") == std::string::npos && !shutdown_phase_) { LOG(FATAL) << e.what(); } } } else { callback(); } OprBlock::Delete(opr_block); } private: /*! * \brief check if thee is duplication in const_vars and mutable_vars. * \param const_vars the variables to read from. * \param mutable_vars the variables to mutate. */ void CheckDuplicate(std::vector const& const_vars, std::vector const& mutable_vars); /*! * \brief Callback on operation completion. * * On operation completion, this will trigger subsequent operations. */ inline void OnComplete(ThreadedOpr* threaded_opr); // callback to the threaded engine static void OnCompleteStatic(Engine *engine, void *threaded_opr); /*! * \brief Number of pending operations. */ std::atomic pending_{0}; /*! \brief whether we want to kill the waiters */ std::atomic kill_{false}; /*! \brief whether it is during shutdown phase*/ std::atomic shutdown_phase_{false}; /*!\brief show more information from engine actions */ bool engine_info_{false}; /*! * \brief Mutex and condition_variable, * used to Notify waits for single or all variables. */ std::mutex finished_m_; std::condition_variable finished_cv_; /*! * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early * See also #309 (https://github.com/dmlc/mxnet/issues/309) */ std::shared_ptr > objpool_opr_ref_; std::shared_ptr > objpool_blk_ref_; std::shared_ptr > objpool_varblk_ref_; std::shared_ptr > objpool_var_ref_; /*! * \brief Disallow copy construction and assignment. */ DISALLOW_COPY_AND_ASSIGN(ThreadedEngine); }; // class ThreadedEngine } // namespace engine } // namespace mxnet #endif // MXNET_ENGINE_THREADED_ENGINE_H_ //===== EXPANDED: mxnet/src/engine/threaded_engine.h ===== //===== EXPANDIND: mxnet/src/common/cuda_utils.h ===== /*! * Copyright (c) 2015 by Contributors * \file cuda_utils.h * \brief CUDA debugging utilities. */ #ifndef MXNET_COMMON_CUDA_UTILS_H_ #define MXNET_COMMON_CUDA_UTILS_H_ #if MXNET_USE_CUDA namespace mxnet { namespace common { /*! \brief common utils for cuda */ namespace cuda { /*! * \brief Get string representation of cuBLAS errors. * \param error The error. * \return String representation. */ inline const char* CublasGetErrorString(cublasStatus_t error) { switch (error) { case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; default: break; } return "Unknown cuBLAS status"; } /*! * \brief Get string representation of cuRAND errors. * \param status The status. * \return String representation. 
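 *
 * These string helpers back the CUDA_CALL / CUBLAS_CALL / CURAND_CALL macros
 * defined below. Illustrative usage (added; `gen`, `dptr`, `n`, and `dev_id`
 * are assumed local variables):
 * \code
 *   CUDA_CALL(cudaSetDevice(dev_id));
 *   CURAND_CALL(curandGenerateUniform(gen, dptr, n));
 * \endcode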
*/ inline const char* CurandGetErrorString(curandStatus_t status) { switch (status) { case CURAND_STATUS_SUCCESS: return "CURAND_STATUS_SUCCESS"; case CURAND_STATUS_VERSION_MISMATCH: return "CURAND_STATUS_VERSION_MISMATCH"; case CURAND_STATUS_NOT_INITIALIZED: return "CURAND_STATUS_NOT_INITIALIZED"; case CURAND_STATUS_ALLOCATION_FAILED: return "CURAND_STATUS_ALLOCATION_FAILED"; case CURAND_STATUS_TYPE_ERROR: return "CURAND_STATUS_TYPE_ERROR"; case CURAND_STATUS_OUT_OF_RANGE: return "CURAND_STATUS_OUT_OF_RANGE"; case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; case CURAND_STATUS_LAUNCH_FAILURE: return "CURAND_STATUS_LAUNCH_FAILURE"; case CURAND_STATUS_PREEXISTING_FAILURE: return "CURAND_STATUS_PREEXISTING_FAILURE"; case CURAND_STATUS_INITIALIZATION_FAILED: return "CURAND_STATUS_INITIALIZATION_FAILED"; case CURAND_STATUS_ARCH_MISMATCH: return "CURAND_STATUS_ARCH_MISMATCH"; case CURAND_STATUS_INTERNAL_ERROR: return "CURAND_STATUS_INTERNAL_ERROR"; } return "Unknown cuRAND status"; } } // namespace cuda } // namespace common } // namespace mxnet /*! * \brief Check CUDA error. * \param msg Message to print if an error occured. */ #define CHECK_CUDA_ERROR(msg) \ { \ cudaError_t e = cudaGetLastError(); \ CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ } /*! * \brief Protected CUDA call. * \param func Expression to call. * * It checks for CUDA errors after invocation of the expression. */ #define CUDA_CALL(func) \ { \ cudaError_t e = (func); \ CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ << "CUDA: " << cudaGetErrorString(e); \ } /*! * \brief Protected cuBLAS call. * \param func Expression to call. * * It checks for cuBLAS errors after invocation of the expression. */ #define CUBLAS_CALL(func) \ { \ cublasStatus_t e = (func); \ CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ << "cuBLAS: " << common::cuda::CublasGetErrorString(e); \ } /*! * \brief Protected cuRAND call. * \param func Expression to call. * * It checks for cuRAND errors after invocation of the expression. */ #define CURAND_CALL(func) \ { \ curandStatus_t e = (func); \ CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ << "cuRAND: " << common::cuda::CurandGetErrorString(e); \ } #endif // MXNET_USE_CUDA #if MXNET_USE_CUDNN #define CUDNN_CALL(func) \ { \ cudnnStatus_t e = (func); \ CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ } #endif // MXNET_USE_CUDNN #endif // MXNET_COMMON_CUDA_UTILS_H_ //===== EXPANDED: mxnet/src/common/cuda_utils.h ===== namespace mxnet { namespace engine { #if ENGINE_DEBUG std::atomic OprBlock::counter{0}; std::atomic VersionedVarBlock::counter{0}; std::atomic ThreadedVar::counter{0}; std::atomic ThreadedOpr::counter{0}; #endif // ENGINE_DEBUG ThreadedVar::ThreadedVar(VersionedVarBlock* head) : head_{head} { #if ENGINE_DEBUG LOG(INFO) << __func__ << " " << ++counter; #endif // ENGINE_DEBUG } inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) { std::lock_guard lock{m_}; if (pending_write_ == nullptr) { // invariant: is_ready_to_read() CHECK_GE(num_pending_reads_, 0); // STATE CHANGE ++num_pending_reads_; // decrease wait counter opr_block->decr_wait(); } else { auto&& new_var_block = VersionedVarBlock::New(); assert(head_->next == nullptr); assert(head_->trigger == nullptr); assert(head_->write == false); // append things to next. 
head_->next = new_var_block; head_->trigger = opr_block; head_ = new_var_block; } } inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { auto&& new_var_block = VersionedVarBlock::New(); std::lock_guard<std::mutex> lock{m_}; // invariant. assert(head_->next == nullptr); assert(head_->trigger == nullptr); assert(head_->write == false); // attach to head. head_->next = new_var_block; head_->trigger = opr_block; head_->write = true; // check if it is ready to write if (pending_write_ == nullptr) { // invariant: is_ready_to_read() pending_write_ = head_; CHECK_GE(num_pending_reads_, 0); if (num_pending_reads_ == 0) { // STATE CHANGE opr_block->decr_wait(); num_pending_reads_ = kWriteTriggered; } } else { CHECK_NE(num_pending_reads_, 0); } head_ = new_var_block; } template <typename Dispatcher> inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { OprBlock *trigger = nullptr; { // this is lock scope std::lock_guard<std::mutex> lock{m_}; CHECK_GT(num_pending_reads_, 0); if (--num_pending_reads_ == 0) { if (pending_write_ != nullptr) { // STATE CHANGE trigger = pending_write_->trigger; num_pending_reads_ = kWriteTriggered; } } } if (trigger != nullptr && trigger->decr_wait() == 0) { dispatcher(trigger); } } template <typename Dispatcher> inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // this is lock scope VersionedVarBlock *old_pending_write, *end_of_read_chain; OprBlock* trigger_write = nullptr; { std::lock_guard<std::mutex> lock{m_}; // invariants assert(head_->next == nullptr); assert(pending_write_ != nullptr); CHECK_EQ(num_pending_reads_, kWriteTriggered); // really delete if (to_delete_) { VersionedVarBlock *head = pending_write_->next; VersionedVarBlock::Delete(pending_write_); assert(head_ == head); VersionedVarBlock::Delete(head); return true; } // detach pending write old_pending_write = pending_write_; // search for chains to trigger end_of_read_chain = old_pending_write->next; // reset to 0 pending reads num_pending_reads_ = 0; while (end_of_read_chain != head_ && end_of_read_chain->write == false) { ++num_pending_reads_; end_of_read_chain = end_of_read_chain->next; } if (end_of_read_chain == head_) { pending_write_ = nullptr; } else { // check if there are pending reads; if not, trigger the write assert(end_of_read_chain->write == true); pending_write_ = end_of_read_chain; if (num_pending_reads_ == 0) { // mark write as already activated in this var num_pending_reads_ = kWriteTriggered; trigger_write = end_of_read_chain->trigger; } } } // This is outside of lock scope // Be very careful, pending_write_ and num_pending_reads_ // can change now, do not rely on the two variables. // The linked list \in [old_pending_write, end_of_read_chain) // is already detached from this Var.
// So it is safe to modify these VersionedVarBlock *cur_head = old_pending_write->next; VersionedVarBlock::Delete(old_pending_write); // dispatch all the events while (cur_head != end_of_read_chain) { if (cur_head->trigger->decr_wait() == 0) { dispatcher(cur_head->trigger); } auto prev = cur_head; cur_head = cur_head->next; assert(cur_head != nullptr); VersionedVarBlock::Delete(prev); } if (trigger_write != nullptr && trigger_write->decr_wait() == 0) { dispatcher(trigger_write); } return false; } inline void ThreadedVar::SetToDelete() { std::lock_guard lock{m_}; to_delete_ = true; } inline bool ThreadedVar::ready_to_read() { std::lock_guard lock{m_}; return this->is_ready_to_read(); } // implementation of threaded engine ThreadedVar* ThreadedEngine::NewVariable() { return ThreadedVar::New(VersionedVarBlock::New()); } ThreadedOpr* ThreadedEngine::NewOperator( ThreadedEngine::AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop) { auto ret = ThreadedOpr::New(); ret->fn = fn; ret->prop = prop; ret->const_vars.resize(const_vars.size()); ret->mutable_vars.resize(mutable_vars.size()); std::transform(const_vars.begin(), const_vars.end(), ret->const_vars.begin(), ThreadedVar::CastFromBase); std::transform(mutable_vars.begin(), mutable_vars.end(), ret->mutable_vars.begin(), ThreadedVar::CastFromBase); if (ENGINE_DEBUG != 0) { CheckDuplicate(const_vars, mutable_vars); } return ret; } void ThreadedEngine::CheckDuplicate(std::vector const& const_vars, std::vector const& mutable_vars) { // Check for duplicates. auto use = const_vars; auto mutate = mutable_vars; auto use_size = use.size(); auto mutate_size = mutate.size(); std::sort(use.begin(), use.end()); std::sort(mutate.begin(), mutate.end()); for (std::size_t i = 0; i < use_size; ++i) { if (i != 0 && use.at(i) == use.at(i - 1)) { LOG(FATAL) << "duplicate items found in `const_vars`"; } } for (std::size_t i = 0; i < mutate_size; ++i) { if (i != 0 && mutate.at(i) == mutate.at(i - 1)) { LOG(FATAL) << "duplicate items found in `mutable_vars`"; } } std::size_t j = 0; for (std::size_t i = 0; i < use_size; ++i) { while (j < mutate_size && mutate.at(j) < use.at(i)) { ++j; } if (j == mutate_size) { break; } if (mutate.at(j) == use.at(i)) { LOG(FATAL) << "duplicate items found between `const_vars` and `mutable_vars`"; } } } void ThreadedEngine::DeleteOperator(OprHandle op) { ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op); std::vector deps; deps.reserve(threaded_opr->const_vars.size() + threaded_opr->mutable_vars.size()); deps.insert(deps.end(), threaded_opr->const_vars.begin(), threaded_opr->const_vars.end()); deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), threaded_opr->mutable_vars.end()); this->PushSync([threaded_opr](RunContext) { ThreadedOpr::Delete(threaded_opr); }, Context::CPU(), {}, deps, FnProperty::kAsync); } void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority) { ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op); OprBlock* opr_block = OprBlock::New(); opr_block->opr = threaded_opr; opr_block->wait.store(static_cast( threaded_opr->const_vars.size() + threaded_opr->mutable_vars.size() + 1)); opr_block->ctx = exec_ctx; opr_block->priority = priority; ++pending_; // Add read dependencies. for (auto&& i : threaded_opr->const_vars) { i->AppendReadDependency(opr_block); } // Add write dependencies. 
for (auto&& i : threaded_opr->mutable_vars) { i->AppendWriteDependency(opr_block); } if (opr_block->decr_wait() == 0) { this->PushToExecute(opr_block, true); } } void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop, int priority) { ThreadedOpr *opr = NewOperator(fn, const_vars, mutable_vars, prop); opr->temporary = true; Push(opr, exec_ctx, priority); } void ThreadedEngine::DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) { ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); this->PushSync([delete_fn, threaded_var](RunContext ctx) { // Mark variable as orphan, // so during `ThreadedEngine::OnComplete` it could be recycled. threaded_var->SetToDelete(); delete_fn(ctx); }, exec_ctx, {}, {var}, FnProperty::kAsync); } void ThreadedEngine::WaitForVar(VarHandle var) { ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); if (threaded_var->ready_to_read()) return; std::atomic done{false}; this->PushSync([this, &done](RunContext) { if (engine_info_) { LOG(INFO) << "Sync is executed"; } { std::unique_lock lock{finished_m_}; done.store(true); } finished_cv_.notify_all(); if (engine_info_) { LOG(INFO) << "Sync is notified"; } }, Context::CPU(), {var}, {}, FnProperty::kNormal); { std::unique_lock lock{finished_m_}; finished_cv_.wait(lock, [this, &done]() { return done.load() || kill_.load(); }); } } void ThreadedEngine::WaitForAll() { std::unique_lock lock{finished_m_}; finished_cv_.wait(lock, [this]() { return pending_.load() == 0 || kill_.load(); }); } inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { // Mark complete for read variables for (auto&& i : threaded_opr->const_vars) { i->CompleteReadDependency([this](OprBlock* opr) { this->PushToExecute(opr, false); }); } // Mark complete for write variables. for (auto&& i : threaded_opr->mutable_vars) { bool to_delete = i->CompleteWriteDependency( [this](OprBlock* opr) { this->PushToExecute(opr, false); }); if (to_delete) { ThreadedVar::Delete(i); } } int npending; { std::unique_lock lock{finished_m_}; npending = --pending_; } CHECK_GE(npending, 0); if (npending == 0) { // no need to grab lock when notify. finished_cv_.notify_all(); } // delte operator if it is temperory if (threaded_opr->temporary) { ThreadedOpr::Delete(threaded_opr); } } void ThreadedEngine::OnCompleteStatic( Engine *engine, void *threaded_opr) { static_cast(engine)->OnComplete( static_cast(threaded_opr)); } } // namespace engine } // namespace mxnet //===== EXPANDED: mxnet/src/engine/threaded_engine.cc ===== //===== EXPANDIND: mxnet/src/engine/threaded_engine_perdevice.cc ===== /*! * Copyright (c) 2015 by Contributors * \file threaded_engine_perdevice.cc * \brief ThreadedEngine that uses fix amount of thread for each device. */ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/omp.h ===== /*! * Copyright (c) 2015 by Contributors * \file omp.h * \brief header to handle OpenMP compatibility issues */ #ifndef DMLC_OMP_H_ #define DMLC_OMP_H_ #if defined(_OPENMP) #else #ifndef DISABLE_OPENMP // use pragma message instead of warning #pragma message("Warning: OpenMP is not available, " \ "project will be compiled into single-thread code. " \ "Use OpenMP-enabled compiler to get benefit of multi-threading.") #endif //! 
\cond Doxygen_Suppress inline int omp_get_thread_num() { return 0; } inline int omp_get_num_threads() { return 1; } inline int omp_get_num_procs() { return 1; } inline void omp_set_num_threads(int nthread) {} #endif // loop variable used in openmp namespace dmlc { #ifdef _MSC_VER typedef int omp_uint; typedef long omp_ulong; // NOLINT(*) #else typedef unsigned omp_uint; typedef unsigned long omp_ulong; // NOLINT(*) #endif //! \endcond } // namespace dmlc #endif // DMLC_OMP_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/omp.h ===== //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/concurrency.h ===== /*! * Copyright (c) 2015 by Contributors * \file concurrency.h * \brief thread-safe data structures. * \author Yutian Li */ #ifndef DMLC_CONCURRENCY_H_ #define DMLC_CONCURRENCY_H_ // this code depends on c++11 #if DMLC_USE_CXX11 namespace dmlc { /*! * \brief Simple userspace spinlock implementation. */ class Spinlock { public: #ifdef _MSC_VER Spinlock() { lock_.clear(); } #else Spinlock() = default; #endif ~Spinlock() = default; /*! * \brief Acquire lock. */ inline void lock() noexcept(true); /*! * \brief Release lock. */ inline void unlock() noexcept(true); private: #ifdef _MSC_VER std::atomic_flag lock_; #else std::atomic_flag lock_ = ATOMIC_FLAG_INIT; #endif /*! * \brief Disable copy and move. */ DISALLOW_COPY_AND_ASSIGN(Spinlock); }; /*! \brief type of concurrent queue */ enum class ConcurrentQueueType { /*! \brief FIFO queue */ kFIFO, /*! \brief queue with priority */ kPriority }; /*! * \brief Cocurrent blocking queue. */ template class ConcurrentBlockingQueue { public: ConcurrentBlockingQueue(); ~ConcurrentBlockingQueue() = default; /*! * \brief Push element into the queue. * \param e Element to push into. * \param priority the priority of the element, only used for priority queue. * The higher the priority is, the better. * \tparam E the element type * * It will copy or move the element into the queue, depending on the type of * the parameter. */ template void Push(E&& e, int priority = 0); /*! * \brief Pop element from the queue. * \param rv Element popped. * \return On false, the queue is exiting. * * The element will be copied or moved into the object passed in. */ bool Pop(T* rv); /*! * \brief Signal the queue for destruction. * * After calling this method, all blocking pop call to the queue will return * false. */ void SignalForKill(); /*! * \brief Get the size of the queue. * \return The size of the queue. */ size_t Size(); private: struct Entry { T data; int priority; Entry(const T& data, int priority) : data(data), priority(priority) {} inline bool operator<(const Entry &b) const { return priority < b.priority; } }; std::mutex mutex_; std::condition_variable cv_; std::atomic exit_now_; int nwait_consumer_; // a priority queue std::priority_queue priority_queue_; // a FIFO queue std::queue fifo_queue_; /*! * \brief Disable copy and move. 
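 *
 * \note A minimal producer/consumer sketch of how this queue is used
 * (illustrative, not taken verbatim from the engine):
 * \code
 *   dmlc::ConcurrentBlockingQueue<int, dmlc::ConcurrentQueueType::kFIFO> q;
 *   std::thread consumer([&q]() {
 *     int v;
 *     while (q.Pop(&v)) {   // Pop blocks; returns false after SignalForKill()
 *       // process v
 *     }
 *   });
 *   q.Push(1);
 *   q.Push(2);
 *   q.SignalForKill();      // wake the consumer and end its loop
 *   consumer.join();
 * \endcode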
*/ DISALLOW_COPY_AND_ASSIGN(ConcurrentBlockingQueue); }; inline void Spinlock::lock() noexcept(true) { while (lock_.test_and_set(std::memory_order_acquire)) { } } inline void Spinlock::unlock() noexcept(true) { lock_.clear(std::memory_order_release); } template ConcurrentBlockingQueue::ConcurrentBlockingQueue() : exit_now_{false}, nwait_consumer_{0} {} template template void ConcurrentBlockingQueue::Push(E&& e, int priority) { static_assert(std::is_same::type>::type, T>::value, "Types must match."); bool notify; { std::lock_guard lock{mutex_}; if (type == ConcurrentQueueType::kFIFO) { fifo_queue_.emplace(std::forward(e)); notify = nwait_consumer_ != 0; } else { priority_queue_.emplace(std::forward(e), priority); notify = nwait_consumer_ != 0; } } if (notify) cv_.notify_one(); } template bool ConcurrentBlockingQueue::Pop(T* rv) { std::unique_lock lock{mutex_}; if (type == ConcurrentQueueType::kFIFO) { ++nwait_consumer_; cv_.wait(lock, [this] { return !fifo_queue_.empty() || exit_now_.load(); }); --nwait_consumer_; if (!exit_now_.load()) { *rv = std::move(fifo_queue_.front()); fifo_queue_.pop(); return true; } else { return false; } } else { ++nwait_consumer_; cv_.wait(lock, [this] { return !priority_queue_.empty() || exit_now_.load(); }); --nwait_consumer_; if (!exit_now_.load()) { *rv = std::move(priority_queue_.top().data); priority_queue_.pop(); return true; } else { return false; } } } template void ConcurrentBlockingQueue::SignalForKill() { { std::lock_guard lock{mutex_}; exit_now_.store(true); } cv_.notify_all(); } template size_t ConcurrentBlockingQueue::Size() { std::lock_guard lock{mutex_}; if (type == ConcurrentQueueType::kFIFO) { return fifo_queue_.size(); } else { return priority_queue_.size(); } } } // namespace dmlc #endif // DMLC_USE_CXX11 #endif // DMLC_CONCURRENCY_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/concurrency.h ===== //===== EXPANDIND: mxnet/src/engine/thread_pool.h ===== /*! * Copyright (c) 2015 by Contributors */ #ifndef MXNET_ENGINE_THREAD_POOL_H_ #define MXNET_ENGINE_THREAD_POOL_H_ namespace mxnet { namespace engine { /*! * \brief Thread pool. */ class ThreadPool { public: /*! * \brief Constructor takes function to run. * \param size size of the thread pool. * \param func the function to run on the thread pool. */ explicit ThreadPool(size_t size, std::function func) : worker_threads_(size) { for (auto& i : worker_threads_) { i = std::thread(func); } } ~ThreadPool() noexcept(false) { for (auto&& i : worker_threads_) { i.join(); } } private: /*! * \brief Worker threads. */ std::vector worker_threads_; /*! * \brief Disallow default construction. */ ThreadPool() = delete; /*! * \brief Disallow copy construction and assignment. */ DISALLOW_COPY_AND_ASSIGN(ThreadPool); }; } // namespace engine } // namespace mxnet #endif // MXNET_ENGINE_THREAD_POOL_H_ //===== EXPANDED: mxnet/src/engine/thread_pool.h ===== //===== EXPANDIND: mxnet/src/common/lazy_alloc_array.h ===== /*! * Copyright (c) 2015 by Contributors * \file lazy_alloc_array.h * \brief An array that lazily allocate elements as * First time the cell get visited. */ #ifndef MXNET_COMMON_LAZY_ALLOC_ARRAY_H_ #define MXNET_COMMON_LAZY_ALLOC_ARRAY_H_ namespace mxnet { namespace common { template class LazyAllocArray { public: /*! * \brief Get element of corresponding index, * if it is not created create by creator * \param index the array index position * \param creator a lambda function to create new element when needed. */ template inline TElem* Get(int index, FCreate creator); /*! 
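 * \note Typical engine usage is to create one element per device id on first
 * access and later visit whatever was actually created (illustrative sketch;
 * `Worker`, `DoWork` and `dev_id` are stand-ins):
 * \code
 *   mxnet::common::LazyAllocArray<Worker> workers;
 *   Worker *w = workers.Get(dev_id, []() { return new Worker(); });
 *   w->DoWork();
 *   workers.ForEach([](size_t i, Worker *worker) {
 *     // e.g. shut the worker down
 *   });
 * \endcode
 *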
* \brief for each not null element of the array, call fvisit * \param fvisit a function of (size_t, TElem*) */ template inline void ForEach(FVisit fvisit); /*! \brief clear all the allocated elements in array */ inline void Clear(); private: /*! \brief the initial size of the array */ static constexpr std::size_t kInitSize = 16; /*! \brief mutex used during creation */ std::mutex create_mutex_; /*! \brief internal data fir initial size */ std::array, kInitSize> head_; /*! \brief overflow array of more elements */ std::vector > more_; }; // implementations template template inline TElem* LazyAllocArray::Get(int index, FCreate creator) { CHECK_GE(index, 0); size_t idx = static_cast(index); if (idx < kInitSize) { TElem *ptr = head_[idx].get(); if (ptr != nullptr) { return ptr; } else { std::lock_guard lock(create_mutex_); TElem *ptr = head_[idx].get(); if (ptr != nullptr) return ptr; head_[idx].reset(ptr = creator()); return ptr; } } else { std::lock_guard lock(create_mutex_); idx -= kInitSize; if (more_.size() <= idx) more_.resize(idx + 1); TElem *ptr = more_[idx].get(); if (ptr != nullptr) return ptr; more_[idx].reset(ptr = creator()); return ptr; } } template inline void LazyAllocArray::Clear() { std::lock_guard lock(create_mutex_); for (size_t i = 0; i < head_.size(); ++i) { head_[i].reset(nullptr); } for (size_t i = 0; i < more_.size(); ++i) { more_[i].reset(nullptr); } } template template inline void LazyAllocArray::ForEach(FVisit fvisit) { std::lock_guard lock(create_mutex_); for (size_t i = 0; i < head_.size(); ++i) { if (head_[i].get() != nullptr) { fvisit(i, head_[i].get()); } } for (size_t i = 0; i < more_.size(); ++i) { if (more_[i].get() != nullptr) { fvisit(i + kInitSize, more_[i].get()); } } } } // namespace common } // namespace mxnet #endif // MXNET_COMMON_LAZY_ALLOC_ARRAY_H_ //===== EXPANDED: mxnet/src/common/lazy_alloc_array.h ===== //===== EXPANDIND: mxnet/src/common/utils.h ===== /*! * Copyright (c) 2015 by Contributors * \file utils.h * \brief Basic utilility functions. */ #ifndef MXNET_COMMON_UTILS_H_ #define MXNET_COMMON_UTILS_H_ #if DMLC_USE_CXX11 #endif // DMLC_USE_CXX11 namespace mxnet { namespace common { #if DMLC_USE_CXX11 // heuristic to dermine number of threads per GPU inline int GetNumThreadPerGPU() { // This is resource efficient option. return dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 2); } // heuristic to get number of matching colors. // this decides how much parallelism we can get in each GPU. inline int GetExecNumMatchColor() { // This is resource efficient option. int num_match_color = dmlc::GetEnv("MXNET_EXEC_NUM_TEMP", 1); return std::min(num_match_color, GetNumThreadPerGPU()); } /*! * \brief Random Engine */ typedef std::mt19937 RANDOM_ENGINE; /*! * \brief Helper functions. */ namespace helper { /*! * \brief Helper for non-array type `T`. */ template struct UniqueIf { /*! * \brief Type of `T`. */ using SingleObject = std::unique_ptr; }; /*! * \brief Helper for an array of unknown bound `T`. */ template struct UniqueIf { /*! * \brief Type of `T`. */ using UnknownBound = std::unique_ptr; }; /*! * \brief Helper for an array of known bound `T`. */ template struct UniqueIf { /*! * \brief Type of `T`. */ using KnownBound = void; }; } // namespace helper /*! * \brief Constructs an object of type `T` and wraps it in a * `std``::``unique_ptr`. * \param args List of arguments with which an instance of `T` will be * constructed. * \return `std``::``unique_ptr` of an instance of type `T`. * * Constructs a non-array type `T`. 
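 * For example (illustrative):
 * \code
 *   auto vec = MakeUnique<std::vector<int>>(4, 0);  // four zero-valued ints
 * \endcode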
The arguments `args` are passed to the * constructor of `T`. The function does not participate in the overload * resolution if `T` is an array type. */ template typename helper::UniqueIf::SingleObject MakeUnique(Args&&... args) { return std::unique_ptr(new T(std::forward(args)...)); } /*! * \brief Constructs an object of type `T` and wraps it in a * `std``::``unique_ptr`. * \param n The size of the array to construct. * \return `std``::``unique_ptr` of an instance of type `T`. * * Constructs an array of unknown bound `T`. The function does not participate * in the overload resolution unless `T` is an array of unknown bound. */ template typename helper::UniqueIf::UnknownBound MakeUnique(size_t n) { using U = typename std::remove_extent::type; return std::unique_ptr(new U[n]{}); } /*! * \brief Constructs an object of type `T` and wraps it in a * `std``::``unique_ptr`. * \param args List of arguments with which an instance of `T` will be * constructed. * * Constructs an arrays of known bound is disallowed. */ template typename helper::UniqueIf::KnownBound MakeUnique(Args&&... args) = delete; #endif // DMLC_USE_CXX11 } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ //===== EXPANDED: mxnet/src/common/utils.h ===== namespace mxnet { namespace engine { /*! * \brief ThreadedEngine uses per device threads. * The policy of this Engine: * - Execute Async operation immediately if pushed from Pusher. * - Use fixed amount of threads for each device. * - Use special threads for copy operations. * - Each stream is allocated and binded to each of the thread. */ class ThreadedEnginePerDevice : public ThreadedEngine { public: static auto constexpr kFIFO = dmlc::ConcurrentQueueType::kFIFO; static auto constexpr kPriority = dmlc::ConcurrentQueueType::kPriority; static auto constexpr kCopyQueue = kPriority; static auto constexpr kPriorityQueue = kPriority; static auto constexpr kWorkerQueue = kFIFO; ThreadedEnginePerDevice() noexcept(false) { gpu_worker_nthreads_ = common::GetNumThreadPerGPU(); gpu_copy_nthreads_ = dmlc::GetEnv("MXNET_GPU_COPY_NTHREADS", 1); cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1); // create CPU task int cpu_priority_nthreads = dmlc::GetEnv("MXNET_CPU_PRIORITY_NTHREADS", 4); cpu_priority_worker_.reset(new ThreadWorkerBlock()); cpu_priority_worker_->pool.reset(new ThreadPool( cpu_priority_nthreads, [this] { this->CPUWorker(cpu_priority_worker_.get()); })); // GPU tasks will be created lazily } ~ThreadedEnginePerDevice() noexcept(false) { gpu_normal_workers_.Clear(); gpu_copy_workers_.Clear(); cpu_normal_workers_.Clear(); cpu_priority_worker_.reset(nullptr); } protected: void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { const Context& ctx = opr_block->ctx; if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { if (ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA MSHADOW_CATCH_ERROR(mshadow::SetDevice(ctx.dev_id)); #endif } RunContext run_ctx; run_ctx.stream = nullptr; this->ExecuteOprBlock(run_ctx, opr_block); } else { if (ctx.dev_mask() == cpu::kDevMask) { if (opr_block->opr->prop == FnProperty::kCPUPrioritized) { cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority); } else { int dev_id = ctx.dev_id; int nthread = cpu_worker_nthreads_; cpu_normal_workers_.Get(dev_id, [this, dev_id, nthread]() { auto blk = new ThreadWorkerBlock(); blk->pool.reset(new ThreadPool(nthread, [this, blk] () { this->CPUWorker(blk); })); return blk; })->task_queue.Push(opr_block, opr_block->priority); } } else { 
CHECK_EQ(ctx.dev_mask(), gpu::kDevMask); // GPU execution. FnProperty prop = opr_block->opr->prop; bool is_copy = (prop == FnProperty::kCopyFromGPU || prop == FnProperty::kCopyToGPU); int nthread = gpu_worker_nthreads_; int dev_id = ctx.dev_id; if (is_copy) { gpu_copy_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() { auto blk = new ThreadWorkerBlock(); blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk] () { this->GPUWorker(dev_id, is_copy, blk); })); return blk; })->task_queue.Push(opr_block, opr_block->priority); } else { gpu_normal_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() { auto blk = new ThreadWorkerBlock(); blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk] () { this->GPUWorker(dev_id, is_copy, blk); })); return blk; })->task_queue.Push(opr_block, opr_block->priority); } } } } private: // working unit for each of the task. template struct ThreadWorkerBlock { // task queue on this task dmlc::ConcurrentBlockingQueue task_queue; // thread pool that works on this task std::unique_ptr pool; // destructor ~ThreadWorkerBlock() noexcept(false) { task_queue.SignalForKill(); } }; /*! \brief number of concurrent thread cpu worker uses */ int cpu_worker_nthreads_; /*! \brief number of concurrent thread each gpu worker uses */ int gpu_worker_nthreads_; /*! \brief number of concurrent thread each gpu copy worker uses */ int gpu_copy_nthreads_; // cpu worker common::LazyAllocArray > cpu_normal_workers_; // cpu priority worker std::unique_ptr > cpu_priority_worker_; // workers doing normal works on GPU common::LazyAllocArray > gpu_normal_workers_; // workers doing copy works from/to GPU common::LazyAllocArray > gpu_copy_workers_; /*! * \brief GPU worker that performs operations on a certain device. * \param dev_id The device id of the worker. * \param is_copy_worker whether the worker only do copy job * \param block The task block of the worker. */ template inline void GPUWorker(int dev_id, bool is_copy_worker, ThreadWorkerBlock *block) { #if MXNET_USE_CUDA // allocate stream mshadow::SetDevice(dev_id); RunContext run_ctx; mshadow::Stream *stream; if (is_copy_worker) { stream = mshadow::NewStream(false, false); } else { stream = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); } run_ctx.stream = stream; // execute task OprBlock* opr_block; auto* task_queue = &(block->task_queue); while (task_queue->Pop(&opr_block)) { this->ExecuteOprBlock(run_ctx, opr_block); } // Catch exception for CUDA driver shutdown MSHADOW_CATCH_ERROR(mshadow::DeleteStream(stream)); #endif } /*! * \brief CPU worker that performs operations on CPU. * \param block The task block of the worker. */ template inline void CPUWorker(ThreadWorkerBlock *block) { auto* task_queue = &(block->task_queue); RunContext run_ctx; run_ctx.stream = nullptr; // execute task OprBlock* opr_block; while (task_queue->Pop(&opr_block)) { this->ExecuteOprBlock(run_ctx, opr_block); } } }; Engine *CreateThreadedEnginePerDevice() { return new ThreadedEnginePerDevice(); } } // namespace engine } // namespace mxnet //===== EXPANDED: mxnet/src/engine/threaded_engine_perdevice.cc ===== //===== EXPANDIND: mxnet/src/engine/threaded_engine_pooled.cc ===== /*! * Copyright (c) 2015 by Contributors * \file threaded_engine_pooled.cc * \brief Pooled threaded engine * \author Yutian Li */ //===== EXPANDIND: mxnet/src/engine/stream_manager.h ===== /*! * Copyright (c) 2015 by Contributors */ #ifndef MXNET_ENGINE_STREAM_MANAGER_H_ #define MXNET_ENGINE_STREAM_MANAGER_H_ namespace mxnet { namespace engine { /*! 
* \brief Stream manager. * * Uses a basic round-robin algorithm to dispatch GPU streams. Returns default * context on CPU. */ template class StreamManager { public: StreamManager(); ~StreamManager() { Finalize(); } RunContext GetRunContext(Context const& ctx); RunContext GetIORunContext(Context const& ctx); void Finalize(); private: std::mutex m_; #if MXNET_USE_CUDA std::array*, kStreams>, kNumGpus> gpu_streams_; std::array*, kNumGpus> gpu_io_streams_; std::array gpu_cnt_; #endif // MXNET_USE_CUDA DISALLOW_COPY_AND_ASSIGN(StreamManager); }; // class StreamManager template RunContext StreamManager::GetRunContext( Context const& ctx) { RunContext ret; ret.stream = nullptr; switch (ctx.dev_mask()) { case cpu::kDevMask: break; case gpu::kDevMask: { #if MXNET_USE_CUDA std::size_t use_counter; CUDA_CALL(cudaSetDevice(ctx.dev_id)); { std::lock_guard lock{m_}; auto&& counter = gpu_cnt_.at(ctx.dev_id); if (counter == -1) { for (auto&& i : gpu_streams_.at(ctx.dev_id)) { i = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); } counter = 0; } use_counter = counter; counter = (counter + 1) % kStreams; } ret.stream = gpu_streams_.at(ctx.dev_id).at(use_counter); break; #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA } } return ret; } template RunContext StreamManager::GetIORunContext( Context const& ctx) { RunContext ret; ret.stream = nullptr; switch (ctx.dev_mask()) { case cpu::kDevMask: break; case gpu::kDevMask: { #if MXNET_USE_CUDA CUDA_CALL(cudaSetDevice(ctx.dev_id)); { std::lock_guard lock{m_}; if (gpu_io_streams_.at(ctx.dev_id) == nullptr) { gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream(false, false); } } ret.stream = gpu_io_streams_.at(ctx.dev_id); break; #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA } } return ret; } template StreamManager::StreamManager() { #if MXNET_USE_CUDA for (std::size_t i = 0; i < kNumGpus; ++i) { gpu_cnt_.at(i) = -1; } for (auto&& i : gpu_io_streams_) { i = nullptr; } #endif // MXNET_USE_CUDA } template void StreamManager::Finalize() { #if MXNET_USE_CUDA for (std::size_t i = 0; i < kNumGpus; ++i) { if (gpu_cnt_.at(i) != -1) { for (auto&& j : gpu_streams_.at(i)) { // Catch exception for CUDA driver shutdown MSHADOW_CATCH_ERROR(mshadow::DeleteStream(j)); } gpu_cnt_.at(i) = -1; } } #endif // MXNET_USE_CUDA } } // namespace engine } // namespace mxnet #endif // MXNET_ENGINE_STREAM_MANAGER_H_ //===== EXPANDED: mxnet/src/engine/stream_manager.h ===== namespace mxnet { namespace engine { /*! * \brief ThreadedEngine using global thread pool across all devices. * The policy of this Engine: * - Execute Async operation immediately if pushed from Pusher. * - Use a common thread pool for normal operations on all devices. * - Use special thread pool for copy operations. */ class ThreadedEnginePooled : public ThreadedEngine { public: ThreadedEnginePooled() : thread_pool_(kNumWorkingThreads, [this]() { ThreadWorker(&task_queue_); }), io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {} ~ThreadedEnginePooled() noexcept(false) { streams_.Finalize(); task_queue_.SignalForKill(); io_task_queue_.SignalForKill(); } protected: void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { DoExecute(opr_block); } else { DoPushToQueue(opr_block); } } private: /*! \brief Concurrency for thread pool */ static constexpr std::size_t kNumWorkingThreads = 16; /*! 
\brief Maximum number of GPUs */ static constexpr std::size_t kMaxNumGpus = 16; /*!\brief number of streams allocated for each GPU */ static constexpr std::size_t kNumStreamsPerGpu = 16; /*! * \brief Streams. */ StreamManager streams_; /*! * \brief Task queues. */ dmlc::ConcurrentBlockingQueue task_queue_; dmlc::ConcurrentBlockingQueue io_task_queue_; /*! * \brief Thread pools. */ ThreadPool thread_pool_; ThreadPool io_thread_pool_; /*! * \brief Worker. * \param task_queue Queue to work on. * * The method to pass to thread pool to parallelize. */ void ThreadWorker(dmlc::ConcurrentBlockingQueue* task_queue) { OprBlock* opr_block; while (task_queue->Pop(&opr_block)) { DoExecute(opr_block); } } /*! * \brief Execute an operation. * \param opr_block The operator block. */ void DoExecute(OprBlock* opr_block) { assert(opr_block->wait.load() == 0); if (opr_block->ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id)); #else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA } bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU || opr_block->opr->prop == FnProperty::kCopyToGPU); auto&& rctx = is_copy ? streams_.GetIORunContext(opr_block->ctx) : streams_.GetRunContext(opr_block->ctx); this->ExecuteOprBlock(rctx, opr_block); } /*! * \brief Push the operation to the queue. * \param opr_block The operator block. */ void DoPushToQueue(OprBlock* opr_block) { switch (opr_block->opr->prop) { case FnProperty::kCopyFromGPU: case FnProperty::kCopyToGPU: { io_task_queue_.Push(opr_block); break; } default: { task_queue_.Push(opr_block); break; } } } }; Engine *CreateThreadedEnginePooled() { return new ThreadedEnginePooled(); } } // namespace engine } // namespace mxnet //===== EXPANDED: mxnet/src/engine/threaded_engine_pooled.cc ===== //===== EXPANDIND: mxnet/src/io/io.cc ===== // Copyright (c) 2015 by Contributors //===== EXPANDIND: mxnet/include/mxnet/io.h ===== /*! * Copyright (c) 2015 by Contributors * \file io.h * \brief mxnet io data structure and data iterator */ #ifndef MXNET_IO_H_ #define MXNET_IO_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/data.h ===== /*! * Copyright (c) 2015 by Contributors * \file data.h * \brief defines common input data structure, * and interface for handling the input data */ #ifndef DMLC_DATA_H_ #define DMLC_DATA_H_ namespace dmlc { /*! * \brief this defines the float point * that will be used to store feature values */ typedef float real_t; /*! * \brief this defines the unsigned integer type * that can normally be used to store feature index */ typedef unsigned index_t; // This file describes common data structure that can be used // for large-scale machine learning, this may not be a complete list // But we will keep the most common and useful ones, and keep adding new ones /*! * \brief data iterator interface * this is not a C++ style iterator, but nice for data pulling:) * This interface is used to pull in the data * The system can do some useful tricks for you like pre-fetching * from disk and pre-computation. * * Usage example: * \code * * itr->BeforeFirst(); * while (itr->Next()) { * const DType &batch = itr->Value(); * // some computations * } * \endcode * \tparam DType the data type */ template class DataIter { public: /*! \brief destructor */ virtual ~DataIter(void) {} /*! \brief set before first of the item */ virtual void BeforeFirst(void) = 0; /*! \brief move to next item */ virtual bool Next(void) = 0; /*! 
\brief get current data */ virtual const DType &Value(void) const = 0; }; /*! * \brief one row of training instance * \tparam IndexType type of index */ template class Row { public: /*! \brief label of the instance */ real_t label; /*! \brief weight of the instance */ real_t weight; /*! \brief length of the sparse vector */ size_t length; /*! * \brief index of each instance */ const IndexType *index; /*! * \brief array value of each instance, this can be NULL * indicating every value is set to be 1 */ const real_t *value; /*! * \param i the input index * \return i-th feature */ inline IndexType get_index(size_t i) const { return index[i]; } /*! * \param i the input index * \return i-th feature value, this function is always * safe even when value == NULL */ inline real_t get_value(size_t i) const { return value == NULL ? 1.0f : value[i]; } /*! * \brief helper function to compute dot product of current * \param weight the dense array of weight we want to product * \param size the size of the weight vector * \tparam V type of the weight vector * \return the result of dot product */ template inline V SDot(const V *weight, size_t size) const { V sum = static_cast(0); if (value == NULL) { for (size_t i = 0; i < length; ++i) { CHECK(index[i] < size) << "feature index exceed bound"; sum += weight[index[i]]; } } else { for (size_t i = 0; i < length; ++i) { CHECK(index[i] < size) << "feature index exceed bound"; sum += weight[index[i]] * value[i]; } } return sum; } }; /*! * \brief a block of data, containing several rows in sparse matrix * This is useful for (streaming-sxtyle) algorithms that scans through rows of data * examples include: SGD, GD, L-BFGS, kmeans * * The size of batch is usually large enough so that parallelizing over the rows * can give significant speedup * \tparam IndexType type to store the index used in row batch */ template struct RowBlock { /*! \brief batch size */ size_t size; /*! \brief array[size+1], row pointer to beginning of each rows */ const size_t *offset; /*! \brief array[size] label of each instance */ const real_t *label; /*! \brief With weight: array[size] label of each instance, otherwise nullptr */ const real_t *weight; /*! \brief feature index */ const IndexType *index; /*! \brief feature value, can be NULL, indicating all values are 1 */ const real_t *value; /*! * \brief get specific rows in the batch * \param rowid the rowid in that row * \return the instance corresponding to the row */ inline Row operator[](size_t rowid) const; /*! \return memory cost of the block in bytes */ inline size_t MemCostBytes(void) const { size_t cost = size * (sizeof(size_t) + sizeof(real_t)); if (weight != NULL) cost += size * sizeof(real_t); size_t ndata = offset[size] - offset[0]; if (index != NULL) cost += ndata * sizeof(IndexType); if (value != NULL) cost += ndata * sizeof(real_t); return cost; } /*! * \brief slice a RowBlock to get rows in [begin, end) * \param begin the begin row index * \param end the end row index * \return the sliced RowBlock */ inline RowBlock Slice(size_t begin, size_t end) const { CHECK(begin <= end && end < size); RowBlock ret; ret.size = end - begin; ret.label = label + begin; if (weight != NULL) { ret.weight = weight + begin; } else { ret.weight = NULL; } ret.offset = offset + begin; ret.index = index; ret.value = value; return ret; } }; /*! 
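 * \note Sketch of consuming row blocks with a linear scorer via Row::SDot
 * (illustrative; `weights` and `num_feature` are assumed to exist, and the
 * file name is a placeholder):
 * \code
 *   dmlc::RowBlockIter<unsigned> *source =
 *       dmlc::RowBlockIter<unsigned>::Create("train.libsvm", 0, 1, "libsvm");
 *   source->BeforeFirst();
 *   while (source->Next()) {
 *     const dmlc::RowBlock<unsigned> &batch = source->Value();
 *     for (size_t i = 0; i < batch.size; ++i) {
 *       float score = batch[i].SDot(weights, num_feature);
 *       // compare score against batch[i].label, accumulate gradient, ...
 *     }
 *   }
 *   delete source;
 * \endcode
 *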
* \brief Data structure that holds the data * Row block iterator interface that gets RowBlocks * Difference between RowBlockIter and Parser: * RowBlockIter caches the data internally that can be used * to iterate the dataset multiple times, * Parser holds very limited internal state and was usually * used to read data only once * * \sa Parser * \tparam IndexType type of index in RowBlock * Create function was only implemented for IndexType uint64_t and uint32_t */ template class RowBlockIter : public DataIter > { public: /*! * \brief create a new instance of iterator that returns rowbatch * by default, a in-memory based iterator will be returned * * \param uri the uri of the input, can contain hdfs prefix * \param part_index the part id of current input * \param num_parts total number of splits * \param type type of dataset can be: "libsvm", ... * * \return the created data iterator */ static RowBlockIter * Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type); /*! \return maximum feature dimension in the dataset */ virtual size_t NumCol() const = 0; }; /*! * \brief parser interface that parses input data * used to load dmlc data format into your own data format * Difference between RowBlockIter and Parser: * RowBlockIter caches the data internally that can be used * to iterate the dataset multiple times, * Parser holds very limited internal state and was usually * used to read data only once * * * \sa RowBlockIter * \tparam IndexType type of index in RowBlock * Create function was only implemented for IndexType uint64_t and uint32_t */ template class Parser : public DataIter > { public: /*! * \brief create a new instance of parser based on the "type" * * \param uri_ the uri of the input, can contain hdfs prefix * \param part_index the part id of current input * \param num_parts total number of splits * \param type type of dataset can be: "libsvm", ... * * \return the created parser */ static Parser * Create(const char *uri_, unsigned part_index, unsigned num_parts, const char *type); /*! \return size of bytes read so far */ virtual size_t BytesRead(void) const = 0; }; // implementation of operator[] template inline Row RowBlock::operator[](size_t rowid) const { CHECK(rowid < size); Row inst; inst.label = label[rowid]; if (weight != NULL) { inst.weight = weight[rowid]; } else { inst.weight = 1.0f; } inst.length = offset[rowid + 1] - offset[rowid]; inst.index = index + offset[rowid]; if (value == NULL) { inst.value = NULL; } else { inst.value = value + offset[rowid]; } return inst; } } // namespace dmlc #endif // DMLC_DATA_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/data.h ===== namespace mxnet { /*! * \brief iterator type * \tparam DType data type */ template class IIterator : public dmlc::DataIter { public: /*! * \brief set the parameters and init iter * \param kwargs key-value pairs */ virtual void Init(const std::vector >& kwargs) = 0; /*! \brief reset the iterator */ virtual void BeforeFirst(void) = 0; /*! \brief move to next item */ virtual bool Next(void) = 0; /*! \brief get current data */ virtual const DType &Value(void) const = 0; /*! \brief constructor */ virtual ~IIterator(void) {} /*! \brief store the name of each data, it could be used for making NDArrays */ std::vector data_names; /*! \brief set data name to each attribute of data */ inline void SetDataName(const std::string data_name) { data_names.push_back(data_name); } }; // class IIterator /*! \brief a single data instance */ struct DataInst { /*! 
\brief unique id for instance */ unsigned index; /*! \brief content of data */ std::vector data; /*! \brief extra data to be fed to the network */ std::string extra_data; }; // struct DataInst /*! * \brief DataBatch of NDArray, returned by Iterator */ struct DataBatch { /*! \brief content of dense data, if this DataBatch is dense */ std::vector data; /*! \brief extra data to be fed to the network */ std::string extra_data; /*! \brief num of example padded to batch */ int num_batch_padd; }; // struct DataBatch /*! \brief typedef the factory function of data iterator */ typedef std::function *()> DataIteratorFactory; /*! * \brief Registry entry for DataIterator factory functions. */ struct DataIteratorReg : public dmlc::FunctionRegEntryBase { }; //-------------------------------------------------------------- // The following part are API Registration of Iterators //-------------------------------------------------------------- /*! * \brief Macro to register Iterators * * \code * // example of registering a mnist iterator * REGISTER_IO_ITE(MNISTIter) * .describe("Mnist data iterator") * .set_body([]() { * return new PrefetcherIter(new MNISTIter()); * }); * \endcode */ #define MXNET_REGISTER_IO_ITER(name) \ DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name) } // namespace mxnet #endif // MXNET_IO_H_ //===== EXPANDED: mxnet/include/mxnet/io.h ===== //===== EXPANDIND: mxnet/src/io/image_augmenter.h ===== /*! * Copyright (c) 2015 by Contributors * \file image_augmenter_opencv.hpp * \brief threaded version of page iterator */ #ifndef MXNET_IO_IMAGE_AUGMENTER_H_ #define MXNET_IO_IMAGE_AUGMENTER_H_ #if MXNET_USE_OPENCV #endif namespace mxnet { namespace io { /*! \brief image augmentation parameters*/ struct ImageAugmentParam : public dmlc::Parameter { /*! \brief whether we do random cropping */ bool rand_crop; /*! \brief whether we do nonrandom croping */ int crop_y_start; /*! \brief whether we do nonrandom croping */ int crop_x_start; /*! \brief [-max_rotate_angle, max_rotate_angle] */ int max_rotate_angle; /*! \brief max aspect ratio */ float max_aspect_ratio; /*! \brief random shear the image [-max_shear_ratio, max_shear_ratio] */ float max_shear_ratio; /*! \brief max crop size */ int max_crop_size; /*! \brief min crop size */ int min_crop_size; /*! \brief max scale ratio */ float max_random_scale; /*! \brief min scale_ratio */ float min_random_scale; /*! \brief min image size */ float min_img_size; /*! \brief max image size */ float max_img_size; /*! \brief rotate angle */ int rotate; /*! \brief filled color while padding */ int fill_value; /*! 
\brief shape of the image data*/ TShape data_shape; // declare parameters DMLC_DECLARE_PARAMETER(ImageAugmentParam) { DMLC_DECLARE_FIELD(rand_crop).set_default(false) .describe("Augmentation Param: Whether to random crop on the image"); DMLC_DECLARE_FIELD(crop_y_start).set_default(-1) .describe("Augmentation Param: Where to nonrandom crop on y."); DMLC_DECLARE_FIELD(crop_x_start).set_default(-1) .describe("Augmentation Param: Where to nonrandom crop on x."); DMLC_DECLARE_FIELD(max_rotate_angle).set_default(0.0f) .describe("Augmentation Param: rotated randomly in [-max_rotate_angle, max_rotate_angle]."); DMLC_DECLARE_FIELD(max_aspect_ratio).set_default(0.0f) .describe("Augmentation Param: denotes the max ratio of random aspect ratio augmentation."); DMLC_DECLARE_FIELD(max_shear_ratio).set_default(0.0f) .describe("Augmentation Param: denotes the max random shearing ratio."); DMLC_DECLARE_FIELD(max_crop_size).set_default(-1) .describe("Augmentation Param: Maximum crop size."); DMLC_DECLARE_FIELD(min_crop_size).set_default(-1) .describe("Augmentation Param: Minimum crop size."); DMLC_DECLARE_FIELD(max_random_scale).set_default(1.0f) .describe("Augmentation Param: Maxmum scale ratio."); DMLC_DECLARE_FIELD(min_random_scale).set_default(1.0f) .describe("Augmentation Param: Minimum scale ratio."); DMLC_DECLARE_FIELD(max_img_size).set_default(1e10f) .describe("Augmentation Param: Maxmum image size after resizing."); DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f) .describe("Augmentation Param: Minimum image size after resizing."); DMLC_DECLARE_FIELD(rotate).set_default(-1.0f) .describe("Augmentation Param: Rotate angle."); DMLC_DECLARE_FIELD(fill_value).set_default(255) .describe("Augmentation Param: Maximum value of illumination variation."); DMLC_DECLARE_FIELD(data_shape) .set_expect_ndim(3).enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); } }; /*! \brief helper class to do image augmentation */ class ImageAugmenter { public: // contructor ImageAugmenter(void) { #if MXNET_USE_OPENCV rotateM_ = cv::Mat(2, 3, CV_32F); #endif } virtual ~ImageAugmenter() { } virtual void Init(const std::vector >& kwargs) { std::vector > kwargs_left; kwargs_left = param_.InitAllowUnknown(kwargs); for (size_t i = 0; i < kwargs_left.size(); i++) { if (!strcmp(kwargs_left[i].first.c_str(), "rotate_list")) { const char* val = kwargs_left[i].second.c_str(); const char *end = val + strlen(val); char buf[128]; while (val < end) { sscanf(val, "%[^,]", buf); val += strlen(buf) + 1; rotate_list_.push_back(atoi(buf)); } } } } #if MXNET_USE_OPENCV #ifdef _MSC_VER #define M_PI CV_PI #endif /*! * \brief augment src image, store result into dst * this function is not thread safe, and will only be called by one thread * however, it will tries to re-use memory space as much as possible * \param src the source image * \param source of random number * \param dst the pointer to the place where we want to store the result */ virtual cv::Mat Process(const cv::Mat &src, common::RANDOM_ENGINE *prnd) { using mshadow::index_t; cv::Mat res; // normal augmentation by affine transformation. 
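/* The branch below builds a 2x3 affine matrix that combines rotation, shear,
 * scale and aspect-ratio jitter.  With a = cos(angle), b = sin(angle), s the
 * random shear, and (ws, hs) the horizontal/vertical scales, the matrix is
 *
 *   M = [ hs*a - s*b*ws   hs*b + s*a*ws   tx ]
 *       [     -b*ws            a*ws       ty ]
 *
 * where (tx, ty) recenters the warped image inside the new canvas whose size
 * is clamped to [min_img_size, max_img_size].  cv::warpAffine then resamples
 * the source with this matrix, padding borders with fill_value.
 */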
if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f || param_.rotate > 0 || rotate_list_.size() > 0) { std::uniform_real_distribution rand_uniform(0, 1); // shear float s = rand_uniform(*prnd) * param_.max_shear_ratio * 2 - param_.max_shear_ratio; // rotate int angle = std::uniform_int_distribution( -param_.max_rotate_angle, param_.max_rotate_angle)(*prnd); if (param_.rotate > 0) angle = param_.rotate; if (rotate_list_.size() > 0) { angle = rotate_list_[std::uniform_int_distribution(0, rotate_list_.size() - 1)(*prnd)]; } float a = cos(angle / 180.0 * M_PI); float b = sin(angle / 180.0 * M_PI); // scale float scale = rand_uniform(*prnd) * (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale; // aspect ratio float ratio = rand_uniform(*prnd) * param_.max_aspect_ratio * 2 - param_.max_aspect_ratio + 1; float hs = 2 * scale / (1 + ratio); float ws = ratio * hs; // new width and height float new_width = std::max(param_.min_img_size, std::min(param_.max_img_size, scale * src.cols)); float new_height = std::max(param_.min_img_size, std::min(param_.max_img_size, scale * src.rows)); cv::Mat M(2, 3, CV_32F); M.at(0, 0) = hs * a - s * b * ws; M.at(1, 0) = -b * ws; M.at(0, 1) = hs * b + s * a * ws; M.at(1, 1) = a * ws; float ori_center_width = M.at(0, 0) * src.cols + M.at(0, 1) * src.rows; float ori_center_height = M.at(1, 0) * src.cols + M.at(1, 1) * src.rows; M.at(0, 2) = (new_width - ori_center_width) / 2; M.at(1, 2) = (new_height - ori_center_height) / 2; cv::warpAffine(src, temp_, M, cv::Size(new_width, new_height), cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); res = temp_; } else { res = src; } // crop logic if (param_.max_crop_size != -1 || param_.min_crop_size != -1) { CHECK(res.cols >= param_.max_crop_size && res.rows >= \ param_.max_crop_size && param_.max_crop_size >= param_.min_crop_size) << "input image size smaller than max_crop_size"; index_t rand_crop_size = std::uniform_int_distribution(param_.min_crop_size, param_.max_crop_size)(*prnd); index_t y = res.rows - rand_crop_size; index_t x = res.cols - rand_crop_size; if (param_.rand_crop != 0) { y = std::uniform_int_distribution(0, y)(*prnd); x = std::uniform_int_distribution(0, x)(*prnd); } else { y /= 2; x /= 2; } cv::Rect roi(x, y, rand_crop_size, rand_crop_size); cv::resize(res(roi), res, cv::Size(param_.data_shape[1], param_.data_shape[2])); } else { CHECK(static_cast(res.cols) >= param_.data_shape[1] && static_cast(res.rows) >= param_.data_shape[2]) << "input image size smaller than input shape"; index_t y = res.rows - param_.data_shape[2]; index_t x = res.cols - param_.data_shape[1]; if (param_.rand_crop != 0) { y = std::uniform_int_distribution(0, y)(*prnd); x = std::uniform_int_distribution(0, x)(*prnd); } else { y /= 2; x /= 2; } cv::Rect roi(x, y, param_.data_shape[1], param_.data_shape[2]); res = res(roi); } return res; } #endif private: #if MXNET_USE_OPENCV // temporal space cv::Mat temp_; // rotation param cv::Mat rotateM_; #endif // parameters ImageAugmentParam param_; /*! \brief list of possible rotate angle */ std::vector rotate_list_; }; } // namespace io } // namespace mxnet #endif // MXNET_IO_IMAGE_AUGMENTER_H_ //===== EXPANDED: mxnet/src/io/image_augmenter.h ===== //===== EXPANDIND: mxnet/src/io/iter_normalize.h ===== /*! * Copyright (c) 2015 by Contributors * \file iter_normalize.h * \brief Iterator that substracts mean and do a few augmentations. 
*/ #ifndef MXNET_IO_ITER_NORMALIZE_H_ #define MXNET_IO_ITER_NORMALIZE_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/timer.h ===== /*! * Copyright (c) 2015 by Contributors * \file timer.h * \brief cross platform timer for timing * \author Tianqi Chen */ #ifndef DMLC_TIMER_H_ #define DMLC_TIMER_H_ #ifdef __MACH__ #endif namespace dmlc { /*! * \brief return time in seconds */ inline double GetTime(void) { // TODO(tqchen): use c++11 chrono when c++11 was available #ifdef __MACH__ clock_serv_t cclock; mach_timespec_t mts; host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock); CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time"; mach_port_deallocate(mach_task_self(), cclock); return static_cast(mts.tv_sec) + static_cast(mts.tv_nsec) * 1e-9; #else #if defined(__unix__) || defined(__linux__) timespec ts; CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time"; return static_cast(ts.tv_sec) + static_cast(ts.tv_nsec) * 1e-9; #else return static_cast(time(NULL)); #endif #endif } } // namespace dmlc #endif // DMLC_TIMER_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/timer.h ===== namespace mxnet { namespace io { // normalize parameters struct ImageNormalizeParam : public dmlc::Parameter { /*! \brief random seed */ int seed; /*! \brief whether to mirror the image */ bool mirror; /*! \brief whether to perform rand mirror the image */ bool rand_mirror; /*! \brief mean file string */ std::string mean_img; /*! \brief mean value for r channel */ float mean_r; /*! \brief mean value for g channel */ float mean_g; /*! \brief mean value for b channel */ float mean_b; /*! \brief scale on color space */ float scale; /*! \brief maximum ratio of contrast variation */ float max_random_contrast; /*! \brief maximum value of illumination variation */ float max_random_illumination; /*! \brief silent */ bool verbose; // declare parameters DMLC_DECLARE_PARAMETER(ImageNormalizeParam) { DMLC_DECLARE_FIELD(seed).set_default(0) .describe("Augmentation Param: Random Seed."); DMLC_DECLARE_FIELD(mirror).set_default(false) .describe("Augmentation Param: Whether to mirror the image."); DMLC_DECLARE_FIELD(rand_mirror).set_default(false) .describe("Augmentation Param: Whether to mirror the image randomly."); DMLC_DECLARE_FIELD(mean_img).set_default("") .describe("Augmentation Param: Mean Image to be subtracted."); DMLC_DECLARE_FIELD(mean_r).set_default(0.0f) .describe("Augmentation Param: Mean value on R channel."); DMLC_DECLARE_FIELD(mean_g).set_default(0.0f) .describe("Augmentation: Mean value on G channel."); DMLC_DECLARE_FIELD(mean_b).set_default(0.0f) .describe("Augmentation: Mean value on B channel."); DMLC_DECLARE_FIELD(scale).set_default(1.0f) .describe("Augmentation Param: Scale in color space."); DMLC_DECLARE_FIELD(max_random_contrast).set_default(0.0f) .describe("Augmentation Param: Maximum ratio of contrast variation."); DMLC_DECLARE_FIELD(max_random_illumination).set_default(0.0f) .describe("Augmentation Param: Maximum value of illumination variation."); DMLC_DECLARE_FIELD(verbose).set_default(true) .describe("Augmentation Param: Whether to print augmentor info."); } }; /*! * \brief Iterator that normalize a image. * It also applies a few augmention before normalization. 
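 *
 * \note The iterator is configured through string key/value pairs handled by
 * dmlc::Parameter.  A minimal sketch (`normalize_iter` stands for an instance
 * of this class and the mean image path is a placeholder):
 * \code
 *   std::vector<std::pair<std::string, std::string> > kwargs;
 *   kwargs.push_back(std::make_pair("mean_img", "mean.bin"));
 *   kwargs.push_back(std::make_pair("rand_mirror", "1"));
 *   kwargs.push_back(std::make_pair("scale", "0.00390625"));  // 1/256, rescale pixels
 *   normalize_iter.Init(kwargs);
 * \endcode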
*/ class ImageNormalizeIter : public IIterator { public: explicit ImageNormalizeIter(IIterator *base) : base_(base), meanfile_ready_(false) { } virtual void Init(const std::vector >& kwargs) { param_.InitAllowUnknown(kwargs); base_->Init(kwargs); rnd_.seed(kRandMagic + param_.seed); if (param_.mean_img.length() != 0) { std::unique_ptr fi( dmlc::Stream::Create(param_.mean_img.c_str(), "r", true)); if (fi.get() == nullptr) { this->CreateMeanImg(); } else { fi.reset(nullptr); if (param_.verbose) { LOG(INFO) << "Load mean image from " << param_.mean_img; } // use python compatible ndarray store format std::vector data; std::vector keys; { std::unique_ptr fi(dmlc::Stream::Create(param_.mean_img.c_str(), "r")); NDArray::Load(fi.get(), &data, &keys); } CHECK_EQ(data.size(), 1) << "Invalid mean image file format"; data[0].WaitToRead(); mshadow::Tensor src = data[0].data().get(); meanimg_.Resize(src.shape_); mshadow::Copy(meanimg_, src); meanfile_ready_ = true; } } } virtual void BeforeFirst(void) { base_->BeforeFirst(); } virtual const DataInst& Value(void) const { return out_; } virtual bool Next(void) { if (!this->Next_()) return false; return true; } private: /*! \brief base iterator */ std::unique_ptr > base_; // whether mean image is ready. bool meanfile_ready_; /*! \brief output data */ DataInst out_; // normalize parameter. ImageNormalizeParam param_; /*! \brief mean image, if needed */ mshadow::TensorContainer meanimg_; /*! \brief temp space for output image */ mshadow::TensorContainer outimg_; /*! \brief random numeber engine */ common::RANDOM_ENGINE rnd_; // random magic number of this iterator static const int kRandMagic = 0; /*! \brief internal next function, inlined for fater processing. */ inline bool Next_(void) { if (!base_->Next()) return false; const DataInst &src = base_->Value(); this->SetOutImg(src); out_.data.resize(2); out_.data[0] = outimg_; out_.data[1] = src.data[1]; out_.index = src.index; out_.extra_data = src.extra_data; return true; } /*! * \brief Set the output image, after augmentation and normalization. * \param src The source image. */ inline void SetOutImg(const DataInst &src) { using namespace mshadow::expr; // NOLINT(*) std::uniform_real_distribution rand_uniform(0, 1); std::bernoulli_distribution coin_flip(0.5); mshadow::Tensor data = src.data[0].get(); outimg_.Resize(data.shape_); float contrast = rand_uniform(rnd_) * param_.max_random_contrast * 2 - param_.max_random_contrast + 1; float illumination = rand_uniform(rnd_) * param_.max_random_illumination * 2 - param_.max_random_illumination; if (param_.mean_r > 0.0f || param_.mean_g > 0.0f || param_.mean_b > 0.0f) { // substract mean value data[0] -= param_.mean_b; data[1] -= param_.mean_g; data[2] -= param_.mean_r; if ((param_.rand_mirror && coin_flip(rnd_)) || param_.mirror) { outimg_ = mirror(data * contrast + illumination) * param_.scale; } else { outimg_ = (data * contrast + illumination) * param_.scale; } } else if (!meanfile_ready_ || param_.mean_img.length() == 0) { // do not substract anything if ((param_.rand_mirror && coin_flip(rnd_)) || param_.mirror) { outimg_ = mirror(data) * param_.scale; } else { outimg_ = F(data) * param_.scale; } } else { CHECK(meanfile_ready_); if ((param_.rand_mirror && coin_flip(rnd_)) || param_.mirror) { outimg_ = mirror((data - meanimg_) * contrast + illumination) * param_.scale; } else { outimg_ = ((data - meanimg_) * contrast + illumination) * param_.scale; } } } // creat mean image. 
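/* SetOutImg above applies, per pixel,
 *
 *   out = ((in - mean) * contrast + illumination) * scale
 *
 * where contrast is drawn uniformly from [1 - max_random_contrast,
 * 1 + max_random_contrast], illumination from [-max_random_illumination,
 * +max_random_illumination], and the result is optionally mirrored.  When
 * neither a mean image nor per-channel means are configured, the mean term is
 * dropped.  CreateMeanImg below accumulates the mean image over one pass of
 * the data, meanimg = (1/N) * sum of the N augmented images, and saves it in
 * the python-compatible NDArray format so later runs can simply load it.
 */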
inline void CreateMeanImg(void) { if (param_.verbose) { LOG(INFO) << "Cannot find " << param_.mean_img << ": create mean image, this will take some time..."; } double start = dmlc::GetTime(); size_t imcnt = 1; // NOLINT(*) CHECK(this->Next_()) << "input iterator failed."; meanimg_.Resize(outimg_.shape_); mshadow::Copy(meanimg_, outimg_); while (this->Next_()) { meanimg_ += outimg_; imcnt += 1; double elapsed = dmlc::GetTime() - start; if (imcnt % 10000L == 0 && param_.verbose) { LOG(INFO) << imcnt << " images processed, " << elapsed << " sec elapsed"; } } meanimg_ *= (1.0f / imcnt); // save as mxnet python compatible format. TBlob tmp = meanimg_; { std::unique_ptr fo(dmlc::Stream::Create(param_.mean_img.c_str(), "w")); NDArray::Save(fo.get(), {NDArray(tmp, 0)}, {"mean_img"}); } if (param_.verbose) { LOG(INFO) << "Save mean image to " << param_.mean_img << ".."; } meanfile_ready_ = true; this->BeforeFirst(); } }; } // namespace io } // namespace mxnet #endif // MXNET_IO_ITER_NORMALIZE_H_ //===== EXPANDED: mxnet/src/io/iter_normalize.h ===== //===== EXPANDIND: mxnet/src/io/iter_batchloader.h ===== /*! * Copyright (c) 2015 by Contributors * \file iter_batchloader.h * \brief define a batch adapter to create tblob batch */ #ifndef MXNET_IO_ITER_BATCHLOADER_H_ #define MXNET_IO_ITER_BATCHLOADER_H_ //===== EXPANDIND: mxnet/src/io/inst_vector.h ===== /*! * Copyright (c) 2015 by Contributors * \file inst_vector.h * \brief holder of a sequence of DataInst in CPU * that are not necessarily of same shape */ #ifndef MXNET_IO_INST_VECTOR_H_ #define MXNET_IO_INST_VECTOR_H_ namespace mxnet { namespace io { /*! * \brief a vector of tensor with various shape * * data are stored in memory continously */ template class TensorVector { public: TensorVector(void) { this->Clear(); } /*! \brief get the buffer to the i-th tensor */ inline mshadow::Tensor operator[](size_t i) const { CHECK_LT(i + 1, offset_.size()); CHECK_EQ(shape_[i].Size(), offset_[i + 1] - offset_[i]); return mshadow::Tensor ((DType*)dmlc::BeginPtr(content_) + offset_[i], shape_[i]); // NOLINT(*) } inline mshadow::Tensor Back() const { return (*this)[Size() - 1]; } inline size_t Size(void) const { return shape_.size(); } /*! \brief allocate space given the shape (data are copied) */ inline void Push(mshadow::Shape shape) { shape_.push_back(shape); offset_.push_back(offset_.back() + shape.Size()); content_.resize(offset_.back()); } inline void Clear(void) { offset_.clear(); offset_.push_back(0); content_.clear(); shape_.clear(); } private: // offset of the data content std::vector offset_; // data content std::vector content_; // shape of data std::vector > shape_; }; /*! * \brief a list of (label, example) pairs, examples can have various shape */ class InstVector { public: /*! 
\brief return the number of (label, example) pairs */ inline size_t Size(void) const { return index_.size(); } // get index inline unsigned Index(unsigned i) const { return index_[i]; } // instance /* \brief get the i-th (label, example) pair */ inline DataInst operator[](size_t i) const { DataInst inst; inst.index = index_[i]; inst.data.push_back(TBlob(data_[i])); inst.data.push_back(TBlob(label_[i])); return inst; } /* \brief get the last (label, example) pair */ inline DataInst Back() const { return (*this)[Size() - 1]; } inline void Clear(void) { index_.clear(); data_.Clear(); label_.Clear(); } /* * \brief push a (label, example) pair * only reserved the space, while the data is not copied */ inline void Push(unsigned index, mshadow::Shape<3> dshape, mshadow::Shape<1> lshape) { index_.push_back(index); data_.Push(dshape); label_.Push(lshape); } /*! \return the data content */ inline const TensorVector<3, real_t>& data() const { return data_; } /*! \return the label content */ inline const TensorVector<1, real_t>& label() const { return label_; } private: /*! \brief index of the data */ std::vector index_; // label TensorVector<3, real_t> data_; // data TensorVector<1, real_t> label_; }; /*! * \brief tblob batch * * data are stored in tblob before going into NDArray */ struct TBlobBatch { public: /*! \brief unique id for instance, can be NULL, sometimes is useful */ unsigned *inst_index; /*! \brief number of instance */ mshadow::index_t batch_size; /*! \brief number of padding elements in this batch, this is used to indicate the last elements in the batch are only padded up to match the batch, and should be discarded */ mshadow::index_t num_batch_padd; /*! \brief content of dense data */ std::vector data; /*! \brief extra data to be fed to the network */ std::string extra_data; /*! \brief constructor */ TBlobBatch(void) { inst_index = NULL; batch_size = 0; num_batch_padd = 0; } /*! \brief destructor */ ~TBlobBatch() { delete inst_index; } }; // struct TBlobBatch } // namespace io } // namespace mxnet #endif // MXNET_IO_INST_VECTOR_H_ //===== EXPANDED: mxnet/src/io/inst_vector.h ===== namespace mxnet { namespace io { // Batch parameters struct BatchParam : public dmlc::Parameter { /*! \brief label width */ index_t batch_size; /*! \brief input shape */ TShape data_shape; /*! \brief label width */ index_t label_width; /*! \brief use round roubin to handle overflow batch */ bool round_batch; // declare parameters DMLC_DECLARE_PARAMETER(BatchParam) { DMLC_DECLARE_FIELD(batch_size) .describe("Batch Param: Batch size."); DMLC_DECLARE_FIELD(data_shape) .set_expect_ndim(3).enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); DMLC_DECLARE_FIELD(label_width).set_default(1) .describe("Dataset Param: Label width."); DMLC_DECLARE_FIELD(round_batch).set_default(true) .describe("Batch Param: Use round robin to handle overflow batch."); } }; /*! 
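 * \note Padding semantics: when round_batch is enabled and the number of
 * instances is not a multiple of batch_size, the final batch wraps around to
 * the start of the data and num_batch_padd records how many of its slots are
 * wrap-around padding.  For example, 10 instances with batch_size = 4 give
 * batches {1..4}, {5..8}, {9, 10, 1, 2} with num_batch_padd = 2 on the last
 * batch.  With round_batch disabled, the last batch keeps its fixed size and
 * num_batch_padd is batch_size minus the number of instances actually read.
 *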
\brief create a batch iterator from single instance iterator */ class BatchLoader : public IIterator { public: explicit BatchLoader(IIterator *base): base_(base), head_(1), num_overflow_(0) {} virtual ~BatchLoader(void) { delete base_; // Free space for TblobBatch mshadow::FreeSpace(&data_holder_); mshadow::FreeSpace(&label_holder_); } inline void Init(const std::vector >& kwargs) { std::vector > kwargs_left; // init batch param, it could have similar param with kwargs_left = param_.InitAllowUnknown(kwargs); // init object attributes std::vector data_shape_vec; data_shape_vec.push_back(param_.batch_size); for (size_t shape_dim = 0; shape_dim < param_.data_shape.ndim(); ++shape_dim) { data_shape_vec.push_back(param_.data_shape[shape_dim]); } data_shape_ = TShape(data_shape_vec.begin(), data_shape_vec.end()); std::vector label_shape_vec; label_shape_vec.push_back(param_.batch_size); label_shape_vec.push_back(param_.label_width); label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end()); // Init space for out_ out_.inst_index = new unsigned[param_.batch_size]; out_.data.clear(); data_holder_ = mshadow::NewTensor(data_shape_.get<4>(), 0.0f); label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); out_.data.push_back(TBlob(data_holder_)); out_.data.push_back(TBlob(label_holder_)); // init base iterator base_->Init(kwargs); } inline void BeforeFirst(void) { if (param_.round_batch == 0 || num_overflow_ == 0) { // otherise, we already called before first base_->BeforeFirst(); } else { num_overflow_ = 0; } head_ = 1; } inline bool Next(void) { out_.num_batch_padd = 0; this->head_ = 0; // if overflow from previous round, directly return false, until before first is called if (num_overflow_ != 0) return false; index_t top = 0; while (base_->Next()) { const DataInst& d = base_->Value(); out_.inst_index[top] = d.index; mshadow::Copy(out_.data[1].get()[top], d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); if (++ top >= param_.batch_size) { return true; } } if (top != 0) { if (param_.round_batch != 0) { num_overflow_ = 0; base_->BeforeFirst(); for (; top < param_.batch_size; ++top, ++num_overflow_) { CHECK(base_->Next()) << "number of input must be bigger than batch size"; const DataInst& d = base_->Value(); out_.inst_index[top] = d.index; mshadow::Copy(out_.data[1].get()[top], d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); } out_.num_batch_padd = num_overflow_; } else { out_.num_batch_padd = param_.batch_size - top; } return true; } return false; } virtual const TBlobBatch &Value(void) const { return out_; } private: /*! \brief batch parameters */ BatchParam param_; /*! \brief output data */ TBlobBatch out_; /*! \brief base iterator */ IIterator *base_; /*! \brief on first */ int head_; /*! \brief number of overflow instances that readed in round_batch mode */ int num_overflow_; /*! \brief data shape */ TShape data_shape_; /*! \brief label shape */ TShape label_shape_; /*! \brief tensor to hold data */ mshadow::Tensor data_holder_; /*! \brief tensor to hold label */ mshadow::Tensor label_holder_; }; // class BatchLoader } // namespace io } // namespace mxnet #endif // MXNET_IO_ITER_BATCHLOADER_H_ //===== EXPANDED: mxnet/src/io/iter_batchloader.h ===== //===== EXPANDIND: mxnet/src/io/iter_prefetcher.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file iter_prefetcher.h * \brief define a prefetcher using threaditer to keep k batch fetched */ #ifndef MXNET_IO_ITER_PREFETCHER_H_ #define MXNET_IO_ITER_PREFETCHER_H_ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/threadediter.h ===== /*! * Copyright (c) 2015 by Contributors * \file threadediter.h * \brief thread backed iterator that can be used to implement * general thread-based pipeline such as prefetch and pre-computation * To use the functions in this header, C++11 is required * \author Tianqi Chen */ #ifndef DMLC_THREADEDITER_H_ #define DMLC_THREADEDITER_H_ // defines DMLC_USE_CXX11 // this code depends on c++11 #if DMLC_USE_CXX11 namespace dmlc { /*! * \brief a iterator that was backed by a thread * to pull data eagerly from a single producer into a bounded buffer * the consumer can pull the data at its own rate * * NOTE: thread concurrency cost time, make sure to store big blob of data in DType * * Usage example: * \code * ThreadedIter iter; * iter.Init(&producer); * // the following code can be in parallel * DType *dptr; * while (iter.Next(&dptr)) { * // do something on dptr * // recycle the space * iter.Recycle(&dptr); * } * \endcode * \tparam DType the type of data blob we support */ template class ThreadedIter : public DataIter { public: /*! * \brief producer class interface * that threaditer used as source to * preduce the content */ class Producer { public: // virtual destructor virtual ~Producer() {} /*! \brief reset the producer to beginning */ virtual void BeforeFirst(void) { NotImplemented(); } /*! * \brief load the data content into DType, * the caller can pass in NULL or an existing address * when inout_dptr is NULL: * producer need to allocate a DType and fill the content * when inout_dptr is specified * producer takes need to fill the content into address * specified inout_dptr, or delete the one and create a new one * * \param inout_dptr used to pass in the data holder cell * and return the address of the cell filled * \return true if there is next record, false if we reach the end */ virtual bool Next(DType **inout_dptr) = 0; }; /*! * \brief constructor * \param max_capacity maximum capacity of the queue */ explicit ThreadedIter(size_t max_capacity = 8) : producer_owned_(NULL), producer_thread_(NULL), max_capacity_(max_capacity), nwait_consumer_(0), nwait_producer_(0), out_data_(NULL) {} /*! \brief destructor */ virtual ~ThreadedIter(void) { this->Destroy(); } /*! * \brief destroy all the related resources * this is equivalent to destructor, can be used * to destroy the threaditer when user think it is * appropriate, it is safe to call this multiple times */ inline void Destroy(void); /*! * \brief set maximum capacity of the queue * \param max_capacity maximum capacity of the queue */ inline void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; } /*! * \brief initialize the producer and start the thread * can only be called once * \param producer pointer to the producer * \param pass_ownership whether pass the ownership to the iter * if this is true, the threaditer will delete the producer * when destructed */ inline void Init(Producer *producer, bool pass_ownership = false); /*! 
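 * \note The closure-based overload documented below avoids writing a Producer
 * subclass.  An illustrative sketch (`Batch`, `LoadNext` and `Reset` are
 * assumed user-side types/functions; LoadNext returns false at end of data):
 * \code
 *   dmlc::ThreadedIter<Batch> iter;
 *   iter.Init([&](Batch **dptr) {
 *               if (*dptr == NULL) *dptr = new Batch();  // allocate on first use
 *               return LoadNext(*dptr);
 *             },
 *             [&]() { Reset(); });
 *   // then consume with iter.Next(&ptr) / iter.Recycle(&ptr) as shown above
 * \endcode
 *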
* \brief initialize the producer and start the thread * pass in two function(closure) of producer to represent the producer * the beforefirst function is optional, and defaults to not implemented * NOTE: the closure must remain valid until the ThreadedIter destructs * \param next the function called to get next element, see Producer.Next * \param beforefirst the function to call to reset the producer, see Producer.BeforeFirst */ inline void Init(std::function next, std::function beforefirst = NotImplemented); /*! * \brief get the next data, this function is threadsafe * \param out_dptr used to hold the pointer to the record * after the function call, the caller takes ownership of the pointer * the caller can call recycle to return ownership back to the threaditer * so that the pointer can be re-used * \return true if there is next record, false if we reach the end * \sa Recycle */ inline bool Next(DType **out_dptr); /*! * \brief recycle the data cell, this function is threadsafe * the threaditer can reuse the data cell for future data loading * \param inout_dptr pointer to the dptr to recycle, after the function call * the content of inout_dptr will be set to NULL */ inline void Recycle(DType **inout_dptr); /*! * \brief adapt the iterator interface's Next * NOTE: the call to this function is not threadsafe * use the other Next instead * \return true if there is next record, false if we reach the end */ virtual bool Next(void) { if (out_data_ != NULL) { this->Recycle(&out_data_); } if (Next(&out_data_)) { return true; } else { return false; } } /*! * \brief adapt the iterator interface's Value * NOTE: the call to this function is not threadsafe * use the other Next instead */ virtual const DType &Value(void) const { CHECK(out_data_ != NULL) << "Calling Value at beginning or end?"; return *out_data_; } /*! \brief set the iterator before first location */ virtual void BeforeFirst(void) { std::unique_lock lock(mutex_); if (out_data_ != NULL) { free_cells_.push(out_data_); out_data_ = NULL; } if (producer_sig_ == kDestroy) return; producer_sig_ = kBeforeFirst; CHECK(!producer_sig_processed_); if (nwait_producer_ != 0) { producer_cond_.notify_one(); } CHECK(!producer_sig_processed_); // wait until the request has been processed consumer_cond_.wait(lock, [this]() { return producer_sig_processed_; }); producer_sig_processed_ = false; bool notify = nwait_producer_ != 0 && !produce_end_; lock.unlock(); // notify producer, in case they are waiting for the condition. if (notify) producer_cond_.notify_one(); } private: /*! \brief not support BeforeFirst */ inline static void NotImplemented(void) { LOG(FATAL) << "BeforeFirst is not supported"; } /*! \brief signals send to producer */ enum Signal { kProduce, kBeforeFirst, kDestroy }; /*! \brief producer class */ Producer *producer_owned_; /*! \brief signal to producer */ Signal producer_sig_; /*! \brief whether the special signal other than kProduce is procssed */ bool producer_sig_processed_; /*! \brief thread that runs the producer */ std::thread *producer_thread_; /*! \brief whether produce ends */ bool produce_end_; /*! \brief maximum queue size */ size_t max_capacity_; /*! \brief internal mutex */ std::mutex mutex_; /*! \brief number of consumer waiting */ unsigned nwait_consumer_; /*! \brief number of consumer waiting */ unsigned nwait_producer_; /*! \brief conditional variable for producer thread */ std::condition_variable producer_cond_; /*! \brief conditional variable for consumer threads */ std::condition_variable consumer_cond_; /*! 
\brief the current output cell */ DType *out_data_; /*! \brief internal queue of producer */ std::queue queue_; /*! \brief free cells that can be used */ std::queue free_cells_; }; // implementation of functions template inline void ThreadedIter::Destroy(void) { if (producer_thread_ != NULL) { { // lock the mutex std::lock_guard lock(mutex_); // send destroy signal producer_sig_ = kDestroy; if (nwait_producer_ != 0) { producer_cond_.notify_one(); } } producer_thread_->join(); delete producer_thread_; producer_thread_ = NULL; } // end of critical region // now the slave thread should exit while (free_cells_.size() != 0) { delete free_cells_.front(); free_cells_.pop(); } while (queue_.size() != 0) { delete queue_.front(); queue_.pop(); } if (producer_owned_ != NULL) { delete producer_owned_; } if (out_data_ != NULL) { delete out_data_; out_data_ = NULL; } } template inline void ThreadedIter:: Init(Producer *producer, bool pass_ownership) { CHECK(producer_owned_ == NULL) << "can only call Init once"; if (pass_ownership) producer_owned_ = producer; auto next = [producer](DType **dptr) { return producer->Next(dptr); }; auto beforefirst = [producer]() { producer->BeforeFirst(); }; this->Init(next, beforefirst); } template inline void ThreadedIter:: Init(std::function next, std::function beforefirst) { producer_sig_ = kProduce; producer_sig_processed_ = false; produce_end_ = false; // procedure running in prodcuer // run producer thread auto producer_fun = [this, next, beforefirst] () { while (true) { DType *cell = NULL; { // lockscope std::unique_lock lock(mutex_); ++this->nwait_producer_; producer_cond_.wait(lock, [this]() { if (producer_sig_ == kProduce) { bool ret = !produce_end_ && (queue_.size() < max_capacity_ || free_cells_.size() != 0); return ret; } else { return true; } }); --this->nwait_producer_; if (producer_sig_ == kProduce) { if (free_cells_.size() != 0) { cell = free_cells_.front(); free_cells_.pop(); } } else if (producer_sig_ == kBeforeFirst) { // reset the producer beforefirst(); // cleanup the queue while (queue_.size() != 0) { free_cells_.push(queue_.front()); queue_.pop(); } // reset the state produce_end_ = false; producer_sig_processed_ = true; producer_sig_ = kProduce; // notify consumer that all the process as been done. 
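            // Note: the mutex is released before notify_all() so that the
            // consumer blocked in BeforeFirst() can re-acquire it right away;
            // that call waits on consumer_cond_ for producer_sig_processed_
            // to become true before it returns.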
lock.unlock(); consumer_cond_.notify_all(); continue; } else { // destroy the thread CHECK(producer_sig_ == kDestroy); producer_sig_processed_ = true; produce_end_ = true; consumer_cond_.notify_all(); return; } } // end of lock scope // now without lock produce_end_ = !next(&cell); CHECK(cell != NULL || produce_end_); bool notify; { // lockscope std::lock_guard lock(mutex_); if (!produce_end_) { queue_.push(cell); } else { if (cell != NULL) free_cells_.push(cell); } // put things into queue notify = nwait_consumer_ != 0; } if (notify) consumer_cond_.notify_all(); } }; producer_thread_ = new std::thread(producer_fun); } template inline bool ThreadedIter:: Next(DType **out_dptr) { if (producer_sig_ == kDestroy) return false; std::unique_lock lock(mutex_); CHECK(producer_sig_ == kProduce) << "Make sure you call BeforeFirst not inconcurrent with Next!"; ++nwait_consumer_; consumer_cond_.wait(lock, [this]() { return queue_.size() != 0 || produce_end_; }); --nwait_consumer_; if (queue_.size() != 0) { *out_dptr = queue_.front(); queue_.pop(); bool notify = nwait_producer_ != 0 && !produce_end_; lock.unlock(); if (notify) producer_cond_.notify_one(); return true; } else { CHECK(produce_end_); return false; } } template inline void ThreadedIter::Recycle(DType **inout_dptr) { bool notify; { std::lock_guard lock(mutex_); free_cells_.push(*inout_dptr); *inout_dptr = NULL; notify = nwait_producer_ != 0 && !produce_end_; } if (notify) producer_cond_.notify_one(); } } // namespace dmlc #endif // DMLC_USE_CXX11 #endif // DMLC_THREADEDITER_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/threadediter.h ===== namespace mxnet { namespace io { // Define prefetcher parameters struct PrefetcherParam : public dmlc::Parameter { /*! \brief number of prefetched batches */ size_t prefetch_buffer; // declare parameters DMLC_DECLARE_PARAMETER(PrefetcherParam) { DMLC_DECLARE_FIELD(prefetch_buffer).set_default(4) .describe("Backend Param: Number of prefetched parameters"); } }; // iterator on image recordio class PrefetcherIter : public IIterator { public: explicit PrefetcherIter(IIterator* base) : out_(nullptr), loader_(base) { } ~PrefetcherIter() { while (recycle_queue_.size() != 0) { DataBatch *batch = recycle_queue_.front(); recycle_queue_.pop(); delete batch; } delete out_; iter_.Destroy(); } virtual void Init(const std::vector >& kwargs) { std::vector > kwargs_left; // init image rec param kwargs_left = param_.InitAllowUnknown(kwargs); // use the kwarg to init batch loader loader_->Init(kwargs); // maximum prefetch threaded iter internal size const int kMaxPrefetchBuffer = 16; // init thread iter iter_.set_max_capacity(kMaxPrefetchBuffer); iter_.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); if (*dptr == nullptr) { // allocate databatch *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } } CHECK(batch.data.size() == (*dptr)->data.size()); // copy data over for (size_t i = 0; i < batch.data.size(); ++i) { CHECK_EQ((*dptr)->data.at(i).shape(), batch.data[i].shape_); mshadow::Copy(((*dptr)->data)[i].data().FlatTo2D(), batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } return true; }, [this]() { loader_->BeforeFirst(); }); } virtual void BeforeFirst(void) { iter_.BeforeFirst(); } virtual bool Next(void) { if (out_ != nullptr) { 
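      // The batch returned by the previous call to Next() is queued for
      // recycling; it is handed back to the threaded iterator further below
      // only after WaitToWrite() guarantees its NDArrays can be safely
      // overwritten.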
recycle_queue_.push(out_); out_ = nullptr; } // do recycle if (recycle_queue_.size() == param_.prefetch_buffer) { DataBatch *old_batch = recycle_queue_.front(); // can be more efficienct on engine for (NDArray& arr : old_batch->data) { arr.WaitToWrite(); } recycle_queue_.pop(); iter_.Recycle(&old_batch); } return iter_.Next(&out_); } virtual const DataBatch &Value(void) const { return *out_; } private: /*! \brief prefetcher parameters */ PrefetcherParam param_; // output data DataBatch *out_; // queue to be recycled std::queue recycle_queue_; // backend thread dmlc::ThreadedIter iter_; // internal batch loader std::unique_ptr > loader_; }; } // namespace io } // namespace mxnet #endif // MXNET_IO_ITER_PREFETCHER_H_ //===== EXPANDED: mxnet/src/io/iter_prefetcher.h ===== // Registers namespace dmlc { DMLC_REGISTRY_ENABLE(::mxnet::DataIteratorReg); } // namespace dmlc namespace mxnet { namespace io { // Register parameters in header files DMLC_REGISTER_PARAMETER(BatchParam); DMLC_REGISTER_PARAMETER(PrefetcherParam); DMLC_REGISTER_PARAMETER(ImageAugmentParam); DMLC_REGISTER_PARAMETER(ImageNormalizeParam); } // namespace io } // namespace mxnet //===== EXPANDED: mxnet/src/io/io.cc ===== //===== EXPANDIND: mxnet/src/kvstore/kvstore.cc ===== /*! * Copyright (c) 2015 by Contributors * \file kvstore.cc * \brief implement kv_store */ //===== EXPANDIND: mxnet/include/mxnet/kvstore.h ===== /*! * Copyright (c) 2015 by Contributors * \file kvstore.h * \brief key-value store interface for mxnet */ #ifndef MXNET_KVSTORE_H_ #define MXNET_KVSTORE_H_ namespace mxnet { /*! * \brief distributed key-value store * * A distributed key-value store for data synchronization over multiple * devices/machines. It support user-defined updater. */ class KVStore { public: /*! \brief virtual destructor */ virtual ~KVStore() {} /*! * \brief Factory function to create a new KVStore. * \param type The type of the kvstore, * - 'local' or 'local_update_cpu' or 'local_allreduce_cpu' * multi-devices on a single machine. can be also * - 'device' or 'local_allreduce_device' : same to local but use gpus for kv * allreduce * - 'dist_*' : multi-machines * \return a new created KVStore. */ static KVStore *Create(const char *type = "local"); /** * \brief return the type */ inline const std::string& type() { return type_; } /*! * \brief Initialize a list of key-value pair to the store. * * One must initalize the key before \ref Push and \ref Pull, and a key * should be only initialized once * * It returns after data have been initialized successfully. * * For multiple workers, all workers must call \ref Init. But only worker 0 * (get_rank() == 0)'s values are used for initialization. So others' values * can be empty (but not keys). This function blocks until all workers are * finished. That means, any worker can push and pull on the keys now. * * \param keys a list of unique keys * \param values a list of values */ virtual void Init(const std::vector& keys, const std::vector& values) = 0; /*! * \brief push a list of key-value pairs into the store * * If a key appears mulitple times in \a keys, then the according values will * be aggregated (summed) before pushing. * * The (aggregated) values are merged into the store one by one * * \code * updater(key, value, &value_in_store); * \endcode * * One can set a user-defined updater by \ref set_updater. The default updater * is Assign. * * This function returns after adding a push operator to the engine. 
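 *
 * As a rough illustration only (a minimal sketch assuming the "local"
 * store; the key, the shapes and the `+=` accumulation rule are arbitrary
 * choices, not part of the interface):
 * \code
 * KVStore* kv = KVStore::Create("local");
 * std::vector<index_t> dims = {2, 2};
 * TShape shape(dims.begin(), dims.end());
 * NDArray weight(shape, Context::CPU());
 * kv->Init({3}, {weight});
 * kv->set_updater([](int key, const NDArray& recv, NDArray* stored) {
 *   *stored += recv;   // assumes the NDArray += overload for accumulation
 * });
 * NDArray grad(shape, Context::CPU());
 * kv->Push({3}, {grad});
 * NDArray out(shape, Context::CPU());
 * kv->Pull({3}, {&out});
 * out.WaitToRead();    // block until the pulled result is ready
 * delete kv;
 * \endcode
 *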
Any * following operator requiring writing value will be blocked until the * actual push is finished. One can wait the push is finished by * * - when type == "local" * \code * for (auto& v : values) v.WaitToWrite() * \endcode * * - when type == "dist" * \code * Wait(keys); * \endcode * * One must call Init() on every key before. And the value NDArray should be * always has the same shape as being inited. * * \param keys the list of keys * \param values the list of values * \param priority Priority of the action. */ virtual void Push(const std::vector& keys, const std::vector& values, int priority = 0) = 0; /*! * \brief pull a list of key-value pairs from the store * * One must call Init() on \a key before. And \a value should be pre-allocated * * This function returns after adding a pull operator to the engine. Any * following operator requiring reading value will be blocked until the * actual pull is finished. One can wait the pull is finished by * * - when type == "local" * \code * for (auto& v : values) v.WaitToRead() * \endcode * * - when type == "dist" * \code * Wait(keys); * \endcode * * \param keys the list of keys * \param values the list of buffers for the pulled data, they should be preallocated * \param priority Priority of the action. */ virtual void Pull(const std::vector& keys, const std::vector& values, int priority = 0) = 0; /** * \brief the prototype of user-defined updater */ typedef std::function Updater; /*! * \brief set an updater * * Given a key, assume \a x is the received (pushed) value and \a y is the * value stored on the store node. The store updates \a y by `h(x, &y)`. The * default \a h is ASSIGN, namely `*y = x`. * * \param updater user-defined updater, default is assign */ virtual void set_updater(const Updater& updater) { CHECK(updater) << "invalid updater"; updater_ = updater; } /****************************************************** * the following are used for multi-machines. ******************************************************/ /** * \return whether or not this process is a worker node. * * Always returns true when type == "local" */ static bool IsWorkerNode() { char* role_str = getenv("DMLC_ROLE"); return (role_str == nullptr) || (!strcmp(role_str, "worker")); } /** * \return whether or not this process is a server node. * * Always returns false when type == "local" */ static bool IsServerNode() { char* role_str = getenv("DMLC_ROLE"); return (role_str != nullptr) && (!strcmp(role_str, "server")); } /** * \return whether or not this process is a scheduler node. * * Always returns false when type == "local" */ static bool IsSchedulerNode() { char* role_str = getenv("DMLC_ROLE"); return (role_str != nullptr) && (!strcmp(role_str, "scheduler")); } /*! * \return The rank of this node in its group, which is in [0, * GroupSize). * * Always return 0 when type == "local" */ virtual int get_rank() const { return 0; } /*! * \return The number of worker nodes */ virtual int get_group_size() const { return 1; } /*! * \brief global barrier among all worker machines * * But note that, this functions only blocks the main thread of workers until * all of them are reached this point. It doesn't guarantee that all * operations issued before are actually finished, such as \ref Push and \ref Pull. 
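 *
 * A minimal sketch (hypothetical `kv`, `keys`, `grads` and `weight_ptrs`;
 * the explicit waits, not the barrier, are what guarantee the data is ready):
 * \code
 * kv->Push(keys, grads);                         // enqueue the push
 * kv->Pull(keys, weight_ptrs);                   // enqueue the pull
 * for (auto* w : weight_ptrs) w->WaitToRead();   // data completion
 * kv->Barrier();                                 // control-flow sync only
 * \endcode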
*/ virtual void Barrier() { } /** * \brief Send a command to all server nodes * * Send a command to all server nodes, which will make each server node run * \a controller * * This function returns after the command has been executed in all server nodes * * \param cmd_id the head of the command * \param cmd_body the body of the command */ virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { } /** * \brief the prototype of a server controller */ typedef std::function Controller; /** * \brief Run as server (or scheduler) * * The behavior of a server: * \code * while(receive(x)) { * if (IsCommand(x)) controller(x) * else if (IsKeyValue(x)) updater(x) * } * \endcode * * \param controller the user-defined server controller */ virtual void RunServer(const Controller& controller) { } protected: /** * \brief the user-defined updater */ Updater updater_; /** * \brief the kvstore type */ std::string type_; }; } // namespace mxnet #endif // MXNET_KVSTORE_H_ //===== EXPANDED: mxnet/include/mxnet/kvstore.h ===== //===== EXPANDIND: mxnet/src/kvstore/kvstore_local.h ===== /** * Copyright (c) 2015 by Contributors * @file kvstore_local.h * @brief local implementation */ #ifndef MXNET_KVSTORE_KVSTORE_LOCAL_H_ #define MXNET_KVSTORE_KVSTORE_LOCAL_H_ namespace mxnet { namespace kvstore { /** * \brief store data in local machine */ class KVStoreLocal : public KVStore { public: KVStoreLocal() { pinned_ctx_ = (MXNET_USE_CUDA != 0) ? Context::CPUPinned(0) : Context::CPU(); // the server perameters nthread_reduction_ = dmlc::GetEnv("MXNET_KVSTORE_REDUCTION_NTHREADS", 4); bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); } void Init(const std::vector& keys, const std::vector& values) override { for (size_t i = 0; i < keys.size(); ++i) { CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i]; local_[keys[i]] = values[i].Copy(pinned_ctx_); } } void Push(const std::vector& keys, const std::vector& values, int priority) override { std::vector uniq_keys; std::vector > grouped_vals; GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; const NDArray& merged = MergePushValue(key, grouped_vals[i], priority); if (updater_ != nullptr) { auto it = local_.find(key); CHECK(it != local_.end()) << "key " << key << " has not been inited"; updater_(key, merged, &(it->second)); } } } void Pull(const std::vector& keys, const std::vector& values, int priority) override { std::vector uniq_keys; std::vector > grouped_vals; GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; auto it = merge_buf_.find(key); if (updater_ != nullptr || it == merge_buf_.end()) { auto it = local_.find(key); CHECK(it != local_.end()) << "key " << key << " has not been inited"; const NDArray& src = it->second; for (auto* vptr : grouped_vals[i]) { CopyFromTo(src, vptr, priority); } } else { auto& src = it->second.merged; for (auto* vptr : grouped_vals[i]) { CopyFromTo(src, vptr, priority); } } } } protected: /// \brief temperal space for pushing and pull struct BufferEntry { // Context of merged Context ctx; // the merged value NDArray merged; /// \brief the cpu buffer for gpu data std::vector copy_buf; // allocate copy buffer, if it has not been allocated inline NDArray *AllocCopyBuf(size_t index, Context ctx, const TShape& shape) { if (index >= copy_buf.size()) copy_buf.resize(index + 1); if (copy_buf[index].is_none()) { copy_buf[index] = 
NDArray(shape, ctx); } return ©_buf[index]; } }; /** * \brief group values on keys */ template void GroupKVPairs(const std::vector& keys, const std::vector& values, std::vector* uniq_keys, std::vector >* grouped_vals) { CHECK_EQ(keys.size(), values.size()); // TODO(mli) check if already sorted as an optimization using Idx = std::pair; std::vector idx(keys.size()); for (size_t i = 0; i < keys.size(); ++i) { idx[i].first = keys[i]; idx[i].second = i; } std::sort(idx.begin(), idx.end(), [](const Idx& a, const Idx& b) { return a.first < b.first; }); int pre_key = idx[0].first - 1; for (auto i : idx) { if (i.first != pre_key) { uniq_keys->push_back(i.first); grouped_vals->push_back({values[i.second]}); pre_key = i.first;; } else { grouped_vals->back().push_back(values[i.second]); } } } /*! * \brief returns the aggregated push value */ virtual const NDArray& MergePushValue( int key, const std::vector& val, int priority) { auto& buf = merge_buf_[key]; // copy buffer std::vector const_vars(val.size() - 1); std::vector reduce(val.size()); if (buf.merged.is_none()) { buf.ctx = Context::CPUPinned(val[0].ctx().dev_id); if (MXNET_USE_CUDA == 0) buf.ctx = Context::CPU(); buf.merged = NDArray(val[0].shape(), buf.ctx); } CopyFromTo(val[0], &(buf.merged), priority); reduce[0] = buf.merged; for (size_t i = 1; i < val.size(); ++i) { const NDArray& v = val[i]; Context ctx = v.ctx(); if (ctx.dev_mask() == cpu::kDevMask) { reduce[i] = val[i]; } else { NDArray *copy_buf = buf.AllocCopyBuf( i, Context::CPUPinned(ctx.dev_id), val[0].shape()); CopyFromTo(val[i], copy_buf, priority); reduce[i] = *copy_buf; } const_vars[i - 1] = reduce[i].var(); } Engine::Get()->PushSync([reduce, this](RunContext rctx) { ReduceSumCPU(reduce); }, Context::CPU(), const_vars, {reduce[0].var()}, FnProperty::kCPUPrioritized, priority); return buf.merged; } /// \brief buffer for merging push value std::unordered_map merge_buf_; // pinned context Context pinned_ctx_; // the lower bound of a big array size_t bigarray_bound_; private: inline static void ReduceSumCPU(const std::vector &dptr, size_t offset, index_t size) { using namespace mshadow; // NOLINT(*) Tensor in_0(dptr[0] + offset, Shape1(size)); switch (dptr.size()) { case 2: { Tensor in_1(dptr[1] + offset, Shape1(size)); in_0 += in_1; break; } case 3: { Tensor in_1(dptr[1] + offset, Shape1(size)); Tensor in_2(dptr[2] + offset, Shape1(size)); in_0 += in_1 + in_2; break; } case 4: { Tensor in_1(dptr[1] + offset, Shape1(size)); Tensor in_2(dptr[2] + offset, Shape1(size)); Tensor in_3(dptr[3] + offset, Shape1(size)); in_0 += in_1 + in_2 + in_3; break; } default: { for (size_t i = 1; i < dptr.size(); ++i) { Tensor in_k(dptr[i] + offset, Shape1(size)); in_0 += in_k; } } } } // reduce sum into val[0] // this is performance critical inline void ReduceSumCPU(const std::vector &in_data) { const size_t step = 4 << 10; // ge ptr out std::vector dptr(in_data.size()); for (size_t i = 0; i < in_data.size(); ++i) { TBlob data = in_data[i].data(); CHECK(data.CheckContiguous()); dptr[i] = data.FlatTo2D().dptr_; } size_t total = in_data[0].shape().Size(); long ntask = (total + 1 - step) / step; // NOLINT(*) if (total < bigarray_bound_ || nthread_reduction_ <= 1) { ReduceSumCPU(dptr, 0, total); } else { #pragma omp parallel for schedule(static) num_threads(nthread_reduction_) for (long j = 0; j < ntask; ++j) { // NOLINT(*) size_t k = static_cast(j); size_t begin = std::min(k * step, total); size_t end = std::min((k + 1) * step, total); ReduceSumCPU(dptr, begin, static_cast(end - begin)); } } } /// \brief 
buffer for storing local values std::unordered_map local_; // number of threads to do reduction int nthread_reduction_; }; } // namespace kvstore } // namespace mxnet #endif // MXNET_KVSTORE_KVSTORE_LOCAL_H_ //===== EXPANDED: mxnet/src/kvstore/kvstore_local.h ===== //===== EXPANDIND: mxnet/src/kvstore/kvstore_device.h ===== /*! * Copyright (c) 2015 by Contributors * \file kvstore_device.h * \brief Device implementation of KVStore that do reduction on GPU reduction. */ #ifndef MXNET_KVSTORE_KVSTORE_DEVICE_H_ #define MXNET_KVSTORE_KVSTORE_DEVICE_H_ namespace mxnet { namespace kvstore { /*! * \brief Device implementation of KVStore that do reduction on GPU reduction. */ class KVStoreDevice : public KVStoreLocal { protected: using KeyShape = std::pair; void Init(const std::vector& keys, const std::vector& values) override { KVStoreLocal::Init(keys, values); for (size_t i = 0; i < keys.size(); ++i) { sorted_key_shape_.push_back(std::make_pair(keys[i], values[i].shape())); } } void InitMergeBuffers(const std::vector& val) { std::sort(sorted_key_shape_.begin(), sorted_key_shape_.end(), []( const KeyShape& a, const KeyShape& b) { return a.second.Size() > b.second.Size(); }); CHECK(!val.empty()); std::unordered_map> ctx_info; for (size_t i = 0; i < val.size(); ++i) { int32_t dev_id = val[i].ctx().dev_id; ctx_info[dev_id] = std::make_pair(val[i].ctx(), 0); } for (size_t i = 0; i < sorted_key_shape_.size(); ++i) { int k = sorted_key_shape_[i].first; TShape s = sorted_key_shape_[i].second; auto& tm_buf = merge_buf_[k]; size_t min_size = std::numeric_limits::max(); for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) { size_t tm_size = it->second.second; if (tm_size <= min_size) { tm_buf.ctx = it->second.first; min_size = tm_size; } } tm_buf.merged = NDArray(s, tm_buf.ctx); ctx_info[tm_buf.ctx.dev_id].second += s.Size(); } } const NDArray& MergePushValue( int key, const std::vector& val, int priority) override { if (updater_ != nullptr) { // fall back to CPU based update if updater presents return KVStoreLocal::MergePushValue(key, val, priority); } if (merge_buf_.empty()) { InitMergeBuffers(val); } auto& buf = merge_buf_[key]; std::vector reduce(val.size()); CHECK(!buf.merged.is_none()); CopyFromTo(val[0], &(buf.merged), priority); reduce[0] = buf.merged; for (size_t i = 1; i < val.size(); ++i) { NDArray *copy_buf = buf.AllocCopyBuf( i, buf.ctx, val[0].shape()); CopyFromTo(val[i], copy_buf, priority); reduce[i] = *copy_buf; } ElementwiseSum(reduce, &buf.merged); return buf.merged; } private: std::vector sorted_key_shape_; }; } // namespace kvstore } // namespace mxnet #endif // MXNET_KVSTORE_KVSTORE_DEVICE_H_ //===== EXPANDED: mxnet/src/kvstore/kvstore_device.h ===== #if MXNET_USE_DIST_KVSTORE #endif // MXNET_USE_DIST_KVSTORE namespace mxnet { KVStore* KVStore::Create(const char *type_name) { std::string tname = type_name; std::transform(tname.begin(), tname.end(), tname.begin(), ::tolower); KVStore* kv = nullptr; if (tname == "local" || tname == "local_update_cpu" || tname == "local_allreduce_cpu") { kv = new kvstore::KVStoreLocal(); } else if (tname == "device" || tname == "local_allreduce_device") { tname = "local_allreduce_device"; kv = new kvstore::KVStoreDevice(); } else if (tname == "dist_async" || tname == "dist_sync" || tname == "dist") { #if MXNET_USE_DIST_KVSTORE kv = new kvstore::KVStoreDist(); if (tname == "dist_sync" && kv->IsWorkerNode() && kv->get_rank() == 0) { // configure the server to be the sync mode kv->SendCommandToServers(kvstore::kSyncMode, ""); } #else LOG(FATAL) << 
"compile with USE_DIST_KVSTORE=1 to use " << tname; return nullptr; #endif // MXNET_USE_DIST_KVSTORE } else { LOG(FATAL) << "Unknown KVStore type \"" << tname << "\""; } kv->type_ = tname; return kv; } } // namespace mxnet //===== EXPANDED: mxnet/src/kvstore/kvstore.cc ===== //===== EXPANDIND: mxnet/src/symbol/graph_executor.cc ===== /*! * Copyright (c) 2015 by Contributors * \file graph_executor.cc * \brief Executor to execute the Graph. */ //===== EXPANDIND: mxnet/include/mxnet/symbolic.h ===== /*! * Copyright (c) 2015 by Contributors * \file symbolic.h * \brief Symbolic interface of mxnet. * \author Min Lin, Bing Xu */ #ifndef MXNET_SYMBOLIC_H_ #define MXNET_SYMBOLIC_H_ // check c++11 #if DMLC_USE_CXX11 == 0 #error "CXX11 was required for symbolic module" #endif namespace mxnet { /*! * \brief Internal data structure used for * graph serializaion and graph algorithms. */ class StaticGraph; /*! * \brief Symbol is used to represent dynamically generated symbolic computation graph. * * This class is used as a tool to generate computation graphs(aka. configuration) of the network. * Symbol is always composite, the head Node is the output node of the symbol. * An atomic symbol can be seen as a special case of the composite symbol with only the head node. */ class Symbol { public: /*! * \brief copy the symbol * \return a deep copy of the graph */ Symbol Copy() const; /*! * \brief print the symbol info to output stream. * \param os the output stream we like to print to */ void Print(std::ostream &os) const; // NOLINT(*) /*! * \brief List the arguments names. * * The position of the returned list also corresponds to calling position in operator() * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ std::vector ListArguments() const; /*! \return get the descriptions of outputs for this symbol */ std::vector ListOutputs() const; /*! \return get the descriptions of auxiliary data for this symbol */ std::vector ListAuxiliaryStates() const; /*! * \brief get the index th element from the returned tuple. * \param index index of multi output * \return the symbol corresponds to the indexed element. */ Symbol operator[] (size_t index) const; /*! * \brief Compose the symbol with arguments, this changes current symbol. * * The positional arguments passed in must be complete(contain all arguments). * * \param args positional arguments for the symbol * \param name name of returned symbol. */ void Compose(const std::vector& args, const std::string& name); /*! * \brief Compose the symbol with arguments, this changes the current symbol. * The kwargs passed in can be in-complete, * * The rest of the symbols will remain the same name. * * \param kwargs keyword arguments for the symbol * \param name name of returned symbol. */ void Compose(const std::unordered_map& kwargs, const std::string& name); /*! * \brief Apply the symbol as a function, compose with arguments * \param args positional arguments for the symbol * \param name name of returned symbol. * \return a new Symbol which is the composition of current symbol with its arguments */ Symbol operator () (const std::vector& args, const std::string& name) const; /*! * \brief compose with named arguments * \param kwargs keyword arguments for the symbol * \param name name of returned symbol. * \return a new symbol which is the composition of current symbol with its arguments */ Symbol operator () (const std::unordered_map& kwargs, const std::string& name) const; /* * \brief Get all the internal nodes of the symbol. 
* \return symbol A new symbol whose output contains all the outputs of the symbols * Including input variables and intermediate outputs. */ Symbol GetInternals() const; /*! * \brief get the gradient graph * \param wrt with respect to the input * \return the new symbol with gradient graph */ Symbol Grad(const std::vector& wrt) const; /*! * \brief infer the shapes of outputs and unknown input arguments * \param arg_shapes the shape of input arguments of the operator * this should be of same length as the vector returned by ListArguments * in_shape allows unknown elements, which are checked by shape.ndim() == 0. * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape * For known shapes, InferShape will check shape consistency * * common practice: set the shape of data input, and usually weight's shape can be infered * * \param out_shapes Use to store the infered shapes of outputs. * \param aux_shapes Use to store the infered shapes of auxiliary states * \return true if the shape inference is successful, false if there is not enough information. * \throws dmlc::Error if the known arg_shapes are inconsistent. */ bool InferShape(std::vector *arg_shapes, std::vector *out_shapes, std::vector *aux_shapes) const; /*! * \brief infer the shapes by providing shapes of known arguments. * \param known_arg_shapes map of argument name to shape of arguments with known shapes. * \param arg_shapes used to store infered shapes of arguments. * \param out_shapes used to store infered shapes of outputs. * \param aux_shapes Use to store the infered shapes of auxiliary states * \return true if the shape inference is successful, false if there is not enough information. * \throws dmlc::Error if the known arg_shapes are inconsistent. */ bool InferShape(const std::unordered_map &known_arg_shapes, std::vector *arg_shapes, std::vector *out_shapes, std::vector *aux_shapes) const; /*! * \brief interface for json serialization. * \param writer the JSON writer write json. */ void Save(dmlc::JSONWriter *writer) const; /*! * \brief interface for json serialization. * \param reader the JSON read to read json. */ void Load(dmlc::JSONReader *reader); /*! * \brief get number of outputs of this symbol * \return number of outputs */ inline size_t NumOutputs() const { return heads_.size(); } /*! * \brief create Symbol by wrapping OperatorProperty * This function takes the ownership of op * * \param op the OperatorProperty of the Operator * \return Symbol * \sa OperatorProperty::Create */ static Symbol Create(OperatorProperty *op); /*! * \brief create equivalence of symbol by grouping the symbols together * \param symbols list of symbols * \return the grouped symbol */ static Symbol CreateGroup(const std::vector &symbols); /*! * \brief create variable symbol node * \param name name of the variable * \return the new variable */ static Symbol CreateVariable(const std::string &name); protected: // Decalre node, internal data structure. struct Node; /*! \brief an entry that represents output data from a node */ struct DataEntry { /*! \brief the source node of this data */ std::shared_ptr source; /*! \brief index of output from the source. */ uint32_t index; /*! \brief enabled default copy constructor */ DataEntry() {} /*! \brief constructor from index */ DataEntry(std::shared_ptr source, uint32_t index) : source(source), index(index) {} }; /*! * \brief the head nodes of Symbols * This head is only effective when */ std::vector heads_; private: /*! 
\return whwther the symbol is atomic */ inline bool is_atomic() const; /*! * \brief Visit all the nodes in left-to-right depth first order. * * This function will visit the graph in DFS order, call fvisit exactly once * for each Node, and store the result in out_result. * * \param fvisit function applied for each visit. * \tparam FVisit visiting function type */ template inline void DFSVisit(FVisit fvisit) const; /*! * \brief Find duplicate arguments in the composition * \param out the map of argument-name -> occurence count * \return maximum number of duplication factor */ int FindDuplicateArgs(std::unordered_map *out) const; /*! * \brief Convert symbol into internal static graph * * \param out_graph the pointer holder of the output graph */ void ToStaticGraph(StaticGraph *out_graph) const; /*! * \brief create equivalence of symbol from static graphs. * This operation will change the content of current symbol. * \param graph the static graph */ void FromStaticGraph(const StaticGraph &graph); /*! \brief let static graph know the contents */ friend class StaticGraph; }; /*! * \brief Executor of a computation graph. * Executor can be created by Binding a symbol. */ class Executor { public: /*! \brief destructor */ virtual ~Executor() {} /*! * \brief Perform a Forward operation of Operator * After this operation, user can get the result by using function head. */ virtual void Forward(bool is_train) = 0; /*! * \brief Perform a Backward operation of the Operator. * This must be called after Forward. * After this operation, NDArrays specified by grad_in_args_store will be updated accordingly. * User is allowed to pass in an empty Array if the head node is * loss function and head gradeitn is not needed. * * \param head_grads the gradient of head nodes to be backproped. */ virtual void Backward(const std::vector &head_grads) = 0; /*! * \brief print the execution plan info to output stream. * \param os the output stream we like to print to. */ virtual void Print(std::ostream &os) const {} // NOLINT(*) /*! * \brief get array of outputs in the executor. * \return array of outputs in the executor. */ virtual const std::vector &outputs() const = 0; /*! * \brief Create an operator by bind symbol with context and arguments. * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp. * * \param ctx the context of binding. * \param symbol the symbol that specifies the output of Forward pass. * \param in_args the NDArray that stores the input arguments to the symbol. * \param arg_grad_store NDArray that is used to store the gradient output of the input arguments. * \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}. * \param aux_states NDArray that is used as internal state in op * \return a new executor. */ static Executor *Bind(Symbol symbol, Context ctx, const std::vector &in_args, const std::vector &arg_grad_store, const std::vector &grad_req_type, const std::vector &aux_states); }; // class operator } // namespace mxnet #endif // MXNET_SYMBOLIC_H_ //===== EXPANDED: mxnet/include/mxnet/symbolic.h ===== //===== EXPANDIND: mxnet/src/symbol/graph_executor.h ===== /*! * Copyright (c) 2015 by Contributors * \file graph_executor.h * \brief Executor to execute the Forward and Backward on Composition Graph. */ #ifndef MXNET_SYMBOL_GRAPH_EXECUTOR_H_ #define MXNET_SYMBOL_GRAPH_EXECUTOR_H_ //===== EXPANDIND: mxnet/src/symbol/static_graph.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file static_graph.h * \brief A memory compact representation of symbolic graph * Used for serialization, and helper data structure. * \author Naiyan Wang */ #ifndef MXNET_SYMBOL_STATIC_GRAPH_H_ #define MXNET_SYMBOL_STATIC_GRAPH_H_ namespace mxnet { /*! * \brief StaticGraph is the configuration of computation graphs. * This is the "configuration file" of mxnet. * It can be converted to/from Symbol, and can be used to bind to operators. * The symbol can be converted from/to StaticGraph, the actual configuration used by mxnet. * Symbol offers more flexible way to composite nodes than StaticGraph, which makes it good * tool to generate configurations from language bindings such as python. * \sa Symbol */ class StaticGraph { public: /*! \brief represents a data in the graph */ struct DataEntry { /*! \brief the source node id in the computation graph */ uint32_t source_id; /*! \brief index of output from the source. */ uint32_t index; /*! \brief default constructor */ DataEntry() {} /*! * \brief constructor with source and index * \param source_id source id * \param index node index */ DataEntry(uint32_t source_id, uint32_t index) : source_id(source_id), index(index) {} /*! * \brief compare equality * \param other the other entry to compare * \return whether two entries equals to each other */ inline bool operator==(const DataEntry &other) const { return source_id == other.source_id && index == other.index; } /*! * \brief comparator, allows to use map * \param other the other entry to compare * \return whether two entries is smaller than the other */ inline bool operator<(const DataEntry &other) const { if (source_id == other.source_id) return index < other.index; return source_id < other.source_id; } /*! * \brief interface for json serialization. * \param writer the JSON writer to write json into. */ inline void Save(dmlc::JSONWriter *writer) const { writer->BeginArray(false); writer->WriteArrayItem(source_id); writer->WriteArrayItem(index); writer->EndArray(); } /*! * \brief interface for json serialization. * \param reader the JSON reader to read json from. */ inline void Load(dmlc::JSONReader *reader) { std::pair p; reader->Read(&p); *this = DataEntry(p.first, p.second); } }; /*! * \brief Operation Node in static graphs. * There are two types of node, Forward and Backward Node. * * - Forward node corresponds to the op.Forward * - Backward node corresponds to the Backward pass, * where the corresponding forward node is indicated by backward_source_id. * The op field in Backward node is nullptr * * The reason we explicit support Backward node is to allow special treatment * such as shape inference and state sharing with Forward pass. */ struct Node { /*! \brief wrapped operator property */ std::unique_ptr op; /*! \brief name of the node */ std::string name; /*! \brief inputs (node_id, index) for of the nodes*/ std::vector inputs; /*! * \brief If this field is nonnegative, this indicates this * Node is corresponds to a Backward Operation of Operator. * backward_source_id will points to the corresponding Forward Node. * * For normal node, this field is -1. * When the node is a Backward node, the op field will be nullptr */ int32_t backward_source_id; /*! \brief default constructor */ Node() : backward_source_id(-1) {} friend void swap(Node& lhs, Node& rhs) { std::swap(lhs.op, rhs.op); std::swap(lhs.name, rhs.name); std::swap(lhs.inputs, rhs.inputs); std::swap(lhs.backward_source_id, rhs.backward_source_id); } /*! 
\brief copy constructor in favor of serialization. */ Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr), name(another.name), inputs(another.inputs), backward_source_id(another.backward_source_id) {} inline Node& operator=(Node another) { swap(*this, another); return *this; } /*! \return whether the node is forward op node */ inline bool is_forward() const { return op != nullptr; } /*! \return whether the node is backward op node */ inline bool is_backward() const { return backward_source_id != -1; } /*! \return whether the node is variable node */ inline bool is_variable() const { return op == nullptr && !is_backward(); } /*! * \brief interface for json serialization. * \param writer the JSON writer write json. */ void Save(dmlc::JSONWriter *writer) const; /*! * \brief interface for json serialization. * \param reader the JSON read to read json. */ void Load(dmlc::JSONReader *reader); }; /*! \brief all nodes in the graph */ std::vector nodes; /*! \brief index of nodes that correspods to arguments */ std::vector arg_nodes; /*! \brief heads outputs of the graph */ std::vector heads; /*! * \brief interface for json serialization. * \param writer the JSON writer write json. */ void Save(dmlc::JSONWriter *writer) const; /*! * \brief interface for json serialization. * \param reader the JSON read to read json. */ void Load(dmlc::JSONReader *reader); // funtions to help inference in static graph /*! * \brief Perform a topological sort on the graph * \return a topological order of node indices. */ std::vector TopoSort() const; /*! * \brief infer the node shapes in the computation graph. * * When calling this function, user can setup the shape information known into right position. * Unknown shape are indicated by shape.ndim() == 0. * * \param topo_order The topological order of node index, as created by TopoSort. * \param node_out_shapes The shapes of the each outputs of nodes in the graph. * \param node_aux_shapes The shapes of the each auxiliary states of nodes in the graph. * \return if the shape inference is successful, return true, else return false. */ bool InferNodeShapes(const std::vector &topo_order, std::vector > *node_out_shapes, std::vector > *node_aux_shapes) const; /*! * \brief infer the shapes of outputs and unknown input arguments * \param in_shape the shape of input arguments of the operator * this should be of same length as the vector returned by ListArguments * in_shape allows unknown elements, which are checked by shape.ndim() == 0. * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape * For known shapes, InferShape will check shape consistency * * common practice: set the shape of data input, and usually weight's shape can be infered * * \param out_shape the shape of outputs of the operator * InferShape will modify the vector to fill output TShape * \param aux_shape the shape of auxiliary states of the operator * InferShape will modify the vector to fill output TShape * \return if the shape inference is successful, return true, else return false. */ bool InferShape(std::vector* in_shape, std::vector* out_shape, std::vector* aux_shape) const; /*! * \brief Add a full backward pass in the static graph. * This function will add gradient nodes for each heads, * and add the backward pass to backprop the gradients all * the way to the arguments. * * This will change the nodes field in the StaticGraph, but will not change other fields. 
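 * (The gradient nodes are appended to \ref nodes, so the indices of the
 * existing forward nodes remain valid after the call.)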
* The head and input of Backward pass will be returned by head_grad_nodes and arg_grads. * * \param head_grad_nodes used to store the created head gradient inputs for backward pass. * \param arg_grads used to store gradients to args, can be multiple one if an argument is used by operator */ void MakeBackwardPass(std::vector *head_grad_nodes, std::vector *arg_grads); /*! * \brief Convert symbol into static graph. * \param symbol the symbol to convert from. */ inline void FromSymbol(const Symbol &symbol) { symbol.ToStaticGraph(this); } /*! * \brief create a sum node that aggregates gradient together * \param grad_source the source of the inputs. * \return a created ElementWiseSum node */ static Node CreateSumNode(const std::vector &grad_source); }; } // namespace mxnet namespace dmlc { DMLC_DECLARE_TRAITS(is_pod, ::mxnet::StaticGraph::DataEntry, true); } #endif // MXNET_SYMBOL_STATIC_GRAPH_H_ //===== EXPANDED: mxnet/src/symbol/static_graph.h ===== //===== EXPANDIND: mxnet/src/symbol/graph_memory_allocator.h ===== /*! * Copyright (c) 2015 by Contributors * \file graph_memory_allocator.h * \brief Memory allocator for graph executor. */ #ifndef MXNET_SYMBOL_GRAPH_MEMORY_ALLOCATOR_H_ #define MXNET_SYMBOL_GRAPH_MEMORY_ALLOCATOR_H_ //===== EXPANDIND: mxnet/src/symbol/graph_algorithm.h ===== /*! * Copyright (c) 2015 by Contributors * \file graph_allocation_helper.h * \brief This header contains graph algorithms on StaticGraph. * It is used compute informations such as whether two * operations can run in parallel, and helps allocation. */ #ifndef MXNET_SYMBOL_GRAPH_ALGORITHM_H_ #define MXNET_SYMBOL_GRAPH_ALGORITHM_H_ namespace mxnet { namespace graph { /*! * \brief Find best path in the DAG, with reward defined * by sum of reward of each node along the path. * \param graph the original static graph. * \param topo_order topo order of the nodes in the graph. * \param node_reward the reward of each node. * \param path the output path of nodes. * \return the total reward of best path. */ inline uint32_t FindBestPath( const StaticGraph &graph, const std::vector &topo_order, const std::vector &node_reward, std::vector *path) { const uint32_t num_nodes = static_cast(graph.nodes.size()); CHECK_EQ(graph.nodes.size(), node_reward.size()); CHECK_EQ(graph.nodes.size(), topo_order.size()); std::vector best_reward(node_reward.size(), 0); std::vector next_node(node_reward.size(), num_nodes); uint32_t best_solution = 0, best_start_node = 0; // traverse in reverse topo order for (auto it = topo_order.rbegin(); it != topo_order.rend(); ++it) { const uint32_t nid = *it; best_reward[nid] += node_reward[nid]; if (best_reward[nid] > best_solution) { best_solution = best_reward[nid]; best_start_node = nid; } for (const StaticGraph::DataEntry& e : graph.nodes[nid].inputs) { const uint32_t prev = e.source_id; if (best_reward[nid] > best_reward[prev]) { best_reward[prev] = best_reward[nid]; next_node[prev] = nid; } } } path->clear(); uint32_t reward = 0; for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) { path->push_back(nid); reward += node_reward[nid]; } CHECK_EQ(reward, best_solution); return best_solution; } /*! * \brief Color the nodes in the graph into index. * The coloring algorithm tries to assign node group * such that node in the same group cannot run in parallel. * * \param graph the original static graph. * \param topo_order topo order of the nodes in the graph. * \param node_importance The importance of the node * \param max_ncolor maximum number of colors allowed. 
* \param color the color index of each of the node. * \return the total number of colors. */ inline uint32_t ColorNodeGroup( const StaticGraph &graph, const std::vector &topo_order, std::vector node_importance, uint32_t max_ncolor, std::vector *color) { CHECK_NE(max_ncolor, 0); CHECK_EQ(graph.nodes.size(), topo_order.size()); CHECK_EQ(graph.nodes.size(), node_importance.size()); color->clear(); color->resize(topo_order.size(), max_ncolor); uint32_t cindex; // greedy algorithm, every time // find a path with best reward and assign a new color // All the nodes in the path cannot run in parallel. for (cindex = 0; cindex < max_ncolor - 1; ++cindex) { std::vector path; uint32_t reward = FindBestPath(graph, topo_order, node_importance, &path); if (reward == 0) break; for (uint32_t nid : path) { if (node_importance[nid] != 0) { CHECK_EQ(color->at(nid), max_ncolor); color->at(nid) = cindex; // make the importance 0 after color is decided. node_importance[nid] = 0; } } } // assign i for rest of the node for (size_t i = 0; i < topo_order.size(); ++i) { if (color->at(i) == max_ncolor) { color->at(i) = cindex; } } return cindex + 1; } } // namespace graph } // namespace mxnet #endif // MXNET_SYMBOL_GRAPH_ALGORITHM_H_ //===== EXPANDED: mxnet/src/symbol/graph_algorithm.h ===== namespace mxnet { /*! * \brief Memory allocators for the GraphExecutor. * This class is intended to be used by GraphExecutor * to allocate the memory for each DataEntryInfo. * * The class algorithm works in two phase: * (1) Planning Phase: GraphExecutor call Request and Release * to request and release resources according to dependency. * - Each call to Request will get a ResourceID that is used to * identify the memory block assigned to each DataEntryInfo. * (2) Allocating phase: GraphExecutor call InitMemory. * - Then each DataEntry will call Get to get the real NDArray. * (3) All the memory will be freed up when reference to all the related NDArray ends. */ class GraphStorageAllocator { public: /*! \brief resource index */ typedef int64_t StorageID; /*! \brief bad storage id */ static const StorageID kBadStorageID = -1; /*! \brief constructor to the graph memory allocator */ explicit GraphStorageAllocator( StaticGraph *graph, const std::vector& topo_order) noexcept(false); /*! * \brief Request a memory. * \param ctx the context of the graph * \param shape shape of the NDArray we want * \param node_id the node that is requesting the memory, used as hint. */ StorageID Request(Context ctx, TShape shape, uint32_t node_id); /*! * \brief Release a memory. * \param id the storage ID of the memory. * \param node_id the node id in the graph that is releasing the memory. */ void Release(StorageID id, uint32_t node_id); /*! * \brief Initialize all the memories requested * \return size of memory allocated. */ size_t InitStorages(); /*! * \brief Get the the memory allocated in planning phase. * \param id the storage id allocated in planning phase. * \param shape the shape of the NDArray requested. */ NDArray Get(StorageID id, TShape shape); protected: /*! \brief internal storage entry */ struct StorageEntry { /*! \brief id of the storage */ StorageID id; /*! \brief the context of the storage */ Context ctx; /*! \brief maximum size of the storage that is requested */ size_t max_size; /*! \brief node index that released it last time */ uint32_t released_by_node; /*! \brief the actual NDArray to hold the data */ NDArray data; /*! \brief constructor */ StorageEntry() : max_size(0), released_by_node(0) {} }; /*! 
* \brief Allocate a StorageID when Request cannot found existing ones. * \param ctx the context of the graph * \param shape shape of the NDArray we want */ StorageID Alloc(Context ctx, size_t size); /*! * \brief Initialize the colors of graph nodes. * \param topo_order the topological order in the graph. */ void InitColor(const std::vector &topo_order); /*! \brief reference to the computation graph */ StaticGraph *graph_; /*! \brief all the resources available */ std::vector > data_; /*! \brief scale used for rough match */ size_t match_range_; /*! * \brief free list of storage entries, maps size to free list */ std::multimap free_; /*! * \brief color of nodes in the graph, used for auxiliary policy making. */ std::vector node_color_; /*! \brief whether use color based match algorithm */ uint32_t num_match_color_; }; // put implementation in header files for now GraphStorageAllocator::GraphStorageAllocator( StaticGraph *graph, const std::vector& topo_order) noexcept(false) : graph_(graph) , num_match_color_(0) { match_range_ = dmlc::GetEnv("MXNET_EXEC_MATCH_RANGE", 16); // if we set this to 1, this means no color based match. // color based match will cost a bit more memory usually // but also enables more parallelization. num_match_color_ = static_cast(common::GetExecNumMatchColor()); this->InitColor(topo_order); } void GraphStorageAllocator::InitColor(const std::vector& topo_order) { std::vector importance(graph_->nodes.size(), 0); for (size_t i = 0; i < topo_order.size(); ++i) { uint32_t nid = topo_order[i]; if (graph_->nodes[nid].is_variable()) continue; importance[nid] = 1; } num_match_color_ = graph::ColorNodeGroup( *graph_, topo_order, importance, num_match_color_, &node_color_); } GraphStorageAllocator::StorageID GraphStorageAllocator::Alloc(Context ctx, size_t size) { StorageID id = static_cast(data_.size()); std::unique_ptr ptr(new StorageEntry()); ptr->id = id; ptr->ctx = ctx; ptr->max_size = size; data_.push_back(std::move(ptr)); return id; } GraphStorageAllocator::StorageID GraphStorageAllocator::Request(Context ctx, TShape shape, uint32_t node_id) { // search memory block in [size / match_range_, size * match_range_) size_t size = shape.Size(); if (match_range_ == 0) return this->Alloc(ctx, size); auto begin = free_.lower_bound(size / match_range_); auto mid = free_.lower_bound(size); auto end = free_.upper_bound(size * match_range_); // TODO(bing, min) consider better strategy // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { StorageEntry *e = it->second; if (e->ctx != ctx) continue; if (node_color_[e->released_by_node] != node_color_[node_id]) continue; // Use exect matching strategy e->max_size = std::max(size, e->max_size); // find a exact match, erase from map and return free_.erase(it); return e->id; } // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; StorageEntry *e = it->second; if (e->ctx != ctx) continue; if (node_color_[e->released_by_node] != node_color_[node_id]) continue; // Use exect matching strategy e->max_size = std::max(size, e->max_size); // find a exact match, erase from map and return free_.erase(it); return e->id; } // cannot find anything return a new one. 
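  // Reaching this point means no free block inside the
  // [size / match_range_, size * match_range_) window matched both the
  // requested context and the color group of the requesting node, so a
  // brand-new storage entry is allocated instead of reusing one.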
return this->Alloc(ctx, size); } void GraphStorageAllocator::Release(StorageID id, uint32_t node_id) { CHECK_NE(id, kBadStorageID); StorageEntry *e = data_[id].get(); e->released_by_node = node_id; free_.insert({e->max_size, e}); } size_t GraphStorageAllocator::InitStorages() { size_t total = 0; for (size_t i = 0; i < data_.size(); ++i) { StorageEntry *e = data_[i].get(); TShape shape = mshadow::Shape1(e->max_size); e->data = NDArray(shape, e->ctx); total += e->max_size; } return total; } NDArray GraphStorageAllocator::Get(StorageID id, TShape shape) { CHECK_NE(id, kBadStorageID); StorageEntry *e = data_[id].get(); return e->data.Slice(0, shape.Size()).Reshape(shape); } } // namespace mxnet #endif // MXNET_SYMBOL_GRAPH_MEMORY_ALLOCATOR_H_ //===== EXPANDED: mxnet/src/symbol/graph_memory_allocator.h ===== namespace mxnet { /*! * \brief Executor of a computation graph. */ class GraphExecutor : public Executor { public: virtual ~GraphExecutor(); void Forward(bool is_train) override; void Backward(const std::vector &head_grads) override; const std::vector &outputs() const override { return heads_ndarray_; } void Print(std::ostream &os) const override; // NOLINT(*) // implement Executor::Bind, only call it once. inline void Init(Symbol symbol, Context ctx, const std::vector &in_args, const std::vector &arg_grad_store, const std::vector &grad_req_type, const std::vector &aux_states) { enable_inplace_allocation_ = dmlc::GetEnv("MXNET_EXEC_ENABLE_INPLACE", true); CHECK_EQ(grad_req_type.size(), arg_grad_store.size()); bool need_backward = false; for (auto req : grad_req_type) { if (req != kNullOp) need_backward = true; } this->InitGraph(symbol, ctx, need_backward); this->InitDataEntryInfo(in_args, arg_grad_store, grad_req_type, aux_states); this->InitDataEntryMemory(); this->InitResources(); this->InitOpNodes(); } protected: // internal class of wrapping BackwardOp as ForwardOp class BackwardOpWrapper; // type of data entry enum DataEntryType { // memory is binded by external NDArray in Bind kBindByExternal, // to be binded by external NDArray in Forward and Backward kTobeBindByExternal, // internal memory, allocated kInternalAllocated, // internal memory, to be allocated kNotInitialized }; // Additional information about each data entry struct DataEntryInfo { // the actual data for the entry NDArray data; // write request to this entry OpReqType op_req; // the operatio node that will take // this DataEntry as inplace input int inplace_op_id; // data entry type DataEntryType type; // shape of this entry TShape shape; // storage id from allocator if it is internal allocation. GraphStorageAllocator::StorageID storage_id; // reference count on how many times this entry is being used. // That is how many operators and heads need this DataEntry // this is a temporal variable that is used during initialization. 
uint32_t temp_ref_count; // real permanent ref count uint32_t ref_count; // constructor DataEntryInfo() : op_req(kNullOp), inplace_op_id(-1), type(kNotInitialized), storage_id(GraphStorageAllocator::kBadStorageID), temp_ref_count(0), ref_count(0) {} }; // all the information needed to push the op to engine struct OpExecEntry { // execution function for Engine::AsyncFn exec_fun; // variables to read from std::vector use_vars; // variables to mutate std::vector mutate_vars; // constructor OpExecEntry() : exec_fun(nullptr) {} }; // Information about operational node struct OpNode { // whether this op node is activated bool activated; // the context of the node Context ctx; // data entry information about outputs of op std::vector outputs; // auxiliary data information of op std::vector aux_states; // The following parts are constructed in InitOpNodes // the real operator std::shared_ptr op; // op context, that is defined for this op. OpContext op_ctx; // executor, this is only allocated for nodes // whose inputs, outputs are pre-defined. // otherwise cached_exec.exec_fun == nullptr OpExecEntry cached_exec; // cached operator handle Engine::OprHandle cached_opr{nullptr}; // constructor OpNode() : activated(false) {} // Manual option for delete operator // need to do this before delete NDArrays inline void DeleteOperator() { if (cached_opr != nullptr) { Engine::Get()->DeleteOperator(cached_opr); cached_opr = nullptr; } } }; /*! * \brief Get input option of a node. * This function is overriden for both Forward and Backward node. * * \param node_id node index of node in StaticGraph * \param in_data the input data entry to the node * \param out_data the output data entry in the graph * \return the paired inplace option. */ template inline std::vector > GetInplaceOption( uint32_t node_id, const std::vector &in_data, const std::vector &out_data) const; /*! * \brief Get resource requirement of a node. * This function is overriden for both Forward and Backward node. * \param node_id node index of node in StaticGraph * \return the desired resource request. */ inline std::vector GetResource(uint32_t node_id) const; /*! * \brief Get number of outputs of a node. * This function is overriden for both Forward and Backward node. * \param node_id node index of node in StaticGraph * \return the number of outputs of the node. */ inline int GetNumOutputs(uint32_t node_id) const; /*! * \brief get execution entry for an OpNode. * This function can only be called after initialization is done. * \param node_id the id of operational node. * \return the execution entry. 
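   * Note: the returned entry captures the input, output and auxiliary
   * NDArrays by value and wraps Operator::Forward in an Engine::AsyncFn;
   * use_vars and mutate_vars list the engine variables the closure reads
   * and writes so the engine can order it against other operations.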
*/ inline OpExecEntry GetOpExecEntry(uint32_t node_id); // initialize the internal graph structure void InitGraph(const Symbol &symbol, Context ctx, bool need_backward); // initialize internal DataEntryInfo, reference counting void InitDataEntryInfo(const std::vector &in_args, const std::vector &arg_grad_store, const std::vector &grad_req_type, const std::vector &aux_states); // initialize internal data entries NDArray void InitDataEntryMemory(); // initialize the internal resources for each op void InitResources(); // initialize OpNode data structure void InitOpNodes(); // run ops from topo order start to end void RunOps(bool is_train, size_t topo_start, size_t topo_end); // internal computational graph StaticGraph graph_; // topological order of nodes in computation graph // backward nodes always follow forward nodes std::vector topo_order_; // whether to enable inplace space bool enable_inplace_allocation_; // total allocated space in #reals size_t total_allocated_reals_; // total allocated temp space size_t total_allocated_temp_; // number of forward nodes in the graph size_t num_forward_nodes_; // head gradient node in the graph, if there is backward pass std::vector head_grad_nodes_; // argument node in the graph, if there is backward pass std::vector arg_grads_; // operational nodes std::vector op_nodes_; // head NDArrays std::vector heads_ndarray_; }; // class GraphExecutor } // namespace mxnet #endif // MXNET_SYMBOL_GRAPH_EXECUTOR_H_ //===== EXPANDED: mxnet/src/symbol/graph_executor.h ===== namespace mxnet { /*! * \brief wrapper class that wraps Backward operation as Forward. */ class GraphExecutor::BackwardOpWrapper : public Operator { public: /*! * \brief create a backward Operator wrapper given forward op. * \param prop pointer to the property of forward wrapper * \param forward_op the shared ptr to Forward operator * \return the created wrapper. */ explicit BackwardOpWrapper(const OperatorProperty *prop, std::shared_ptr forward_op) : op_(forward_op) { out_grad_.resize(prop->NumVisibleOutputs()); in_data_.resize(prop->ListArguments().size()); out_data_.resize(prop->NumOutputs()); std::vector out_grad_ptr(out_grad_.size()); for (size_t i = 0; i < out_grad_.size(); ++i) { out_grad_ptr[i] = &out_grad_[i]; } std::vector in_data_ptr(in_data_.size()); for (size_t i = 0; i < in_data_.size(); ++i) { in_data_ptr[i] = &in_data_[i]; } std::vector out_data_ptr(out_data_.size()); for (size_t i = 0; i < out_data_.size(); ++i) { out_data_ptr[i] = &out_data_[i]; } arg_data_ptr_ = prop->BackwardInputs( out_grad_ptr, in_data_ptr, out_data_ptr); } // implement forward virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_states) { // set things correctly CHECK(arg_data_ptr_.size() == in_data.size()); for (size_t i = 0; i < in_data.size(); ++i) { *(arg_data_ptr_[i]) = in_data[i]; } // redirect internally op_->Backward(ctx, out_grad_, in_data_, out_data_, req, out_data, aux_states); } private: /*! \brief internal forward operator */ std::shared_ptr op_; /*! \brief internal space for out_grad */ std::vector out_grad_; /*! \brief internal space for in_data */ std::vector in_data_; /*! \brief internal space for out_data */ std::vector out_data_; /*! * \brief pointer to places in the internal space. * arg_data_ptr_ maps in_data in Forward to the internal space. 
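   * Note: BackwardOpWrapper::Forward copies the incoming in_data into these
   * slots and then redirects to Backward of the wrapped forward operator,
   * so a backward node can be scheduled through the same Forward interface
   * as any other operator.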
*/ std::vector arg_data_ptr_; }; // get resource inline std::vector GraphExecutor::GetResource(uint32_t node_id) const { const StaticGraph::Node &node = graph_.nodes[node_id]; // use input shape std::vector in_shapes; for (StaticGraph::DataEntry e : node.inputs) { in_shapes.push_back(op_nodes_[e.source_id].outputs[e.index].shape); } if (node.is_forward()) { return node.op->ForwardResource(in_shapes); } else { CHECK(node.is_backward()); return graph_.nodes[node.backward_source_id] .op->BackwardResource(in_shapes); } } inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const { const StaticGraph::Node &node = graph_.nodes[node_id]; if (node.is_forward()) { return node.op->NumOutputs(); } else if (node.is_backward()) { return static_cast( graph_.nodes[node.backward_source_id].op->ListArguments().size()); } else { CHECK(node.is_variable()); return 1; } } // implement get input option template inline std::vector > GraphExecutor::GetInplaceOption( uint32_t node_id, const std::vector &in_data, const std::vector &out_data) const { // get the node const StaticGraph::Node &node = graph_.nodes[node_id]; if (node.is_forward()) { std::vector in_data_index(in_data.size()); for (size_t i = 0; i < in_data.size(); ++i) { in_data_index[i] = static_cast(i); } std::vector out_data_ptr(out_data.size()); for (size_t i = 0; i < out_data.size(); ++i) { out_data_ptr[i] = (void*)&out_data[i]; // NOLINT(*) } auto rmap_index = node.op->ForwardInplaceOption(in_data_index, out_data_ptr); std::vector > remap(rmap_index.size()); for (size_t i = 0; i < remap.size(); ++i) { remap[i].first = in_data[rmap_index[i].first]; remap[i].second = *static_cast(rmap_index[i].second); } return remap; } else { CHECK(node.is_backward()); // forward property const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get(); std::vector out_grad_index(fwd->NumVisibleOutputs()); std::vector in_data_index(fwd->ListArguments().size()); std::vector out_data_index(fwd->NumOutputs()); CHECK_EQ(in_data_index.size(), out_data.size()); int counter = 0; for (size_t i = 0; i < out_grad_index.size(); ++i) { out_grad_index[i] = counter++; } for (size_t i = 0; i < in_data_index.size(); ++i) { in_data_index[i] = counter++; } for (size_t i = 0; i < out_data_index.size(); ++i) { out_data_index[i] = counter++; } auto args_index = fwd->DeclareBackwardDependency( out_grad_index, in_data_index, out_data_index); std::vector args_array(counter, nullptr); CHECK_EQ(args_index.size(), in_data.size()); for (size_t i = 0; i < in_data.size(); ++i) { args_array[args_index[i]] = &in_data[i]; } std::vector in_grad_ptr(out_data.size()); for (size_t i = 0; i < in_grad_ptr.size(); ++i) { in_grad_ptr[i] = (void*)&out_data[i]; // NOLINT(*) } auto remap_index = fwd->BackwardInplaceOption( out_grad_index, in_data_index, out_data_index, in_grad_ptr); std::vector > remap(remap_index.size()); for (size_t i = 0; i < remap_index.size(); ++i) { if (args_array[remap_index[i].first] == nullptr) { LOG(FATAL) << "BackwardInplaceOption not consistent with DeclareBackwardDependency"; } remap[i].first = *args_array[remap_index[i].first]; remap[i].second = *static_cast(remap_index[i].second); } return remap; } } inline GraphExecutor::OpExecEntry GraphExecutor::GetOpExecEntry(uint32_t nid) { OpNode& op_node = op_nodes_[nid]; std::vector req; std::vector in_array, out_array, aux_array; in_array.reserve(graph_.nodes[nid].inputs.size()); out_array.reserve(op_node.outputs.size()); req.reserve(op_node.outputs.size()); aux_array.reserve(op_node.aux_states.size()); OpExecEntry 
exec; // output for (const DataEntryInfo& out : op_node.outputs) { out_array.push_back(out.data); exec.mutate_vars.push_back(out.data.var()); req.push_back(out.op_req); } // aux for (const DataEntryInfo& aux : op_node.aux_states) { aux_array.push_back(aux.data); exec.mutate_vars.push_back(aux.data.var()); } // input for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { const DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; in_array.push_back(info.data); // skip inplace since they already appear in mutate vars if (info.inplace_op_id != static_cast(nid)) { exec.use_vars.push_back(info.data.var()); } } // start setup exec function. for (const Resource& r : op_node.op_ctx.requested) { exec.mutate_vars.push_back(r.var); } Operator* op = op_node.op.get(); OpContext* op_ctx_ptr = &op_node.op_ctx; bool is_gpu = op_node.ctx.dev_mask() == gpu::kDevMask; exec.exec_fun = [op, is_gpu, op_ctx_ptr, in_array, req, out_array, aux_array] (RunContext ctx, Engine::CallbackOnComplete on_complete) { std::vector in_data(in_array.size()); std::vector out_data(out_array.size()); std::vector aux_data(aux_array.size()); std::transform(in_array.begin(), in_array.end(), in_data.begin(), [](const NDArray& nd) { return nd.data(); }); std::transform(out_array.begin(), out_array.end(), out_data.begin(), [](const NDArray& nd) { return nd.data(); }); std::transform(aux_array.begin(), aux_array.end(), aux_data.begin(), [](const NDArray& nd) { return nd.data(); }); op_ctx_ptr->run_ctx = ctx; op->Forward(*op_ctx_ptr, in_data, req, out_data, aux_data); if (is_gpu) { #if MXNET_USE_CUDA // Wait GPU kernel to finish. ctx.get_stream()->Wait(); #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } on_complete(); }; return exec; } GraphExecutor::~GraphExecutor() { Engine::Get()->WaitForAll(); // need to delete the operators before delete the NDArray they referenced. 
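  // Note: WaitForAll() drains the engine first; the cached engine operators
  // capture NDArray variables, so they must be deleted before the executor's
  // NDArray members are destructed.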
for (OpNode& node : op_nodes_) { node.DeleteOperator(); } } void GraphExecutor::InitGraph(const Symbol &symbol, Context ctx, bool need_backward) { // initialize all internal data structures graph_.FromSymbol(symbol); num_forward_nodes_ = graph_.nodes.size(); if (need_backward) { graph_.MakeBackwardPass(&head_grad_nodes_, &arg_grads_); } // reorganize so backward node always follow forward // note that this may not be the case, because existence of head_grad_nodes std::vector topo = graph_.TopoSort(); std::vector backward; for (uint32_t nid : topo) { if (nid < num_forward_nodes_) { topo_order_.push_back(nid); } else { backward.push_back(nid); } } topo_order_.insert(topo_order_.end(), backward.begin(), backward.end()); // setup all the operator nodes data structure op_nodes_.resize(graph_.nodes.size()); for (size_t i = 0; i < graph_.nodes.size(); ++i) { op_nodes_[i].ctx = ctx; op_nodes_[i].outputs.resize(GetNumOutputs(i)); } } void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, const std::vector &arg_grad_store, const std::vector &grad_req_type, const std::vector &aux_states) { CHECK_EQ(arg_grad_store.size(), grad_req_type.size()); CHECK_EQ(in_args.size(), graph_.arg_nodes.size()); // bind inputs for (size_t i = 0; i < graph_.arg_nodes.size(); ++i) { DataEntryInfo &info = op_nodes_[graph_.arg_nodes[i]].outputs[0]; info.type = kBindByExternal; info.data = in_args[i]; } // setup ref for head nodes for (StaticGraph::DataEntry e : graph_.heads) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; ++info.ref_count; op_nodes_[e.source_id].activated = true; } // need Backward pass if (arg_grads_.size() != 0) { CHECK_EQ(arg_grads_.size(), arg_grad_store.size()); CHECK_EQ(arg_grads_.size(), grad_req_type.size()); // setup gradient placeholders for (size_t i = 0; i < arg_grads_.size(); ++i) { if (grad_req_type[i] == kNullOp) continue; CHECK_NE(grad_req_type[i], kWriteInplace) << "Gradient request can only be nullop, add, write"; StaticGraph::DataEntry &grad_source = arg_grads_[i]; DataEntryInfo &info = op_nodes_[grad_source.source_id].outputs[grad_source.index]; info.type = kBindByExternal; info.op_req = grad_req_type[i]; info.data = arg_grad_store[i]; ++info.ref_count; op_nodes_[grad_source.source_id].activated = true; } // setup head gradient for (uint32_t nid : head_grad_nodes_) { DataEntryInfo &info = op_nodes_[nid].outputs[0]; info.type = kTobeBindByExternal; } } // update ref counters for all other nodes, in reverse topo order for (auto it = topo_order_.rbegin(); it != topo_order_.rend(); ++it) { uint32_t nid = *it; if (op_nodes_[nid].activated) { for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; ++info.ref_count; op_nodes_[e.source_id].activated = true; } } } // shape inference std::vector > out_shapes(op_nodes_.size()); std::vector > aux_shapes(op_nodes_.size()); for (size_t i = 0; i < out_shapes.size(); ++i) { out_shapes[i].resize(op_nodes_[i].outputs.size()); } for (size_t i = 0; i < graph_.arg_nodes.size(); ++i) { out_shapes[graph_.arg_nodes[i]][0] = in_args[i].shape(); } CHECK(graph_.InferNodeShapes(topo_order_, &out_shapes, &aux_shapes)) << "Shape inference cannot be complete in bind"; for (size_t i = 0; i < out_shapes.size(); ++i) { for (size_t j = 0; j < out_shapes[i].size(); ++j) { op_nodes_[i].outputs[j].shape = out_shapes[i][j]; } } // bind aux args size_t aux_ndarray_idx = 0; for (size_t i = 0; i < aux_shapes.size(); ++i) { op_nodes_[i].aux_states.resize(aux_shapes[i].size()); for 
(size_t j = 0; j < aux_shapes[i].size(); ++j) { DataEntryInfo &info = op_nodes_[i].aux_states[j]; info.shape = aux_shapes[i][j]; info.type = kBindByExternal; CHECK_GT(aux_states.size(), aux_ndarray_idx) << "Input auxiliary NDArray is less than required"; info.data = aux_states[aux_ndarray_idx++]; CHECK_EQ(info.data.data().shape_, info.shape) << "Incorrect NDArray shape" << " Input: " << info.data.data().shape_ << " Desired: " << info.shape; } } } void GraphExecutor::InitDataEntryMemory() { // setup the temp ref counter for allocator algorithms for (OpNode &op : op_nodes_) { for (DataEntryInfo &node : op.outputs) { node.temp_ref_count = node.ref_count; } } // use allocator to allocate memory. GraphStorageAllocator allocator(&graph_, topo_order_); for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; if (graph_.nodes[nid].is_variable()) continue; // check inplace option std::vector in_data; in_data.reserve(graph_.nodes[nid].inputs.size()); // check inputs are ready. for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; CHECK_NE(info.type, kNotInitialized); CHECK_NE(info.temp_ref_count, 0); in_data.push_back(&info); } std::vector out_data(op_nodes_[nid].outputs.size()); for (size_t i = 0; i < op_nodes_[nid].outputs.size(); ++i) { out_data[i] = &op_nodes_[nid].outputs[i]; CHECK_NE(out_data[i]->type, kInternalAllocated); } auto inplace = GetInplaceOption(nid, in_data, out_data); for (std::pair kv : inplace) { DataEntryInfo* in = kv.first; DataEntryInfo* out = kv.second; if (enable_inplace_allocation_ && in->temp_ref_count == 1 && in->type == kInternalAllocated && out->type == kNotInitialized) { // we can only do inplace if we are last user of in // and out is not initialized. 
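        // Note: inplace sharing is gated by the MXNET_EXEC_ENABLE_INPLACE
        // environment variable (read into enable_inplace_allocation_ in
        // Init); when it applies, the output adopts the input's storage_id
        // and its write request becomes kWriteInplace.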
out->type = kInternalAllocated; out->op_req = kWriteInplace; out->storage_id = in->storage_id; // set inplace op id in->temp_ref_count = 0; in->inplace_op_id = static_cast(nid); } } // allocate output, for (DataEntryInfo *out : out_data) { if (out->op_req == kNullOp && out->temp_ref_count != 0) { out->op_req = kWriteTo; } if (out->type == kNotInitialized) { out->storage_id = allocator.Request( op_nodes_[nid].ctx, out->shape, nid); out->type = kInternalAllocated; } } // then free inputs for (DataEntryInfo *in : in_data) { // temp_ref_count == 0 means it is taken by inplace op if (in->temp_ref_count == 0) { CHECK_EQ(in->inplace_op_id, static_cast(nid)); continue; } // if we decrease it to zero, means we are ready to relase --in->temp_ref_count; if (in->temp_ref_count == 0 && in->type == kInternalAllocated) { allocator.Release(in->storage_id, nid); } } // check out again, if there is temp_ref_count == 0, release it for (DataEntryInfo *out : out_data) { if (out->temp_ref_count == 0 && out->type == kInternalAllocated) { allocator.Release(out->storage_id, nid); } } } // one pass complete, allocate real memory this->total_allocated_reals_ = allocator.InitStorages(); // get the real data NDArray into the DataEntryInfo for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; for (DataEntryInfo &out : op_nodes_[nid].outputs) { CHECK_NE(out.type, kNotInitialized); if (out.type == kInternalAllocated) { out.data = allocator.Get(out.storage_id, out.shape); } } } // setup heads for (StaticGraph::DataEntry e : graph_.heads) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; CHECK_EQ(info.type, kInternalAllocated); heads_ndarray_.push_back(info.data); } } void GraphExecutor::InitResources() { // prepare for temp space allocation std::vector req_temp_cnt(topo_order_.size(), 0); for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; if (graph_.nodes[nid].is_variable()) continue; uint32_t cnt = 0; for (const ResourceRequest& req : GetResource(nid)) { if (req.type == ResourceRequest::kTempSpace) ++cnt; } CHECK_LE(cnt, 1) << "Node can only have one temp space request"; req_temp_cnt[nid] = cnt; } uint32_t num_color = static_cast(common::GetExecNumMatchColor()); std::vector req_temp_color; // use graph coloring to find node that won't run in parallel num_color = graph::ColorNodeGroup(graph_, topo_order_, req_temp_cnt, num_color, &req_temp_color); // cached resources temp space std::map > cached_temp; total_allocated_temp_ = 0; // Resource allocation for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; if (graph_.nodes[nid].is_variable()) continue; const std::vector& reqs = GetResource(nid); auto& requested = op_nodes_[nid].op_ctx.requested; requested.clear(); // Get the resource of temporal space. 
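    // Note: nodes that received the same color from graph::ColorNodeGroup
    // never run in parallel, so a single kTempSpace resource is requested
    // per (context, color) pair, cached in cached_temp, and shared by all
    // nodes of that color.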
for (const ResourceRequest& req : reqs) { const Context &ctx = op_nodes_[nid].ctx; if (req.type == ResourceRequest::kTempSpace) { uint32_t color = req_temp_color[nid]; // try to reuse graph in same color std::map &cmap = cached_temp[ctx]; if (cmap.count(color) != 0) { requested.push_back(cmap.at(color)); } else { Resource r = ResourceManager::Get()->Request(ctx, req); requested.push_back(r); cmap[color] = r; ++total_allocated_temp_; } } else if (req.type == ResourceRequest::kRandom) { requested.push_back(ResourceManager::Get()->Request(ctx, req)); } else { LOG(FATAL) << "resource type not yet supported"; } } } } void GraphExecutor::InitOpNodes() { for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; if (graph_.nodes[nid].is_variable()) continue; OpNode& op_node = op_nodes_[nid]; if (graph_.nodes[nid].is_forward()) { op_node.op.reset(graph_.nodes[nid].op->CreateOperator(op_node.ctx)); } else { CHECK(graph_.nodes[nid].is_backward()); op_node.op.reset(new BackwardOpWrapper( graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(), op_nodes_[graph_.nodes[nid].backward_source_id].op)); } bool allow_cache = true; for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { DataEntryInfo& info = op_nodes_[e.source_id].outputs[e.index]; if (info.type == kTobeBindByExternal) allow_cache = false; } for (DataEntryInfo& info : op_node.outputs) { if (info.type == kTobeBindByExternal) allow_cache = false; } if (allow_cache) { op_node.cached_exec = GetOpExecEntry(nid); op_node.cached_opr = Engine::Get()->NewOperator( op_node.cached_exec.exec_fun, op_node.cached_exec.use_vars, op_node.cached_exec.mutate_vars, FnProperty::kNormal); } } } void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { for (size_t i = topo_start; i < topo_end; ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; if (graph_.nodes[nid].is_variable()) continue; OpNode& opnode = op_nodes_[nid]; opnode.op_ctx.is_train = is_train; if (opnode.cached_opr != nullptr) { Engine::Get()->Push(opnode.cached_opr, opnode.ctx); } else { auto exec = GetOpExecEntry(nid); Engine::Get()->PushAsync( exec.exec_fun, opnode.ctx, exec.use_vars, exec.mutate_vars, FnProperty::kNormal); } } } void GraphExecutor::Print(std::ostream &os) const { os << "num_forward_nodes=" << num_forward_nodes_ << '\n'; for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; os << "Op " << i << ":" << graph_.nodes[nid].name << '\n'; for (size_t j = 0; j < op_nodes_[nid].outputs.size(); ++j) { const DataEntryInfo &info = op_nodes_[nid].outputs[j]; os << "\toutput[" << j << "]: shape=" << info.shape; if (info.storage_id != GraphStorageAllocator::kBadStorageID) { os << ", storage_id=" << info.storage_id; } if (info.inplace_op_id != -1) { os << ", inplace_consumer=" << graph_.nodes[info.inplace_op_id].name; } os << '\n'; } for (size_t j = 0; j < op_nodes_[nid].op_ctx.requested.size(); ++j) { const Resource& resource = op_nodes_[nid].op_ctx.requested[j]; os << "\tresource[" << j << "]: "; if (resource.req.type == ResourceRequest::kTempSpace) { os << "type=TempSpace, id=" << resource.id; } else if (resource.req.type == ResourceRequest::kRandom) { os << "type=RandomNumber"; } os << '\n'; } } os << "Total " << (total_allocated_reals_ >> 18UL) <<" MB allocated\n"; os << "Total " << total_allocated_temp_ <<" TempSpace resource requested\n"; } void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, 
num_forward_nodes_); } void GraphExecutor::Backward(const std::vector &head_grads) { if (head_grads.size() != 0) { // TODO(bing, min): consider pass a map for backward CHECK_EQ(head_grad_nodes_.size(), head_grads.size()); for (size_t i = 0; i < head_grad_nodes_.size(); ++i) { uint32_t nid = head_grad_nodes_[i]; CHECK(graph_.nodes[nid].is_variable()); DataEntryInfo &info = op_nodes_[nid].outputs[0]; CHECK_EQ(info.type, kTobeBindByExternal); info.data = head_grads[i]; } } else { // check all the head_grad_nodes need to have zero ref_count // loss function do not need out_grad for (size_t i = 0; i < head_grad_nodes_.size(); ++i) { uint32_t nid = head_grad_nodes_[i]; DataEntryInfo &info = op_nodes_[nid].outputs[0]; CHECK_EQ(info.ref_count, 0) << "Because the last operator is not Loss function, " << "head_gradient is required in calling backward."; } } RunOps(true, num_forward_nodes_, topo_order_.size()); } Executor *Executor::Bind(Symbol symbol, Context ctx, const std::vector &in_args, const std::vector &arg_grad_store, const std::vector &grad_req_type, const std::vector &aux_states) { GraphExecutor *exec = new GraphExecutor(); exec->Init(symbol, ctx, in_args, arg_grad_store, grad_req_type, aux_states); return exec; } } // namespace mxnet //===== EXPANDED: mxnet/src/symbol/graph_executor.cc ===== //===== EXPANDIND: mxnet/src/symbol/static_graph.cc ===== /*! * Copyright (c) 2015 by Contributors * \file static_graph.cc * \brief static graph of mxnet */ namespace mxnet { std::vector StaticGraph::TopoSort() const { std::vector > stack; std::unordered_set visited; std::vector ret(nodes.size()); std::vector head_node; // out degree std::vector out_degree(nodes.size(), 0); for (const Node& n : nodes) { for (const DataEntry& e : n.inputs) { ++out_degree[e.source_id]; } if (n.is_backward()) { ++out_degree[n.backward_source_id]; } } for (size_t i = 0; i < nodes.size(); ++i) { if (out_degree[i] == 0) { stack.push_back(std::make_pair(static_cast(i), 0)); } } // heads for (auto &head : head_node) { stack.push_back(std::make_pair(head, 0)); } int count = 0; while (!stack.empty()) { std::pair& back = stack.back(); const Node& n = nodes[back.first]; if (back.second == n.inputs.size() + (n.is_backward() ? 
1 : 0)) { ret[count++] = back.first; visited.insert(back.first); stack.pop_back(); } else { uint32_t input; if (back.second == n.inputs.size() && n.is_backward()) { input = n.backward_source_id; back.second++; } else { input = n.inputs[back.second++].source_id; } if (visited.count(input) == 0) { stack.push_back(std::make_pair(input, 0)); } } } return ret; } bool StaticGraph::InferNodeShapes(const std::vector &topo_order, std::vector > *node_out_shapes, std::vector > *node_aux_shapes) const { for (uint32_t nid : topo_order) { const Node& node = nodes[nid]; if (node.is_forward()) { std::vector in_shape; for (const DataEntry& e : node.inputs) { in_shape.push_back((*node_out_shapes)[e.source_id][e.index]); } try { if (!node.op->InferShape(&in_shape, &(*node_out_shapes)[nid], &(*node_aux_shapes)[nid])) return false; } catch (const op::InferShapeError &err) { // error handling const std::string &op_name = node.name; std::string arg_name = node.op->ListArguments()[err.index]; std::ostringstream os; os << "InferShape Error in " << op_name << "\'s" << ' ' << arg_name << " argument\n"; auto &source = nodes[node.inputs[err.index].source_id]; if (source.is_variable()) { os << "Corresponding keyword of symbol: " << source.name << '\n' << err.msg; } throw dmlc::Error(os.str()); } for (size_t i = 0; i < node.inputs.size(); ++i) { const DataEntry& e = node.inputs[i]; (*node_out_shapes)[e.source_id][e.index] = in_shape[i]; } } else if (nodes[nid].is_backward()) { // simply use shapes from forward pass to assign backward shape const Node& forward = nodes[node.backward_source_id]; CHECK(forward.is_forward()); std::vector& in_grad_shapes = (*node_out_shapes)[nid]; CHECK(in_grad_shapes.size() == forward.inputs.size()); // assign the input shape to output gradients for (size_t i = 0; i < forward.inputs.size(); ++i) { const DataEntry &e = forward.inputs[i]; try { SHAPE_ASSIGN_CHECK(in_grad_shapes, i, (*node_out_shapes)[e.source_id][e.index]); } catch (const op::InferShapeError &err) { const std::string &op_name = forward.name; std::string arg_name = forward.op->ListArguments()[e.index]; std::ostringstream os; os << "InferShape Error in " << op_name << "\'s" << ' ' << arg_name << " gradient argument\n" << err.msg; throw dmlc::Error(os.str()); } } // consistent check for input shapes auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id]; // use BackwardInputs to select entries corresponding to node.inputs auto in_shape = forward.op->BackwardInputs( out_data_shapes, in_grad_shapes, out_data_shapes); for (size_t i = 0; i < node.inputs.size(); ++i) { const DataEntry& e = node.inputs[i]; try { SHAPE_ASSIGN_CHECK((*node_out_shapes)[e.source_id], e.index, in_shape[i]); } catch (const op::InferShapeError &err) { const std::string &op_name = nodes[e.source_id].name; std::ostringstream os; os << "InferShape Error in " << op_name << "\'s" << " gradient values\n" << err.msg; throw dmlc::Error(os.str()); } } } } // TODO(bing) assign shape for head gradient return true; } bool StaticGraph::InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const { std::vector > node_out_shapes(nodes.size()); std::vector > node_aux_shapes(nodes.size()); for (size_t i = 0; i < nodes.size(); ++i) { int nout = 1; if (nodes[i].is_forward()) { nout = nodes[i].op->NumOutputs(); } else if (nodes[i].is_backward()) { nout = static_cast(nodes[nodes[i].backward_source_id].inputs.size()); } node_out_shapes[i].resize(nout); } CHECK(in_shape->size() == arg_nodes.size()) << "Wrong number of inputs to infer shape"; 
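  // Note: shape inference seeds the argument nodes from in_shape, runs
  // InferNodeShapes over the topological order (which also propagates shapes
  // through backward nodes), and then copies the completed argument, head
  // output and auxiliary shapes back to the caller.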
for (size_t i = 0; i < arg_nodes.size(); ++i) { node_out_shapes[arg_nodes[i]][0] = (*in_shape)[i]; } if (!InferNodeShapes(this->TopoSort(), &node_out_shapes, &node_aux_shapes)) return false; for (size_t i = 0; i < arg_nodes.size(); ++i) { (*in_shape)[i] = node_out_shapes[arg_nodes[i]][0]; } out_shape->resize(heads.size()); for (size_t i = 0; i < heads.size(); ++i) { const DataEntry &e = heads[i]; (*out_shape)[i] = node_out_shapes[e.source_id][e.index]; } aux_shape->clear(); for (size_t i = 0; i < node_aux_shapes.size(); ++i) { if (node_aux_shapes[i].size() > 0) { for (auto const &shape : node_aux_shapes[i]) { aux_shape->push_back(shape); } } } return true; } StaticGraph::Node StaticGraph::CreateSumNode( const std::vector &grad_source) { // find multiple gradients, need aggregate std::ostringstream os_size; Node agg_node; agg_node.op.reset(OperatorProperty::Create("ElementWiseSum")); os_size << grad_source.size(); agg_node.op->Init({{"num_args", os_size.str()}}); agg_node.inputs = grad_source; return agg_node; } void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, std::vector *arg_grads) { arg_grads->clear(); head_grad_nodes->clear(); // get topo order of nodes, before new nodes are added std::vector topo_order = TopoSort(); // map out_data entry to out_grad std::map > grad_map; // allocate head gradient nodes for (DataEntry head : heads) { Node node; std::ostringstream os; os << nodes[head.source_id].name << '_' << head.index << "_grad"; // TODO(bing): add index to name node.name = os.str(); // node id uint32_t nid = static_cast(nodes.size()); nodes.push_back(std::move(node)); // create a variable node for gradient input DataEntry igrad(nid, 0); head_grad_nodes->push_back(nid); // update gradient map auto it = grad_map.find(head); if (it == grad_map.end()) { grad_map[head] = {igrad}; } else { it->second.push_back(igrad); } } // do backward pass traverse for (auto it = topo_order.rbegin(); it != topo_order.rend(); ++it) { uint32_t nid = *it; // skip variables if (nodes[nid].is_variable()) continue; CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward"; // get out_grad and out_data entry std::vector out_grad, out_data; // nvisible is out_grad.size() int nvisible = nodes[nid].op->NumVisibleOutputs(); // ntotal is out_data.size() int ntotal = nodes[nid].op->NumOutputs(); // check all outpus for (int i = 0; i < ntotal; ++i) { DataEntry odata(nid, static_cast(i)); out_data.push_back(odata); if (i >= nvisible) continue; // get out_grad auto it = grad_map.find(odata); CHECK(it != grad_map.end()) << "bad graph"; std::vector &gnodes = it->second; if (gnodes.size() == 1) { out_grad.push_back(gnodes[0]); } else { std::ostringstream os_name; Node agg_node = StaticGraph::CreateSumNode(gnodes); os_name << nodes[nid].name << '_' << i << "_out_grad_agg"; agg_node.name = os_name.str(); uint32_t agg_node_id = static_cast(nodes.size()); nodes.push_back(std::move(agg_node)); out_grad.push_back(DataEntry(agg_node_id, 0)); } } // Create a gradient backward node Node grad_node; // Point to the corresponding source grad_node.backward_source_id = nid; // select out the dependent inputs grad_node.inputs = nodes[nid].op->BackwardInputs( out_grad, nodes[nid].inputs, out_data); grad_node.name = nodes[nid].name + "_backward"; uint32_t grad_node_id = static_cast(nodes.size()); nodes.push_back(std::move(grad_node)); // update gradient map for (size_t i = 0; i < nodes[nid].inputs.size(); ++i) { DataEntry idata = nodes[nid].inputs[i]; DataEntry igrad(grad_node_id, static_cast(i)); auto it = 
grad_map.find(idata); if (it == grad_map.end()) { grad_map[idata] = {igrad}; } else { it->second.push_back(igrad); } } } // create return values of arg_grads arg_grads->resize(arg_nodes.size()); for (size_t i = 0; i < arg_nodes.size(); ++i) { DataEntry odata(arg_nodes[i], 0); auto it = grad_map.find(odata); CHECK(it != grad_map.end()) << "bad graph"; if (it->second.size() == 1) { arg_grads->at(i) = it->second[0]; } else { std::ostringstream os_name; Node agg_node = StaticGraph::CreateSumNode(it->second); os_name << nodes[arg_nodes[i]].name << "_grad_agg"; agg_node.name = os_name.str(); uint32_t agg_node_id = static_cast(nodes.size()); nodes.push_back(std::move(agg_node)); arg_grads->at(i) = DataEntry(agg_node_id, 0); } } } void StaticGraph::Node::Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); if (op.get() != nullptr) { writer->WriteObjectKeyValue("op", op->TypeString()); std::map param = op->GetParams(); writer->WriteObjectKeyValue("param", param); } else { std::map empty_param; std::string json_null = "null"; writer->WriteObjectKeyValue("op", json_null); writer->WriteObjectKeyValue("param", empty_param); } writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("inputs", inputs); writer->WriteObjectKeyValue("backward_source_id", backward_source_id); writer->EndObject(); } void StaticGraph::Node::Load(dmlc::JSONReader *reader) { dmlc::JSONObjectReadHelper helper; std::string op_type_str; std::map param; helper.DeclareField("op", &op_type_str); helper.DeclareField("param", ¶m); helper.DeclareField("name", &name); helper.DeclareField("inputs", &inputs); helper.DeclareField("backward_source_id", &backward_source_id); helper.ReadAllFields(reader); if (op_type_str != "null") { op.reset(OperatorProperty::Create(op_type_str.c_str())); std::vector > vec(param.begin(), param.end()); op->Init(vec); } else { op.reset(nullptr); } } void StaticGraph::Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("arg_nodes", arg_nodes); writer->WriteObjectKeyValue("heads", heads); writer->EndObject(); } void StaticGraph::Load(dmlc::JSONReader *reader) { dmlc::JSONObjectReadHelper helper; helper.DeclareField("nodes", &nodes); helper.DeclareField("arg_nodes", &arg_nodes); helper.DeclareField("heads", &heads); helper.ReadAllFields(reader); } } // namespace mxnet //===== EXPANDED: mxnet/src/symbol/static_graph.cc ===== //===== EXPANDIND: mxnet/src/symbol/symbol.cc ===== /*! * Copyright (c) 2015 by Contributors * \file symbol.cc * \brief symbol of mxnet */ namespace mxnet { /*! * \brief Node is represents node of an operator in the symbolic graph. * * It stores connection to the inputs to function represented by OperatorProperty * NOTE on data structure: there are three types of node: * - Normal node: contains all the necessary elements of a graph. * - OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied. * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. */ struct Symbol::Node { /*! \brief source node of the current node */ std::shared_ptr backward_source_node; /*! \brief Operator of this node */ std::unique_ptr op; /*! \brief name of the node */ std::string name; /*! \brief inputs to this node */ std::vector inputs; /*! 
*\brief constructor *\param op the OperatorProperty to construct the Node *\param name the name of the symbol */ explicit Node(OperatorProperty *op = nullptr, const std::string& name = "") : op(op), name(name) { } /*! \return Whether the symbol is atomic */ inline bool is_atomic() const { return inputs.size() == 0 && op != nullptr; } /*! \return Whether it is unit variable */ inline bool is_variable() const { return op == nullptr && !backward_source_node; } /*! \return Whether it is backward op */ inline bool is_backward() const { return backward_source_node.get() != nullptr; } }; /*! \return whwther the symbol is atomic */ inline bool Symbol::is_atomic() const { return heads_[0].source->is_atomic(); } // implementation of template functions template inline void Symbol::DFSVisit(FVisit fvisit) const { std::vector*, uint32_t> > stack; std::unordered_set visited; // put the head into the graph for (auto &head : heads_) { Node* ptr = head.source.get(); if (visited.count(ptr) == 0) { stack.push_back(std::make_pair(&head.source, 0)); visited.insert(ptr); } } while (!stack.empty()) { std::pair *, uint32_t>& back = stack.back(); if (back.second == back.first->get()->inputs.size()) { fvisit(*(back.first)); stack.pop_back(); } else { std::vector& inputs = back.first->get()->inputs; Symbol::DataEntry& input = inputs.at(back.second++); Node* ptr = input.source.get(); if (visited.count(ptr) == 0) { stack.push_back(std::make_pair(&input.source, 0)); visited.insert(ptr); } } } } // helper function to handle keyword argument mismatch // throw approperiate messages inline void KeywordArgumentMismatch(const char *source, const std::vector &user_args, const std::vector &args) { std::unordered_set keys(args.begin(), args.end()); std::ostringstream head, msg; msg << "\nCandidate arguments:\n"; for (size_t i = 0; i < args.size(); ++i) { msg << "\t[" << i << ']' << args[i] << '\n'; } for (const auto& key : user_args) { if (keys.count(key) == 0) { LOG(FATAL) << source << "Keyword argument name " << key << " not found." 
<< msg.str(); } } } int Symbol::FindDuplicateArgs(std::unordered_map *out) const { out->clear(); int max_dup = 1; this->DFSVisit([out, &max_dup](const std::shared_ptr &node) { if (node->is_variable()) { auto iter = out->find(node->name); if (iter == out->end()) { (*out)[node->name] = 1; } else { ++iter->second; max_dup = std::max(max_dup, iter->second); } } }); return max_dup; } // public functions Symbol Symbol::Copy() const { std::unordered_map > old_new; // use DFSVisit to copy all the nodes this->DFSVisit([&old_new](const std::shared_ptr &node) { if (node->op == nullptr) { old_new[node.get()] = std::make_shared(nullptr, node->name); } else { old_new[node.get()] = std::make_shared(node->op->Copy(), node->name); } }); // connect nodes of new graph for (const auto &kv : old_new) { for (const DataEntry& n : kv.first->inputs) { Node *ptr = n.source.get(); kv.second->inputs.push_back(DataEntry(old_new[ptr], n.index)); } } // set the head Symbol s; for (auto &head : heads_) { s.heads_.push_back(DataEntry(old_new[head.source.get()], head.index)); } return s; } void Symbol::Print(std::ostream &os) const { if (this->is_atomic()) { os << "AtomicFunction "<< " Type:" << heads_[0].source->op->TypeString() << '\n' << "Inputs:"; std::vector args = this->ListArguments(); for (size_t i = 0; i < args.size(); ++i) { os << "\targ[" << i << "]=" << args[i] << "\n"; } } else { // use DFSVisit to copy all the nodes os << "Outputs:\n"; for (size_t i = 0; i < heads_.size(); ++i) { os << "\toutput[" << i << "]=" << heads_[i].source->name << '(' << heads_[i].index << ")\n"; } this->DFSVisit([&os](const std::shared_ptr &node) { if (node->is_variable()) { os << "Variable:" << node->name << '\n'; } else { std::string type_string; if (!node->backward_source_node) { type_string = node->op->TypeString(); } else { type_string = node->backward_source_node->op->TypeString(); } os << "Name: " << node->name << " Type:" << type_string << '\n' << "Inputs:\n"; for (size_t i = 0; i < node->inputs.size(); ++i) { os << "\targ[" << i << "]=" << node->inputs[i].source->name << '(' << node->inputs[i].index << ")\n"; } } }); } } std::vector Symbol::ListArguments() const { std::vector ret; if (this->is_atomic()) { return heads_[0].source->op->ListArguments(); } else { this->DFSVisit([&ret](const std::shared_ptr &node) { if (node->is_variable()) { ret.push_back(node->name); } }); return ret; } } std::vector Symbol::ListOutputs() const { std::vector ret; for (auto &head : heads_) { if (head.source->is_variable()) { ret.push_back(head.source->name); } else { auto &hname = head.source->name; std::string rname; if (head.source->is_backward()) { rname = head.source->backward_source_node->op->ListArguments()[head.index]; } else { rname = head.source->op->ListOutputs()[head.index]; } if (hname.length() == 0) { ret.push_back(std::move(rname)); } else { ret.push_back(hname + '_' + rname); } } } return ret; } std::vector Symbol::ListAuxiliaryStates() const { // TODO(linmin, bing): better solution std::vector ret; StaticGraph g; this->ToStaticGraph(&g); std::vector topo_order = g.TopoSort(); for (uint32_t nid : topo_order) { const auto& node = g.nodes[nid]; if (node.op != nullptr) { auto aux_args = node.op->ListAuxiliaryStates(); if (aux_args.size() > 0) { auto &hname = node.name; for (auto const &aux : aux_args) { ret.push_back(hname + '_' + aux); } } } } return ret; } Symbol Symbol::operator[] (size_t index) const { size_t nreturn = NumOutputs(); CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; if (nreturn == 1) { return 
*this; } else { Symbol s; s.heads_.push_back(heads_[index]); return s; } } Symbol Symbol::GetInternals() const { Symbol ret; this->DFSVisit([&ret](const std::shared_ptr &node) { Node* n = node.get(); uint32_t nout; if (n->is_variable()) { nout = 1; } else if (n->is_backward()) { nout = static_cast(n->backward_source_node->inputs.size()); } else { nout = n->op->NumVisibleOutputs(); } for (uint32_t i = 0; i < nout; ++i) { ret.heads_.push_back(DataEntry(node, i)); } }); return ret; } // create a default variable name inline std::string DefaultVarName(const std::string &op_name, const std::string &arg_name) { if (op_name.length() == 0) { return arg_name; } else { return op_name + '_' + arg_name; } } void Symbol::Compose(const std::vector& args, const std::string& name) { // CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently"; CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; heads_[0].source->name = name; for (size_t i = 0; i < args.size(); ++i) { CHECK_EQ(args[i].NumOutputs(), 1) << "Argument " << i << " is a tuple with " << args[i].NumOutputs() << " elements, scalar is required"; } // positional arguments requires all arguments for now. // TODO(bing) consider partial assignments if (this->is_atomic()) { // atomic symbol do not have place holder for all the arguments std::vector req_args = heads_[0].source->op->ListArguments(); CHECK_LE(args.size(), req_args.size()) << "Incorrect number of arguments, requires " << req_args.size() << ", provided " << args.size(); heads_[0].source->inputs.resize(req_args.size()); for (size_t i = 0; i < args.size(); ++i) { heads_[0].source->inputs[i] = args[i].heads_[0]; } for (size_t i = args.size(); i < req_args.size(); ++i) { heads_[0].source->inputs[i] = DataEntry( std::make_shared(nullptr, DefaultVarName(name, req_args[i])), 0); } } else { // find all the place holders size_t arg_counter = 0; std::unordered_map replace_map; std::vector > replace_plan; // replace map stores the existing replacement plan for arguments node this->DFSVisit([&arg_counter, &replace_map, &replace_plan, &args] (const std::shared_ptr &node) { // visit all the childs, find possible replacement for (size_t i = 0; i < node->inputs.size(); ++i) { DataEntry *e = &(node->inputs[i]); if (e->source->is_variable()) { const DataEntry *target = nullptr; auto iter = replace_map.find(e->source.get()); if (iter == replace_map.end()) { if (arg_counter < args.size()) { target = &(args[arg_counter].heads_[0]); replace_map[e->source.get()] = target; } ++arg_counter; } else { target = iter->second; } replace_plan.push_back(std::make_pair(e, target)); } } }); CHECK_EQ(args.size(), arg_counter) << "Incorrect number of arguments, requires " << arg_counter << ", provided " << args.size(); // now run the replacement for (const auto& kv : replace_plan) { *(kv.first) = *(kv.second); } } } void Symbol::Compose(const std::unordered_map& kwargs, const std::string& name) { // CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently"; CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; heads_[0].source->name = name; for (const auto& kv : kwargs) { CHECK_EQ(kv.second.NumOutputs(), 1) << "Keyword Argument " << kv.first << " is a tuple, scalar is required"; } size_t nmatched = 0; if (this->is_atomic()) { // atomic symbol do not have place holder for all the arguments std::vector req_args = heads_[0].source->op->ListArguments(); heads_[0].source->inputs.resize(req_args.size()); for (size_t i = 0; i < 
req_args.size(); ++i) { auto iter = kwargs.find(req_args[i]); if (iter != kwargs.end()) { heads_[0].source->inputs[i] = iter->second.heads_[0]; ++nmatched; } else { heads_[0].source->inputs[i] = DataEntry( std::make_shared(nullptr, DefaultVarName(name, req_args[i])), 0); } } // if things goes wrong recover the old state if (nmatched != kwargs.size()) { heads_[0].source->inputs.clear(); } } else { // find all the arguments positions std::unordered_map dup_args; int max_dup = this->FindDuplicateArgs(&dup_args); if (max_dup > 1) { for (const auto& kv : dup_args) { CHECK_EQ(kv.second, 1) << " Argument name=\"" << kv.first << "\" occured in " << kv.second << " places in the Symbol, " << "Keyword argument call is not supported because this duplication."; } } CHECK_EQ(max_dup, 1); std::vector > replace_plan; std::unordered_set visited; // replace map stores the existing replacement plan for arguments node this->DFSVisit([&nmatched, &visited, &kwargs, &replace_plan] (const std::shared_ptr &node) { // visit all the childs, find possible replacement for (size_t i = 0; i < node->inputs.size(); ++i) { DataEntry *e = &(node->inputs[i]); if (e->source->is_variable()) { const DataEntry *target = nullptr; auto iter = kwargs.find(e->source->name); if (iter != kwargs.end()) { target = &(iter->second.heads_[0]); // count how many arguments have been matched. if (visited.count(e->source.get()) == 0) { visited.insert(e->source.get()); ++nmatched; } replace_plan.push_back(std::make_pair(e, target)); } } } }); if (nmatched == kwargs.size()) { for (const auto& kv : replace_plan) { *(kv.first) = *(kv.second); } } } if (nmatched != kwargs.size()) { std::vector keys(kwargs.size()); std::transform(kwargs.begin(), kwargs.end(), keys.begin(), [](decltype(*kwargs.begin())& kv)->std::string { return kv.first; }); KeywordArgumentMismatch("Symbol.Compose", keys, ListArguments()); } } Symbol Symbol::operator () (const std::vector& args, const std::string& name) const { Symbol s = this->Copy(); s.Compose(args, name); return s; } Symbol Symbol::operator () (const std::unordered_map& kwargs, const std::string& name) const { Symbol s = this->Copy(); s.Compose(kwargs, name); return s; } Symbol Symbol::Grad(const std::vector& wrt) const { StaticGraph g; this->ToStaticGraph(&g); uint32_t num_nodes = g.nodes.size(); std::vector head_grad_nodes; std::vector arg_grads; g.MakeBackwardPass(&head_grad_nodes, &arg_grads); std::vector > shared_node; this->DFSVisit([&shared_node](const std::shared_ptr &n) { shared_node.push_back(n); }); for (std::vector::const_iterator it = g.nodes.begin() + num_nodes; it != g.nodes.end(); ++it) { auto sym_node = std::make_shared(); sym_node->name = it->name; if (it->backward_source_id != -1) { sym_node->backward_source_node = shared_node[it->backward_source_id]; } shared_node.push_back(sym_node); for (auto e : it->inputs) { Symbol::DataEntry entry(shared_node[e.source_id], e.index); sym_node->inputs.push_back(std::move(entry)); } } // make arg lookup dict auto arg_list = ListArguments(); std::unordered_map arg_index; for (uint32_t i = 0; i < arg_list.size(); ++i) { arg_index[arg_list[i]] = i; } // generate the heads Symbol ret; for (const std::string& name : wrt) { if (arg_index.find(name) != arg_index.end()) { uint32_t index = arg_index[name]; Symbol::DataEntry entry(shared_node[arg_grads[index].source_id], arg_grads[index].index); ret.heads_.push_back(entry); } else { KeywordArgumentMismatch("Symbol.Grad ", wrt, arg_list); } } return ret; } bool Symbol::InferShape(std::vector *arg_shapes, std::vector 
*out_shapes, std::vector *aux_shapes) const { StaticGraph g; this->ToStaticGraph(&g); return g.InferShape(arg_shapes, out_shapes, aux_shapes); } bool Symbol::InferShape(const std::unordered_map& known_arg_shapes, std::vector *arg_shapes, std::vector *out_shapes, std::vector *aux_shapes) const { StaticGraph g; this->ToStaticGraph(&g); arg_shapes->clear(); arg_shapes->resize(g.arg_nodes.size(), TShape()); size_t nmatched = 0; for (size_t i = 0; i < g.arg_nodes.size(); ++i) { const std::string& name = g.nodes[g.arg_nodes[i]].name; auto it = known_arg_shapes.find(name); if (it != known_arg_shapes.end()) { arg_shapes->at(i) = it->second; ++nmatched; } } if (nmatched != known_arg_shapes.size()) { std::vector keys(known_arg_shapes.size()); std::transform(known_arg_shapes.begin(), known_arg_shapes.end(), keys.begin(), [](decltype(*known_arg_shapes.begin())& kv)->std::string { return kv.first; }); KeywordArgumentMismatch("Symbol.InterShape", keys, ListArguments()); } return g.InferShape(arg_shapes, out_shapes, aux_shapes); } void Symbol::Save(dmlc::JSONWriter *writer) const { StaticGraph g; this->ToStaticGraph(&g); g.Save(writer); } void Symbol::Load(dmlc::JSONReader *reader) { StaticGraph g; g.Load(reader); this->FromStaticGraph(g); } Symbol Symbol::Create(OperatorProperty *op) { // use special representation for atomic symbol auto node = std::make_shared(op, ""); size_t nret = op->NumVisibleOutputs(); Symbol s; for (uint32_t i = 0; i < nret; ++i) { s.heads_.push_back(DataEntry(node, i)); } return s; } Symbol Symbol::CreateGroup(const std::vector &symbols) { Symbol ret; for (const auto &s : symbols) { ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end()); } return ret; } Symbol Symbol::CreateVariable(const std::string &name) { Symbol s; s.heads_.push_back(DataEntry(std::make_shared(nullptr, name), 0)); return s; } void Symbol::ToStaticGraph(StaticGraph *out_graph) const { // TODO(bing): Check unique name std::vector node_order; std::unordered_map node_index; auto &arg_nodes = out_graph->arg_nodes; arg_nodes.clear(); this->DFSVisit([&node_order, &node_index, &arg_nodes](const std::shared_ptr &n) { uint32_t nid = static_cast(node_index.size()); node_index[n.get()] = nid; if (n->is_variable()) { arg_nodes.push_back(nid); } node_order.push_back(n.get()); }); // setup nodes out_graph->nodes.resize(node_index.size()); for (uint32_t nid = 0; nid < node_order.size(); ++nid) { if (node_order[nid]->op != nullptr) { out_graph->nodes[nid].op.reset(node_order[nid]->op->Copy()); } else { out_graph->nodes[nid].op.reset(nullptr); } // backward source if (node_order[nid]->backward_source_node) { out_graph->nodes[nid].backward_source_id = node_index[node_order[nid]->backward_source_node.get()]; } else { out_graph->nodes[nid].backward_source_id = -1; } out_graph->nodes[nid].name = node_order[nid]->name; auto &inputs = out_graph->nodes[nid].inputs; inputs.clear(); for (const DataEntry &src : node_order[nid]->inputs) { StaticGraph::DataEntry e; e.index = src.index; e.source_id = node_index[src.source.get()]; inputs.push_back(e); } } // setup heads out_graph->heads.clear(); for (auto &head : heads_) { StaticGraph::DataEntry e; e.source_id = node_index[head.source.get()]; e.index = head.index; out_graph->heads.push_back(e); } } void Symbol::FromStaticGraph(const StaticGraph &graph) { std::unordered_map > nodes; std::vector topo_order = graph.TopoSort(); // copy ver nodes in topo order for (uint32_t nid : topo_order) { auto &gnode = graph.nodes[nid]; auto sym_node = std::make_shared(); sym_node->name = 
gnode.name; if (gnode.op.get() != nullptr) { sym_node->op.reset(gnode.op->Copy()); } if (gnode.backward_source_id != -1) { sym_node->backward_source_node = nodes.at(gnode.backward_source_id); } for (const StaticGraph::DataEntry& e : gnode.inputs) { Symbol::DataEntry entry(nodes.at(e.source_id), e.index); sym_node->inputs.push_back(std::move(entry)); } nodes[nid] = sym_node; } // generate the heads heads_.clear(); for (const StaticGraph::DataEntry& e : graph.heads) { Symbol::DataEntry entry(nodes.at(e.source_id), e.index); heads_.push_back(std::move(entry)); } } } // namespace mxnet //===== EXPANDED: mxnet/src/symbol/symbol.cc ===== //===== EXPANDIND: mxnet/src/operator/operator.cc ===== /*! * Copyright (c) 2015 by Contributors * \file operator.cc * \brief operator module of mxnet */ namespace dmlc { DMLC_REGISTRY_ENABLE(::mxnet::OperatorPropertyReg); } // namespace dmlc namespace mxnet { // implementation of all factory functions OperatorProperty *OperatorProperty::Create(const char* type_name) { auto *creator = dmlc::Registry::Find(type_name); if (creator == nullptr) { LOG(FATAL) << "Cannot find Operator " << type_name << " in registry"; } return creator->body(); } } // namespace mxnet //===== EXPANDED: mxnet/src/operator/operator.cc ===== //===== EXPANDIND: mxnet/src/operator/activation.cc ===== /*! * Copyright (c) 2015 by Contributors * \file activation.cc * \brief activation op * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/activation-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file activation-inl.h * \brief Activation operator * \author Bing Xu */ #ifndef MXNET_OPERATOR_ACTIVATION_INL_H_ #define MXNET_OPERATOR_ACTIVATION_INL_H_ namespace mxnet { namespace op { // Declare enumeration of input order to make code more intuitive. // // These enums are only visible within this header namespace activation { enum ActivationOpInputs {kData}; enum ActivationOpOutputs {kOut}; enum ActivationOpType {kReLU, kSigmoid, kTanh}; } // activation struct ActivationParam : public dmlc::Parameter { // use int for enumeration int act_type; DMLC_DECLARE_PARAMETER(ActivationParam) { DMLC_DECLARE_FIELD(act_type) .add_enum("relu", activation::kReLU) .add_enum("sigmoid", activation::kSigmoid) .add_enum("tanh", activation::kTanh) .describe("Activation function to be applied."); } }; /** * \brief This is the implementation of activation operator. * \tparam xpu The device that the op will be executed on. 
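 * Note: Forward applies the elementwise activation functor to the flattened
 * input and writes the result to out; Backward multiplies out_grad by the
 * functor's gradient evaluated at out_data, i.e. the derivative is computed
 * from the activation's output (which is why out_data appears in
 * DeclareBackwardDependency below).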
*/ template class ActivationOp : public Operator { public: virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); Tensor data = in_data[activation::kData].FlatTo2D(s); Tensor out = out_data[activation::kOut].FlatTo2D(s); Assign(out, req[activation::kOut], F(data)); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK(in_data.size() == 1 && in_grad.size() == 1); CHECK_EQ(req.size(), 1); Stream *s = ctx.get_stream(); Tensor m_out_grad = out_grad[activation::kOut].FlatTo2D(s); Tensor m_out_data = out_data[activation::kOut].FlatTo2D(s); Tensor m_in_grad = in_grad[activation::kData].FlatTo2D(s); Assign(m_in_grad, req[activation::kData], F(m_out_data) * m_out_grad); } }; // class ActivationOp // Decalre Factory function, used for dispatch specialization template Operator* CreateOp(ActivationParam type); #if DMLC_USE_CXX11 class ActivationProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; const TShape &dshape = in_shape->at(activation::kData); if (dshape.ndim() == 0) return false; out_shape->clear(); out_shape->push_back(dshape); return true; } OperatorProperty* Copy() const override { auto ptr = new ActivationProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "Activation"; } // decalre dependency and inplace optimization options std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { #if MXNET_USE_CUDNN == 1 return {out_grad[activation::kOut], out_data[activation::kOut], in_data[activation::kData]}; #else return {out_grad[activation::kOut], out_data[activation::kOut]}; #endif // MXNET_USE_CUDNN } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_grad[activation::kOut], in_grad[activation::kData]}}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[activation::kData], out_data[activation::kOut]}}; } Operator* CreateOperator(Context ctx) const override; private: ActivationParam param_; }; #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_ACTIVATION_INL_H_ //===== EXPANDED: mxnet/src/operator/activation-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(ActivationParam param) { switch (param.act_type) { case activation::kReLU: return new ActivationOp(); case activation::kSigmoid: return new ActivationOp(); case activation::kTanh: return new ActivationOp(); default: LOG(FATAL) << "unknown activation type"; return NULL; } } // DO_BIND_DISPATCH comes from operator_common.h Operator 
*ActivationProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(ActivationParam); MXNET_REGISTER_OP_PROPERTY(Activation, ActivationProp) .describe("Apply activation function to input.") .add_argument("data", "Symbol", "Input data to activation function.") .add_arguments(ActivationParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/activation.cc ===== //===== EXPANDIND: mxnet/src/operator/batch_norm.cc ===== /*! * Copyright (c) 2015 by Contributors * \file batch_norm.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/batch_norm-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file batch_norm-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_BATCH_NORM_INL_H_ #define MXNET_OPERATOR_BATCH_NORM_INL_H_ namespace mxnet { namespace op { namespace batchnorm { enum BatchNormOpInputs {kData, kGamma, kBeta}; enum BatchNormOpOutputs {kOut, kOutNoAffine, kMean, kVar}; enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; enum BatchNormBackResource {kTempSpace}; } // namespace batchnorm struct BatchNormParam : public dmlc::Parameter { float eps; float momentum; DMLC_DECLARE_PARAMETER(BatchNormParam) { DMLC_DECLARE_FIELD(eps).set_default(1e-10f) .describe("Epsilon to prevent div 0"); DMLC_DECLARE_FIELD(momentum).set_default(0.1f) .describe("Momentum for moving average"); } }; template class BatchNormOp : public Operator { public: explicit BatchNormOp(BatchNormParam param) { this->param_ = param; } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 3); CHECK_EQ(aux_states.size(), 2); if (ctx.is_train) { CHECK_EQ(out_data.size(), 4); CHECK_EQ(req.size(), 4); } else { CHECK_GE(out_data.size(), 1); CHECK_GE(req.size(), 1); CHECK_EQ(req[batchnorm::kOut], kWriteTo); } Stream *s = ctx.get_stream(); const real_t scale = static_cast(in_data[batchnorm::kData].shape_[1]) / static_cast(in_data[batchnorm::kData].shape_.Size()); Tensor data; Tensor out, out_no_affine; if (in_data[batchnorm::kData].ndim() == 2) { Shape<4> dshape = Shape4(in_data[batchnorm::kData].shape_[0], in_data[batchnorm::kData].shape_[1], 1, 1); data = in_data[batchnorm::kData].get_with_shape(dshape, s); out = out_data[batchnorm::kOut].get_with_shape(dshape, s); if (ctx.is_train) { out_no_affine = out_data[batchnorm::kOutNoAffine].get_with_shape(dshape, s); } } else { data = in_data[batchnorm::kData].get(s); out = out_data[batchnorm::kOut].get(s); if (ctx.is_train) { out_no_affine = out_data[batchnorm::kOutNoAffine].get(s); } } Tensor slope = in_data[batchnorm::kGamma].get(s); Tensor bias = in_data[batchnorm::kBeta].get(s); Tensor moving_mean = aux_states[batchnorm::kMovingMean].get(s); Tensor moving_var = aux_states[batchnorm::kMovingVar].get(s); // cal if (ctx.is_train) { Tensor mean = out_data[batchnorm::kMean].get(s); Tensor var = out_data[batchnorm::kVar].get(s); Assign(mean, req[batchnorm::kMean], scale * sumall_except_dim<1>(data)); Assign(var, req[batchnorm::kVar], scale * sumall_except_dim<1>( F(data - broadcast<1>(mean, data.shape_)))); Assign(out_no_affine, req[batchnorm::kOutNoAffine], (data - broadcast<1>(mean, data.shape_)) / F(broadcast<1>(var + param_.eps, data.shape_))); Assign(out, req[batchnorm::kOut], out_no_affine * broadcast<1>(slope, out.shape_) + broadcast<1>(bias, out.shape_)); moving_mean = moving_mean * 
param_.momentum + mean * (1 - param_.momentum); moving_var = moving_var * param_.momentum + var * (1 - param_.momentum); } else { Assign(out, req[batchnorm::kOut], broadcast<1>(slope / F(moving_var + param_.eps), data.shape_) * data + broadcast<1>(bias - (slope * moving_mean) / F(moving_var + param_.eps), data.shape_)); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK_EQ(in_data.size(), 3); CHECK_EQ(out_data.size(), 4); CHECK_EQ(in_grad.size(), 3); Stream *s = ctx.get_stream(); Tensor data, grad, grad_in; Tensor out, out_no_affine; const real_t scale = static_cast(out_data[batchnorm::kOut].shape_[1]) / static_cast(out_data[batchnorm::kOut].shape_.Size()); if (in_data[batchnorm::kData].ndim() == 2) { Shape<4> dshape = Shape4(out_data[batchnorm::kOut].shape_[0], out_data[batchnorm::kOut].shape_[1], 1, 1); data = in_data[batchnorm::kData].get_with_shape(dshape, s); grad = out_grad[batchnorm::kOut].get_with_shape(dshape, s); grad_in = in_grad[batchnorm::kData].get_with_shape(dshape, s); out = out_data[batchnorm::kOut].get_with_shape(dshape, s); out_no_affine = out_data[batchnorm::kOutNoAffine].get_with_shape(dshape, s); } else { data = in_data[batchnorm::kData].get(s); grad = out_grad[batchnorm::kOut].get(s); grad_in = in_grad[batchnorm::kData].get(s); out = out_data[batchnorm::kOut].get(s); out_no_affine = out_data[batchnorm::kOutNoAffine].get(s); } Tensor mean = out_data[batchnorm::kMean].get(s); Tensor var = out_data[batchnorm::kVar].get(s); Tensor slope = in_data[batchnorm::kGamma].get(s); // Tensor bias = in_data[kBeta].get(s); Tensor gslope = in_grad[batchnorm::kGamma].get(s); Tensor gbias = in_grad[batchnorm::kBeta].get(s); // get requested temp space Tensor workspace = ctx.requested[batchnorm::kTempSpace].get_space( mshadow::Shape2(3, out.shape_[1]), s); Tensor gmean = workspace[0]; Tensor gvar = workspace[1]; Tensor tmp = workspace[2]; // cal gvar = sumall_except_dim<1>((grad * broadcast<1>(slope, data.shape_)) * (data - broadcast<1>(mean, data.shape_)) * -0.5f * F(broadcast<1>(var + param_.eps, data.shape_), -1.5f)); gmean = sumall_except_dim<1>(grad * broadcast<1>(slope, data.shape_)); gmean *= -1.0f / F(var + param_.eps); tmp = scale * sumall_except_dim<1>(-2.0f * (data - broadcast<1>(mean, data.shape_))); tmp *= gvar; gmean += tmp; // assign Assign(gslope, req[batchnorm::kGamma], sumall_except_dim<1>(grad * out_no_affine)); Assign(gbias, req[batchnorm::kBeta], sumall_except_dim<1>(grad)); Assign(grad_in, req[batchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) * broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + broadcast<1>(gvar, data.shape_) * scale * 2.0f * (data - broadcast<1>(mean, data.shape_)) + broadcast<1>(gmean, data.shape_) * scale); } private: BatchNormParam param_; }; // class BatchNormOp template Operator *CreateOp(BatchNormParam param); #if DMLC_USE_CXX11 class BatchNormProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 3) << "Input:[data, gamma, beta]"; const TShape &dshape = in_shape->at(0); if (dshape.ndim() == 0) 
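/*!
 * Per channel, the training-mode Forward above computes
 *   mean = E[x],  var = E[(x - mean)^2],
 *   out_no_affine = (x - mean) / sqrt(var + eps),
 *   out = gamma * out_no_affine + beta,
 * and updates the running statistics as
 * moving = moving * momentum + batch_stat * (1 - momentum); inference folds the
 * moving statistics into a single scale and shift. A per-channel scalar sketch
 * (illustrative only):
 * \code
 * #include <cmath>
 *
 * struct BNChannel { float gamma, beta, moving_mean, moving_var; };
 *
 * inline float bn_train_point(float x, float mean, float var, float eps,
 *                             const BNChannel &c) {
 *   return c.gamma * (x - mean) / std::sqrt(var + eps) + c.beta;
 * }
 *
 * inline void bn_update_moving(BNChannel *c, float mean, float var, float momentum) {
 *   c->moving_mean = c->moving_mean * momentum + mean * (1.0f - momentum);
 *   c->moving_var  = c->moving_var  * momentum + var  * (1.0f - momentum);
 * }
 *
 * inline float bn_infer_point(float x, float eps, const BNChannel &c) {
 *   float scale = c.gamma / std::sqrt(c.moving_var + eps);
 *   return scale * x + (c.beta - scale * c.moving_mean);
 * }
 * \endcode
 */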
return false; in_shape->at(1) = TShape(Shape1(dshape[1])); in_shape->at(2) = TShape(Shape1(dshape[1])); out_shape->clear(); out_shape->push_back(dshape); out_shape->push_back(dshape); out_shape->push_back(Shape1(dshape[1])); out_shape->push_back(Shape1(dshape[1])); aux_shape->clear(); aux_shape->push_back(Shape1(dshape[1])); aux_shape->push_back(Shape1(dshape[1])); return true; } OperatorProperty* Copy() const override { auto ptr = new BatchNormProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "BatchNorm"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {out_grad[batchnorm::kOut], out_data[batchnorm::kOut], out_data[batchnorm::kOutNoAffine], out_data[batchnorm::kMean], out_data[batchnorm::kVar], in_data[batchnorm::kData], in_data[batchnorm::kGamma], in_data[batchnorm::kBeta] }; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_grad[batchnorm::kOut], in_grad[batchnorm::kData]}}; } std::vector BackwardResource( const std::vector &in_shape) const override { return {ResourceRequest::kTempSpace}; } int NumVisibleOutputs() const override { return 1; } int NumOutputs() const override { return 4; } std::vector ListArguments() const override { return {"data", "gamma", "beta"}; } std::vector ListOutputs() const override { return {"output", "output_no_affine", "mean", "var"}; } std::vector ListAuxiliaryStates() const override { return {"moving_mean", "moving_var"}; } Operator* CreateOperator(Context ctx) const override; private: BatchNormParam param_; }; // class BatchNormProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_BATCH_NORM_INL_H_ //===== EXPANDED: mxnet/src/operator/batch_norm-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(BatchNormParam param) { return new BatchNormOp(param); } Operator *BatchNormProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(BatchNormParam); MXNET_REGISTER_OP_PROPERTY(BatchNorm, BatchNormProp) .describe("Apply batch normalization to input.") .add_argument("data", "Symbol", "Input data to batch normalization") .add_arguments(BatchNormParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/batch_norm.cc ===== //===== EXPANDIND: mxnet/src/operator/block_grad.cc ===== /*! * Copyright (c) 2015 by Contributors * \file block_grad.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/block_grad-inl.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file block_grad-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_BLOCK_GRAD_INL_H_ #define MXNET_OPERATOR_BLOCK_GRAD_INL_H_ namespace mxnet { namespace op { namespace blockgrad { enum BlockGradientOpInputs {kData}; enum BlockGradientOpOutputs {kOut}; } // namespace blockgrad template class BlockGradientOp : public Operator { public: virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); Tensor data = in_data[blockgrad::kData].FlatTo2D(s); Tensor out = out_data[blockgrad::kOut].FlatTo2D(s); out = F(data); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); Tensor grad = in_grad[blockgrad::kData].FlatTo2D(s); grad = 0.f; } }; // class BlockGradientOp template Operator *CreateOp(); #if DMLC_USE_CXX11 class BlockGradientProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override {} std::map GetParams() const override { return std::map(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 1); const TShape &dshape = in_shape->at(blockgrad::kData); if (dshape.ndim() == 0) return false; out_shape->clear(); out_shape->push_back(dshape); return true; } OperatorProperty* Copy() const override { return new BlockGradientProp(); } std::string TypeString() const override { return "BlockGrad"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[blockgrad::kData], out_data[blockgrad::kOut]}}; } Operator* CreateOperator(Context ctx) const override; }; // class BlockGradientProperty #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_BLOCK_GRAD_INL_H_ //===== EXPANDED: mxnet/src/operator/block_grad-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp() { return new BlockGradientOp(); } Operator *BlockGradientProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp); } MXNET_REGISTER_OP_PROPERTY(BlockGrad, BlockGradientProp) .describe("Get output from a symbol and pass 0 gradient back") .add_argument("data", "Symbol", "Input data."); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/block_grad.cc ===== //===== EXPANDIND: mxnet/src/operator/concat.cc ===== /*! * Copyright (c) 2015 by Contributors * \file concat.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/concat-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file concat-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_CONCAT_INL_H_ #define MXNET_OPERATOR_CONCAT_INL_H_ //===== EXPANDIND: mxnet/src/operator/channel_op_common.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file channel_op_common.h * \brief common function used for concat and split channel * \author Bing Xu */ #ifndef MXNET_OPERATOR_CHANNEL_OP_COMMON_H_ #define MXNET_OPERATOR_CHANNEL_OP_COMMON_H_ namespace mxnet { namespace op { template inline void Concatenate(const std::vector > &input, mshadow::Tensor *output) { using mshadow::expr::concat; using mshadow::expr::slice; mshadow::Tensor out = *output; size_t size = input.size(); switch (size) { case 2: { out = concat<1>(input[0], input[1]); break; } case 3: { out = concat<1>(input[0], concat<1>(input[1], input[2])); break; } case 4: { out = concat<1>(input[0], concat<1>(input[1], concat<1>(input[2], input[3]))); break; } default: { index_t begin = 0; for (index_t i = 0; i < size; ++i) { index_t end = begin + input[i].size(1); slice<1>(out, begin, end) = input[i]; begin = end; } break; } } } template void Split(const mshadow::Tensor &input, std::vector > *output) { using mshadow::expr::concat; using mshadow::expr::slice; std::vector > out = *output; size_t size = out.size(); switch (size) { case 2: { concat<1>(out[0], out[1]) = input; break; } case 3: { concat<1>(out[0], concat<1>(out[1], out[2])) = input; break; } case 4: { concat<1>(out[0], concat<1>(out[1], concat<1>(out[2], out[3]))) = input; break; } default: { index_t begin = 0; for (index_t i = 0; i < size; ++i) { index_t end = begin + out[i].size(1); out[i] = slice<1>(input, begin, end); begin = end; } break; } } } } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_CHANNEL_OP_COMMON_H_ //===== EXPANDED: mxnet/src/operator/channel_op_common.h ===== namespace mxnet { namespace op { namespace concat_enum { enum ConcatOpInputs {kData0, kData1, kData2, kData3, kData4}; enum ConcatOpOutputs {kOut}; } // namespace concat_enum struct ConcatParam : public dmlc::Parameter { int num_args; DMLC_DECLARE_PARAMETER(ConcatParam) { DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) .describe("Number of inputs to be concated."); } }; // struct ConcatParam template class ConcatOp : public Operator { public: explicit ConcatOp(ConcatParam param) : size_(param.num_args) {} virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(static_cast(in_data.size()), size_); CHECK_EQ(out_data.size(), 1); CHECK_EQ(req[concat_enum::kOut], kWriteTo); Stream *s = ctx.get_stream(); std::vector > data(size_); Tensor out; if (in_data[concat_enum::kData0].ndim() == 2) { uint32_t dim = 0; for (int i = 0; i < size_; ++i) { Shape<4> dshape = Shape4(in_data[i].shape_[0], in_data[i].shape_[1], 1, 1); data[i] = in_data[i].get_with_shape(dshape, s); dim += in_data[i].shape_[1]; } Shape<4> dshape_out = Shape4(in_data[concat_enum::kData0].shape_[0], dim, 1, 1); out = out_data[concat_enum::kOut].get_with_shape(dshape_out, s); } else { for (int i = 0; i < size_; ++i) { data[i] = in_data[i].get(s); } out = out_data[concat_enum::kOut].get(s); } Concatenate(data, &out); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK_EQ(in_grad.size(), static_cast(size_)); Stream *s = ctx.get_stream(); std::vector > grad_in(size_); Tensor grad; if (out_grad[concat_enum::kOut].ndim() == 2) { 
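/*!
 * Concatenate() and Split() from channel_op_common.h operate along dimension 1
 * (channels): the output channel count is the sum of the input channel counts,
 * input i owns the slice [begin_i, begin_i + c_i) of that dimension, and Split()
 * applies the inverse mapping in Backward. A 1-D sketch of that bookkeeping,
 * using plain per-example channel vectors instead of tensors (illustrative only):
 * \code
 * #include <cstddef>
 * #include <vector>
 *
 * // forward: copy each input's channels into its slice of the output
 * inline std::vector<float> concat_channels(const std::vector<std::vector<float> > &in) {
 *   std::vector<float> out;
 *   for (size_t i = 0; i < in.size(); ++i)
 *     out.insert(out.end(), in[i].begin(), in[i].end());
 *   return out;
 * }
 *
 * // backward: slice the output gradient back into per-input gradients
 * inline void split_channels(const std::vector<float> &grad,
 *                            std::vector<std::vector<float> > *grad_in) {
 *   size_t begin = 0;
 *   for (size_t i = 0; i < grad_in->size(); ++i) {
 *     const size_t end = begin + (*grad_in)[i].size();
 *     for (size_t j = begin; j < end; ++j) (*grad_in)[i][j - begin] = grad[j];
 *     begin = end;
 *   }
 * }
 * \endcode
 */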
uint32_t dim = 0; for (int i = 0; i < size_; ++i) { Shape<4> dshape = Shape4(in_grad[i].shape_[0], in_grad[i].shape_[1], 1, 1); grad_in[i] = in_grad[i].get_with_shape(dshape, s); dim += in_grad[i].shape_[1]; CHECK_EQ(req[i], kWriteTo); } Shape<4> dshape_out = Shape4(in_grad[concat_enum::kData0].shape_[0], dim, 1, 1); grad = out_grad[concat_enum::kOut].get_with_shape(dshape_out, s); } else { for (int i = 0; i < size_; ++i) { grad_in[i] = in_grad[i].get(s); CHECK_EQ(req[i], kWriteTo); } grad = out_grad[concat_enum::kOut].get(s); } Split(grad, &grad_in); } private: int size_; }; // class ConcatOp template Operator *CreateOp(ConcatParam param); #if DMLC_USE_CXX11 class ConcatProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } std::vector ListArguments() const override { std::vector ret; for (int i = 0; i < param_.num_args; ++i) { ret.push_back(std::string("arg") + static_cast('0' + i)); } return ret; } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); TShape dshape = in_shape->at(concat_enum::kData0); if (dshape.ndim() == 0) return false; CHECK_GT(dshape.ndim(), 1); for (int i = 1; i < param_.num_args; ++i) { const TShape &tmp = in_shape->at(i); if (tmp.ndim() == 0) return false; for (uint32_t j = 0; j < dshape.ndim(); ++j) { if (j == 1) { dshape[1] += tmp[1]; } else { CHECK_EQ(dshape[j], tmp[j]) << "Incorrect shape[" << i << "]: " << tmp << ". " << "(first input shape: " << dshape << ")"; } } } out_shape->clear(); out_shape->push_back(dshape); return true; } OperatorProperty* Copy() const override { auto ptr = new ConcatProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "Concat"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return out_grad; } Operator* CreateOperator(Context ctx) const override; private: ConcatParam param_; }; // class ConcatProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_CONCAT_INL_H_ //===== EXPANDED: mxnet/src/operator/concat-inl.h ===== namespace mxnet { namespace op { template<> Operator* CreateOp(ConcatParam param) { return new ConcatOp(param); } Operator* ConcatProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(ConcatParam); MXNET_REGISTER_OP_PROPERTY(Concat, ConcatProp) .describe("Perform an feature concat on channel dim (dim 1) over all the inputs.") .add_arguments(ConcatParam::__FIELDS__()) .set_key_var_num_args("num_args"); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/concat.cc ===== //===== EXPANDIND: mxnet/src/operator/convolution.cc ===== /*! * Copyright (c) 2015 by Contributors * \file convolution.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/convolution-inl.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file convolution-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_CONVOLUTION_INL_H_ #define MXNET_OPERATOR_CONVOLUTION_INL_H_ namespace mxnet { namespace op { namespace conv { enum ConvolutionOpInputs {kData, kWeight, kBias}; enum ConvolutionOpOutputs {kOut}; enum ConvolutionOpResource {kTempSpace}; } struct ConvolutionParam : public dmlc::Parameter { TShape kernel; TShape stride; TShape pad; uint32_t num_filter; uint32_t num_group; uint64_t workspace; bool no_bias; DMLC_DECLARE_PARAMETER(ConvolutionParam) { int shape[] = {1, 1}; DMLC_DECLARE_FIELD(kernel).describe("convolution kernel size: (y, x)"); DMLC_DECLARE_FIELD(stride).set_default(TShape(shape, shape + 2)) .describe("convolution stride: (y, x)"); shape[0] = shape[1] = 0; DMLC_DECLARE_FIELD(pad).set_default(TShape(shape, shape + 2)) .describe("pad for convolution: (y, x)"); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("convolution filter(channel) number"); DMLC_DECLARE_FIELD(num_group).set_default(1) .describe("Number of groups partition. " "This option is not supported by CuDNN, you can use SliceChannel to num_group," "apply convolution and concat instead to achieve the same need."); DMLC_DECLARE_FIELD(workspace).set_default(512).set_range(128, 4096) .describe("Tmp workspace for convolution (MB)"); DMLC_DECLARE_FIELD(no_bias).set_default(false) .describe("Whether to disable bias parameter."); } }; template class ConvolutionOp : public Operator { public: explicit ConvolutionOp(ConvolutionParam p) { this->param_ = p; // convert MB to words param_.workspace = (param_.workspace << 20) / sizeof(real_t); } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[conv::kOut], kWriteTo); size_t expected = param_.no_bias ? 
2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); Tensor data = in_data[conv::kData].get(s); Shape<3> wmat_shape = Shape3(param_.num_group, param_.num_filter / param_.num_group, data.shape_[1] / param_.num_group * param_.kernel[0] * param_.kernel[1]); Tensor wmat = in_data[conv::kWeight].get_with_shape(wmat_shape, s); Tensor out = out_data[conv::kOut].get(s); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif const index_t nbatch = data.size(0); Tensor workspace = ctx.requested[conv::kTempSpace].get_space( Shape1(this->InitTemp(data.shape_, out.shape_)), s); for (index_t i = 0; i < nbatch; i += nstep_) { const index_t step = std::min(nstep_, nbatch - i); Tensor temp_col = Tensor(workspace.dptr_, Shape2(shape_colunit_[0], shape_colunit_[1] * step), s); Tensor temp_dst = Tensor(workspace.dptr_ + temp_col.shape_.Size(), Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * step), s); if (param_.pad[0] == 0 && param_.pad[1] == 0) { temp_col = unpack_patch2col(data.Slice(i, i + step), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } else { temp_col = unpack_patch2col(pad(data.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { mshadow::Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); temp_dst[gid] = dot(wmat[gid], tmpc); } out.Slice(i, i + step) = swapaxis<1, 0>(reshape(temp_dst, mshadow::Shape4(param_.num_filter, step, out.size(2), out.size(3)))); } if (!param_.no_bias) { // add bias, broadcast bias to dim 1: channel Tensor bias = in_data[conv::kBias].get(s); out += broadcast<1>(bias, out.shape_); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; // TODO(bing): check the BLAS Handle, be careful CHECK_EQ(out_grad.size(), 1); size_t expected = param_.no_bias == 0 ? 
3 : 2; CHECK(in_data.size() == expected && in_grad.size() == expected); CHECK_EQ(req.size(), expected); CHECK_EQ(in_data[conv::kWeight].CheckContiguous(), true); // get data Stream *s = ctx.get_stream(); Tensor data = in_data[conv::kData].get(s); Shape<3> wmat_shape = Shape3(param_.num_group, param_.num_filter / param_.num_group, data.shape_[1] / param_.num_group * param_.kernel[0] * param_.kernel[1]); Tensor wmat = in_data[conv::kWeight].get_with_shape(wmat_shape, s); Tensor grad = out_grad[conv::kOut].get(s); Tensor gdata = in_grad[conv::kData].get(s); Tensor gwmat = in_grad[conv::kWeight].get_with_shape(wmat_shape, s); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif const index_t nbatch = data.size(0); Tensor workspace = ctx.requested[conv::kTempSpace].get_space( Shape1(this->InitTemp(data.shape_, grad.shape_)), s); for (index_t i = 0; i < nbatch; i += nstep_) { const index_t step = std::min(nstep_, nbatch - i); Tensor temp_col = Tensor(workspace.dptr_, Shape2(shape_colunit_[0], shape_colunit_[1] * step), s); Tensor temp_dst = Tensor(workspace.dptr_ + temp_col.shape_.Size(), Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * step), s); temp_dst = reshape(swapaxis<1, 0>(grad.Slice(i, i + step)), temp_dst.shape_); if (param_.pad[0] == 0 && param_.pad[1] == 0) { temp_col = unpack_patch2col(data.Slice(i, i + step), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } else { temp_col = unpack_patch2col(pad(data.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); if (i == 0) { Tensor tmp_gwmat = gwmat[gid]; Assign(tmp_gwmat, req[conv::kWeight], dot(temp_dst[gid], tmpc.T())); } else { gwmat[gid] += dot(temp_dst[gid], tmpc.T()); } } if (req[conv::kData] == kWriteTo || req[conv::kData] == kWriteInplace) { for (uint32_t gid = 0; gid < param_.num_group; ++gid) { Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); tmpc = dot(wmat[gid].T(), temp_dst[gid]); } if (param_.pad[0] == 0 && param_.pad[1] == 0) { gdata.Slice(i, i + step) = pack_col2patch(temp_col, data.Slice(i, i + step).shape_, param_.kernel[0], param_.kernel[1], param_.stride[0]); } else { Shape<4> pshape = data.Slice(i, i + step).shape_; pshape[2] += 2 * param_.pad[0]; pshape[3] += 2 * param_.pad[1]; gdata.Slice(i, i + step) = crop(pack_col2patch(temp_col, pshape, param_.kernel[0], param_.kernel[1], param_.stride[0]), gdata[i][0].shape_); } } } if (!param_.no_bias) { Tensor gbias = in_grad[conv::kBias].get(s); Assign(gbias, req[conv::kBias], sumall_except_dim<1>(grad)); } } private: inline index_t InitTemp(const mshadow::Shape<4> &ishape, const mshadow::Shape<4> &oshape) { const int ksize_y = param_.kernel[0]; const int ksize_x = param_.kernel[1]; shape_colunit_ = mshadow::Shape2(ishape[1] * ksize_y * ksize_x, oshape[2] * oshape[3]); shape_dstunit_ = mshadow::Shape3(param_.num_group, param_.num_filter / param_.num_group, oshape[2] * oshape[3]); const uint64_t workspace_size = param_.workspace; nstep_ = std::max(std::min(static_cast(workspace_size / shape_colunit_.Size()), ishape[0]), 1U); int nop = (ishape[0] + nstep_ - 1) / nstep_; nstep_ = (ishape[0] + nop - 1) / nop; mshadow::Shape<2> scol = mshadow::Shape2(shape_colunit_[0], shape_colunit_[1] * nstep_); 
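/*!
 * The sizing logic above picks nstep_, the number of images unpacked per
 * im2col pass, so that the per-pass buffers fit into the requested workspace
 * (already converted from MB to real_t words in the constructor):
 *   shape_colunit_ = (C * ky * kx, OH * OW)             one image's im2col matrix
 *   shape_dstunit_ = (G, F / G, OH * OW)                one image's output matrix
 *   nstep_ = clamp(workspace / colunit.Size(), 1, N),   then rounded so the N
 *                                                       images split into equal passes
 * Hypothetical worked example (numbers chosen for illustration only): with
 * C = 3, ky = kx = 3, OH = OW = 32, N = 8, F = 16, G = 1 and the default 512 MB
 * workspace (134217728 words), colunit.Size() = 27 * 1024 = 27648, so
 * workspace / 27648 is about 4854 >= N and all 8 images fit in one pass
 * (nstep_ = 8, nop = 1); the returned temp size is
 * 27 * 1024 * 8 + 16 * 1024 * 8 = 352256 words.
 */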
mshadow::Shape<3> sdst = mshadow::Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * nstep_); CHECK_GE(param_.workspace, scol.Size() + sdst.Size()) << "\nMinimum workspace size: " << scol.Size() + sdst.Size() << "\n" << "Given: " << param_.workspace; return scol.Size() + sdst.Size(); } ConvolutionParam param_; mshadow::Shape<2> shape_colunit_; mshadow::Shape<3> shape_dstunit_; index_t nstep_; }; // class ConvolutionOp template Operator* CreateOp(ConvolutionParam param); #if DMLC_USE_CXX11 class ConvolutionProp : public OperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { return {"data", "weight", "bias"}; } else { return {"data", "weight"}; } } void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; } else { CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; } const TShape &dshape = (*in_shape)[conv::kData]; if (dshape.ndim() == 0) return false; CHECK_EQ(dshape.ndim(), 4) \ << "Input data should be 4D in batch-num_filter-y-x"; SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, Shape4(param_.num_filter, dshape[1], param_.kernel[0], param_.kernel[1])); if (!param_.no_bias) { SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter)); } out_shape->clear(); out_shape->push_back(dshape); const index_t ksize_y = static_cast(param_.kernel[0]); const index_t ksize_x = static_cast(param_.kernel[1]); CHECK_EQ(dshape[1] % param_.num_group, 0) \ << "input num_filter must divide group size"; CHECK_EQ(param_.num_filter % param_.num_group, 0) \ << "output num_filter must divide group size"; CHECK_GE(param_.kernel.Size(), 0) \ << "incorrect kernel size: " << param_.kernel; CHECK_GE(param_.stride.Size(), 0) \ << "incorrect stride size: " << param_.stride; CHECK(ksize_x <= dshape[3] && ksize_y <= dshape[2]) << "kernel size exceed input"; (*out_shape)[conv::kOut][1] = param_.num_filter; (*out_shape)[conv::kOut][2] = (dshape[2] + 2 * param_.pad[0] - ksize_y) / param_.stride[0] + 1; (*out_shape)[conv::kOut][3] = (dshape[3] + 2 * param_.pad[1] - ksize_x) / param_.stride[1] + 1; return true; } OperatorProperty* Copy() const override { auto ptr = new ConvolutionProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "Convolution"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {out_grad[conv::kOut], in_data[conv::kData], in_data[conv::kWeight]}; } std::vector ForwardResource( const std::vector &in_shape) const override { return {ResourceRequest::kTempSpace}; } std::vector BackwardResource( const std::vector &in_shape) const override { return {ResourceRequest::kTempSpace}; } Operator* CreateOperator(Context ctx) const override; private: ConvolutionParam param_; }; // class ConvolutionProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_CONVOLUTION_INL_H_ //===== EXPANDED: mxnet/src/operator/convolution-inl.h ===== namespace mxnet { namespace op { template<> Operator* CreateOp(ConvolutionParam param) { return new ConvolutionOp(param); } Operator* ConvolutionProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(ConvolutionParam); 
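/*!
 * ConvolutionProp::InferShape above derives the output shape from the data shape
 * and the kernel/stride/pad parameters. A standalone sketch of that arithmetic,
 * assuming the NCHW layout used throughout this operator (illustrative only):
 * \code
 * struct ConvOutShape { unsigned n, c, h, w; };
 *
 * inline ConvOutShape conv_out_shape(unsigned n, unsigned h, unsigned w,
 *                                    unsigned num_filter,
 *                                    unsigned ky, unsigned kx,
 *                                    unsigned sy, unsigned sx,
 *                                    unsigned py, unsigned px) {
 *   ConvOutShape out;
 *   out.n = n;
 *   out.c = num_filter;
 *   out.h = (h + 2 * py - ky) / sy + 1;
 *   out.w = (w + 2 * px - kx) / sx + 1;
 *   return out;
 * }
 *
 * // e.g. an 8x3x32x32 input with 16 filters of 3x3, stride 1, pad 1 -> 8x16x32x32
 * \endcode
 */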
MXNET_REGISTER_OP_PROPERTY(Convolution, ConvolutionProp) .add_argument("data", "Symbol", "Input data to the ConvolutionOp.") .add_argument("weight", "Symbol", "Weight matrix.") .add_argument("bias", "Symbol", "Bias parameter.") .add_arguments(ConvolutionParam::__FIELDS__()) .describe("Apply convolution to input then add a bias."); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/convolution.cc ===== //===== EXPANDIND: mxnet/src/operator/dropout.cc ===== /*! * Copyright (c) 2015 by Contributors * \file dropout.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/dropout-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file dropout-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_DROPOUT_INL_H_ #define MXNET_OPERATOR_DROPOUT_INL_H_ namespace dropout { enum DropoutOpInputs {kData}; enum DropoutOpOutputs {kOut, kMask}; enum DropoutOpForwardResource {kRandom}; } // namespace dropout namespace mxnet { namespace op { struct DropoutParam : public dmlc::Parameter { float p; DMLC_DECLARE_PARAMETER(DropoutParam) { DMLC_DECLARE_FIELD(p).set_default(0.5) .set_range(0, 1) .describe("Fraction of the input that gets dropped out at training time"); } }; // struct DropoutParam template class DropoutOp : public Operator { public: explicit DropoutOp(DropoutParam param) { this->pkeep_ = 1.0f - param.p; } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); if (ctx.is_train) { CHECK_EQ(out_data.size(), 2); } Stream *s = ctx.get_stream(); Tensor data = in_data[dropout::kData].FlatTo2D(s); Tensor out = out_data[dropout::kOut].FlatTo2D(s); if (ctx.is_train) { Tensor mask = out_data[dropout::kMask].FlatTo2D(s); Random *prnd = ctx.requested[dropout::kRandom].get_random(s); mask = F(prnd->uniform(mask.shape_), pkeep_) * (1.0f / pkeep_); Assign(out, req[dropout::kOut], data * mask); } else { Assign(out, req[dropout::kOut], F(data)); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK_EQ(in_grad.size(), 1); Stream *s = ctx.get_stream(); Tensor grad = out_grad[dropout::kOut].FlatTo2D(s); Tensor mask = out_data[dropout::kMask].FlatTo2D(s); Tensor gdata = in_grad[dropout::kData].FlatTo2D(s); Assign(gdata, req[dropout::kData], grad * mask); } private: real_t pkeep_; }; // class DropoutOp template Operator *CreateOp(DropoutParam param); #if DMLC_USE_CXX11 class DropoutProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 1); const TShape &dshape = in_shape->at(0); if (dshape.ndim() == 0) return false; out_shape->clear(); out_shape->push_back(dshape); out_shape->push_back(dshape); return true; } OperatorProperty* Copy() const override { auto ptr = new DropoutProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "Dropout"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector 
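/*!
 * DropoutOp implements inverted dropout: kept activations are rescaled by
 * 1/pkeep at training time, so inference is a plain copy and needs no rescaling.
 * A scalar sketch of the rule, where u stands for the uniform sample drawn by
 * prnd->uniform() (illustrative only):
 * \code
 * // train:            mask = (u < pkeep) / pkeep,  out = data * mask
 * // test:             out  = data
 * // backward (train): d_data = d_out * mask
 * inline float dropout_mask(float u, float pkeep) {
 *   return (u < pkeep ? 1.0f : 0.0f) / pkeep;
 * }
 * inline float dropout_fwd_train(float x, float mask)  { return x * mask; }
 * inline float dropout_bwd_train(float gy, float mask) { return gy * mask; }
 * \endcode
 */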
&in_data, const std::vector &out_data) const override { return {out_grad[dropout::kOut], out_data[dropout::kMask]}; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_grad[dropout::kOut], in_grad[dropout::kData]}}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[dropout::kData], out_data[dropout::kOut]}}; } std::vector ForwardResource( const std::vector &in_shape) const override { return {ResourceRequest::kRandom}; } int NumVisibleOutputs() const override { return 1; } int NumOutputs() const override { return 2; } std::vector ListOutputs() const override { return {"output", "mask"}; } Operator* CreateOperator(Context ctx) const override; private: DropoutParam param_; }; // class DropoutProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_DROPOUT_INL_H_ //===== EXPANDED: mxnet/src/operator/dropout-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(DropoutParam param) { return new DropoutOp(param); } // DO_BIND_DISPATCH comes from operator_common.h Operator *DropoutProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(DropoutParam); MXNET_REGISTER_OP_PROPERTY(Dropout, DropoutProp) .describe("Apply dropout to input") .add_argument("data", "Symbol", "Input data to dropout.") .add_arguments(DropoutParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/dropout.cc ===== //===== EXPANDIND: mxnet/src/operator/elementwise_binary_op.cc ===== /*! * Copyright (c) 2015 by Contributors * \file elementwise_binary_op.cc * \brief elementwise binary operator */ //===== EXPANDIND: mxnet/src/operator/elementwise_binary_op-inl.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file elementwise_binary_op-inl.h * \brief Elementwise binary operation, plus, minus, mul, div */ #ifndef MXNET_OPERATOR_ELEMENTWISE_BINARY_OP_INL_H_ #define MXNET_OPERATOR_ELEMENTWISE_BINARY_OP_INL_H_ namespace mxnet { namespace op { namespace elembinary { enum ElementWiseBinaryOpInputs {kLhs, kRhs}; enum ElementWiseBinaryOpOutputs {kOut}; enum ElementWiseBinaryOpType {kPlus, kMinus, kMul, kDiv}; } // elembinary template inline elembinary::ElementWiseBinaryOpType GetOpType(); template inline const char* GetOpTypeString(); template<> inline elembinary::ElementWiseBinaryOpType GetOpType() { return elembinary::kPlus; } template<> inline elembinary::ElementWiseBinaryOpType GetOpType() { return elembinary::kMinus; } template<> inline elembinary::ElementWiseBinaryOpType GetOpType() { return elembinary::kMul; } template<> inline elembinary::ElementWiseBinaryOpType GetOpType() { return elembinary::kDiv; } template<> inline const char* GetOpTypeString() { return "_Plus"; } template<> inline const char* GetOpTypeString() { return "_Minus"; } template<> inline const char* GetOpTypeString() { return "_Mul"; } template<> inline const char* GetOpTypeString() { return "_Div"; } template class ElementWiseBinaryOp : public Operator { public: virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); Tensor lhs = in_data[elembinary::kLhs].FlatTo2D(s); Tensor rhs = in_data[elembinary::kRhs].FlatTo2D(s); Tensor out = out_data[elembinary::kOut].FlatTo2D(s); Assign(out, req[elembinary::kOut], F(lhs, rhs)); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK(in_data.size() == 2 && in_grad.size() == 2); CHECK_EQ(req.size(), 2); Stream *s = ctx.get_stream(); Tensor m_out_grad = out_grad[elembinary::kOut].FlatTo2D(s); Tensor lhs_grad = in_grad[elembinary::kLhs].FlatTo2D(s); Tensor rhs_grad = in_grad[elembinary::kRhs].FlatTo2D(s); switch (GetOpType()) { case elembinary::kPlus: { Assign(lhs_grad, req[elembinary::kLhs], F(m_out_grad)); Assign(rhs_grad, req[elembinary::kRhs], F(m_out_grad)); break; } case elembinary::kMinus: { Assign(lhs_grad, req[elembinary::kLhs], F(m_out_grad)); Assign(rhs_grad, req[elembinary::kRhs], F(m_out_grad)); break; } case elembinary::kMul: { Tensor lhs_data = in_data[elembinary::kLhs].FlatTo2D(s); Tensor rhs_data = in_data[elembinary::kRhs].FlatTo2D(s); // rhs cannot do inplace CHECK_NE(req[elembinary::kRhs], kWriteInplace); Assign(rhs_grad, req[elembinary::kRhs], lhs_data * m_out_grad); Assign(lhs_grad, req[elembinary::kLhs], rhs_data * m_out_grad); break; } case elembinary::kDiv: { Tensor lhs_data = in_data[elembinary::kLhs].FlatTo2D(s); Tensor rhs_data = in_data[elembinary::kRhs].FlatTo2D(s); // rhs cannot do inplace CHECK_NE(req[elembinary::kRhs], kWriteInplace); Assign(rhs_grad, req[elembinary::kRhs], F(m_out_grad * lhs_data) / F(rhs_data)); Assign(lhs_grad, req[elembinary::kLhs], m_out_grad / rhs_data); break; } } } }; // class ElementWiseBinaryOp template inline Operator* CreateElementWiseBinaryOp_(elembinary::ElementWiseBinaryOpType type) { switch (type) { 
case elembinary::kPlus: return new ElementWiseBinaryOp(); case elembinary::kMinus: return new ElementWiseBinaryOp(); case elembinary::kMul: return new ElementWiseBinaryOp(); case elembinary::kDiv: return new ElementWiseBinaryOp(); } LOG(FATAL) << "uknown op type"; return NULL; } // Decalre Factory function, used for dispatch specialization template Operator* CreateElementWiseBinaryOp(elembinary::ElementWiseBinaryOpType type); #if DMLC_USE_CXX11 template class ElementWiseBinaryOpProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { CHECK_EQ(kwargs.size(), 0) << TypeString() << " do not take any additional keyword arguments besides lhs and rhs"; } std::map GetParams() const override { return std::map(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 2) << "Input:[lhs, rhs]"; if (in_shape->at(elembinary::kLhs).ndim() != 0) { SHAPE_ASSIGN_CHECK(*in_shape, elembinary::kRhs, in_shape->at(elembinary::kLhs)); } else if (in_shape->at(elembinary::kRhs).ndim() != 0) { in_shape->at(elembinary::kLhs) = in_shape->at(elembinary::kRhs); } else { return false; } const TShape &dshape = in_shape->at(elembinary::kLhs); out_shape->clear(); out_shape->push_back(dshape); return true; } std::vector ListArguments() const override { return {"lhs", "rhs"}; } OperatorProperty* Copy() const override { return new ElementWiseBinaryOpProp(); } std::string TypeString() const override { return GetOpTypeString(); } // decalre dependency and inplace optimization options std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { switch (GetOpType()) { case elembinary::kPlus: case elembinary::kMinus: return {out_grad[elembinary::kOut]}; case elembinary::kMul: case elembinary::kDiv: return {out_grad[elembinary::kOut], in_data[elembinary::kLhs], in_data[elembinary::kRhs]}; } LOG(FATAL) << "not reached"; return {}; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { switch (GetOpType()) { case elembinary::kPlus: case elembinary::kMinus: return {}; case elembinary::kMul: case elembinary::kDiv: return {{out_grad[elembinary::kOut], in_grad[elembinary::kLhs]}}; } LOG(FATAL) << "not reached"; return {}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[elembinary::kLhs], out_data[elembinary::kOut]}}; } Operator* CreateOperator(Context ctx) const override; }; #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_ELEMENTWISE_BINARY_OP_INL_H_ //===== EXPANDED: mxnet/src/operator/elementwise_binary_op-inl.h ===== namespace mxnet { namespace op { template<> Operator* CreateElementWiseBinaryOp(elembinary::ElementWiseBinaryOpType type) { return CreateElementWiseBinaryOp_(type); } // DO_BIND_DISPATCH comes from static_operator_common.h template Operator* ElementWiseBinaryOpProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateElementWiseBinaryOp, GetOpType()); } MXNET_REGISTER_OP_PROPERTY(_Plus, ElementWiseBinaryOpProp) .describe("Perform an elementwise plus."); MXNET_REGISTER_OP_PROPERTY(_Minus, ElementWiseBinaryOpProp) .describe("Perform an elementwise minus."); MXNET_REGISTER_OP_PROPERTY(_Mul, ElementWiseBinaryOpProp) .describe("Perform an elementwise mul."); 
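/*!
 * The Backward switch in ElementWiseBinaryOp corresponds to the usual gradient
 * rules for out = lhs op rhs with incoming gradient g = dL/dout:
 *   plus:  d_lhs = g          d_rhs = g
 *   minus: d_lhs = g          d_rhs = -g
 *   mul:   d_lhs = g * rhs    d_rhs = g * lhs
 *   div:   d_lhs = g / rhs    d_rhs = -(g * lhs) / (rhs * rhs)
 * A scalar sketch of the division case (illustrative only):
 * \code
 * inline void div_backward(float g, float lhs, float rhs,
 *                          float *d_lhs, float *d_rhs) {
 *   *d_lhs = g / rhs;
 *   *d_rhs = -(g * lhs) / (rhs * rhs);
 * }
 * \endcode
 */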
MXNET_REGISTER_OP_PROPERTY(_Div, ElementWiseBinaryOpProp) .describe("Perform an elementwise div."); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/elementwise_binary_op.cc ===== //===== EXPANDIND: mxnet/src/operator/elementwise_sum.cc ===== /*! * Copyright (c) 2015 by Contributors * \file elementwise_sum.cc * \brief elementwise sum operator */ //===== EXPANDIND: mxnet/src/operator/elementwise_sum-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file elemementwise_sum-inl.h * \brief elementwise sum * \author Bing Xu */ #ifndef MXNET_OPERATOR_ELEMENTWISE_SUM_INL_H_ #define MXNET_OPERATOR_ELEMENTWISE_SUM_INL_H_ namespace mxnet { namespace op { namespace elemsum { enum ElementWiseSumOpInputs {kData0, kData1, kData2, kData3}; enum ElementWiseSumOpOutputs {kOut}; } // namespace elemsum struct ElementWiseSumParam : public dmlc::Parameter { int num_args; DMLC_DECLARE_PARAMETER(ElementWiseSumParam) { DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) .describe("Number of inputs to be sumed."); } }; template class ElementWiseSumOp : public Operator { public: explicit ElementWiseSumOp(ElementWiseSumParam param) : size_(param.num_args) {} virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(static_cast(in_data.size()), size_); CHECK_EQ(out_data.size(), 1); if (req[elemsum::kOut] == kNullOp) return; Stream *s = ctx.get_stream(); Tensor out = out_data[elemsum::kOut].FlatTo2D(s); switch (size_) { case 2: { Tensor in_0 = in_data[elemsum::kData0].FlatTo2D(s); Tensor in_1 = in_data[elemsum::kData1].FlatTo2D(s); Assign(out, req[elemsum::kOut], in_0 + in_1); break; } case 3: { Tensor in_0 = in_data[elemsum::kData0].FlatTo2D(s); Tensor in_1 = in_data[elemsum::kData1].FlatTo2D(s); Tensor in_2 = in_data[elemsum::kData2].FlatTo2D(s); Assign(out, req[elemsum::kOut], in_0 + in_1 + in_2); break; } case 4: { Tensor in_0 = in_data[elemsum::kData0].FlatTo2D(s); Tensor in_1 = in_data[elemsum::kData1].FlatTo2D(s); Tensor in_2 = in_data[elemsum::kData2].FlatTo2D(s); Tensor in_3 = in_data[elemsum::kData3].FlatTo2D(s); Assign(out, req[elemsum::kOut], in_0 + in_1 + in_2 + in_3); break; } default: { Tensor in_0 = in_data[elemsum::kData0].FlatTo2D(s); Assign(out, req[elemsum::kOut], F(in_0)); for (int i = 1; i < size_; ++i) { out += in_data[i].FlatTo2D(s); } break; } } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_grad.size(), static_cast(size_)); Stream *s = ctx.get_stream(); Tensor ograd = out_grad[elemsum::kOut].FlatTo2D(s); for (int i = 0; i < size_; ++i) { if (req[i] == kNullOp || req[i] == kWriteInplace) continue; Tensor igrad = in_grad[i].FlatTo2D(s); Assign(igrad, req[i], F(ograd)); } } inline void Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("size_", size_); writer->EndObject(); } inline void Load(dmlc::JSONReader *reader) { dmlc::JSONObjectReadHelper helper; helper.DeclareField("size_", &size_); helper.ReadAllFields(reader); } private: int size_; }; // class ElementWiseSumOp template Operator* CreateOp(ElementWiseSumParam param); #if DMLC_USE_CXX11 class ElementWiseSumProp : public OperatorProperty { public: void Init(const std::vector >& 
kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); int sidx = -1; for (int i = 0; i < param_.num_args; ++i) { if (in_shape->at(i).ndim() != 0) { sidx = i; break; } } if (sidx == -1) return false; for (int i = 0; i < param_.num_args; ++i) { if (i != sidx) { SHAPE_ASSIGN_CHECK(*in_shape, i, in_shape->at(sidx)); } } out_shape->clear(); out_shape->push_back(in_shape->at(sidx)); return true; } std::vector ListArguments() const override { std::vector ret; for (int i = 0; i < param_.num_args; ++i) { ret.push_back(std::string("arg") + static_cast('0' + i)); } return ret; } OperatorProperty* Copy() const override { auto ptr = new ElementWiseSumProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "ElementWiseSum"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return out_grad; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_grad[0], in_grad[0]}}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[0], out_data[0]}}; } Operator* CreateOperator(Context ctx) const override; private: ElementWiseSumParam param_; }; // class ElementWiseSumProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_ELEMENTWISE_SUM_INL_H_ //===== EXPANDED: mxnet/src/operator/elementwise_sum-inl.h ===== namespace mxnet { namespace op { template<> Operator* CreateOp(ElementWiseSumParam param) { return new ElementWiseSumOp(param); } // DO_BIND_DISPATCH comes from static_operator_common.h Operator* ElementWiseSumProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(ElementWiseSumParam); MXNET_REGISTER_OP_PROPERTY(ElementWiseSum, ElementWiseSumProp) .describe("Perform an elementwise sum over all the inputs.") .add_arguments(ElementWiseSumParam::__FIELDS__()) .set_key_var_num_args("num_args"); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/elementwise_sum.cc ===== //===== EXPANDIND: mxnet/src/operator/fully_connected.cc ===== /*! * Copyright (c) 2015 by Contributors * \file fully_connected.cc * \brief fully connect operator */ //===== EXPANDIND: mxnet/src/operator/fully_connected-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file fully_connect_op-inl.h * \brief fully connect operator and symbol */ #ifndef MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ #define MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ namespace mxnet { namespace op { // Declare enumeration of input order to make code more intuitive. 
// These enums are only visible within this header namespace fullc { enum FullyConnectedOpInputs {kData, kWeight, kBias}; enum FullyConnectedOpOutputs {kOut}; } // fullc struct FullyConnectedParam : public dmlc::Parameter { int num_hidden; bool no_bias; DMLC_DECLARE_PARAMETER(FullyConnectedParam) { // TODO(bing) change to only set lower bound // add support for boolean DMLC_DECLARE_FIELD(num_hidden).set_range(1, 100000) .describe("Number of hidden nodes of the output."); DMLC_DECLARE_FIELD(no_bias).set_default(false) .describe("Whether to disable bias parameter."); } }; /** * \brief This is the implementation of fully connected operator. * \tparam xpu The device that the op will be executed on. */ template class FullyConnectedOp : public Operator { public: explicit FullyConnectedOp(FullyConnectedParam p) { this->param_ = p; } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[fullc::kOut], kWriteTo); size_t expected = param_.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context // TODO(bing): judge shape to remove flatten op Stream *s = ctx.get_stream(); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif // __CUDACC__ Tensor data = in_data[fullc::kData].FlatTo2D(s); Tensor wmat = in_data[fullc::kWeight].get(s); Tensor out = out_data[fullc::kOut].FlatTo2D(s); out = dot(data, wmat.T()); if (!param_.no_bias) { Tensor bias = in_data[fullc::kBias].get(s); out += repmat(bias, data.size(0)); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); size_t expected = param_.no_bias ? 
2 : 3; CHECK(in_data.size() == expected && in_grad.size() == expected); CHECK_EQ(req.size(), expected); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context Stream *s = ctx.get_stream(); Tensor data = in_data[fullc::kData].FlatTo2D(s); Tensor wmat = in_data[fullc::kWeight].get(s); Tensor grad = out_grad[fullc::kOut].FlatTo2D(s); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif // backprop CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; // gradient of weight Tensor gwmat = in_grad[fullc::kWeight].get(s); Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data)); // gradient of bias if (!param_.no_bias) { Tensor gbias = in_grad[fullc::kBias].get(s); Assign(gbias, req[fullc::kBias], sum_rows(grad)); } // gradient of data Tensor gdata = in_grad[fullc::kData].FlatTo2D(s); Assign(gdata, req[fullc::kData], dot(grad, wmat)); } private: FullyConnectedParam param_; }; // class FullyConnectedOp // Decalre Factory function, used for dispatch specialization template Operator* CreateOp(FullyConnectedParam param); #if DMLC_USE_CXX11 class FullyConnectedProp : public OperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { return {"data", "weight", "bias"}; } else { return {"data", "weight"}; } } void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; } else { CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; } const TShape &dshape = (*in_shape)[fullc::kData]; // require data to be known if (dshape.ndim() == 0) return false; index_t num_input = 0; mshadow::Shape<2> ishape = dshape.FlatTo2D(); num_input = ishape[1]; SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input)); if (!param_.no_bias) { SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param_.num_hidden)); } out_shape->clear(); out_shape->push_back(Shape2(dshape[0], param_.num_hidden)); return true; } OperatorProperty* Copy() const override { FullyConnectedProp* fc_sym = new FullyConnectedProp(); fc_sym->param_ = this->param_; return fc_sym; } std::string TypeString() const override { return "FullyConnected"; } // decalre dependency and inplace optimization options std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {out_grad[fullc::kOut], in_data[fullc::kData], in_data[fullc::kWeight]}; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{in_data[fullc::kData], in_grad[fullc::kData]}}; } Operator* CreateOperator(Context ctx) const override; private: FullyConnectedParam param_; }; // class FullyConnectedSymbol #endif } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ //===== EXPANDED: mxnet/src/operator/fully_connected-inl.h ===== namespace mxnet { namespace op { template<> Operator* CreateOp(FullyConnectedParam param) { return new FullyConnectedOp(param); } // DO_BIND_DISPATCH comes from static_operator_common.h Operator* FullyConnectedProp::CreateOperator(Context ctx) const { 
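/*!
 * With data flattened to (batch, num_input), weight of shape
 * (num_hidden, num_input) and bias of shape (num_hidden,), FullyConnectedOp
 * computes
 *   out   = data * weight^T + bias       in Forward,
 *   gW    = grad^T * data                matching dot(grad.T(), data),
 *   gbias = sum of grad over the batch   matching sum_rows(grad),
 *   gdata = grad * weight                matching dot(grad, wmat)
 * in Backward. A tiny dense sketch of the forward product, using row-major
 * float arrays (illustrative only):
 * \code
 * // out[b][h] = sum_i data[b][i] * weight[h][i] + bias[h]
 * inline void fc_forward(const float *data, const float *weight, const float *bias,
 *                        float *out, int batch, int num_input, int num_hidden) {
 *   for (int b = 0; b < batch; ++b) {
 *     for (int h = 0; h < num_hidden; ++h) {
 *       float acc = bias != 0 ? bias[h] : 0.0f;   // bias may be absent (no_bias)
 *       for (int i = 0; i < num_input; ++i) {
 *         acc += data[b * num_input + i] * weight[h * num_input + i];
 *       }
 *       out[b * num_hidden + h] = acc;
 *     }
 *   }
 * }
 * \endcode
 */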
DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(FullyConnectedParam); MXNET_REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedProp) .describe("Apply matrix multiplication to input then add a bias.") .add_argument("data", "Symbol", "Input data to the FullyConnectedOp.") .add_argument("weight", "Symbol", "Weight matrix.") .add_argument("bias", "Symbol", "Bias parameter.") .add_arguments(FullyConnectedParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/fully_connected.cc ===== //===== EXPANDIND: mxnet/src/operator/leaky_relu.cc ===== /*! * Copyright (c) 2015 by Contributors * \file leaky_relu.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/leaky_relu-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file leaky_relu-inl.h * \brief leaky relu family operator * \author Bing Xu */ #ifndef MXNET_OPERATOR_LEAKY_RELU_INL_H_ #define MXNET_OPERATOR_LEAKY_RELU_INL_H_ namespace mxnet { namespace op { namespace leakyrelu { enum LeakyReLUOpInputs {kData, kGamma}; enum LeakyReLUOpOutputs {kOut, kMask}; enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU}; enum LeakyReLUOpResource {kRandom}; } // namespace leakyrelu struct LeakyReLUParam : public dmlc::Parameter { // use int for enumeration int act_type; float slope; float lower_bound; float upper_bound; DMLC_DECLARE_PARAMETER(LeakyReLUParam) { DMLC_DECLARE_FIELD(act_type).set_default(leakyrelu::kLeakyReLU) .add_enum("rrelu", leakyrelu::kRReLU) .add_enum("leaky", leakyrelu::kLeakyReLU) .add_enum("prelu", leakyrelu::kPReLU) .describe("Activation function to be applied."); DMLC_DECLARE_FIELD(slope).set_default(0.25f) .describe("Init slope for the activation. (For leaky only)"); DMLC_DECLARE_FIELD(lower_bound).set_default(0.125f) .describe("Lower bound of random slope. (For rrelu only)"); DMLC_DECLARE_FIELD(upper_bound).set_default(0.334f) .describe("Upper bound of random slope. (For rrelu only)"); } }; struct prelu_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return a > 0.0f ? 0.0f : a; } }; template class LeakyReLUOp : public Operator { public: explicit LeakyReLUOp(LeakyReLUParam param) { param_ = param; } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; size_t expected = param_.act_type == leakyrelu::kPReLU ? 
2 : 1; CHECK_EQ(in_data.size(), expected); Stream *s = ctx.get_stream(); Tensor data; Tensor out; Tensor mask; Tensor weight; if (in_data[leakyrelu::kData].ndim() == 2) { Shape<4> dshape = Shape4(in_data[leakyrelu::kData].shape_[0], in_data[leakyrelu::kData].shape_[1], 1, 1); data = in_data[leakyrelu::kData].get_with_shape(dshape, s); out = out_data[leakyrelu::kOut].get_with_shape(dshape, s); if (param_.act_type == leakyrelu::kRReLU) { mask = out_data[leakyrelu::kMask].get_with_shape(dshape, s); } } else { data = in_data[leakyrelu::kData].get(s); out = out_data[leakyrelu::kOut].get(s); if (param_.act_type == leakyrelu::kRReLU) { mask = out_data[leakyrelu::kMask].get(s); } } switch (param_.act_type) { case leakyrelu::kLeakyReLU: { Assign(out, req[leakyrelu::kOut], F(data, param_.slope)); break; } case leakyrelu::kPReLU: { weight = in_data[leakyrelu::kGamma].get(s); Assign(out, req[leakyrelu::kOut], F(data, broadcast<1>(weight, out.shape_))); break; } case leakyrelu::kRReLU: { if (ctx.is_train) { Random* prnd = ctx.requested[leakyrelu::kRandom].get_random(s); mask = prnd->uniform(mask.shape_); mask = mask * (param_.upper_bound - param_.lower_bound) + param_.lower_bound; Assign(out, req[leakyrelu::kOut], F(data, mask)); } else { const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f; Assign(out, req[leakyrelu::kOut], F(data, slope)); } break; } default: LOG(FATAL) << "Not implmented"; } } virtual void Backward(const OpContext & ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; size_t expected = param_.act_type == leakyrelu::kPReLU ? 2 : 1; CHECK_EQ(out_grad.size(), 1); CHECK_EQ(req.size(), expected); CHECK_EQ(in_data.size(), expected); Stream *s = ctx.get_stream(); Tensor output; Tensor data; Tensor gdata; Tensor grad; Tensor mask; Tensor weight; Tensor grad_weight; if (out_grad[leakyrelu::kOut].ndim() == 2) { Shape<4> dshape = Shape4(out_grad[leakyrelu::kOut].shape_[0], out_grad[leakyrelu::kOut].shape_[1], 1, 1); grad = out_grad[leakyrelu::kOut].get_with_shape(dshape, s); gdata = in_grad[leakyrelu::kData].get_with_shape(dshape, s); output = out_data[leakyrelu::kOut].get_with_shape(dshape, s); if (param_.act_type == leakyrelu::kRReLU) { mask = out_data[leakyrelu::kMask].get_with_shape(dshape, s); } if (param_.act_type == leakyrelu::kPReLU) { data = in_data[leakyrelu::kData].get_with_shape(dshape, s); } } else { grad = out_grad[leakyrelu::kOut].get(s); gdata = in_grad[leakyrelu::kData].get(s); output = out_data[leakyrelu::kOut].get(s); if (param_.act_type == leakyrelu::kRReLU) { mask = out_data[leakyrelu::kMask].get(s); } if (param_.act_type == leakyrelu::kPReLU) { data = in_data[leakyrelu::kData].get(s); } } switch (param_.act_type) { case leakyrelu::kLeakyReLU: { Assign(gdata, req[leakyrelu::kData], F(output, param_.slope) * grad); break; } case leakyrelu::kPReLU: { weight = in_data[leakyrelu::kGamma].get(s); grad_weight = in_grad[leakyrelu::kGamma].get(s); grad_weight = sumall_except_dim<1>(F(data) * grad); gdata = F(output, broadcast<1>(weight, data.shape_)) * grad; break; } case leakyrelu::kRReLU: { Assign(gdata, req[leakyrelu::kData], F(output, mask) * grad); break; } default: LOG(FATAL) << "Not implmented"; } } private: LeakyReLUParam param_; }; // class LeakyReLUOp template Operator* CreateOp(LeakyReLUParam type); #if DMLC_USE_CXX11 class LeakyReLUProp : public OperatorProperty { public: void 
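// Summary of the three modes handled by LeakyReLUOp::Forward above (negative
// inputs are scaled by a slope; the exact elementwise functors are not visible
// in this expansion, so this is a hedged reading of the code):
//   "leaky": uses the fixed `slope` field, default 0.25;
//   "prelu": reads the slope from the learned `gamma` input, broadcast along
//            the channel axis;
//   "rrelu": in training, samples the slope per element uniformly from
//            [lower_bound, upper_bound] (defaults 0.125 and 0.334) and stores it
//            in the mask output; at inference it uses the deterministic midpoint
//            (lower_bound + upper_bound) / 2, which is 0.2295 with the defaults.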
Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; if (param_.act_type == leakyrelu::kPReLU) { CHECK_EQ(in_shape->size(), 2) << "Input:[data, gamma]"; } else { CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; } const TShape &dshape = in_shape->at(leakyrelu::kData); if (dshape.ndim() == 0) return false; if (param_.act_type == leakyrelu::kPReLU) { in_shape->at(leakyrelu::kGamma) = TShape(Shape1(dshape[1])); } out_shape->clear(); out_shape->push_back(dshape); if (param_.act_type == leakyrelu::kRReLU) { out_shape->push_back(dshape); } return true; } OperatorProperty* Copy() const override { auto ptr = new LeakyReLUProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "LeakyReLU"; } // decalre dependency and inplace optimization options std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { if (param_.act_type == leakyrelu::kPReLU) { return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kOut], in_data[leakyrelu::kData], in_data[leakyrelu::kGamma]}; } else if (param_.act_type == leakyrelu::kRReLU) { return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kMask], out_data[leakyrelu::kOut]}; } else { return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kData]}; } } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_grad[leakyrelu::kOut], in_grad[leakyrelu::kData]}}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { if (param_.act_type == leakyrelu::kPReLU) { return {}; } else { return {{in_data[leakyrelu::kData], out_data[leakyrelu::kOut]}}; } } std::vector ListArguments() const override { if (param_.act_type == leakyrelu::kPReLU) { return {"data", "gamma"}; } else { return {"data"}; } } std::vector ListOutputs() const override { if (param_.act_type == leakyrelu::kRReLU) { return {"output", "mask"}; } else { return {"output"}; } } int NumOutputs() const override { if (param_.act_type == leakyrelu::kRReLU) { return 2; } else { return 1; } } int NumVisibleOutputs() const override { return 1; } std::vector ForwardResource( const std::vector &in_shape) const override { if (param_.act_type == leakyrelu::kRReLU) { return {ResourceRequest::kRandom}; } else { return std::vector(); } } Operator* CreateOperator(Context ctx) const override; private: LeakyReLUParam param_; }; #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_LEAKY_RELU_INL_H_ //===== EXPANDED: mxnet/src/operator/leaky_relu-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(LeakyReLUParam param) { return new LeakyReLUOp(param); } Operator *LeakyReLUProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(LeakyReLUParam); MXNET_REGISTER_OP_PROPERTY(LeakyReLU, LeakyReLUProp) .describe("Apply activation function to input.") .add_argument("data", "Symbol", "Input data to activation function.") .add_arguments(LeakyReLUParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/leaky_relu.cc ===== //===== EXPANDIND: mxnet/src/operator/lrn.cc ===== /*! 
* Copyright (c) 2015 by Contributors * \file lrn.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/lrn-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file lrn-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_LRN_INL_H_ #define MXNET_OPERATOR_LRN_INL_H_ namespace mxnet { namespace op { namespace lrn_enum { enum LRNInputs {kData}; enum LRNOutputs {kOut, kTmpNorm}; } // namespace lrn_enum struct LRNParam : public dmlc::Parameter { float alpha; float beta; float knorm; uint32_t nsize; DMLC_DECLARE_PARAMETER(LRNParam) { DMLC_DECLARE_FIELD(alpha).set_default(1e-4f) .describe("value of the alpha variance scaling parameter in the normalization formula"); DMLC_DECLARE_FIELD(beta).set_default(0.75f) .describe("value of the beta power parameter in the normalization formula"); DMLC_DECLARE_FIELD(knorm).set_default(2.0f) .describe("value of the k parameter in normalization formula"); DMLC_DECLARE_FIELD(nsize) .describe("normalization window width in elements."); } }; // struct LRNParam template class LocalResponseNormOp : public Operator { public: explicit LocalResponseNormOp(LRNParam param) { param_ = param; } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; // TODO(xxx): Test with gradient chceker CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 2); // CHECK_EQ(req.size(), 2); CHECK_EQ(param_.nsize % 2, 1) << "LRN only supports odd values for local_size"; const real_t salpha = param_.alpha / param_.nsize; Stream *s = ctx.get_stream(); Tensor data = in_data[lrn_enum::kData].get(s); Tensor out = out_data[lrn_enum::kOut].get(s); Tensor tmp_norm = out_data[lrn_enum::kTmpNorm].get(s); tmp_norm = chpool(F(data) , param_.nsize) * salpha + param_.knorm; Assign(out, req[lrn_enum::kOut], data * F(tmp_norm, -param_.beta)); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 2); const real_t salpha = param_.alpha / param_.nsize; Stream *s = ctx.get_stream(); Tensor grad = out_grad[lrn_enum::kOut].get(s); Tensor tmp_norm = out_data[lrn_enum::kTmpNorm].get(s); Tensor data = in_data[lrn_enum::kData].get(s); Tensor grad_in = in_grad[lrn_enum::kData].get(s); grad_in = grad * F(tmp_norm, -param_.beta); grad_in += (- 2.0f * param_.beta * salpha) * chpool(grad * data * F(tmp_norm, -param_.beta - 1.0f), param_.nsize) * data; } private: LRNParam param_; }; // class LocalResponseNormOp template Operator *CreateOp(LRNParam param); #if DMLC_USE_CXX11 class LocalResponseNormProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; const TShape &dshape = in_shape->at(0); if (dshape.ndim() == 0) return false; out_shape->clear(); out_shape->push_back(dshape); #if MXNET_USE_CUDNN != 1 out_shape->push_back(dshape); #endif return true; } OperatorProperty* Copy() const override { auto ptr = new LocalResponseNormProp(); ptr->param_ = 
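// For reference, the normalization computed by LocalResponseNormOp above is,
// reading the code with its elided functors presumed to be square and power:
//   tmp_norm = knorm + (alpha / nsize) * sum over an nsize-channel window of data^2
//   out      = data * tmp_norm^(-beta)
// e.g. with the defaults alpha=1e-4, beta=0.75, knorm=2 and a hypothetical nsize=5,
// a windowed squared sum of 1000 scales the input by (2.02)^(-0.75), about 0.59.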
param_; return ptr; } std::string TypeString() const override { return "LRN"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { #if MXNET_USE_CUDNN == 1 return {out_grad[lrn_enum::kOut], in_data[lrn_enum::kData], out_data[lrn_enum::kOut]}; #else return {out_grad[lrn_enum::kOut], in_data[lrn_enum::kData], out_data[lrn_enum::kTmpNorm]}; #endif } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { #if MXNET_USE_CUDNN == 1 return {}; #else return {{out_grad[lrn_enum::kOut], in_grad[lrn_enum::kData]}}; #endif } int NumVisibleOutputs() const override { return 1; } int NumOutputs() const override { return MXNET_USE_CUDNN == 1 ? 1 : 2; } std::vector ListArguments() const override { return {"data"}; } std::vector ListOutputs() const override { #if MXNET_USE_CUDNN == 1 return {"output"}; #else return {"output", "tmp_norm"}; #endif } Operator* CreateOperator(Context ctx) const override; private: LRNParam param_; }; // LocalResponseNormProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_LRN_INL_H_ //===== EXPANDED: mxnet/src/operator/lrn-inl.h ===== #if MXNET_USE_CUDNN == 1 #endif namespace mxnet { namespace op { template<> Operator* CreateOp(LRNParam param) { return new LocalResponseNormOp(param); } Operator* LocalResponseNormProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(LRNParam); MXNET_REGISTER_OP_PROPERTY(LRN, LocalResponseNormProp) .add_argument("data", "Symbol", "Input data to the normalization operator.") .add_arguments(LRNParam::__FIELDS__()) .describe("Apply local response normalization to the input."); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/lrn.cc ===== //===== EXPANDIND: mxnet/src/operator/pooling.cc ===== /*! * Copyright (c) 2015 by Contributors * \file pooling.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/pooling-inl.h ===== /*!
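 * Shape arithmetic used by PoolingProp::InferShape below, restated here as a
 * hedged illustration (the example numbers are hypothetical): for input height d,
 * kernel k, pad p and stride s the output height is
 *   min(d + 2*p - k + s - 1, d + 2*p - 1) / s + 1   (integer division),
 * e.g. d=28, k=2, p=0, s=2 gives 14. For avg pooling both Forward and Backward
 * additionally scale by 1/(kernel_y * kernel_x), i.e. 1/4 for a 2x2 kernel.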
* Copyright (c) 2015 by Contributors * \file pooling-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_POOLING_INL_H_ #define MXNET_OPERATOR_POOLING_INL_H_ namespace mxnet { namespace op { namespace pool_enum { enum PoolingOpInputs {kData}; enum PoolingOpOutputs {kOut}; enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling}; } // namespace pool_enum struct PoolingParam : public dmlc::Parameter { TShape kernel; TShape stride; TShape pad; int pool_type; DMLC_DECLARE_PARAMETER(PoolingParam) { // TODO(bing) change to only set lower bound DMLC_DECLARE_FIELD(kernel) .set_expect_ndim(2).enforce_nonzero() .describe("pooling kernel size: (y, x)"); DMLC_DECLARE_FIELD(pool_type) .add_enum("max", pool_enum::kMaxPooling) .add_enum("avg", pool_enum::kAvgPooling) .add_enum("sum", pool_enum::kSumPooling) .describe("Pooling type to be applied."); int stride_shape[] = {1, 1}; DMLC_DECLARE_FIELD(stride).set_default(TShape(stride_shape, stride_shape + 2)) .set_expect_ndim(2).enforce_nonzero() .describe("stride: for pooling (y, x)"); int pad_shape[] = {0, 0}; DMLC_DECLARE_FIELD(pad).set_default(TShape(pad_shape, pad_shape + 2)) .set_expect_ndim(2) .describe("pad for pooling: (y, x)"); } }; template class PoolingOp : public Operator { public: explicit PoolingOp(PoolingParam p) { this->param_ = p; } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); Tensor data = in_data[pool_enum::kData].get(s); Tensor out = out_data[pool_enum::kOut].get(s); mshadow::Shape<2> out_shape = Shape2(out.shape_[2], out.shape_[3]); // TODO(bing): dual stride in mshadow CHECK_EQ(param_.stride[0], param_.stride[1]) << "Only same stride is supported now"; if (param_.pool_type == pool_enum::kMaxPooling || param_.pool_type == pool_enum::kSumPooling) { Assign(out, req[pool_enum::kOut], pool(pad(data, param_.pad[0], param_.pad[1]), out_shape, param_.kernel[0], param_.kernel[1], param_.stride[0])); } else if (param_.pool_type == pool_enum::kAvgPooling) { Assign(out, req[pool_enum::kOut], (1.0f / (param_.kernel[0] * param_.kernel[1])) * \ pool(pad(data, param_.pad[0], param_.pad[1]), out_shape, param_.kernel[0], param_.kernel[1], param_.stride[0])); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 1); CHECK_EQ(req.size(), 1); CHECK_EQ(in_grad.size(), 1); // TODO(bing): remove pad (0,0) Stream *s = ctx.get_stream(); Tensor grad = out_grad[pool_enum::kOut].get(s); Tensor data = in_data[pool_enum::kData].get(s); Tensor output_data = out_data[pool_enum::kOut].get(s); Tensor input_grad = in_grad[pool_enum::kData].get(s); mshadow::Shape<2> in_shape = Shape2(data.shape_[2], data.shape_[3]); if (param_.pool_type == pool_enum::kMaxPooling || param_.pool_type == pool_enum::kSumPooling) { Assign(input_grad, req[pool_enum::kData], crop(unpool(pad(data, param_.pad[0], param_.pad[1]), pad(output_data, 0, 0), pad(grad, 0, 0), param_.kernel[0], param_.kernel[1], param_.stride[0]), in_shape, param_.pad[0], param_.pad[1])); } else if (param_.pool_type == pool_enum::kAvgPooling) { Assign(input_grad, 
req[pool_enum::kData], (1.0f / param_.kernel[0] / param_.kernel[1]) *\ crop(unpool(pad(data, param_.pad[0], param_.pad[1]), pad(output_data, 0, 0), pad(grad, 0, 0), param_.kernel[0], param_.kernel[1], param_.stride[0]), in_shape, param_.pad[0], param_.pad[1])); } } private: PoolingParam param_; }; // class PoolingOp template Operator* CreateOp(PoolingParam param); #if DMLC_USE_CXX11 class PoolingProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { CHECK_EQ(in_shape->size(), 1); const TShape &dshape = (*in_shape)[0]; CHECK_EQ(dshape.ndim(), 4) << \ "Pooling: Input data should be 4D in (batch, channel, y, x)"; TShape oshape = dshape; if (dshape.ndim() == 0) return false; oshape[2] = std::min(dshape[2] + 2 * param_.pad[0] - param_.kernel[0] + param_.stride[0] - 1, dshape[2] + 2 * param_.pad[0] - 1) / param_.stride[0] + 1; oshape[3] = std::min(dshape[3] + 2 * param_.pad[1] - param_.kernel[1] + param_.stride[1] - 1, dshape[3] + 2 * param_.pad[1] - 1) / param_.stride[1] + 1; CHECK(oshape[2] > 0 && oshape[3] > 0) << "Pooling: kernel size exceed input"; out_shape->clear(); out_shape->push_back(oshape); return true; } OperatorProperty* Copy() const override { PoolingProp *prop_sym = new PoolingProp(); prop_sym->param_ = this->param_; return prop_sym; } std::string TypeString() const override { return "Pooling"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {out_grad[pool_enum::kOut], in_data[pool_enum::kData], out_data[pool_enum::kOut]}; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { #if MXNET_USE_CUDNN == 1 return {}; #else return {{in_data[pool_enum::kData], in_grad[pool_enum::kData]}}; #endif } Operator* CreateOperator(Context ctx) const override; private: PoolingParam param_; }; // class PoolingProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_POOLING_INL_H_ //===== EXPANDED: mxnet/src/operator/pooling-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(PoolingParam param) { switch (param.pool_type) { case pool_enum::kMaxPooling: return new PoolingOp(param); case pool_enum::kAvgPooling: return new PoolingOp(param); case pool_enum::kSumPooling: return new PoolingOp(param); default: LOG(FATAL) << "unknown activation type"; return NULL; } } Operator* PoolingProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(PoolingParam); MXNET_REGISTER_OP_PROPERTY(Pooling, PoolingProp) .describe("Perform spatial pooling on inputs.") .add_argument("data", "Symbol", "Input data to the pooling operator.") .add_arguments(PoolingParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/pooling.cc ===== //===== EXPANDIND: mxnet/src/operator/regression_output.cc ===== /*! * Copyright (c) 2015 by Contributors * \file regression_output.cc * \brief regression output operator */ //===== EXPANDIND: mxnet/src/operator/regression_output-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file regression_ouput-inl.h * \brief Regression output operator. 
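 * Reading of the code below, hedged because the elementwise functors were lost
 * in this expansion: Forward applies a unary op to the data (presumably the
 * identity for LinearRegressionOutput and the logistic sigmoid for
 * LogisticRegressionOutput) and Backward assigns
 *   grad = (out - label) / num_output, num_output = label.Size() / label.shape_[0],
 * e.g. a hypothetical (batch=64, 10) output gives num_output = 10, so every
 * element's gradient is scaled by 1/10.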
*/ #ifndef MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_ #define MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_ namespace mxnet { namespace op { namespace reg_enum { enum RegressionOutputOpInputs {kData, kLabel}; enum RegressionOutputOutputs {kOut}; enum RegressionOutputType {kLinear, kLogistic}; } // reg_enum // Special Operator to output regression value in forward // And get gradient in calculation. template class RegressionOutputOp : public Operator { public: virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2) << "RegressionOutputOp Input: [data, label]"; CHECK_EQ(out_data.size(), 1) << "RegressionOutputOp Output: [output]"; Stream *s = ctx.get_stream(); Tensor data = in_data[reg_enum::kData].FlatTo2D(s); Tensor out = out_data[reg_enum::kOut].FlatTo2D(s); Assign(out, req[reg_enum::kOut], F(data)); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2); CHECK_EQ(out_grad.size(), 1); CHECK_GE(in_grad.size(), 1); CHECK_GE(req.size(), 1); Stream *s = ctx.get_stream(); real_t num_output = in_data[reg_enum::kLabel].Size()/in_data[reg_enum::kLabel].shape_[0]; Tensor out = out_data[reg_enum::kOut].FlatTo2D(s); Tensor grad = in_grad[reg_enum::kData].FlatTo2D(s); Tensor label = in_data[reg_enum::kLabel] .get_with_shape(out.shape_, s); Assign(grad, req[reg_enum::kData], F(out, label)/num_output); } }; // Decalre Factory function, used for dispatch specialization template Operator* CreateRegressionOutputOp(reg_enum::RegressionOutputType type); #if DMLC_USE_CXX11 template class RegressionOutputProp : public OperatorProperty { public: std::vector ListArguments() const override { return {"data", "label"}; } void Init(const std::vector >& kwargs) override { } std::map GetParams() const override { return std::map(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]"; const TShape &dshape = in_shape->at(0); if (dshape.ndim() == 0) return false; auto &lshape = (*in_shape)[1]; if (lshape.ndim() == 0) { lshape = dshape; } else if (lshape[0] != dshape[0] || lshape.Size() != dshape.Size()) { std::ostringstream os; os << "Shape inconsistent, Provided " << '='<< lshape << ',' << " inferred shape=" << dshape; throw ::mxnet::op::InferShapeError(os.str(), 1); } out_shape->clear(); out_shape->push_back(dshape); return true; } OperatorProperty* Copy() const override { return new RegressionOutputProp(); } std::string TypeString() const override { switch (type) { case reg_enum::kLinear: return "LinearRegressionOutput"; case reg_enum::kLogistic: return "LogisticRegressionOutput"; default: LOG(FATAL) << "unknown type"; return ""; } } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {in_data[reg_enum::kLabel], out_data[reg_enum::kOut]}; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_data[reg_enum::kOut], in_grad[reg_enum::kData]}}; } std::vector 
> ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[reg_enum::kData], out_data[reg_enum::kOut]}}; } Operator* CreateOperator(Context ctx) const override; }; #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_ //===== EXPANDED: mxnet/src/operator/regression_output-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateRegressionOutputOp(reg_enum::RegressionOutputType type) { switch (type) { case reg_enum::kLinear: return new RegressionOutputOp(); case reg_enum::kLogistic: return new RegressionOutputOp(); default: LOG(FATAL) << "unknown activation type " << type; } return nullptr; } // DO_BIND_DISPATCH comes from operator_common.h template Operator *RegressionOutputProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateRegressionOutputOp, type); } MXNET_REGISTER_OP_PROPERTY(LinearRegressionOutput, RegressionOutputProp) .describe("Use linear regression for final output, this is used on final output of a net.") .add_argument("data", "Symbol", "Input data to function.") .add_argument("label", "Symbol", "Input label to function."); MXNET_REGISTER_OP_PROPERTY(LogisticRegressionOutput, RegressionOutputProp) .describe("Use Logistic regression for final output, this is used on final output of a net.\n" "Logistic regression is suitable for binary classification " "or probability prediction tasks.") .add_argument("data", "Symbol", "Input data to function.") .add_argument("label", "Symbol", "Input label to function."); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/regression_output.cc ===== //===== EXPANDIND: mxnet/src/operator/reshape.cc ===== /*! * Copyright (c) 2015 by Contributors * \file flatten.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/reshape-inl.h ===== /*! 
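 * Quick illustration of the shape rules implemented below (example shapes are
 * hypothetical): Reshape only requires that the element count be preserved,
 * e.g. (64, 3, 32, 32) -> target_shape (64, 3072) is valid since both hold
 * 196608 elements; Flatten always maps (d0, d1, ..., dn) to (d0, d1*...*dn),
 * e.g. (64, 3, 32, 32) -> (64, 3072). Both operators just copy (or alias, when
 * input and output share the same pointer) one contiguous buffer.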
* Copyright (c) 2015 by Contributors * \file reshape-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_RESHAPE_INL_H_ #define MXNET_OPERATOR_RESHAPE_INL_H_ namespace mxnet { namespace op { namespace reshape_enum { enum ReshapeOpInputs {kData}; enum ReshapeOpOutputs {kOut}; } // namespace reshape_enum struct ReshapeParam : public dmlc::Parameter { TShape target_shape; DMLC_DECLARE_PARAMETER(ReshapeParam) { DMLC_DECLARE_FIELD(target_shape).describe("Target new shape"); } }; template class ReshapeOp : public Operator { public: explicit ReshapeOp(ReshapeParam param) {} // Do nothing virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); CHECK_EQ(req.size(), 1); CHECK_EQ(out_data.size(), 1); if (req[reshape_enum::kOut] == kNullOp) return; Stream *s = ctx.get_stream(); Tensor data = in_data[reshape_enum::kData].FlatTo2D(s); Tensor out = out_data[reshape_enum::kOut].FlatTo2D(s); CHECK_EQ(data.CheckContiguous(), true); CHECK_EQ(out.CheckContiguous(), true); if (data.dptr_ == out.dptr_) return; CHECK_EQ(data.shape_.Size(), out.shape_.Size()); Assign(out, req[reshape_enum::kOut], reshape(data, out.shape_)); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req.size(), 1); if (req[reshape_enum::kData] == kNullOp) return; CHECK_EQ(out_grad.size(), 1); CHECK_EQ(in_grad.size(), 1); Stream *s = ctx.get_stream(); Tensor grad_in = in_grad[reshape_enum::kOut].FlatTo2D(s); Tensor grad_out = out_grad[reshape_enum::kData].FlatTo2D(s); CHECK_EQ(grad_out.CheckContiguous(), true); CHECK_EQ(grad_in.CheckContiguous(), true); if (grad_out.dptr_ == grad_in.dptr_) return; CHECK_EQ(grad_out.shape_.Size(), grad_in.shape_.Size()); Assign(grad_in, req[reshape_enum::kData], reshape(grad_out, grad_in.shape_)); } }; // class ReshapeOp template Operator* CreateOp(ReshapeParam); #if DMLC_USE_CXX11 class ReshapeProp : public OperatorProperty { public: ReshapeProp() {} void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { CHECK_EQ(in_shape->size(), 1) << "Input: [data]"; const TShape &dshape = in_shape->at(reshape_enum::kData); if (dshape.ndim() == 0) return false; CHECK(param_.target_shape.Size() == dshape.Size()) << "Target shape size is different to source. 
" << "Target: " << param_.target_shape.Size() << "\nSource: " << dshape.Size(); out_shape->clear(); out_shape->push_back(param_.target_shape); return true; } OperatorProperty* Copy() const override { auto ptr = new ReshapeProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "Reshape"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {out_grad[reshape_enum::kOut]}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[reshape_enum::kData], out_data[reshape_enum::kOut]}}; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_grad[reshape_enum::kOut], in_grad[reshape_enum::kData]}}; } Operator* CreateOperator(Context ctx) const override; protected: ReshapeParam param_; }; // class ReshapeProp class FlattenProp : public ReshapeProp { public: void Init(const std::vector >& kwargs) override {} std::map GetParams() const override { // need to use this on osx return std::map(); } std::string TypeString() const override { return "Flatten"; } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { CHECK_EQ(in_shape->size(), 1) << "Input: [data]"; const TShape &dshape = in_shape->at(reshape_enum::kData); if (dshape.ndim() == 0) return false; out_shape->clear(); uint32_t target_dim = 1; for (uint32_t i = 1; i < dshape.ndim(); ++i) { target_dim *= dshape[i]; } out_shape->push_back(mshadow::Shape2(dshape[0], target_dim)); return true; } OperatorProperty* Copy() const override { auto ptr = new FlattenProp(); return ptr; } }; // class FlattenProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_RESHAPE_INL_H_ //===== EXPANDED: mxnet/src/operator/reshape-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(ReshapeParam param) { return new ReshapeOp(param); } Operator* ReshapeProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(ReshapeParam); MXNET_REGISTER_OP_PROPERTY(Reshape, ReshapeProp) .describe("Reshape input to target shape") .add_argument("data", "Symbol", "Input data to reshape.") .add_arguments(ReshapeParam::__FIELDS__()); MXNET_REGISTER_OP_PROPERTY(Flatten, FlattenProp) .describe("Flatten input") .add_argument("data", "Symbol", "Input data to flatten."); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/reshape.cc ===== //===== EXPANDIND: mxnet/src/operator/slice_channel.cc ===== /*! * Copyright (c) 2015 by Contributors * \file slice_channel.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/slice_channel-inl.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file slice_channel-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_SLICE_CHANNEL_INL_H_ #define MXNET_OPERATOR_SLICE_CHANNEL_INL_H_ namespace mxnet { namespace op { namespace slice_enum { enum SliceChannelOpInputs {kData}; enum SliceChannelOpOutputs {kOut0, kOut1, kOut2, kOut3, kOut4}; } // namespace slice_enum struct SliceChannelParam : public dmlc::Parameter { int num_outputs; DMLC_DECLARE_PARAMETER(SliceChannelParam) { DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) .describe("Number of outputs to be sliced."); } }; // struct SliceChannelParam template class SliceChannelOp : public Operator { public: explicit SliceChannelOp(SliceChannelParam param) : size_(param.num_outputs) {} virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), static_cast(size_)); Stream *s = ctx.get_stream(); std::vector > outputs(size_); Tensor data; if (in_data[slice_enum::kData].ndim() == 2) { Shape<4> dshape = Shape4(in_data[slice_enum::kData].shape_[0], in_data[slice_enum::kData].shape_[1], 1, 1); data = in_data[slice_enum::kData].get_with_shape(dshape, s); Shape<4> slice_shape = dshape; slice_shape[1] = dshape[1] / size_; for (int i = 0; i < size_; ++i) { outputs[i] = out_data[i].get_with_shape(slice_shape, s); } } else { data = in_data[slice_enum::kData].get(s); for (int i = 0; i < size_; ++i) { outputs[i] = out_data[i].get(s); } } Split(data, &outputs); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_states) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), static_cast(size_)); CHECK_EQ(in_grad.size(), 1); Stream *s = ctx.get_stream(); std::vector > grad_out(size_); Tensor grad; if (out_grad[slice_enum::kOut0].ndim() == 2) { Shape<4> slice_shape = Shape4(out_grad[slice_enum::kOut0].shape_[0], out_grad[slice_enum::kOut0].shape_[1], 1, 1); for (int i = 0; i < size_; ++i) { grad_out[i] = out_grad[i].get_with_shape(slice_shape, s); } Shape<4> dshape = slice_shape; dshape[1] *= size_; grad = in_grad[slice_enum::kData].get_with_shape(dshape, s); } else { for (int i = 0; i < size_; ++i) { grad_out[i] = out_grad[i].get(s); } grad = in_grad[slice_enum::kData].get(s); } Concatenate(grad_out, &grad); } private: int size_; }; // class SliceChannelOp template Operator *CreateOp(SliceChannelParam param); #if DMLC_USE_CXX11 class SliceChannelProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } std::vector ListOutputs() const override { std::vector ret; for (int i = 0; i < param_.num_outputs; ++i) { ret.push_back(std::string("output") + static_cast('0' + i)); } return ret; } int NumOutputs() const override { return param_.num_outputs; } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 1); TShape dshape = in_shape->at(slice_enum::kData); if (dshape.ndim() == 0) return false; CHECK_GT(dshape.ndim(), 1); CHECK_EQ(dshape[1] % param_.num_outputs, 0) << "Channel must be divided by the output number: " << dshape[1] << " / " << param_.num_outputs; 
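// Worked example of the slicing below, with hypothetical shapes: for an input of
// shape (32, 6, h, w) and num_outputs=3 the check above passes (6 % 3 == 0) and
// each of the three outputs is inferred as (32, 2, h, w); Backward concatenates
// the three gradients back into a single (32, 6, h, w) tensor.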
dshape[1] /= param_.num_outputs; out_shape->clear(); for (int i = 0; i < param_.num_outputs; ++i) { out_shape->push_back(dshape); } return true; } OperatorProperty* Copy() const override { auto ptr = new SliceChannelProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "SliceChannel"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return out_grad; } Operator* CreateOperator(Context ctx) const override; private: SliceChannelParam param_; }; // class SliceChannelProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_SLICE_CHANNEL_INL_H_ //===== EXPANDED: mxnet/src/operator/slice_channel-inl.h ===== namespace mxnet { namespace op { template<> Operator* CreateOp(SliceChannelParam param) { return new SliceChannelOp(param); } Operator* SliceChannelProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(SliceChannelParam); MXNET_REGISTER_OP_PROPERTY(SliceChannel, SliceChannelProp) .describe("Slice channel into many outputs with equally divided channel") .add_arguments(SliceChannelParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/slice_channel.cc ===== //===== EXPANDIND: mxnet/src/operator/softmax_output.cc ===== /*! * Copyright (c) 2015 by Contributors * \file softmax_output.cc * \brief * \author Bing Xu */ //===== EXPANDIND: mxnet/src/operator/softmax_output-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file softmax_output-inl.h * \brief * \author Bing Xu */ #ifndef MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ #define MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ namespace mxnet { namespace op { namespace softmaxout_enum { enum SoftmaxOutputOpInputs {kData, kLabel}; enum SoftmaxOutputOpOutputs {kOut}; } // namespace softmaxout_enum struct SoftmaxOutputParam : public dmlc::Parameter { float grad_scale; bool multi_output; DMLC_DECLARE_PARAMETER(SoftmaxOutputParam) { DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f) .describe("Scale the gradient by a float factor"); DMLC_DECLARE_FIELD(multi_output).set_default(false) .describe("If set to true, for a (n,k,x_1,..,x_n) dimensional" "input tensor, softmax will generate n*x_1*...*x_n output, each" "has k classes"); }; }; template class SoftmaxOutputOp : public Operator { public: explicit SoftmaxOutputOp(SoftmaxOutputParam param) : param_(param) {} virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2) << "SoftmaxOutput Input: [data, label]"; CHECK_EQ(out_data.size(), 1) << "SoftmaxOutput Output: [output]"; Stream *s = ctx.get_stream(); if (param_.multi_output) { int n = in_data[softmaxout_enum::kData].size(0); int k = in_data[softmaxout_enum::kData].size(1); Shape<3> s3 = Shape3(n, k, static_cast(in_data[softmaxout_enum::kData].Size()/n/k)); Tensor data = in_data[softmaxout_enum::kData].get_with_shape(s3, s); Tensor out = out_data[softmaxout_enum::kOut].get_with_shape(s3, s); Softmax(out, data); } else { Tensor data = in_data[softmaxout_enum::kData].FlatTo2D(s); Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); Softmax(out, data); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector 
&in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2); CHECK_EQ(out_grad.size(), 1); CHECK_GE(in_grad.size(), 1); CHECK_GE(req.size(), 1); Stream *s = ctx.get_stream(); if (param_.multi_output) { int n = out_data[softmaxout_enum::kOut].size(0); int k = out_data[softmaxout_enum::kOut].size(1); Shape<3> s3 = Shape3(n, k, static_cast(out_data[softmaxout_enum::kOut].Size()/n/k)); Tensor label = in_data[softmaxout_enum::kLabel].FlatTo2D(s); Tensor out = out_data[softmaxout_enum::kOut].get_with_shape(s3, s); Tensor grad = in_grad[softmaxout_enum::kData].get_with_shape(s3, s); SoftmaxGrad(grad, out, label); if (param_.grad_scale < 1.0) { grad *= param_.grad_scale; } } else { Tensor label = in_data[softmaxout_enum::kLabel].get(s); Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); Tensor grad = in_grad[softmaxout_enum::kData].FlatTo2D(s); SoftmaxGrad(grad, out, label); if (param_.grad_scale < 1.0) { grad *= param_.grad_scale; } } } private: SoftmaxOutputParam param_; }; // class SoftmaxOutputOp // Decalre Factory function, used for dispatch specialization template Operator* CreateOp(SoftmaxOutputParam param); #if DMLC_USE_CXX11 class SoftmaxOutputProp : public OperatorProperty { public: std::vector ListArguments() const override { return {"data", "label"}; } void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]"; const TShape &dshape = in_shape->at(0); if (dshape.ndim() == 0) return false; if (param_.multi_output) { SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1])); } else { SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, Shape1(dshape[0])); } out_shape->clear(); out_shape->push_back(dshape); return true; } OperatorProperty* Copy() const override { auto ptr = new SoftmaxOutputProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "SoftmaxOutput"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {in_data[softmaxout_enum::kLabel], out_data[softmaxout_enum::kOut]}; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {{out_data[softmaxout_enum::kOut], in_grad[softmaxout_enum::kData]}}; } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { return {{in_data[softmaxout_enum::kData], out_data[softmaxout_enum::kOut]}}; } Operator* CreateOperator(Context ctx) const override; protected: SoftmaxOutputParam param_; }; // class SoftmaxOutputProp class DeprecatedSoftmaxProp : public SoftmaxOutputProp { public: void Init(const std::vector >& kwargs) override { LOG(INFO) << "Softmax symbol is renamed to SoftmaxOutput. 
" << "This API will be deprecated in Dec, 2015"; SoftmaxOutputProp::param_.Init(kwargs); } std::string TypeString() const override { return "Softmax"; } }; #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ //===== EXPANDED: mxnet/src/operator/softmax_output-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(SoftmaxOutputParam param) { return new SoftmaxOutputOp(param); } Operator *SoftmaxOutputProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(SoftmaxOutputParam); MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp) .describe("Perform a softmax transformation on input, backprop with logloss.") .add_argument("data", "Symbol", "Input data to softmax.") .add_arguments(SoftmaxOutputParam::__FIELDS__()); MXNET_REGISTER_OP_PROPERTY(Softmax, DeprecatedSoftmaxProp) .describe("DEPRECATED: Perform a softmax transformation on input. Please use SoftmaxOutput") .add_argument("data", "Symbol", "Input data to softmax.") .add_arguments(SoftmaxOutputParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/softmax_output.cc ===== //===== EXPANDIND: mxnet/src/operator/deconvolution.cc ===== /*! * Copyright (c) 2015 by Contributors * \file deconvolution.cc * \brief * \author Wei Wu */ //===== EXPANDIND: mxnet/src/operator/deconvolution-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file deconvolution-inl.h * \brief * \author Wei Wu */ #ifndef MXNET_OPERATOR_DECONVOLUTION_INL_H_ #define MXNET_OPERATOR_DECONVOLUTION_INL_H_ namespace mxnet { namespace op { namespace deconv { enum DeconvolutionOpInputs {kData, kWeight, kBias}; enum DeconvolutionOpOutputs {kOut}; enum DeconvolutionOpResource {kTempSpace}; } struct DeconvolutionParam : public dmlc::Parameter { TShape kernel; TShape stride; TShape pad; uint32_t num_filter; uint32_t num_group; uint64_t workspace; bool no_bias; DMLC_DECLARE_PARAMETER(DeconvolutionParam) { int shape[] = {1, 1}; DMLC_DECLARE_FIELD(kernel).describe("deconvolution kernel size: (y, x)"); DMLC_DECLARE_FIELD(stride).set_default(TShape(shape, shape + 2)) .describe("deconvolution stride: (y, x)"); shape[0] = shape[1] = 0; DMLC_DECLARE_FIELD(pad).set_default(TShape(shape, shape + 2)) .describe("pad for deconvolution: (y, x)"); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("deconvolution filter(channel) number"); DMLC_DECLARE_FIELD(num_group).set_default(1) .describe("number of groups partition"); DMLC_DECLARE_FIELD(workspace).set_default(512).set_range(128, 4096) .describe("Tmp workspace for deconvolution (MB)"); DMLC_DECLARE_FIELD(no_bias).set_default(true) .describe("Whether to disable bias parameter."); } }; template class DeconvolutionOp : public Operator { public: explicit DeconvolutionOp(DeconvolutionParam p) { this->param_ = p; // convert MB to words param_.workspace = (param_.workspace << 20) / sizeof(real_t); } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[deconv::kOut], kWriteTo); size_t expected = param_.no_bias ? 
2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); Tensor data = in_data[deconv::kData].get(s); Tensor out = out_data[deconv::kOut].get(s); Shape<3> wmat_shape = Shape3(param_.num_group, data.shape_[1] / param_.num_group, param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]); Tensor wmat = in_data[deconv::kWeight].get_with_shape(wmat_shape, s); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif const index_t nbatch = data.size(0); Tensor workspace = ctx.requested[deconv::kTempSpace].get_space( Shape1(this->InitTemp(out.shape_, data.shape_)), s); for (index_t i = 0; i < nbatch; i += nstep_) { const index_t step = std::min(nstep_, nbatch - i); Tensor temp_col = Tensor(workspace.dptr_, Shape2(shape_colunit_[0], shape_colunit_[1] * step), s); Tensor temp_dst = Tensor(workspace.dptr_ + temp_col.shape_.Size(), Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * step), s); temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); if (param_.pad[0] == 0 && param_.pad[1] == 0) { temp_col = unpack_patch2col(out.Slice(i, i + step), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } else { temp_col = unpack_patch2col(pad(out.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { mshadow::Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); tmpc = dot(wmat[gid].T(), temp_dst[gid]); } if (param_.pad[0] == 0 && param_.pad[1] == 0) { out.Slice(i, i + step) = pack_col2patch(temp_col, out.Slice(i, i + step).shape_, param_.kernel[0], param_.kernel[1], param_.stride[0]); } else { Shape<4> pshape = out.Slice(i, i + step).shape_; pshape[2] += 2 * param_.pad[0]; pshape[3] += 2 * param_.pad[1]; out.Slice(i, i + step) = crop(pack_col2patch(temp_col, pshape, param_.kernel[0], param_.kernel[1], param_.stride[0]), out[i][0].shape_); } } if (!param_.no_bias) { // add bias, broadcast bias to dim 1: channel Tensor bias = in_data[deconv::kBias].get(s); out += broadcast<1>(bias, out.shape_); } } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; // TODO(bing): check the BLAS Handle, be careful CHECK_EQ(out_grad.size(), 1); size_t expected = param_.no_bias == 0 ? 
3 : 2; CHECK(in_data.size() == expected && in_grad.size() == expected); CHECK_EQ(req.size(), expected); CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true); // get data Stream *s = ctx.get_stream(); Tensor data = in_data[deconv::kData].get(s); Tensor grad = out_grad[deconv::kOut].get(s); Tensor gdata = in_grad[deconv::kData].get(s); Shape<3> wmat_shape = Shape3(param_.num_group, data.shape_[1] / param_.num_group, param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]); Tensor wmat = in_data[deconv::kWeight].get_with_shape(wmat_shape, s); Tensor gwmat = in_grad[deconv::kWeight].get_with_shape(wmat_shape, s); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif const index_t nbatch = data.size(0); Tensor workspace = ctx.requested[deconv::kTempSpace].get_space( Shape1(this->InitTemp(grad.shape_, data.shape_)), s); for (index_t i = 0; i < nbatch; i += nstep_) { const index_t step = std::min(nstep_, nbatch - i); Tensor temp_col = Tensor(workspace.dptr_, Shape2(shape_colunit_[0], shape_colunit_[1] * step), s); Tensor temp_dst = Tensor(workspace.dptr_ + temp_col.shape_.Size(), Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * step), s); temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); if (param_.pad[0] == 0 && param_.pad[1] == 0) { temp_col = unpack_patch2col(grad.Slice(i, i + step), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } else { temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); if (i == 0) { Tensor tmp_gwmat = gwmat[gid]; Assign(tmp_gwmat, req[deconv::kWeight], dot(temp_dst[gid], tmpc.T())); } else { gwmat[gid] += dot(temp_dst[gid], tmpc.T()); } } if (req[deconv::kData] == kWriteTo || req[deconv::kData] == kWriteInplace) { for (uint32_t gid = 0; gid < param_.num_group; ++gid) { Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); temp_dst[gid] = dot(wmat[gid], tmpc); } gdata.Slice(i, i + step) = swapaxis<1, 0>(reshape(temp_dst, mshadow::Shape4(gdata.shape_[1], step, gdata.size(2), gdata.size(3)))); } } if (!param_.no_bias) { Tensor gbias = in_grad[deconv::kBias].get(s); Assign(gbias, req[deconv::kBias], sumall_except_dim<1>(grad)); } } private: inline index_t InitTemp(const mshadow::Shape<4> &ishape, const mshadow::Shape<4> &oshape) { const int ksize_y = param_.kernel[0]; const int ksize_x = param_.kernel[1]; shape_colunit_ = mshadow::Shape2(ishape[1] * ksize_y * ksize_x, oshape[2] * oshape[3]); shape_dstunit_ = mshadow::Shape3(param_.num_group, oshape[1] / param_.num_group, oshape[2] * oshape[3]); const uint64_t workspace_size = param_.workspace; nstep_ = std::max(std::min(static_cast(workspace_size / shape_colunit_.Size()), ishape[0]), 1U); int nop = (ishape[0] + nstep_ - 1) / nstep_; nstep_ = (ishape[0] + nop - 1) / nop; mshadow::Shape<2> scol = mshadow::Shape2(shape_colunit_[0], shape_colunit_[1] * nstep_); mshadow::Shape<3> sdst = mshadow::Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * nstep_); CHECK_GE(param_.workspace, scol.Size() + sdst.Size()) << "\nMinimum workspace size: " << scol.Size() + sdst.Size() << "\n" << "Given: " << param_.workspace; return scol.Size() + sdst.Size(); } 
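// Workspace bookkeeping above, restated with hypothetical numbers: the constructor
// converts the user-facing `workspace` from MB to elements via
// (workspace << 20) / sizeof(real_t), e.g. 512 MB -> 134217728 elements assuming a
// 4-byte real_t; InitTemp then picks nstep_ = clamp(workspace / colunit_size, 1,
// batch), rounds it so the batch is processed in equal chunks, and the final
// CHECK_GE verifies that one column buffer plus one destination buffer fit.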
DeconvolutionParam param_; mshadow::Shape<2> shape_colunit_; mshadow::Shape<3> shape_dstunit_; index_t nstep_; }; // class DeconvolutionOp template Operator* CreateOp(DeconvolutionParam param); #if DMLC_USE_CXX11 class DeconvolutionProp : public OperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { return {"data", "weight", "bias"}; } else { return {"data", "weight"}; } } void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; } else { CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; } const TShape &dshape = (*in_shape)[deconv::kData]; if (dshape.ndim() == 0) return false; CHECK_EQ(dshape.ndim(), 4) \ << "Input data should be 4D in batch-num_filter-y-x"; SHAPE_ASSIGN_CHECK(*in_shape, deconv::kWeight, Shape4(dshape[1], param_.num_filter, param_.kernel[0], param_.kernel[1])); if (!param_.no_bias) { SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter)); } out_shape->clear(); out_shape->push_back(dshape); const index_t ksize_y = static_cast(param_.kernel[0]); const index_t ksize_x = static_cast(param_.kernel[1]); CHECK_EQ(dshape[1] % param_.num_group, 0) \ << "input num_filter must divide group size"; CHECK_EQ(param_.num_filter % param_.num_group, 0) \ << "output num_filter must divide group size"; CHECK_GE(param_.kernel.Size(), 0) \ << "incorrect kernel size: " << param_.kernel; CHECK_GE(param_.stride.Size(), 0) \ << "incorrect stride size: " << param_.stride; (*out_shape)[deconv::kOut][1] = param_.num_filter; (*out_shape)[deconv::kOut][2] = param_.stride[0] * (dshape[2] - 1) + ksize_y - 2 * param_.pad[0]; (*out_shape)[deconv::kOut][3] = param_.stride[1] * (dshape[3] - 1) + ksize_x - 2 * param_.pad[1]; return true; } OperatorProperty* Copy() const override { auto ptr = new DeconvolutionProp(); ptr->param_ = param_; return ptr; } std::string TypeString() const override { return "Deconvolution"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { return {out_grad[deconv::kOut], in_data[deconv::kData], in_data[deconv::kWeight]}; } std::vector ForwardResource( const std::vector &in_shape) const override { return {ResourceRequest::kTempSpace}; } std::vector BackwardResource( const std::vector &in_shape) const override { return {ResourceRequest::kTempSpace}; } Operator* CreateOperator(Context ctx) const override; private: DeconvolutionParam param_; }; // class DeconvolutionProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_DECONVOLUTION_INL_H_ //===== EXPANDED: mxnet/src/operator/deconvolution-inl.h ===== namespace mxnet { namespace op { template<> Operator* CreateOp(DeconvolutionParam param) { return new DeconvolutionOp(param); } Operator* DeconvolutionProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(DeconvolutionParam); MXNET_REGISTER_OP_PROPERTY(Deconvolution, DeconvolutionProp) .add_argument("data", "Symbol", "Input data to the DeconvolutionOp.") .add_argument("weight", "Symbol", "Weight matrix.") .add_argument("bias", "Symbol", "Bias parameter.") .add_arguments(DeconvolutionParam::__FIELDS__()) .describe("Apply deconvolution to input then add a 
bias."); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/deconvolution.cc ===== //===== EXPANDIND: mxnet/src/operator/native_op.cc ===== /*! * Copyright (c) 2015 by Contributors * \file native_op.cc * \brief * \author Junyuan Xie */ //===== EXPANDIND: mxnet/src/operator/native_op-inl.h ===== /*! * Copyright (c) 2015 by Contributors * \file native_op-inl.h * \brief * \author Junyuan Xie */ #ifndef MXNET_OPERATOR_NATIVE_OP_INL_H_ #define MXNET_OPERATOR_NATIVE_OP_INL_H_ //===== EXPANDIND: mxnet/include/mxnet/c_api.h ===== /*! * Copyright (c) 2015 by Contributors * \file c_api.h * \brief C API of mxnet */ #ifndef MXNET_C_API_H_ #define MXNET_C_API_H_ #ifdef __cplusplus #define MXNET_EXTERN_C extern "C" #endif /*! \brief MXNET_DLL prefix for windows" */ #ifdef _WIN32 #ifdef MXNET_EXPORTS #define MXNET_DLL MXNET_EXTERN_C __declspec(dllexport) #else #define MXNET_DLL MXNET_EXTERN_C __declspec(dllimport) #endif #else #define MXNET_DLL MXNET_EXTERN_C #endif /*! \brief manually define unsigned int */ typedef unsigned int mx_uint; /*! \brief manually define unsigned int */ typedef float mx_float; // all the handles are simply void * // will be casted internally to specific pointers types // these typedefs are mainly used for readablity reasons /*! \brief handle to NDArray */ typedef void *NDArrayHandle; /*! \brief handle to a mxnet narray function that changes NDArray */ typedef const void *FunctionHandle; /*! \brief handle to a function that takes param and creates symbol */ typedef void *AtomicSymbolCreator; /*! \brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; /*! \brief handle to a AtomicSymbol */ typedef void *AtomicSymbolHandle; /*! \brief handle to an Executor */ typedef void *ExecutorHandle; /*! \brief handle a dataiter creator */ typedef void *DataIterCreator; /*! \brief handle to a DataIterator */ typedef void *DataIterHandle; /*! \brief handle to KVStore */ typedef void *KVStoreHandle; /*! \brief handle to RecordIO */ typedef void *RecordIOHandle; MXNET_EXTERN_C { struct NativeOpInfo { void (*forward)(int, float**, int*, unsigned**, int*, void*); void (*backward)(int, float**, int*, unsigned**, int*, void*); void (*infer_shape)(int, int*, unsigned**, void*); void (*list_outputs)(char***, void*); void (*list_arguments)(char***, void*); // all functions also pass a payload void* pointer void* p_forward; void* p_backward; void* p_infer_shape; void* p_list_outputs; void* p_list_arguments; }; } /*! * \brief return str message of the last error * all function in this file will return 0 when success * and -1 when an error occured, * MXGetLastError can be called to retrieve the error * * this function is threadsafe and can be called by different thread * \return error info */ MXNET_DLL const char *MXGetLastError(); //------------------------------------- // Part 0: Global State setups //------------------------------------- /*! * \brief Seed the global random number generators in mxnet. * \param seed the random number seed. * \return 0 when success, -1 when failure happens. */ MXNET_DLL int MXRandomSeed(int seed); /*! * \brief Notify the engine about a shutdown, * This can help engine to print less messages into display. * * User do not have to call this function. * \return 0 when success, -1 when failure happens. */ MXNET_DLL int MXNotifyShutdown(); //------------------------------------- // Part 1: NDArray creation and deletion //------------------------------------- /*! 
* \brief create a NDArray handle that is not initialized * can be used to pass in as mutate variables * to hold the result of NDArray * \param out the returning handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out); /*! * \brief create a NDArray with specified shape * \param shape the pointer to the shape * \param ndim the dimension of the shape * \param dev_type device type, specify device we want to take * \param dev_id the device id of the specific device * \param delay_alloc whether to delay allocation until * the narray is first mutated * \param out the returning handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayCreate(const mx_uint *shape, mx_uint ndim, int dev_type, int dev_id, int delay_alloc, NDArrayHandle *out); /*! * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes * \param size size of the raw bytes * \param out the returning handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf, size_t size, NDArrayHandle *out); /*! * \brief save the NDArray into raw bytes. * \param handle the NDArray handle * \param out_size size of the raw bytes * \param out_buf the head of returning memory bytes. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle, size_t *out_size, const char **out_buf); /*! * \brief Save list of narray into the file. * \param fname name of the file. * \param num_args number of arguments to save. * \param args the array of NDArrayHandles to be saved. * \param keys the name of the NDArray, optional, can be NULL * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArraySave(const char* fname, mx_uint num_args, NDArrayHandle* args, const char** keys); /*! * \brief Load list of narray from the file. * \param fname name of the file. * \param out_size number of narray loaded. * \param out_arr head of the returning narray handles. * \param out_name_size size of output name arrray. * \param out_names the names of returning NDArrays, can be NULL * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayLoad(const char* fname, mx_uint *out_size, NDArrayHandle** out_arr, mx_uint *out_name_size, const char*** out_names); /*! * \brief Perform a synchronize copy from a continugous CPU memory region. * * This function will call WaitToWrite before the copy is performed. * This is useful to copy data from existing memory region that are * not wrapped by NDArray(thus dependency not being tracked). * * \param handle the NDArray handle * \param data the data source to copy from. * \param size the memory size we want to copy from. */ MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, const mx_float *data, size_t size); /*! * \brief Perform a synchronize copyto a continugous CPU memory region. * * This function will call WaitToRead before the copy is performed. * This is useful to copy data from existing memory region that are * not wrapped by NDArray(thus dependency not being tracked). * * \param handle the NDArray handle * \param data the data source to copy into. * \param size the memory size we want to copy into. */ MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle, mx_float *data, size_t size); /*! * \brief Wait until all the pending writes with respect NDArray are finished. * Always call this before read data out synchronizely. 
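*
* A minimal read-out sketch (hypothetical handle `h`, error codes unchecked):
* wait for pending writes, then access the raw data.
* \code
*   mx_float *pdata = NULL;
*   MXNDArrayWaitToRead(h);       // make sure all pending writes finished
*   MXNDArrayGetData(h, &pdata);  // raw data can now be read safely
* \endcode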
* \param handle the NDArray handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayWaitToRead(NDArrayHandle handle); /*! * \brief Wait until all the pending read/write with respect NDArray are finished. * Always call this before write data into NDArray synchronizely. * \param handle the NDArray handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayWaitToWrite(NDArrayHandle handle); /*! * \brief wait until all delayed operations in * the system is completed * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayWaitAll(); /*! * \brief free the narray handle * \param handle the handle to be freed * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayFree(NDArrayHandle handle); /*! * \brief Slice the NDArray along axis 0. * \param handle the handle to the narraya * \param slice_begin The beginning index of slice * \param slice_end The ending index of slice * \param out The NDArrayHandle of sliced NDArray * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out); /*! * \brief get the shape of the array * \param handle the handle to the narray * \param out_dim the output dimension * \param out_pdata pointer holder to get data pointer of the shape * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); /*! * \brief get the content of the data in NDArray * \param handle the handle to the narray * \param out_pdata pointer holder to get pointer of data * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, mx_float **out_pdata); /*! * \brief get the context of the NDArray * \param handle the handle to the narray * \param out_dev_type the output device type * \param out_dev_id the output device id * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id); //-------------------------------- // Part 2: functions on NDArray //-------------------------------- /*! * \brief list all the available functions handles * most user can use it to list all the needed functions * \param out_size the size of returned array * \param out_array the output function array * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array); /*! * \brief get the function handle by name * \param name the name of the function * \param out the corresponding function handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXGetFunction(const char *name, FunctionHandle *out); /*! * \brief Get the information of the function handle. * \param fun The function handle. * \param name The returned name of the function. * \param description The returned description of the function. * \param num_args Number of arguments. * \param arg_names Name of the arguments. * \param arg_type_infos Type informations about the arguments. * \param arg_descriptions Description information about the arguments. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXFuncGetInfo(FunctionHandle fun, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions); /*! 
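* A hedged usage sketch for the function API in this part (the function name
* is illustrative, error codes unchecked): look a function up, query its
* argument counts, then invoke it with matching arrays.
* \code
*   FunctionHandle f;
*   MXGetFunction("_copyto", &f);                    // name is illustrative
*   mx_uint n_use, n_scalar, n_mutate; int mask;
*   MXFuncDescribe(f, &n_use, &n_scalar, &n_mutate, &mask);
*   // use_vars / scalars / mutate_vars are arrays sized by the counts above
*   MXFuncInvoke(f, use_vars, scalars, mutate_vars);
* \endcode
*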
* \brief get the argument requirements of the function * \param fun input function handle * \param num_use_vars how many NDArrays to be passed in as used_vars * \param num_scalars scalar variable is needed * \param num_mutate_vars how many NDArrays to be passed in as mutate_vars * \param type_mask the type mask of this function * \return 0 when success, -1 when failure happens * \sa MXFuncInvoke */ MXNET_DLL int MXFuncDescribe(FunctionHandle fun, mx_uint *num_use_vars, mx_uint *num_scalars, mx_uint *num_mutate_vars, int *type_mask); /*! * \brief invoke a function, the array size of passed in arguments * must match the values in the * \param fun the function * \param use_vars the normal arguments passed to function * \param scalar_args the scalar qarguments * \param mutate_vars the mutate arguments * \return 0 when success, -1 when failure happens * \sa MXFuncDescribeArgs */ MXNET_DLL int MXFuncInvoke(FunctionHandle fun, NDArrayHandle *use_vars, mx_float *scalar_args, NDArrayHandle *mutate_vars); //-------------------------------------------- // Part 3: symbolic configuration generation //-------------------------------------------- /*! * \brief list all the available AtomicSymbolEntry * \param out_size the size of returned array * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, AtomicSymbolCreator **out_array); /*! * \brief Get the detailed information about atomic symbol. * \param creator the AtomicSymbolCreator. * \param name The returned name of the creator. * \param description The returned description of the symbol. * \param num_args Number of arguments. * \param arg_names Name of the arguments. * \param arg_type_infos Type informations about the arguments. * \param arg_descriptions Description information about the arguments. * \param key_var_num_args The keyword argument for specifying variable number of arguments. * When this parameter has non-zero length, the function allows variable number * of positional arguments, and will need the caller to pass it in in * MXSymbolCreateAtomicSymbol, * With key = key_var_num_args, and value = number of positional arguments. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions, const char **key_var_num_args); /*! * \brief Create an AtomicSymbol. * \param creator the AtomicSymbolCreator * \param num_param the number of parameters * \param keys the keys to the params * \param vals the vals of the params * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, mx_uint num_param, const char **keys, const char **vals, SymbolHandle *out); /*! * \brief Create a Variable Symbol. * \param name name of the variable * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out); /*! 
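* A small sketch (hypothetical names, error codes unchecked) of building a
* grouped symbol from two variables:
* \code
*   SymbolHandle a, b, grouped;
*   MXSymbolCreateVariable("a", &a);
*   MXSymbolCreateVariable("b", &b);
*   SymbolHandle parts[2] = {a, b};
*   MXSymbolCreateGroup(2, parts, &grouped);
* \endcode
*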
* \brief Create a Symbol by grouping list of symbols together * \param num_symbols number of symbols to be grouped * \param symbols array of symbol handles * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCreateGroup(mx_uint num_symbols, SymbolHandle *symbols, SymbolHandle *out); /*! * \brief Load a symbol from a json file. * \param fname the file name. * \param out the output symbol. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out); /*! * \brief Load a symbol from a json string. * \param json the json string. * \param out the output symbol. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out); /*! * \brief Save a symbol into a json file. * \param symbol the input symbol. * \param fname the file name. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname); /*! * \brief Save a symbol into a json string * \param symbol the input symbol. * \param out_json output json string. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json); /*! * \brief Free the symbol handle. * \param symbol the symbol * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolFree(SymbolHandle symbol); /*! * \brief Copy the symbol to another handle * \param symbol the source symbol * \param out used to hold the result of copy * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out); /*! * \brief Print the content of symbol, used for debug. * \param symbol the symbol * \param out_str pointer to hold the output string of the printing. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str); /*! * \brief List arguments in the symbol. * \param symbol the symbol * \param out_size output size * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); /*! * \brief List returns in the symbol. * \param symbol the symbol * \param out_size output size * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); /*! * \brief Get a symbol that contains all the internals. * \param symbol The symbol * \param out The output symbol whose outputs are all the internals. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out); /*! * \brief Get index-th outputs of the symbol. * \param symbol The symbol * \param index the Index of the output. * \param out The output symbol whose outputs are the index-th symbol. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol, mx_uint index, SymbolHandle *out); /*! * \brief List auxiliary states in the symbol. 
* \param symbol the symbol * \param out_size output size * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); /*! * \brief Compose the symbol on other symbols. * * This function will change the sym hanlde. * To achieve function apply behavior, copy the symbol first * before apply. * * \param sym the symbol to apply * \param name the name of symbol * \param num_args number of arguments * \param keys the key of keyword args (optional) * \param args arguments to sym * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCompose(SymbolHandle sym, const char *name, mx_uint num_args, const char** keys, SymbolHandle* args); /*! * \brief Get the gradient graph of the symbol * * \param sym the symbol to get gradient * \param num_wrt number of arguments to get gradient * \param wrt the name of the arguments to get gradient * \param out the returned symbol that has gradient * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out); /*! * \brief infer shape of unknown input shapes given the known one. * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. * * \param sym symbol handle * \param num_args numbe of input arguments. * \param keys the key of keyword args (optional) * \param arg_ind_ptr the head pointer of the rows in CSR * \param arg_shape_data the content of the CSR * \param in_shape_size sizeof the returning array of in_shapes * \param in_shape_ndim returning array of shape dimensions of eachs input shape. * \param in_shape_data returning array of pointers to head of the input shape. * \param out_shape_size sizeof the returning array of out_shapes * \param out_shape_ndim returning array of shape dimensions of eachs input shape. * \param out_shape_data returning array of pointers to head of the input shape. * \param aux_shape_size sizeof the returning array of aux_shapes * \param aux_shape_ndim returning array of shape dimensions of eachs auxiliary shape. * \param aux_shape_data returning array of pointers to head of the auxiliary shape. * \param complete whether infer shape completes or more information is needed. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, mx_uint num_args, const char** keys, const mx_uint *arg_ind_ptr, const mx_uint *arg_shape_data, mx_uint *in_shape_size, const mx_uint **in_shape_ndim, const mx_uint ***in_shape_data, mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data, mx_uint *aux_shape_size, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *complete); //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- /*! * \brief Delete the executor * \param handle the executor. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXExecutorFree(ExecutorHandle handle); /*! * \brief Print the content of execution plan, used for debug. * \param handle the executor. * \param out_str pointer to hold the output string of the printing. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXExecutorPrint(ExecutorHandle handle, const char **out_str); /*! 
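* A training-step sketch against the executor API in this part (hypothetical
* handles `exec` and `head_grad`, error codes unchecked):
* \code
*   MXExecutorForward(exec, 1);                // is_train = 1
*   MXExecutorBackward(exec, 1, &head_grad);   // one head gradient
*   mx_uint n_out; NDArrayHandle *outs;
*   MXExecutorOutputs(exec, &n_out, &outs);
* \endcode
*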
* \brief Executor forward method * * \param handle executor handle * \param is_train bool value to indicate whether the forward pass is for evaluation * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train); /*! * \brief Excecutor run backward * * \param handle execute handle * \param len lenth * \param head_grads NDArray handle for heads' gradient * * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads); /*! * \brief Get executor's head NDArray * * \param handle executor handle * \param out_size output narray vector size * \param out out put narray handles * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out); /*! * \brief Generate Executor from symbol * * \param symbol_handle symbol handle * \param dev_type device type * \param dev_id device id * \param len length * \param in_args in args array * \param arg_grad_store arg grads handle array * \param grad_req_type grad req array * \param aux_states_len length of auxiliary states * \param aux_states auxiliary states array * \param out output executor handle * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle, int dev_type, int dev_id, mx_uint len, NDArrayHandle *in_args, NDArrayHandle *arg_grad_store, mx_uint *grad_req_type, mx_uint aux_states_len, NDArrayHandle *aux_states, ExecutorHandle *out); //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- /*! * \brief List all the available iterator entries * \param out_size the size of returned iterators * \param out_array the output iteratos entries * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXListDataIters(mx_uint *out_size, DataIterCreator **out_array); /*! * \brief Init an iterator, init with parameters * the array size of passed in arguments * \param handle of the iterator creator * \param num_param number of parameter * \param keys parameter keys * \param vals parameter values * \param out resulting iterator * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, mx_uint num_param, const char **keys, const char **vals, DataIterHandle *out); /*! * \brief Get the detailed information about data iterator. * \param creator the DataIterCreator. * \param name The returned name of the creator. * \param description The returned description of the symbol. * \param num_args Number of arguments. * \param arg_names Name of the arguments. * \param arg_type_infos Type informations about the arguments. * \param arg_descriptions Description information about the arguments. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions); /*! * \brief Free the handle to the IO module * \param handle the handle pointer to the data iterator * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterFree(DataIterHandle handle); /*! 
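* Iteration sketch for the data iterator API (hypothetical handle `iter`,
* error codes unchecked):
* \code
*   MXDataIterBeforeFirst(iter);
*   int has_next = 0;
*   while (MXDataIterNext(iter, &has_next) == 0 && has_next) {
*     NDArrayHandle data, label;
*     MXDataIterGetData(iter, &data);
*     MXDataIterGetLabel(iter, &label);
*     // ... consume the batch ...
*   }
* \endcode
*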
* \brief Move iterator to next position * \param handle the handle to iterator * \param out return value of next * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterNext(DataIterHandle handle, int *out); /*! * \brief Call iterator.Reset * \param handle the handle to iterator * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle); /*! * \brief Get the handle to the NDArray of underlying data * \param handle the handle pointer to the data iterator * \param out handle to underlying data NDArray * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out); /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator * \param pad pad number ptr * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle, int *pad); /*! * \brief Get the handle to the NDArray of underlying label * \param handle the handle pointer to the data iterator * \param out the handle to underlying label NDArray * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out); //-------------------------------------------- // Part 5: basic KVStore interface //-------------------------------------------- /*! * \brief Create a kvstore * \param type the type of KVStore * \param out The output type of KVStore * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreCreate(const char *type, KVStoreHandle *out); /*! * \brief Delete a KVStore handle. * \param handle handle to the kvstore * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreFree(KVStoreHandle handle); /*! * \brief Init a list of (key,value) pairs in kvstore * \param handle handle to the kvstore * \param num the number of key-value pairs * \param keys the list of keys * \param vals the list of values * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals); /*! * \brief Push a list of (key,value) pairs to kvstore * \param handle handle to the kvstore * \param num the number of key-value pairs * \param keys the list of keys * \param vals the list of values * \param priority the priority of the action * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStorePush(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals, int priority); /*! * \brief pull a list of (key, value) pairs from the kvstore * \param handle handle to the kvstore * \param num the number of key-value pairs * \param keys the list of keys * \param vals the list of values * \param priority the priority of the action * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals, int priority); /*! * \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local * \param the key * \param recv the pushed value on this key * \param local the value stored on local on this key * \param handle The additional handle to the updater */ typedef void (MXKVStoreUpdater)(int key, NDArrayHandle recv, NDArrayHandle local, void *handle); /*! 
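* Sketch of a user-defined updater and its registration (hypothetical kvstore
* handle `kv`; the body only indicates the expected callback shape):
* \code
*   void MyUpdater(int key, NDArrayHandle recv, NDArrayHandle local, void *h) {
*     // merge `recv` into `local`; this callback owns and must free both handles
*   }
*   // ...
*   MXKVStoreSetUpdater(kv, MyUpdater, NULL);
* \endcode
*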
* \brief register a push updater * \param handle handle to the KVStore * \param updater updater function * \param updater_handle The additional handle used to invoke the updater * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void *updater_handle); /*! * \brief get the type of the kvstore * \param handle handle to the KVStore * \param type a string type * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreGetType(KVStoreHandle handle, const char** type); //-------------------------------------------- // Part 6: advanced KVStore for multi-machines //-------------------------------------------- /** * \brief Return the rank of this node in its group, which is in [0, GroupSize). * * \param handle handle to the KVStore * \param ret the node rank * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreGetRank(KVStoreHandle handle, int *ret); /** * \brief Return the number of nodes in this group, which is * - number of workers if `IsWorkerNode() == true`, * - number of servers if `IsServerNode() == true`, * - 1 if `IsSchedulerNode() == true`, * \param handle handle to the KVStore * \param ret the group size * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreGetGroupSize(KVStoreHandle handle, int *ret); /** * \brief Return whether or not this process is a worker node. * \param ret 1 for yes, 0 for no * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreIsWorkerNode(int *ret); /** * \brief Return whether or not this process is a server node. * \param ret 1 for yes, 0 for no * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreIsServerNode(int *ret); /** * \brief Return whether or not this process is a scheduler node. 
* \param ret 1 for yes, 0 for no * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreIsSchedulerNode(int *ret); /** * \brief global barrier among all worker machines * * \param handle handle to the KVStore * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreBarrier(KVStoreHandle handle); /** * \brief the prototype of a server controller * \param head the head of the command * \param body the body of the command */ typedef void (MXKVStoreServerController)(int head, const char* body); /** * \return Run as server (or scheduler) * * \param handle handle to the KVStore * \param controller the user-defined server controller * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle, MXKVStoreServerController controller); /** * \return Send a command to all server nodes * * \param handle handle to the KVStore * \param cmd_id the head of the command * \param cmd_body the body of the command * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStoreSendCommmandToServers(KVStoreHandle handle, int cmd_id, const char* cmd_body); /** * \brief Create a RecordIO writer object * \param uri path to file * \param out handle pointer to the created object * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out); /** * \brief Delete a RecordIO writer object * \param handle handle to RecordIO object * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXRecordIOWriterFree(RecordIOHandle handle); /** * \brief Write a record to a RecordIO object * \param handle handle to RecordIO object * \param buf buffer to write * \param size size of buffer * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXRecordIOWriterWriteRecord(RecordIOHandle *handle, const char *buf, size_t size); /** * \brief Create a RecordIO reader object * \param uri path to file * \param out handle pointer to the created object * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out); /** * \brief Delete a RecordIO reader object * \param handle handle to RecordIO object * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXRecordIOReaderFree(RecordIOHandle *handle); /** * \brief Write a record to a RecordIO object * \param handle handle to RecordIO object * \param buf pointer to return buffer * \param size point to size of buffer * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXRecordIOReaderReadRecord(RecordIOHandle *handle, char const **buf, size_t *size); #endif // MXNET_C_API_H_ //===== EXPANDED: mxnet/include/mxnet/c_api.h ===== namespace mxnet { namespace op { struct NativeOpParam : public dmlc::Parameter { void *info; bool need_top_grad; NativeOpInfo *pinfo; int num_inputs_, num_outputs_; DMLC_DECLARE_PARAMETER(NativeOpParam) { DMLC_DECLARE_FIELD(info); DMLC_DECLARE_FIELD(need_top_grad).set_default(true) .describe("Whether this layer needs out grad for backward. 
" "Should be false for loss layers."); } }; template class NativeOp : public Operator { public: explicit NativeOp(NativeOpParam p) { this->param_ = p; } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) { using namespace mshadow; Stream *s = ctx.get_stream(); ptrs.clear(); ndims.clear(); shapes.clear(); tags.clear(); SyncVec(in_data, "in_data", s, 0); SyncVec(out_data, "out_data", s, 1); s->Wait(); param_.pinfo->forward(ptrs.size(), ptrs.data(), ndims.data(), shapes.data(), tags.data(), param_.pinfo->p_forward); for (index_t i = 0; i < out_data.size(); ++i) { CHECK_NE(req[i], kAddTo) << "NativeOp doesn't support AddTo for output"; if (req[i] != kNullOp) { std::stringstream ss; ss << std::string("out_data") << i; Copy(out_data[i].FlatTo2D(s), buffer_map[ss.str()].second, s); } } s->Wait(); } virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) { using namespace mshadow; Stream *s = ctx.get_stream(); ptrs.clear(); ndims.clear(); shapes.clear(); tags.clear(); SyncVec(in_data, "in_data", s, 0); SyncVec(out_data, "out_data", s, 1); SyncVec(in_grad, "in_grad", s, 2); if (param_.need_top_grad) { SyncVec(out_grad, "out_grad", s, 3); } s->Wait(); param_.pinfo->backward(ptrs.size(), ptrs.data(), ndims.data(), shapes.data(), tags.data(), param_.pinfo->p_backward); for (index_t i = 0; i < in_grad.size(); ++i) { CHECK_NE(req[i], kAddTo) << "NativeOp doesn't support AddTo for output"; if (req[i] != kNullOp) { std::stringstream ss; ss << std::string("in_grad") << i; Copy(in_grad[i].FlatTo2D(s), buffer_map[ss.str()].second, s); } } s->Wait(); } private: NativeOpParam param_; std::vector ptrs; std::vector ndims; std::vector shapes; std::vector tags; std::map > > buffer_map; virtual void SyncBuffer(const TBlob &tblob, const std::string &name, mshadow::Stream *stream) { using namespace mshadow; std::map > >::iterator buffer = buffer_map.find(name); if (buffer == buffer_map.end() || buffer->second.first != tblob.shape_) { if (buffer != buffer_map.end()) { FreeSpace<2, real_t>(&(buffer->second.second)); buffer_map.erase(buffer); } buffer_map[name] = std::pair >(tblob.shape_, NewTensor(tblob.shape_.FlatTo2D(), 0.0f, false)); buffer = buffer_map.find(name); } Copy(buffer->second.second, tblob.FlatTo2D(stream), stream); } virtual void SyncVec(const std::vector &vec, const std::string &prefix, mshadow::Stream *stream, int tag) { for (size_t i = 0; i < vec.size(); ++i) { std::stringstream name; name << prefix << i; SyncBuffer(vec[i], name.str(), stream); ptrs.push_back(buffer_map[name.str()].second.dptr_); ndims.push_back(vec[i].ndim()); shapes.push_back(const_cast(vec[i].shape_.data())); tags.push_back(tag); } } }; // NativeOp template Operator* CreateOp(NativeOpParam param); #if DMLC_USE_CXX11 class NativeOpProp : public OperatorProperty { public: std::vector ListArguments() const override { char ** args = NULL; param_.pinfo->list_arguments(&args, param_.pinfo->p_list_arguments); std::vector ret; for (int i = 0; args[i] != NULL; ++i) { ret.push_back(args[i]); } return ret; } std::vector ListOutputs() const override { char ** args = NULL; param_.pinfo->list_outputs(&args, param_.pinfo->p_list_outputs); std::vector ret; for (int i = 0; args[i] != NULL; ++i) { ret.push_back(args[i]); } return ret; } int NumOutputs() const override { return param_.num_outputs_; 
} void Init(const std::vector >& kwargs) override { param_.Init(kwargs); for (auto iter = kwargs.begin(); iter != kwargs.end(); ++iter) { if (iter->first == "info") { sscanf(iter->second.c_str(), "%p", ¶m_.pinfo); } } param_.num_inputs_ = ListArguments().size(); param_.num_outputs_ = ListOutputs().size(); } std::map GetParams() const override { return param_.__DICT__(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { std::vector shapes; std::vector ndims; for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) { shapes.push_back(iter->data()); ndims.push_back(iter->ndim()); } shapes.resize(param_.num_inputs_+param_.num_outputs_); ndims.resize(param_.num_inputs_+param_.num_outputs_); param_.pinfo->infer_shape(shapes.size(), ndims.data(), shapes.data(), param_.pinfo->p_infer_shape); for (unsigned i = 0; i < in_shape->size(); ++i) { SHAPE_ASSIGN_CHECK(*in_shape, i, TShape(shapes[i], shapes[i]+ndims[i])); } out_shape->clear(); for (unsigned i = param_.num_inputs_; i < shapes.size(); ++i) { out_shape->push_back(TShape(shapes[i], shapes[i]+ndims[i])); } return true; } OperatorProperty* Copy() const override { NativeOpProp *prop_sym = new NativeOpProp(); prop_sym->param_ = this->param_; return prop_sym; } std::string TypeString() const override { return "_Native"; } std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { std::vector deps; if (param_.need_top_grad) { deps.insert(deps.end(), out_grad.begin(), out_grad.end()); } deps.insert(deps.end(), in_data.begin(), in_data.end()); deps.insert(deps.end(), out_data.begin(), out_data.end()); return deps; } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { return {}; } Operator* CreateOperator(Context ctx) const override; private: NativeOpParam param_; }; // class PythonProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_NATIVE_OP_INL_H_ //===== EXPANDED: mxnet/src/operator/native_op-inl.h ===== namespace mxnet { namespace op { template<> Operator *CreateOp(NativeOpParam param) { return new NativeOp(param); } Operator* NativeOpProp::CreateOperator(Context ctx) const { DO_BIND_DISPATCH(CreateOp, param_); } DMLC_REGISTER_PARAMETER(NativeOpParam); MXNET_REGISTER_OP_PROPERTY(_Native, NativeOpProp) .describe("Stub for implementing an operator implemented in native frontend language.") .add_arguments(NativeOpParam::__FIELDS__()); } // namespace op } // namespace mxnet //===== EXPANDED: mxnet/src/operator/native_op.cc ===== //===== EXPANDIND: mxnet/src/storage/storage.cc ===== /*! * Copyright (c) 2015 by Contributors */ //===== EXPANDIND: mxnet/src/storage/storage_manager.h ===== /*! * Copyright (c) 2015 by Contributors * \file storage_manager.h * \brief Storage manager. */ #ifndef MXNET_STORAGE_STORAGE_MANAGER_H_ #define MXNET_STORAGE_STORAGE_MANAGER_H_ namespace mxnet { namespace storage { /*! * \brief Storage manager interface. */ class StorageManager { public: /*! * \brief Allocation. * \param size Size to allocate. * \return Pointer to the storage. */ virtual void* Alloc(size_t size) = 0; /*! * \brief Deallocation. * \param ptr Pointer to deallocate. * \param size Size of the storage. */ virtual void Free(void* ptr, size_t size) = 0; /*! * \brief Destructor. 
*/ virtual ~StorageManager() = default; }; // namespace StorageManager } // namespace storage } // namespace mxnet #endif // MXNET_STORAGE_STORAGE_MANAGER_H_ //===== EXPANDED: mxnet/src/storage/storage_manager.h ===== //===== EXPANDIND: mxnet/src/storage/naive_storage_manager.h ===== /*! * Copyright (c) 2015 by Contributors * \file naive_storage_manager.h * \brief Naive storage manager. */ #ifndef MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ #define MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ namespace mxnet { namespace storage { /*! * \brief Naive storage manager. */ template class NaiveStorageManager final : public StorageManager { public: /*! * \brief Default constructor. */ NaiveStorageManager() = default; /*! * \brief Default destructor. */ ~NaiveStorageManager() = default; void* Alloc(size_t size) override; void Free(void* ptr, size_t) override; private: DISALLOW_COPY_AND_ASSIGN(NaiveStorageManager); }; // class NaiveStorageManager template void* NaiveStorageManager::Alloc(size_t size) { return DeviceStorage::Alloc(size); } template void NaiveStorageManager::Free(void* ptr, size_t) { DeviceStorage::Free(ptr); } } // namespace storage } // namespace mxnet #endif // MXNET_STORAGE_NAIVE_STORAGE_MANAGER_H_ //===== EXPANDED: mxnet/src/storage/naive_storage_manager.h ===== //===== EXPANDIND: mxnet/src/storage/pooled_storage_manager.h ===== /*! * Copyright (c) 2015 by Contributors * \file pooled_storage_manager.h * \brief Storage manager with a memory pool. */ #ifndef MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_ #define MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_ namespace mxnet { namespace storage { /*! * \brief Storage manager with a memory pool. */ template class PooledStorageManager final : public StorageManager { public: /*! * \brief Default constructor. */ PooledStorageManager() = default; /*! * \brief Default destructor. */ ~PooledStorageManager() { ReleaseAll(); } void* Alloc(size_t size) override; void Free(void* ptr, size_t size) override; private: void ReleaseAll(); // internal mutex std::mutex mutex_; // used memory size_t used_memory_ = 0; // memory pool std::unordered_map> memory_pool_; DISALLOW_COPY_AND_ASSIGN(PooledStorageManager); }; // class PooledStorageManager template void* PooledStorageManager::Alloc(size_t size) { std::lock_guard lock(mutex_); auto&& reuse_it = memory_pool_.find(size); if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) { if (kThreshold <= used_memory_) { ReleaseAll(); } used_memory_ += size; return DeviceStorage::Alloc(size); } else { auto&& reuse_pool = reuse_it->second; auto ret = reuse_pool.back(); reuse_pool.pop_back(); return ret; } } template void PooledStorageManager::Free(void* ptr, size_t size) { std::lock_guard lock(mutex_); auto&& reuse_pool = memory_pool_[size]; reuse_pool.push_back(ptr); } template void PooledStorageManager::ReleaseAll() { for (auto&& i : memory_pool_) { for (auto&& j : i.second) { DeviceStorage::Free(j); used_memory_ -= i.first; } } memory_pool_.clear(); } } // namespace storage } // namespace mxnet #endif // MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_ //===== EXPANDED: mxnet/src/storage/pooled_storage_manager.h ===== //===== EXPANDIND: mxnet/src/storage/cpu_device_storage.h ===== /*! * Copyright (c) 2015 by Contributors * \file cpu_device_storage.h * \brief CPU storage implementation. */ #ifndef MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ #define MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ namespace mxnet { namespace storage { /*! * \brief CPU storage implementation. */ class CPUDeviceStorage { public: /*! * \brief Aligned allocation on CPU. 
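*
* A minimal pairing sketch (assuming direct use rather than going through a
* StorageManager); the returned pointer is aligned to alignment_ (16 bytes):
* \code
*   void *p = mxnet::storage::CPUDeviceStorage::Alloc(256);
*   mxnet::storage::CPUDeviceStorage::Free(p);
* \endcode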
* \param size Size to allocate. * \return Pointer to the storage. */ inline static void* Alloc(size_t size); /*! * \brief Deallocation. * \param ptr Pointer to deallocate. */ inline static void Free(void* ptr); private: /*! * \brief Alignment of allocation. */ static constexpr size_t alignment_ = 16; }; // class CPUDeviceStorage inline void* CPUDeviceStorage::Alloc(size_t size) { #if _MSC_VER void* ptr; ptr = _aligned_malloc(size, alignment_); return CHECK_NOTNULL(ptr); #else void* ptr; int ret = posix_memalign(&ptr, alignment_, size); CHECK_EQ(ret, 0) << "Allocation failed"; return ptr; #endif } inline void CPUDeviceStorage::Free(void* ptr) { #if _MSC_VER _aligned_free(ptr); #else free(ptr); #endif } } // namespace storage } // namespace mxnet #endif // MXNET_STORAGE_CPU_DEVICE_STORAGE_H_ //===== EXPANDED: mxnet/src/storage/cpu_device_storage.h ===== //===== EXPANDIND: mxnet/src/storage/gpu_device_storage.h ===== /*! * Copyright (c) 2015 by Contributors * \file gpu_device_storage.h * \brief GPU storage implementation. */ #ifndef MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ #define MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ #if MXNET_USE_CUDA #endif // MXNET_USE_CUDA namespace mxnet { namespace storage { /*! * \brief GPU storage implementation. */ class GPUDeviceStorage { public: /*! * \brief Allocation. * \param size Size to allocate. * \return Pointer to the storage. */ inline static void* Alloc(size_t size); /*! * \brief Deallocation. * \param ptr Pointer to deallocate. */ inline static void Free(void* ptr); }; // class GPUDeviceStorage inline void* GPUDeviceStorage::Alloc(size_t size) { void* ret = nullptr; #if MXNET_USE_CUDA CUDA_CALL(cudaMalloc(&ret, size)); #else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA return ret; } inline void GPUDeviceStorage::Free(void* ptr) { #if MXNET_USE_CUDA // throw special exception for caller to catch. cudaError_t err = cudaFree(ptr); // ignore unloading error, as memory has already been recycled if (err != cudaSuccess && err != cudaErrorCudartUnloading) { LOG(FATAL) << "CUDA: " << cudaGetErrorString(err); } #else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA } } // namespace storage } // namespace mxnet #endif // MXNET_STORAGE_GPU_DEVICE_STORAGE_H_ //===== EXPANDED: mxnet/src/storage/gpu_device_storage.h ===== //===== EXPANDIND: mxnet/src/storage/pinned_memory_storage.h ===== /*! * Copyright (c) 2015 by Contributors * \file cpu_device_storage.h * \brief CPU storage with pinned memory */ #ifndef MXNET_STORAGE_PINNED_MEMORY_STORAGE_H_ #define MXNET_STORAGE_PINNED_MEMORY_STORAGE_H_ namespace mxnet { namespace storage { class PinnedMemoryStorage { public: /*! * \brief Allocation. * \param size Size to allocate. * \return Pointer to the storage. */ inline static void* Alloc(size_t size); /*! * \brief Deallocation. * \param ptr Pointer to deallocate. 
*/ inline static void Free(void* ptr); }; inline void* PinnedMemoryStorage::Alloc(size_t size) { void* ret = nullptr; #if MXNET_USE_CUDA // make the memory available across all devices CUDA_CALL(cudaHostAlloc(&ret, size, cudaHostAllocPortable)); #else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA return ret; } inline void PinnedMemoryStorage::Free(void* ptr) { #if MXNET_USE_CUDA cudaError_t err = cudaFreeHost(ptr); // ignore unloading error, as memory has already been recycled if (err != cudaSuccess && err != cudaErrorCudartUnloading) { LOG(FATAL) << "CUDA: " << cudaGetErrorString(err); } #else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA } } // namespace storage } // namespace mxnet #endif // MXNET_STORAGE_PINNED_MEMORY_STORAGE_H_ //===== EXPANDED: mxnet/src/storage/pinned_memory_storage.h ===== namespace mxnet { // consider change storage as a pure abstract class class StorageImpl : public Storage { public: Handle Alloc(size_t size, Context ctx) override; void Free(Handle handle) override; virtual ~StorageImpl() = default; private: static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul; static constexpr size_t kMaxNumberOfDevices = Context::kMaxDevType + 1; static constexpr size_t kMaxNumberOfDeviceIDs = Context::kMaxDevID + 1; template using CurrentStorageManager = storage::PooledStorageManager; static void ActivateDevice(Context ctx) { switch (ctx.dev_type) { case Context::kCPU: break; case Context::kGPU: case Context::kCPUPinned: #if MXNET_USE_CUDA CUDA_CALL(cudaSetDevice(ctx.dev_id)); #else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA break; default: LOG(FATAL) << "Unimplemented device"; } } // internal storage managers std::array, kMaxNumberOfDevices> storage_managers_; }; // struct Storage::Impl Storage::Handle StorageImpl::Alloc(size_t size, Context ctx) { // space already recycled, ignore request Handle hd; hd.ctx = ctx; hd.size = size; auto&& device = storage_managers_.at(ctx.dev_type); storage::StorageManager *manager = device.Get( ctx.dev_id, [ctx]() { storage::StorageManager *ptr = nullptr; switch (ctx.dev_type) { case Context::kCPU: { ptr = new CurrentStorageManager(); break; } case Context::kCPUPinned: { ptr = new CurrentStorageManager(); break; } case Context::kGPU: { ptr = new CurrentStorageManager(); break; } default: LOG(FATAL) << "Unimplemented device"; } return ptr; }); this->ActivateDevice(ctx); hd.dptr = manager->Alloc(size); return hd; } void StorageImpl::Free(Storage::Handle handle) { const Context &ctx = handle.ctx; auto&& device = storage_managers_.at(ctx.dev_type); storage::StorageManager *maneger = device.Get( ctx.dev_id, []() { LOG(FATAL) << "Cannot Free space to a device you have not allocated"; return nullptr; }); this->ActivateDevice(ctx); maneger->Free(handle.dptr, handle.size); } std::shared_ptr Storage::_GetSharedRef() { static std::shared_ptr inst(new StorageImpl()); return inst; } Storage* Storage::Get() { static Storage *ptr = _GetSharedRef().get(); return ptr; } } // namespace mxnet //===== EXPANDED: mxnet/src/storage/storage.cc ===== //===== EXPANDIND: mxnet/src/common/tblob_op_registry.cc ===== /*! 
* Copyright (c) 2015 by Contributors * \file tblob_op_registry.cc * Implementation of tblob op registry */ namespace mxnet { namespace common { class TBlobUnaryOpProp; class TBlobOpRegEntryImpl : public TBlobOpRegEntry { public: // functions TSelf& set_function(int dev_mask, UnaryFunction funary, bool inplace_in_out, bool register_symbolic) override { std::lock_guard lock(mutex_); ++reg_counter_; if (funary_.size() <= static_cast(dev_mask)) { funary_.resize(dev_mask + 1, nullptr); } if (funary_[dev_mask] != nullptr) { LOG(FATAL) << "Device function " << this->name << " already registerd for device " << dev_mask; } funary_[dev_mask] = funary; inplace_in0_out_forward_ = inplace_in_out; if (reg_counter_ == 1) { this->RegisterUnary(); register_symbolic_ = register_symbolic; if (register_symbolic) { this->RegisterUnarySymbolic(); } } return *this; } TSelf& set_gradient(int dev_mask, UnaryGradType1 fgrad, bool inplace_out_in_grad) override { std::lock_guard lock(mutex_); if (funary_grad_t1_.size() <= static_cast(dev_mask)) { funary_grad_t1_.resize(dev_mask + 1, nullptr); } if (funary_grad_t1_[dev_mask] != nullptr) { LOG(FATAL) << "Device gradient function " << this->name << " already registerd for device " << dev_mask; } funary_grad_t1_[dev_mask] = fgrad; inplace_out_in0_grad_ = inplace_out_in_grad; return *this; } TSelf& set_gradient(int dev_mask, UnaryGradType2 fgrad, bool inplace_out_in_grad) override { std::lock_guard lock(mutex_); if (funary_grad_t2_.size() <= static_cast(dev_mask)) { funary_grad_t2_.resize(dev_mask + 1, nullptr); } if (funary_grad_t2_[dev_mask] != nullptr) { LOG(FATAL) << "Device gradient function " << this->name << " already registerd for device " << dev_mask; } funary_grad_t2_[dev_mask] = fgrad; inplace_out_in0_grad_ = inplace_out_in_grad; return *this; } TSelf& set_shape_infer(UnaryShapeInfer fshapeinfer) override { std::lock_guard lock(mutex_); unary_infer_ = fshapeinfer; return *this; } TSelf& describe(const std::string &description) override { std::lock_guard lock(mutex_); if (reg_counter_ != 1) return *this; NDArrayReg().describe(description); if (register_symbolic_) { OpReg().describe(description); } return *this; } private: // make friend with unary op friend class TBlobUnaryOpProp; // internal mutex std::mutex mutex_; // registration counter int reg_counter_{0}; bool register_symbolic_{true}; // unary shape inferencer UnaryShapeInfer unary_infer_{nullptr}; // unary functions on each device mask std::vector funary_; // type 1 gradient function std::vector funary_grad_t1_; // type 2 gradient function std::vector funary_grad_t2_; // whether do inplace optimization of in 0 and output bool inplace_in0_out_forward_{true}; // whether do inplace optimization of out_grad and in_grad0 bool inplace_out_in0_grad_{false}; // NDArray registry NDArrayFunctionReg *ndarray_reg_{nullptr}; OperatorPropertyReg *op_reg_{nullptr}; // internal function to register NDArray function. inline NDArrayFunctionReg &NDArrayReg() { if (ndarray_reg_ == nullptr) { NDArrayFunctionReg ® = ::dmlc::Registry::Get()->__REGISTER__(this->name); ndarray_reg_ = ® } return *ndarray_reg_; } // internal function to register NDArray function. inline OperatorPropertyReg &OpReg() { if (op_reg_ == nullptr) { OperatorPropertyReg ® = ::dmlc::Registry::Get()->__REGISTER__(this->name); op_reg_ = ® } return *op_reg_; } // start registering all stuffs void RegisterUnary(); void RegisterUnarySymbolic(); }; // Unary operator to invoke generic TBlob function. 
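// A registration sketch (hypothetical kernel and operator name, shown only to
// connect the registry entry above with the operator below): the entry stores
// the kernel per device mask, and TBlobUnaryOperator::Forward dispatches to it.
//
//   void MySquare(const TBlob &src, TBlob *ret, OpReqType req, RunContext ctx);
//   common::TBlobOpRegistry::Get()->__REGISTER_OR_FIND__("my_square")
//       .set_function(cpu::kDevMask, MySquare, true, true)
//       .describe("Take square of the source tensor (illustrative).");
//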
struct TBlobUnaryOperator : public Operator { TBlobOpRegEntry::UnaryFunction forward; TBlobOpRegEntry::UnaryGradType1 backward1{nullptr}; TBlobOpRegEntry::UnaryGradType2 backward2{nullptr}; void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data, const std::vector &aux_args) override { CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 1); TBlob out = out_data[0]; (*forward)(in_data[0], &out, req[0], ctx.run_ctx); } void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &req, const std::vector &in_grad, const std::vector &aux_args) override { CHECK_EQ(out_grad.size(), 1); CHECK(in_data.size() == 1 && in_grad.size() == 1); CHECK_EQ(req.size(), 1); arg::OutGrad ograd; ograd.data = out_grad[0]; TBlob igrad = in_grad[0]; if (backward1 != nullptr) { arg::OutValue out_value; out_value.data = out_data[0]; (*backward1)(ograd, out_value, &igrad, req[0], ctx.run_ctx); } else if (backward2 != nullptr) { arg::Input0 in0; in0.data = in_data[0]; (*backward2)(ograd, in0, &igrad, req[0], ctx.run_ctx); } else { LOG(FATAL) << "Backward is not supported"; } } }; // class UnaryOperator class TBlobUnaryOpProp : public OperatorProperty { public: std::string name; TBlobOpRegEntryImpl* source; void Init(const std::vector >& kwargs) override { } std::map GetParams() const override { return std::map(); } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; const TShape &dshape = in_shape->at(0); if (dshape.ndim() == 0) return false; out_shape->clear(); if (source->unary_infer_ == nullptr) { out_shape->push_back(dshape); } else { out_shape->push_back((*(source->unary_infer_))(dshape)); } return true; } OperatorProperty* Copy() const override { auto ptr = new TBlobUnaryOpProp(); ptr->source = source; ptr->name = name; return ptr; } std::string TypeString() const override { return name; } // decalre dependency and inplace optimization options std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { if (source->funary_grad_t1_.size() != 0) { return {out_grad[0], out_data[0]}; } else if (source->funary_grad_t2_.size() != 0) { return {out_grad[0], in_data[0]}; } else { LOG(FATAL) << "Backward of " << name << " is not decalred"; return {}; } } std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { if (source->inplace_out_in0_grad_) { return {{out_grad[0], in_grad[0]}}; } else { return {}; } } std::vector > ForwardInplaceOption( const std::vector &in_data, const std::vector &out_data) const override { if (source->inplace_in0_out_forward_) { return {{in_data[0], out_data[0]}}; } else { return {}; } } Operator* CreateOperator(Context ctx) const override { size_t dev_mask = ctx.dev_mask(); TBlobUnaryOperator *op = new TBlobUnaryOperator(); CHECK(dev_mask < source->funary_.size() && source->funary_[dev_mask] != nullptr); op->forward = source->funary_[dev_mask]; if (dev_mask < source->funary_grad_t1_.size()) { op->backward1 = source->funary_grad_t1_[dev_mask]; } if (dev_mask < source->funary_grad_t2_.size()) { op->backward2 = source->funary_grad_t2_[dev_mask]; } return op; } }; void TBlobOpRegEntryImpl::RegisterUnary() { CHECK_EQ(reg_counter_, 1); // The body to be 
registered auto body = [this] (NDArray **used_vars, real_t *s, NDArray **mutate_vars) { NDArray src = *used_vars[0]; NDArray *out = mutate_vars[0]; TShape dshape = src.shape(); if (unary_infer_ != nullptr) dshape = unary_infer_(dshape); if (out->is_none()) { *out = NDArray(dshape, src.ctx(), true); } else { CHECK(out->ctx() == src.ctx()) << "target context mismatch"; CHECK(out->shape() == dshape) << "target shape mismatch " << out->shape() << " vs. " << dshape; } // important: callback must always capture by value NDArray ret = *out; // get the const variables std::vector const_vars; if (src.var() != ret.var()) const_vars.push_back(src.var()); // check if the function exist int dev_mask = src.ctx().dev_mask(); if (static_cast(dev_mask) >= funary_.size() || funary_[dev_mask] == nullptr) { if (dev_mask == gpu::kDevMask) LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; LOG(FATAL) << "Function " << this->name << "not registered for device " << dev_mask; } // invoke the function UnaryFunction fun = funary_[dev_mask]; Engine::Get()->PushSync([src, ret, fun, dev_mask](RunContext ctx) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); (*fun)(src.data(), &tmp, kWriteTo, ctx); #if MXNET_USE_CUDA if (dev_mask == gpu::kDevMask) { ctx.get_stream()->Wait(); } #endif }, src.ctx(), const_vars, {ret.var()}); }; // register the function. NDArrayReg() .set_body(body) .set_num_use_vars(1) .set_num_mutate_vars(1) .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) .add_argument("src", "NDArray", "Source input to the function"); } void TBlobOpRegEntryImpl::RegisterUnarySymbolic() { // register the operator auto op_factory = [this]() { TBlobUnaryOpProp *prop = new TBlobUnaryOpProp(); prop->name = this->name; prop->source = this; return prop; }; OpReg() .set_body(op_factory) .add_argument("src", "Symbol", "Source symbolic input to the function"); } TBlobOpRegEntry& TBlobOpRegistry::__REGISTER_OR_FIND__(const std::string &name) { if (fmap_.count(name) != 0) return *fmap_.at(name); TBlobOpRegEntry *e = new TBlobOpRegEntryImpl(); e->name = name; fmap_[name] = e; return *e; } TBlobOpRegistry* TBlobOpRegistry::Get() { static TBlobOpRegistry inst; return &inst; } TBlobOpRegistry::~TBlobOpRegistry() { for (auto kv : fmap_) { delete kv.second; } } } // namespace common } // namespace mxnet //===== EXPANDED: mxnet/src/common/tblob_op_registry.cc ===== //===== EXPANDIND: mxnet/src/resource.cc ===== /*! * Copyright (c) 2015 by Contributors * \file resource.cc * \brief Implementation of resource manager. */ namespace mxnet { namespace resource { // implements resource manager class ResourceManagerImpl : public ResourceManager { public: ResourceManagerImpl() noexcept(false) : global_seed_(0) { cpu_temp_space_copy_ = dmlc::GetEnv("MXNET_CPU_TEMP_COPY", 16); gpu_temp_space_copy_ = dmlc::GetEnv("MXNET_GPU_TEMP_COPY", 4); engine_ref_ = Engine::_GetSharedRef(); cpu_rand_.reset(new ResourceRandom( Context::CPU(), global_seed_)); cpu_space_.reset(new ResourceTempSpace( Context::CPU(), cpu_temp_space_copy_)); } ~ResourceManagerImpl() { // need explicit delete, before engine get killed cpu_rand_.reset(nullptr); cpu_space_.reset(nullptr); #if MXNET_USE_CUDA gpu_rand_.Clear(); gpu_space_.Clear(); #endif if (engine_ref_ != nullptr) { // release the reference to engine. 
engine_ref_ = nullptr; } } // request resources Resource Request(Context ctx, const ResourceRequest &req) override { if (ctx.dev_mask() == cpu::kDevMask) { switch (req.type) { case ResourceRequest::kRandom: return cpu_rand_->resource; case ResourceRequest::kTempSpace: return cpu_space_->GetNext(); default: LOG(FATAL) << "Unknown supported type " << req.type; } } else { CHECK_EQ(ctx.dev_mask(), gpu::kDevMask); #if MSHADOW_USE_CUDA switch (req.type) { case ResourceRequest::kRandom: { return gpu_rand_.Get(ctx.dev_id, [ctx, this]() { return new ResourceRandom(ctx, global_seed_); })->resource; } case ResourceRequest::kTempSpace: { return gpu_space_.Get(ctx.dev_id, [ctx, this]() { return new ResourceTempSpace(ctx, gpu_temp_space_copy_); })->GetNext(); } default: LOG(FATAL) << "Unknown supported type " << req.type; } #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } Resource ret; return ret; } void SeedRandom(uint32_t seed) override { global_seed_ = seed; cpu_rand_->Seed(global_seed_); #if MXNET_USE_CUDA gpu_rand_.ForEach([seed](size_t i, ResourceRandom *p) { p->Seed(seed); }); #endif } private: /*! \brief Maximum number of GPUs */ static constexpr std::size_t kMaxNumGPUs = 16; /*! \brief Random number magic number to seed different random numbers */ static constexpr uint32_t kRandMagic = 127UL; // the random number resources template struct ResourceRandom { /*! \brief the context of the PRNG */ Context ctx; /*! \brief pointer to PRNG */ mshadow::Random *prnd; /*! \brief resource representation */ Resource resource; /*! \brief constructor */ explicit ResourceRandom(Context ctx, uint32_t global_seed) : ctx(ctx) { mshadow::SetDevice(ctx.dev_id); resource.var = Engine::Get()->NewVariable(); prnd = new mshadow::Random(ctx.dev_id + global_seed * kRandMagic); resource.ptr_ = prnd; resource.req = ResourceRequest(ResourceRequest::kRandom); } ~ResourceRandom() { mshadow::Random *r = prnd; Engine::Get()->DeleteVariable( [r](RunContext rctx) { MSHADOW_CATCH_ERROR(delete r); }, ctx, resource.var); } // set seed to a PRNG inline void Seed(uint32_t global_seed) { uint32_t seed = ctx.dev_id + global_seed * kRandMagic; mshadow::Random *r = prnd; Engine::Get()->PushSync([r, seed](RunContext rctx) { r->set_stream(rctx.get_stream()); r->Seed(seed); }, ctx, {}, {resource.var}); } }; // temporal space resource. template struct ResourceTempSpace { /*! \brief the context of the device */ Context ctx; /*! \brief the underlying space */ std::vector*> space; /*! \brief resource representation */ std::vector resource; /*! \brief current pointer to the round roubin alloator */ std::atomic curr_ptr; /*! 
\brief constructor */ explicit ResourceTempSpace(Context ctx, size_t ncopy) : ctx(ctx), space(ncopy), resource(ncopy), curr_ptr(0) { mshadow::SetDevice(ctx.dev_id); for (size_t i = 0; i < space.size(); ++i) { space[i] = new mshadow::TensorContainer(); resource[i].var = Engine::Get()->NewVariable(); resource[i].id = static_cast(i); resource[i].ptr_ = space[i]; resource[i].req = ResourceRequest(ResourceRequest::kTempSpace); } } ~ResourceTempSpace() { for (size_t i = 0; i < space.size(); ++i) { mshadow::TensorContainer* r = space[i]; Engine::Get()->DeleteVariable( [r](RunContext rctx){ MSHADOW_CATCH_ERROR(r->Release()); }, ctx, resource[i].var); } } // get next resource in round roubin matter inline Resource GetNext() { const size_t kMaxDigit = std::numeric_limits::max() / 2; size_t ptr = ++curr_ptr; // reset ptr to avoid undefined behavior during overflow // usually this won't happen if (ptr > kMaxDigit) { curr_ptr.store((ptr + 1) % space.size()); } return resource[ptr % space.size()]; } }; /*! \brief number of copies in CPU temp space */ int cpu_temp_space_copy_; /*! \brief number of copies in GPU temp space */ int gpu_temp_space_copy_; /*! \brief Reference to the engine */ std::shared_ptr engine_ref_; /*! \brief internal seed to the random number generator */ uint32_t global_seed_; /*! \brief CPU random number resources */ std::unique_ptr > cpu_rand_; /*! \brief CPU temp space resources */ std::unique_ptr > cpu_space_; #if MXNET_USE_CUDA /*! \brief random number generator for GPU */ common::LazyAllocArray > gpu_rand_; /*! \brief temp space for GPU */ common::LazyAllocArray > gpu_space_; #endif }; } // namespace resource ResourceManager* ResourceManager::Get() { static resource::ResourceManagerImpl inst; return &inst; } } // namespace mxnet //===== EXPANDED: mxnet/src/resource.cc ===== //===== EXPANDIND: mxnet/src/c_api/c_api.cc ===== /*! * Copyright (c) 2015 by Contributors * \file c_api.cc * \brief C API of mxnet */ //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/memory_io.h ===== /*! * Copyright (c) 2015 by Contributors * \file memory_io.h * \brief defines binary serialization class to serialize things into/from memory region. */ #ifndef DMLC_MEMORY_IO_H_ #define DMLC_MEMORY_IO_H_ namespace dmlc { /*! * \brief A Stream that operates on fixed region of memory * This class allows us to read/write from/to a fixed memory region. */ struct MemoryFixedSizeStream : public SeekStream { public: /*! * \brief constructor * \param p_buffer the head pointer of the memory region. * \param buffer_size the size of the memorybuffer */ MemoryFixedSizeStream(void *p_buffer, size_t buffer_size) : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) { curr_ptr_ = 0; } virtual size_t Read(void *ptr, size_t size) { CHECK(curr_ptr_ + size <= buffer_size_); size_t nread = std::min(buffer_size_ - curr_ptr_, size); if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); curr_ptr_ += nread; return nread; } virtual void Write(const void *ptr, size_t size) { if (size == 0) return; CHECK(curr_ptr_ + size <= buffer_size_); std::memcpy(p_buffer_ + curr_ptr_, ptr, size); curr_ptr_ += size; } virtual void Seek(size_t pos) { curr_ptr_ = static_cast(pos); } virtual size_t Tell(void) { return curr_ptr_; } private: /*! \brief in memory buffer */ char *p_buffer_; /*! \brief current pointer */ size_t buffer_size_; /*! \brief current pointer */ size_t curr_ptr_; }; // class MemoryFixedSizeStream /*! * \brief A in memory stream that is backed by std::string. 
* This class allows us to read/write from/to a std::string. */ struct MemoryStringStream : public dmlc::SeekStream { public: /*! * \brief constructor * \param p_buffer the pointer to the string. */ explicit MemoryStringStream(std::string *p_buffer) : p_buffer_(p_buffer) { curr_ptr_ = 0; } virtual size_t Read(void *ptr, size_t size) { CHECK(curr_ptr_ <= p_buffer_->length()); size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); curr_ptr_ += nread; return nread; } virtual void Write(const void *ptr, size_t size) { if (size == 0) return; if (curr_ptr_ + size > p_buffer_->length()) { p_buffer_->resize(curr_ptr_+size); } std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); curr_ptr_ += size; } virtual void Seek(size_t pos) { curr_ptr_ = static_cast(pos); } virtual size_t Tell(void) { return curr_ptr_; } private: /*! \brief in memory buffer */ std::string *p_buffer_; /*! \brief current pointer */ size_t curr_ptr_; }; // class MemoryStringStream } // namespace dmlc #endif // DMLC_MEMORY_IO_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/memory_io.h ===== //===== EXPANDIND: mxnet/dmlc-core/include/dmlc/recordio.h ===== /*! * Copyright (c) 2015 by Contributors * \file recordio.h * \brief recordio that is able to pack binary data into a splittable * format, useful to exchange data in binary serialization, * such as binary raw data or protobuf */ #ifndef DMLC_RECORDIO_H_ #define DMLC_RECORDIO_H_ namespace dmlc { /*! * \brief writer of binary recordio * binary format for recordio * recordio format: magic lrecord data pad * * - magic is magic number * - pad is simply a padding space to make record align to 4 bytes * - lrecord encodes length and continue bit * - data.length() = (lrecord & (1U<<29U - 1)); * - cflag == (lrecord >> 29U) & 7; * * cflag was used to handle (rare) special case when magic number * occured in the data sequence. * * In such case, the data is splitted into multiple records by * the cells of magic number * * (1) cflag == 0: this is a complete record; * (2) cflag == 1: start of a multiple-rec; * cflag == 2: middle of multiple-rec; * cflag == 3: end of multiple-rec */ class RecordIOWriter { public: /*! * \brief magic number of recordio * note: (kMagic >> 29U) & 7 > 3 * this ensures lrec will not be kMagic */ static const uint32_t kMagic = 0xced7230a; /*! * \brief encode the lrecord * \param cflag cflag part of the lrecord * \param length length part of lrecord * \return the encoded data */ inline static uint32_t EncodeLRec(uint32_t cflag, uint32_t length) { return (cflag << 29U) | length; } /*! * \brief decode the flag part of lrecord * \param rec the lrecord * \return the flag */ inline static uint32_t DecodeFlag(uint32_t rec) { return (rec >> 29U) & 7U; } /*! * \brief decode the length part of lrecord * \param rec the lrecord * \return the length */ inline static uint32_t DecodeLength(uint32_t rec) { return rec & ((1U << 29U) - 1U); } /*! * \brief constructor * \param stream the stream to be constructed */ explicit RecordIOWriter(Stream *stream) : stream_(stream), except_counter_(0) { CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; } /*! * \brief write record to the stream * \param buf the buffer of memory region * \param size the size of record to write out */ void WriteRecord(const void *buf, size_t size); /*! 
* \brief write record to the stream * \param data the data to write out */ inline void WriteRecord(const std::string &data) { this->WriteRecord(data.c_str(), data.length()); } /*! * \return number of exceptions(occurance of magic number) * during the writing process */ inline size_t except_counter(void) const { return except_counter_; } private: /*! \brief output stream */ Stream *stream_; /*! \brief counts the number of exceptions */ size_t except_counter_; }; /*! * \brief reader of binary recordio to reads in record from stream * \sa RecordIOWriter */ class RecordIOReader { public: /*! * \brief constructor * \param stream the stream to be constructed */ explicit RecordIOReader(Stream *stream) : stream_(stream), end_of_stream_(false) { CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; } /*! * \brief read next complete record from stream * \param out_rec used to store output record in string * \return true of read was successful, false if end of stream was reached */ bool NextRecord(std::string *out_rec); private: /*! \brief output stream */ Stream *stream_; /*! \brief whether we are at end of stream */ bool end_of_stream_; }; /*! * \brief reader of binary recordio from Blob returned by InputSplit * This class divides the blob into several independent parts specified by caller, * and read from one segment. * The part reading can be used together with InputSplit::NextChunk for * multi-threaded parsing(each thread take a RecordIOChunkReader) * * \sa RecordIOWriter, InputSplit */ class RecordIOChunkReader { public: /*! * \brief constructor * \param chunk source data returned by InputSplit * \param part_index which part we want to reado * \param num_parts number of total segments */ explicit RecordIOChunkReader(InputSplit::Blob chunk, unsigned part_index = 0, unsigned num_parts = 1); /*! * \brief read next complete record from stream * the blob contains the memory content * NOTE: this function is not threadsafe, use one * RecordIOChunkReader per thread * \param out_rec used to store output blob, the header is already * removed and out_rec only contains the memory content * \return true of read was successful, false if end was reached */ bool NextRecord(InputSplit::Blob *out_rec); private: /*! \brief internal temporal data */ std::string temp_; /*! \brief internal data pointer */ char *pbegin_, *pend_; }; } // namespace dmlc #endif // DMLC_RECORDIO_H_ //===== EXPANDED: mxnet/dmlc-core/include/dmlc/recordio.h ===== //===== EXPANDIND: mxnet/src/c_api/c_api_error.h ===== /*! * Copyright (c) 2015 by Contributors * \file c_api_error.h * \brief Error handling for C API. */ #ifndef MXNET_C_API_C_API_ERROR_H_ #define MXNET_C_API_C_API_ERROR_H_ /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { /*! \brief every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR */ #define API_END() } catch(dmlc::Error &_except_) { return MXAPIHandleException(_except_); } return 0; // NOLINT(*) /*! * \brief every function starts with API_BEGIN(); * and finishes with API_END() or API_END_HANDLE_ERROR * The finally clause contains procedure to cleanup states when an error happens. */ #define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return MXAPIHandleException(_except_); } return 0; // NOLINT(*) /*! * \brief Set the last error message needed by C API * \param msg The error message to set. */ void MXAPISetLastError(const char* msg); /*! 
* \brief handle exception throwed out * \param e the exception * \return the return value of API after exception is handled */ inline int MXAPIHandleException(const dmlc::Error &e) { MXAPISetLastError(e.what()); return -1; } #endif // MXNET_C_API_C_API_ERROR_H_ //===== EXPANDED: mxnet/src/c_api/c_api_error.h ===== //===== EXPANDIND: mxnet/src/common/thread_local.h ===== /*! * Copyright (c) 2015 by Contributors * \file thread_local.h * \brief Common utility for thread local storage. */ #ifndef MXNET_COMMON_THREAD_LOCAL_H_ #define MXNET_COMMON_THREAD_LOCAL_H_ namespace mxnet { namespace common { // macro hanlding for threadlocal variables #ifdef __GNUC__ #define MX_TREAD_LOCAL __thread #elif __STDC_VERSION__ >= 201112L #define MX_TREAD_LOCAL _Thread_local #elif defined(_MSC_VER) #define MX_TREAD_LOCAL __declspec(thread) #endif #ifndef MX_TREAD_LOCAL #message("Warning: Threadlocal is not enabled"); #endif /*! * \brief A threadlocal store to store threadlocal variables. * Will return a thread local singleton of type T * \tparam T the type we like to store */ template class ThreadLocalStore { public: /*! \return get a thread local singleton */ static T* Get() { static MX_TREAD_LOCAL T* ptr = nullptr; if (ptr == nullptr) { ptr = new T(); Singleton()->RegisterDelete(ptr); } return ptr; } private: /*! \brief constructor */ ThreadLocalStore() {} /*! \brief destructor */ ~ThreadLocalStore() { for (size_t i = 0; i < data_.size(); ++i) { delete data_[i]; } } /*! \return singleton of the store */ static ThreadLocalStore *Singleton() { static ThreadLocalStore inst; return &inst; } /*! * \brief register str for internal deletion * \param str the string pointer */ void RegisterDelete(T *str) { std::unique_lock lock(mutex_); data_.push_back(str); lock.unlock(); } /*! \brief internal mutex */ std::mutex mutex_; /*!\brief internal data */ std::vector data_; }; } // namespace common } // namespace mxnet #endif // MXNET_COMMON_THREAD_LOCAL_H_ //===== EXPANDED: mxnet/src/common/thread_local.h ===== using namespace mxnet; /*! \brief entry to to easily hold returning information */ struct MXAPIThreadLocalEntry { /*! \brief result holder for returning string */ std::string ret_str; /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ std::vector ret_vec_charp; /*! \brief result holder for returning handles */ std::vector ret_handles; /*! \brief result holder for returning shapes */ std::vector arg_shapes, out_shapes, aux_shapes; /*! \brief result holder for returning shape dimensions */ std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; /*! \brief result holder for returning shape pointer */ std::vector arg_shape_data, out_shape_data, aux_shape_data; // helper function to setup return value of shape array inline static void SetupShapeArrayReturn( const std::vector &shapes, std::vector *ndim, std::vector *data) { ndim->resize(shapes.size()); data->resize(shapes.size()); for (size_t i = 0; i < shapes.size(); ++i) { ndim->at(i) = shapes[i].ndim(); data->at(i) = shapes[i].data(); } } }; // define the threadlocal store. 
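// The API entries below that return strings, handles or shape arrays park the
// results in the calling thread's MXAPIThreadLocalEntry (obtained through
// MXAPIThreadLocalStore::Get()), so the pointers handed back to the caller
// stay valid until the next API call made on the same thread.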
typedef mxnet::common::ThreadLocalStore MXAPIThreadLocalStore; // Internal function to get the information // from function registry // Used to implement MXSymbolGetAtomicSymbolInfo and MXFuncGetInfo template inline int MXAPIGetFunctionRegInfo(const FunRegType *e, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); *name = e->name.c_str(); *description = e->description.c_str(); *num_args = static_cast(e->arguments.size()); ret->ret_vec_charp.clear(); for (size_t i = 0; i < e->arguments.size(); ++i) { ret->ret_vec_charp.push_back(e->arguments[i].name.c_str()); } for (size_t i = 0; i < e->arguments.size(); ++i) { ret->ret_vec_charp.push_back(e->arguments[i].type_info_str.c_str()); } for (size_t i = 0; i < e->arguments.size(); ++i) { ret->ret_vec_charp.push_back(e->arguments[i].description.c_str()); } *arg_names = dmlc::BeginPtr(ret->ret_vec_charp); *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + e->arguments.size(); *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (e->arguments.size() * 2); API_END(); } // NOTE: return value is added in API_END int MXRandomSeed(int seed) { API_BEGIN(); mxnet::RandomSeed(seed); API_END(); } int MXNotifyShutdown() { API_BEGIN(); Engine::Get()->NotifyShutdown(); API_END(); } int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); *out = new NDArray(); API_END(); } int MXNDArrayCreate(const mx_uint *shape, mx_uint ndim, int dev_type, int dev_id, int delay_alloc, NDArrayHandle *out) { API_BEGIN(); *out = new NDArray( TShape(shape, shape + ndim), Context::Create(static_cast(dev_type), dev_id), delay_alloc != 0); API_END(); } int MXNDArrayLoadFromRawBytes(const void *buf, size_t size, NDArrayHandle *out) { NDArray *ptr = nullptr; API_BEGIN(); dmlc::MemoryFixedSizeStream strm((void*)buf, size); // NOLINT(*) ptr = new NDArray(); if (!ptr->Load(&strm)) { throw dmlc::Error("Invalid NDArray serialization format"); } *out = ptr; API_END_HANDLE_ERROR(delete ptr); } int MXNDArraySaveRawBytes(NDArrayHandle handle, size_t *out_size, const char **out_buf) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_str.resize(0); dmlc::MemoryStringStream strm(&ret->ret_str); static_cast(handle)->Save(&strm); *out_size = ret->ret_str.length(); *out_buf = ret->ret_str.c_str(); API_END(); } int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, const mx_float *data, size_t size) { API_BEGIN(); static_cast(handle)->SyncCopyFromCPU(data, size); API_END(); } int MXNDArraySyncCopyToCPU(NDArrayHandle handle, mx_float *data, size_t size) { API_BEGIN(); static_cast(handle)->SyncCopyToCPU(data, size); API_END(); } int MXNDArrayWaitToRead(NDArrayHandle handle) { API_BEGIN(); static_cast(handle)->WaitToRead(); API_END(); } int MXNDArrayWaitToWrite(NDArrayHandle handle) { API_BEGIN(); static_cast(handle)->WaitToWrite(); API_END(); } int MXNDArraySave(const char* fname, mx_uint num_args, NDArrayHandle* args, const char** keys) { API_BEGIN(); std::vector data(num_args); std::vector names; for (mx_uint i = 0; i < num_args; ++i) { data[i] = *static_cast(args[i]); } if (keys != nullptr) { names.resize(num_args); for (mx_uint i = 0; i < num_args; ++i) { names[i] = keys[i]; } } { std::unique_ptr fo(dmlc::Stream::Create(fname, "w")); mxnet::NDArray::Save(fo.get(), data, names); } API_END(); } int MXNDArrayLoad(const char* fname, mx_uint *out_size, NDArrayHandle** out_arr, mx_uint 
*out_name_size, const char*** out_names) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); ret->ret_vec_str.clear(); API_BEGIN(); std::vector data; std::vector &names = ret->ret_vec_str; { std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); mxnet::NDArray::Load(fi.get(), &data, &names); } ret->ret_handles.resize(data.size()); for (size_t i = 0; i < data.size(); ++i) { NDArray *ptr = new NDArray(); *ptr = data[i]; ret->ret_handles[i] = ptr; } ret->ret_vec_charp.resize(names.size()); for (size_t i = 0; i < names.size(); ++i) { ret->ret_vec_charp[i] = names[i].c_str(); } *out_size = static_cast(data.size()); *out_arr = dmlc::BeginPtr(ret->ret_handles); *out_name_size = static_cast(names.size()); *out_names = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } int MXNDArrayFree(NDArrayHandle handle) { API_BEGIN(); delete static_cast(handle); API_END(); } int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out) { NDArray *ptr = new NDArray(); API_BEGIN(); *ptr = static_cast(handle)->Slice( slice_begin, slice_end); *out = ptr; API_END_HANDLE_ERROR(delete ptr); } int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata) { API_BEGIN(); NDArray *arr = static_cast(handle); if (!arr->is_none()) { const TShape &s = arr->shape(); *out_dim = s.ndim(); *out_pdata = s.data(); } else { *out_dim = 0; } API_END(); } int MXNDArrayGetData(NDArrayHandle handle, mx_float **out_pdata) { API_BEGIN(); NDArray *arr = static_cast(handle); if (!arr->is_none()) { CHECK(arr->ctx().dev_mask() == cpu::kDevMask) << "MXNDArrayGetData can only be called for NDArray on CPU"; const TBlob &b = arr->data(); CHECK(b.CheckContiguous()); *out_pdata = b.FlatTo2D().dptr_; } else { *out_pdata = nullptr; } API_END(); } int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id) { API_BEGIN(); NDArray *arr = static_cast(handle); if (!arr->is_none()) { const Context &ctx = arr->ctx(); *out_dev_type = ctx.dev_type; *out_dev_id = ctx.dev_id; } else { *out_dev_type = 0; *out_dev_id = 0; } API_END(); } int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array) { API_BEGIN(); auto &vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); *out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } int MXGetFunction(const char *name, FunctionHandle *out) { API_BEGIN(); *out = dmlc::Registry::Find(name); API_END(); } int MXFuncGetInfo(FunctionHandle fun, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions) { return MXAPIGetFunctionRegInfo(static_cast(fun), name, description, num_args, arg_names, arg_type_infos, arg_descriptions); } int MXFuncDescribe(FunctionHandle fun, mx_uint *num_use_vars, mx_uint *num_scalars, mx_uint *num_mutate_vars, int *type_mask) { API_BEGIN(); auto *f = static_cast(fun); *num_use_vars = f->num_use_vars; *num_scalars = f->num_scalars; *num_mutate_vars = f->num_mutate_vars; *type_mask = f->type_mask; API_END(); } int MXFuncInvoke(FunctionHandle fun, NDArrayHandle *use_vars, mx_float *scalar_args, NDArrayHandle *mutate_vars) { API_BEGIN(); auto *f = static_cast(fun); f->body((NDArray**)(use_vars), // NOLINT(*) scalar_args, (NDArray**)(mutate_vars)); // NOLINT(*) API_END(); } //-------------------------------------------- // Part 3: symbolic configuration generation //-------------------------------------------- int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, 
AtomicSymbolCreator **out_array) { API_BEGIN(); auto &vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char **out) { API_BEGIN(); OperatorPropertyReg *e = static_cast(creator); *out = e->name.c_str(); API_END(); } int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions, const char **key_var_num_args) { OperatorPropertyReg *e = static_cast(creator); *key_var_num_args = e->key_var_num_args.c_str(); return MXAPIGetFunctionRegInfo(e, name, description, num_args, arg_names, arg_type_infos, arg_descriptions); } int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, mx_uint num_param, const char **keys, const char **vals, SymbolHandle *out) { Symbol *s = new Symbol(); OperatorProperty *op = nullptr; API_BEGIN(); OperatorPropertyReg *e = static_cast(creator); op = e->body(); std::vector > kwargs; for (mx_uint i = 0; i < num_param; ++i) { kwargs.push_back({std::string(keys[i]), std::string(vals[i])}); } op->Init(kwargs); *s = Symbol::Create(op); *out = s; API_END_HANDLE_ERROR(delete s; delete op); } int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { Symbol *s = new Symbol(); API_BEGIN(); *s = Symbol::CreateVariable(name); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolCreateGroup(mx_uint num_symbols, SymbolHandle *symbols, SymbolHandle *out) { Symbol *s = new Symbol(); Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*) API_BEGIN(); std::vector syms; for (mx_uint i = 0; i < num_symbols; ++i) { syms.push_back(*sym_arr[i]); } *s = Symbol::CreateGroup(syms); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolGetOutput(SymbolHandle symbol, mx_uint index, SymbolHandle *out) { Symbol *s = new Symbol(); API_BEGIN(); *s = (*static_cast(symbol))[index]; *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out) { Symbol *s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetInternals(); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) { Symbol *s = new Symbol(); API_BEGIN(); std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); dmlc::istream is(fi.get()); dmlc::JSONReader reader(&is); s->Load(&reader); // reset file pointer is.set_stream(nullptr); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) { Symbol *s = new Symbol(); API_BEGIN(); std::string buf(json); std::istringstream is(buf); dmlc::JSONReader reader(&is); s->Load(&reader); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) { Symbol *s = static_cast(symbol); API_BEGIN(); std::unique_ptr fo(dmlc::Stream::Create(fname, "w")); dmlc::ostream os(fo.get()); dmlc::JSONWriter writer(&os); s->Save(&writer); // reset file pointer, force flush os.set_stream(nullptr); API_END(); } int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) { Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); std::ostringstream os; dmlc::JSONWriter writer(&os); s->Save(&writer); ret->ret_str = os.str(); *out_json = ret->ret_str.c_str(); API_END(); } int MXSymbolFree(SymbolHandle symbol) { API_BEGIN(); delete static_cast(symbol); API_END(); } 
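// Illustrative sketch of the convention shared by every entry point in this
// file: the body is wrapped in API_BEGIN()/API_END() so a thrown dmlc::Error
// becomes MXAPISetLastError() plus a -1 return value, and the opaque handle is
// cast back to its concrete type inside the guarded region. MXNDArrayIsNone is
// a hypothetical name used only for illustration (it is not part of the API),
// and the block is deliberately compiled out.
#if 0
int MXNDArrayIsNone(NDArrayHandle handle, int *out) {
  API_BEGIN();
  *out = static_cast<NDArray*>(handle)->is_none() ? 1 : 0;
  API_END();
}
#endif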
int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { Symbol *s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->Copy(); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolPrint(SymbolHandle symbol, const char **out_str) { Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); std::ostringstream os; s->Print(os); ret->ret_str = os.str(); *out_str = (ret->ret_str).c_str(); API_END(); } int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array) { Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_vec_str = std::move(s->ListArguments()); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); } *out_size = static_cast(ret->ret_vec_charp.size()); *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array) { Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_vec_str = std::move(s->ListOutputs()); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); } *out_size = static_cast(ret->ret_vec_charp.size()); *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } int MXSymbolListAuxiliaryStates(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array) { Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_vec_str = std::move(s->ListAuxiliaryStates()); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); } *out_size = static_cast(ret->ret_vec_charp.size()); *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } int MXSymbolCompose(SymbolHandle sym, const char *name, mx_uint num_args, const char** keys, SymbolHandle* args) { API_BEGIN(); std::string s_name; if (name != nullptr) s_name = name; Symbol* s = static_cast(sym); if (keys == nullptr && num_args != 0) { std::vector pos_args; for (mx_uint i = 0; i < num_args; ++i) { pos_args.push_back(*((Symbol*)args[i])); // NOLINT(*) } s->Compose(pos_args, s_name); } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { kwargs[keys[i]] = *((Symbol*)args[i]); // NOLINT(*) } s->Compose(kwargs, s_name); } API_END(); } int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) { API_BEGIN(); Symbol* s = static_cast(sym); std::vector wrts(num_wrt); for (mx_uint i = 0; i < num_wrt; ++i) { wrts[i] = wrt[i]; } Symbol* ret = new Symbol; *ret = s->Grad(wrts); *out = ret; API_END(); } int MXSymbolInferShape(SymbolHandle sym, mx_uint num_args, const char** keys, const mx_uint *arg_ind_ptr, const mx_uint *arg_shape_data, mx_uint *in_shape_size, const mx_uint **in_shape_ndim, const mx_uint ***in_shape_data, mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data, mx_uint *aux_shape_size, const mx_uint **aux_shape_ndim, const mx_uint ***aux_shape_data, int *complete) { Symbol *s = static_cast(sym); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); bool succ; API_BEGIN(); if (keys == nullptr && num_args != 0) { ret->arg_shapes.clear(); for (mx_uint i = 0; i < num_args; ++i) { 
ret->arg_shapes.push_back(TShape(arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1])); } succ = s->InferShape(&(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { kwargs[keys[i]] = TShape(arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); } succ = s->InferShape(kwargs, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); } if (succ) { MXAPIThreadLocalEntry::SetupShapeArrayReturn( ret->arg_shapes, &(ret->arg_shape_ndim), &(ret->arg_shape_data)); MXAPIThreadLocalEntry::SetupShapeArrayReturn( ret->out_shapes, &(ret->out_shape_ndim), &(ret->out_shape_data)); MXAPIThreadLocalEntry::SetupShapeArrayReturn( ret->aux_shapes, &(ret->aux_shape_ndim), &(ret->aux_shape_data)); *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim); *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data); *out_shape_size = static_cast(ret->out_shapes.size()); *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim); *out_shape_data = dmlc::BeginPtr(ret->out_shape_data); *aux_shape_size = static_cast(ret->aux_shapes.size()); *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim); *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data); *complete = 1; } else { *complete = 0; } API_END(); } int MXExecutorPrint(ExecutorHandle handle, const char **out_str) { Executor *exec = static_cast(handle); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); std::ostringstream os; exec->Print(os); ret->ret_str = os.str(); *out_str = (ret->ret_str).c_str(); API_END(); } int MXExecutorFree(ExecutorHandle handle) { API_BEGIN(); delete static_cast(handle); API_END(); } int MXExecutorForward(ExecutorHandle handle, int is_train) { API_BEGIN(); Executor *exec = static_cast(handle); exec->Forward(is_train != 0); API_END(); } int MXExecutorBackward(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads) { API_BEGIN(); Executor *exec = static_cast(handle); std::vector ndarrays; NDArray **args_ptr = reinterpret_cast(head_grads); for (mx_uint i = 0; i < len; ++i) { ndarrays.push_back(*args_ptr[i]); } exec->Backward(ndarrays); API_END(); } int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); Executor *exec = static_cast(handle); std::vector heads = exec->outputs(); ret->ret_handles.resize(heads.size()); for (size_t i = 0; i < heads.size(); ++i) { NDArray *ptr = new NDArray(); *ptr = heads[i]; ret->ret_handles[i] = ptr; } *out_size = heads.size(); *out = dmlc::BeginPtr(ret->ret_handles); API_END(); } int MXExecutorBind(SymbolHandle symbol_handle, int dev_type, int dev_id, mx_uint len, NDArrayHandle *in_args, NDArrayHandle *arg_grad_store, mx_uint *grad_req_type, mx_uint aux_states_len, NDArrayHandle *aux_states, ExecutorHandle *out) { API_BEGIN(); Symbol *symb = static_cast(symbol_handle); Context ctx = Context::Create(static_cast(dev_type), dev_id); NDArray **in_args_ptr = reinterpret_cast(in_args); NDArray **arg_grad_ptr = reinterpret_cast(arg_grad_store); NDArray **aux_states_ptr = reinterpret_cast(aux_states); std::vector in_args_vec; std::vector arg_grad_vec; std::vector grad_req_vec; std::vector aux_states_vec; for (mx_uint i = 0; i < len; ++i) { in_args_vec.push_back(*(in_args_ptr[i])); if (arg_grad_ptr[i] == nullptr) { arg_grad_vec.push_back(NDArray()); grad_req_vec.push_back(kNullOp); } else { arg_grad_vec.push_back(*(arg_grad_ptr[i])); 
grad_req_vec.push_back(static_cast(grad_req_type[i])); } } for (mx_uint i = 0; i < aux_states_len; ++i) { aux_states_vec.push_back(*(aux_states_ptr[i])); } *out = Executor::Bind(*symb, ctx, in_args_vec, arg_grad_vec, grad_req_vec, aux_states_vec); API_END(); } //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- int MXListDataIters(mx_uint *out_size, DataIterCreator **out_array) { API_BEGIN(); auto &vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); *out_array = (DataIterCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } int MXDataIterGetIterInfo(DataIterCreator creator, const char **name, const char **description, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions) { DataIteratorReg *e = static_cast(creator); return MXAPIGetFunctionRegInfo(e, name, description, num_args, arg_names, arg_type_infos, arg_descriptions); } int MXDataIterCreateIter(DataIterCreator creator, mx_uint num_param, const char **keys, const char **vals, DataIterHandle *out) { IIterator *iter = nullptr; API_BEGIN(); DataIteratorReg *e = static_cast(creator); iter = e->body(); std::vector > kwargs; for (mx_uint i = 0; i < num_param; ++i) { kwargs.push_back({std::string(keys[i]), std::string(vals[i])}); } iter->Init(kwargs); *out = iter; API_END_HANDLE_ERROR(delete iter); } int MXDataIterFree(DataIterHandle handle) { API_BEGIN(); delete static_cast *>(handle); API_END(); } int MXDataIterBeforeFirst(DataIterHandle handle) { API_BEGIN(); static_cast* >(handle)->BeforeFirst(); API_END(); } int MXDataIterNext(DataIterHandle handle, int *out) { API_BEGIN(); *out = static_cast* >(handle)->Next(); API_END(); } int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); NDArray* pndarray = new NDArray(); // temp hack to make label 1D // TODO(tianjun) make label 1D when label_width=0 TShape shape = db.data[1].shape(); if (shape[1] == 1) { *pndarray = db.data[1].Reshape(mshadow::Shape1(shape[0])); } else { *pndarray = db.data[1]; } *out = pndarray; API_END(); } int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); NDArray* pndarray = new NDArray(); *pndarray = db.data[0]; *out = pndarray; API_END(); } int MXDataIterGetPadNum(DataIterHandle handle, int *pad) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); *pad = db.num_batch_padd; API_END(); } int MXKVStoreCreate(const char *type, KVStoreHandle *out) { API_BEGIN(); *out = KVStore::Create(type); API_END(); } int MXKVStoreFree(KVStoreHandle handle) { API_BEGIN(); delete static_cast(handle); API_END(); } int MXKVStoreInit(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); for (mx_uint i = 0; i < num; ++i) { v_keys[i] = keys[i]; v_vals[i] = *static_cast(vals[i]); } static_cast(handle)->Init(v_keys, v_vals); API_END(); } int MXKVStorePush(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals, int priority) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); for (mx_uint i = 0; i < num; ++i) { v_keys[i] = keys[i]; v_vals[i] = *static_cast(vals[i]); } static_cast(handle)->Push(v_keys, v_vals, priority); API_END(); } int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals, int priority) { API_BEGIN(); std::vector v_keys(num); 
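// Unlike MXKVStorePush, which copies the NDArrays by value, Pull keeps the
// caller's NDArray pointers so the fetched values are written back in place.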
std::vector v_vals(num); for (mx_uint i = 0; i < num; ++i) { v_keys[i] = keys[i]; v_vals[i] = static_cast(vals[i]); } static_cast(handle)->Pull(v_keys, v_vals, priority); API_END(); } int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) { API_BEGIN(); MXKVStoreUpdater * updater_temp = updater; void* updater_handle_temp = updater_handle; std::function updt = [updater_temp, updater_handle_temp](int key, const NDArray& recv, NDArray* local) { NDArray* recv_copy = new NDArray(); *recv_copy = recv; NDArray* local_copy = new NDArray(); *local_copy = *local; updater_temp(key, recv_copy, local_copy, updater_handle_temp); }; static_cast(handle)->set_updater(updt); API_END(); } int MXKVStoreGetRank(KVStoreHandle handle, int *rank) { API_BEGIN(); *rank = static_cast(handle)->get_rank(); API_END(); } int MXKVStoreGetGroupSize(KVStoreHandle handle, int *size) { API_BEGIN(); *size = static_cast(handle)->get_group_size(); API_END(); } int MXKVStoreBarrier(KVStoreHandle handle) { API_BEGIN(); static_cast(handle)->Barrier(); API_END(); } int MXKVStoreIsWorkerNode(int *ret) { API_BEGIN(); *ret = KVStore::IsWorkerNode(); API_END(); } int MXKVStoreIsServerNode(int *ret) { API_BEGIN(); *ret = KVStore::IsServerNode(); API_END(); } int MXKVStoreIsSchedulerNode(int *ret) { API_BEGIN(); *ret = KVStore::IsSchedulerNode(); API_END(); } int MXKVStoreRunServer(KVStoreHandle handle, MXKVStoreServerController controller) { API_BEGIN(); MXKVStoreServerController *controller_temp = controller; auto ctrl = [controller_temp](int head, const std::string& body) { controller_temp(head, body.c_str()); }; static_cast(handle)->RunServer(ctrl); API_END(); } int MXKVStoreSendCommmandToServers(KVStoreHandle handle, int cmd_id, const char* cmd_body) { API_BEGIN(); static_cast(handle)->SendCommandToServers( cmd_id, std::string(cmd_body)); API_END(); } int MXKVStoreGetType(KVStoreHandle handle, const char** type) { API_BEGIN(); *CHECK_NOTNULL(type) = static_cast(handle)->type().c_str(); API_END(); } struct MXRecordIOContext { dmlc::RecordIOWriter *writer; dmlc::RecordIOReader *reader; dmlc::Stream *stream; std::string *read_buff; }; int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out) { API_BEGIN(); dmlc::Stream *stream = dmlc::Stream::Create(uri, "w"); MXRecordIOContext *context = new MXRecordIOContext; context->writer = new dmlc::RecordIOWriter(stream); context->reader = NULL; context->stream = stream; context->read_buff = NULL; *out = reinterpret_cast(context); API_END(); } int MXRecordIOWriterFree(RecordIOHandle handle) { API_BEGIN(); MXRecordIOContext *context = reinterpret_cast(handle); delete context->writer; delete context->stream; API_END(); } int MXRecordIOWriterWriteRecord(RecordIOHandle *handle, const char *buf, size_t size) { API_BEGIN(); MXRecordIOContext *context = reinterpret_cast(handle); context->writer->WriteRecord(reinterpret_cast(buf), size); API_END(); } int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out) { API_BEGIN(); dmlc::Stream *stream = dmlc::Stream::Create(uri, "r"); MXRecordIOContext *context = new MXRecordIOContext; context->reader = new dmlc::RecordIOReader(stream); context->writer = NULL; context->stream = stream; context->read_buff = new std::string(); *out = reinterpret_cast(context); API_END(); } int MXRecordIOReaderFree(RecordIOHandle *handle) { API_BEGIN(); MXRecordIOContext *context = reinterpret_cast(handle); delete context->reader; delete context->stream; delete context->read_buff; API_END(); } int 
MXRecordIOReaderReadRecord(RecordIOHandle *handle, char const **buf, size_t *size) { API_BEGIN(); MXRecordIOContext *context = reinterpret_cast(handle); context->reader->NextRecord(context->read_buff); *buf = context->read_buff->c_str(); *size = context->read_buff->size(); API_END(); } //===== EXPANDED: mxnet/src/c_api/c_api.cc ===== //===== EXPANDIND: mxnet/src/c_api/c_api_error.cc ===== /*! * Copyright (c) 2015 by Contributors * \file c_api_error.cc * \brief C error handling */ struct ErrorEntry { std::string last_error; }; typedef mxnet::common::ThreadLocalStore MXAPIErrorStore; const char *MXGetLastError() { return MXAPIErrorStore::Get()->last_error.c_str(); } void MXAPISetLastError(const char* msg) { MXAPIErrorStore::Get()->last_error = msg; } //===== EXPANDED: mxnet/src/c_api/c_api_error.cc ===== //===== EXPANDIND: mxnet/src/c_api/c_predict_api.cc ===== /*! * Copyright (c) 2015 by Contributors * \file c_predict_api.cc * \brief C predict API of mxnet */ //===== EXPANDIND: mxnet/include/mxnet/c_predict_api.h ===== /*! * Copyright (c) 2015 by Contributors * \file c_predict_api.h * \brief C predict API of mxnet, contains a minimum API to run prediction. * This file is self-contained, and do not dependent on any other files. */ #ifndef MXNET_C_PREDICT_API_H_ #define MXNET_C_PREDICT_API_H_ #ifdef __cplusplus #define MXNET_EXTERN_C extern "C" #else #define MXNET_EXTERN_C #endif #ifdef _WIN32 #ifdef MXNET_EXPORTS #define MXNET_DLL MXNET_EXTERN_C __declspec(dllexport) #else #define MXNET_DLL MXNET_EXTERN_C __declspec(dllimport) #endif #else #define MXNET_DLL MXNET_EXTERN_C #endif /*! \brief manually define unsigned int */ typedef unsigned int mx_uint; /*! \brief manually define float */ typedef float mx_float; /*! \brief handle to Predictor */ typedef void *PredictorHandle; /*! \brief handle to NDArray list */ typedef void *NDListHandle; /*! * \brief Get the last error happeneed. * \return The last error happened at the predictor. */ MXNET_DLL const char* MXGetLastError(); /*! * \brief create a predictor * \param symbol_json_str The JSON string of the symbol. * \param param_bytes The in-memory raw bytes of parameter ndarray file. * \param param_size The size of parameter ndarray file. * \param dev_type The device type, 1: cpu, 2:gpu * \param dev_id The device id of the predictor. * \param num_input_nodes Number of input nodes to the net, * For feedforward net, this is 1. * \param input_keys The name of input argument. * For feedforward net, this is {"data"} * \param input_shape_indptr Index pointer of shapes of each input node. * The length of this array = num_input_nodes + 1. * For feedforward net that takes 4 dimensional input, this is {0, 4}. * \param input_shape_data A flatted data of shapes of each input node. * For feedforward net that takes 4 dimensional input, this is the shape data. * \param out The created predictor handle. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXPredCreate(const char* symbol_json_str, const char* param_bytes, size_t param_size, int dev_type, int dev_id, mx_uint num_input_nodes, const char** input_keys, const mx_uint* input_shape_indptr, const mx_uint* input_shape_data, PredictorHandle* out); /*! * \brief Get the shape of output node. * The returned shape_data and shape_ndim is only valid before next call to MXPred function. * \param handle The handle of the predictor. * \param index The index of output node, set to 0 if there is only one output. 
* \param shape_data Used to hold pointer to the shape data * \param shape_ndim Used to hold shape dimension. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle, mx_uint index, mx_uint** shape_data, mx_uint* shape_ndim); /*! * \brief Set the input data of predictor. * \param handle The predictor handle. * \param key The name of input node to set. * For feedforward net, this is "data". * \param data The pointer to the data to be set, with the shape specified in MXPredCreate. * \param size The size of data array, used for safety check. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXPredSetInput(PredictorHandle handle, const char* key, const mx_float* data, mx_uint size); /*! * \brief Run a forward pass to get the output * \param handle The handle of the predictor. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXPredForward(PredictorHandle handle); /*! * \brief Get the output value of prediction. * \param handle The handle of the predictor. * \param index The index of output node, set to 0 if there is only one output. * \param data User allocated data to hold the output. * \param size The size of data array, used for safe checking. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXPredGetOutput(PredictorHandle handle, mx_uint index, mx_float* data, mx_uint size); /*! * \brief Free a predictor handle. * \param handle The handle of the predictor. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXPredFree(PredictorHandle handle); /*! * \brief Create a NDArray List by loading from ndarray file. * This can be used to load mean image file. * \param nd_file_bytes The byte contents of nd file to be loaded. * \param nd_file_size The size of the nd file to be loaded. * \param out The out put NDListHandle * \param out_length Length of the list. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXNDListCreate(const char* nd_file_bytes, size_t nd_file_size, NDListHandle *out, mx_uint* out_length); /*! * \brief Get an element from list * \param handle The handle to the NDArray * \param index The index in the list * \param out_key The output key of the item * \param out_data The data region of the item * \param out_shape The shape of the item. * \param out_ndim The number of dimension in the shape. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXNDListGet(NDListHandle handle, mx_uint index, const char** out_key, const mx_float** out_data, const mx_uint** out_shape, mx_uint* out_ndim); /*! * \brief Free a predictor handle. * \param handle The handle of the predictor. * \return 0 when success, -1 when failure. */ MXNET_DLL int MXNDListFree(NDListHandle handle); #endif // MXNET_C_PREDICT_API_H_ //===== EXPANDED: mxnet/include/mxnet/c_predict_api.h ===== using namespace mxnet; // predictor interface struct MXAPIPredictor { // output arrays std::vector out_arrays; // argument arrays std::vector arg_arrays; // output shapes std::vector out_shapes; // key to arguments std::unordered_map key2arg; // executor std::unique_ptr exec; }; struct MXAPINDList { std::vector keys; std::vector shapes; std::vector indptr; std::vector data; }; int MXPredCreate(const char* symbol_json_str, const char* param_bytes, size_t param_size, int dev_type, int dev_id, mx_uint num_input_nodes, const char** input_keys, const mx_uint* input_shape_indptr, const mx_uint* input_shape_data, PredictorHandle* out) { MXAPIPredictor* ret = new MXAPIPredictor(); API_BEGIN(); Symbol sym; // load in the symbol. 
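// The symbol is recovered from the JSON string via dmlc::JSONReader, the
// parameter blob is split into "arg:"/"aux:" entries, shapes are inferred from
// the user-supplied input shapes, and the executor is bound with all gradient
// requests set to kNullOp, since a predictor only ever runs forward passes.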
{ std::string json = symbol_json_str; std::istringstream is(json); dmlc::JSONReader reader(&is); sym.Load(&reader); } // load the parameters std::unordered_map arg_params, aux_params; { std::vector data; std::vector names; dmlc::MemoryFixedSizeStream fi((void*)param_bytes, param_size); // NOLINT(*) NDArray::Load(&fi, &data, &names); CHECK_EQ(names.size(), data.size()) << "Invalid param file format"; for (size_t i = 0; i < names.size(); ++i) { if (!strncmp(names[i].c_str(), "aux:", 4)) { aux_params[std::string(names[i].c_str() + 4)] = data[i]; } if (!strncmp(names[i].c_str(), "arg:", 4)) { arg_params[std::string(names[i].c_str() + 4)] = data[i]; } } } // shape inference and bind std::unordered_map known_shape; for (mx_uint i = 0; i < num_input_nodes; ++i) { known_shape[std::string(input_keys[i])] = TShape(input_shape_data + input_shape_indptr[i], input_shape_data + input_shape_indptr[i + 1]); } std::vector arg_shapes; std::vector arg_names = sym.ListArguments(); std::vector aux_names = sym.ListAuxiliaryStates(); std::vector out_shapes(sym.ListOutputs().size()); std::vector aux_shapes(aux_names.size()); for (size_t i = 0; i < arg_names.size(); ++i) { std::string key = arg_names[i]; ret->key2arg[key] = i; if (known_shape.count(key) != 0) { arg_shapes.push_back(known_shape[key]); } else { arg_shapes.push_back(TShape()); } } CHECK(sym.InferShape(&arg_shapes, &out_shapes, &aux_shapes)) << "The shape information of is not enough to get the shapes"; ret->out_shapes = out_shapes; Context ctx = Context::Create(static_cast(dev_type), dev_id); std::vector arg_arrays, aux_arrays; for (size_t i = 0; i < arg_shapes.size(); ++i) { NDArray nd = NDArray(arg_shapes[i], ctx); if (arg_params.count(arg_names[i]) != 0) { CopyFromTo(arg_params[arg_names[i]], &nd); } arg_arrays.push_back(nd); } for (size_t i = 0; i < aux_shapes.size(); ++i) { NDArray nd = NDArray(aux_shapes[i], ctx); if (aux_params.count(aux_names[i]) != 0) { CopyFromTo(aux_params[aux_names[i]], &nd); } aux_arrays.push_back(nd); } ret->arg_arrays = arg_arrays; // bind { std::vector grad_store(arg_arrays.size()); std::vector grad_req(arg_arrays.size(), kNullOp); ret->exec.reset(Executor::Bind(sym, ctx, arg_arrays, grad_store, grad_req, aux_arrays)); ret->out_arrays = ret->exec->outputs(); } *out = ret; API_END_HANDLE_ERROR(delete ret); } int MXPredGetOutputShape(PredictorHandle handle, mx_uint out_index, mx_uint** shape_data, mx_uint* shape_ndim) { MXAPIPredictor* p = static_cast(handle); API_BEGIN(); CHECK_LT(out_index, p->out_arrays.size()) << "Index exceed number of outputs"; *shape_data = p->out_shapes[out_index].data(); *shape_ndim = p->out_shapes[out_index].ndim(); API_END(); } int MXPredSetInput(PredictorHandle handle, const char* key, const mx_float* data, mx_uint size) { MXAPIPredictor* p = static_cast(handle); API_BEGIN(); auto it = p->key2arg.find(key); if (it == p->key2arg.end()) { LOG(FATAL) << "cannot find input key " << key; } NDArray& nd = p->arg_arrays[it->second]; nd.SyncCopyFromCPU(data, size); API_END(); } int MXPredForward(PredictorHandle handle) { MXAPIPredictor* p = static_cast(handle); API_BEGIN(); p->exec->Forward(false); API_END(); } int MXPredGetOutput(PredictorHandle handle, mx_uint index, mx_float* data, mx_uint size) { MXAPIPredictor* p = static_cast(handle); API_BEGIN(); CHECK_LT(index, p->out_arrays.size()) << "Output index out of range"; const NDArray& nd = p->out_arrays[index]; nd.SyncCopyToCPU(data, size); API_END(); } int MXPredFree(PredictorHandle handle) { API_BEGIN(); delete static_cast(handle); API_END(); } 
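// Illustrative end-to-end use of the predict API declared above, assuming the
// caller already holds the symbol JSON and the parameter blob in memory and
// that the network has a single 4-d input named "data"; the 1x3x224x224 shape
// is an assumed example, not something mandated by the API. Return-code and
// MXGetLastError() checks are omitted for brevity, and the block is compiled
// out: it is a sketch of the intended call sequence, not part of the library.
#if 0
void ExamplePredict(const char* symbol_json, const char* param_bytes,
                    size_t param_size, const mx_float* input, mx_uint input_size) {
  PredictorHandle pred = nullptr;
  const char* input_keys[] = {"data"};
  const mx_uint input_shape_indptr[] = {0, 4};
  const mx_uint input_shape_data[] = {1, 3, 224, 224};
  // dev_type 1 selects CPU (see the MXPredCreate documentation above).
  MXPredCreate(symbol_json, param_bytes, param_size, 1, 0,
               1, input_keys, input_shape_indptr, input_shape_data, &pred);
  MXPredSetInput(pred, "data", input, input_size);
  MXPredForward(pred);
  mx_uint *shape = nullptr, ndim = 0;
  MXPredGetOutputShape(pred, 0, &shape, &ndim);
  mx_uint out_size = 1;
  for (mx_uint i = 0; i < ndim; ++i) out_size *= shape[i];
  std::vector<mx_float> output(out_size);
  MXPredGetOutput(pred, 0, output.data(), out_size);
  MXPredFree(pred);
}
#endif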
int MXNDListCreate(const char* nd_file_bytes, size_t nd_file_size, NDListHandle *out, mx_uint* out_length) { MXAPINDList* ret = new MXAPINDList(); API_BEGIN(); std::vector arrays; dmlc::MemoryFixedSizeStream fi((void*)nd_file_bytes, nd_file_size); // NOLINT(*) NDArray::Load(&fi, &(arrays), &(ret->keys)); if (ret->keys.size() == 0) { ret->keys.resize(arrays.size()); } ret->indptr.push_back(0); for (size_t i = 0; i < arrays.size(); ++i) { TShape shape = arrays[i].shape(); size_t begin = ret->data.size(); size_t size = shape.Size(); ret->shapes.push_back(shape); ret->data.resize(begin + size); arrays[i].SyncCopyToCPU(dmlc::BeginPtr(ret->data) + begin, size); ret->indptr.push_back(begin + size); } *out = ret; *out_length = static_cast(arrays.size()); API_END(); } int MXNDListGet(NDListHandle handle, mx_uint index, const char** out_key, const mx_float** out_data, const mx_uint** out_shape, mx_uint* out_ndim) { MXAPINDList* p = static_cast(handle); API_BEGIN(); CHECK_LT(index, p->shapes.size()) << "Index out of range"; *out_key = p->keys[index].c_str(); *out_data = dmlc::BeginPtr(p->data) + p->indptr[index]; *out_shape = p->shapes[index].data(); *out_ndim = p->shapes[index].ndim(); API_END(); } int MXNDListFree(NDListHandle handle) { API_BEGIN(); delete static_cast(handle); API_END(); } //===== EXPANDED: mxnet/src/c_api/c_predict_api.cc ===== //===== EXPANDIND: mxnet/dmlc-core/src/data.cc ===== // Copyright by Contributors //===== EXPANDIND: mxnet/dmlc-core/src/io/uri_spec.h ===== /*! * Copyright (c) 2015 by Contributors * \file uri_spec.h * \brief common specification of sugars in URI * string passed to dmlc Create functions * such as local file cache * \author Tianqi Chen */ #ifndef DMLC_IO_URI_SPEC_H_ #define DMLC_IO_URI_SPEC_H_ //===== EXPANDIND: mxnet/dmlc-core/src/io/filesys.h ===== /*! * Copyright (c) 2015 by Contributors * \file filesystem.h * \brief general file system io interface * \author Tianqi Chen */ #ifndef DMLC_IO_FILESYS_H_ #define DMLC_IO_FILESYS_H_ namespace dmlc { namespace io { /*! \brief common data structure for URI */ struct URI { /*! \brief protocol */ std::string protocol; /*! * \brief host name, namenode for HDFS, bucket name for s3 */ std::string host; /*! \brief name of the path */ std::string name; /*! \brief enable default constructor */ URI(void) {} /*! * \brief construct from URI string */ explicit URI(const char *uri) { const char *p = std::strstr(uri, "://"); if (p == NULL) { name = uri; } else { protocol = std::string(uri, p - uri + 3); uri = p + 3; p = std::strchr(uri, '/'); if (p == NULL) { host = uri; name = '/'; } else { host = std::string(uri, p - uri); name = p; } } } /*! \brief string representation */ inline std::string str(void) const { return protocol + host + name; } }; /*! \brief type of file */ enum FileType { /*! \brief the file is file */ kFile, /*! \brief the file is directory */ kDirectory }; /*! \brief use to store file information */ struct FileInfo { /*! \brief full path to the file */ URI path; /*! \brief the size of the file */ size_t size; /*! \brief the type of the file */ FileType type; /*! \brief default constructor */ FileInfo() : size(0), type(kFile) {} }; /*! \brief file system system interface */ class FileSystem { public: /*! * \brief get singleton of filesystem instance according to protocol * \param protocol can be s3://, hdfs://, file://, * empty string(will return local) * \return a corresponding filesystem, report error if * we cannot find a matching system */ static FileSystem *GetInstance(const std::string &protocol); /*! 
\brief virtual destructor */ virtual ~FileSystem() {} /*! * \brief get information about a path * \param path the path to the file * \return the information about the file */ virtual FileInfo GetPathInfo(const URI &path) = 0; /*! * \brief list files in a directory * \param path to the file * \param out_list the output information about the files */ virtual void ListDirectory(const URI &path, std::vector *out_list) = 0; /*! * \brief open a stream * \param path path to file * \param uri the uri of the input, can contain hdfs prefix * \param flag can be "w", "r", "a * \param allow_null whether NULL can be returned, or directly report error * \return the created stream, can be NULL when allow_null == true and file do not exist */ virtual Stream *Open(const URI &path, const char* const flag, bool allow_null = false) = 0; /*! * \brief open a seekable stream for read * \param path the path to the file * \param allow_null whether NULL can be returned, or directly report error * \return the created stream, can be NULL when allow_null == true and file do not exist */ virtual SeekStream *OpenForRead(const URI &path, bool allow_null = false) = 0; }; } // namespace io } // namespace dmlc #endif // DMLC_IO_FILESYS_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/filesys.h ===== namespace dmlc { namespace io { /*! * \brief some super set of URI * that allows sugars to be passed around */ struct URISpec { /*! \brief the real URI */ std::string uri; /*! \brief the path to cache file */ std::string cache_file; explicit URISpec(const char *uri, unsigned part_index, unsigned num_parts) { const char *dlm = strchr(uri, '#'); if (dlm != NULL) { CHECK(strchr(dlm + 1, '#') == NULL) << "only one `#` is allowed in file path for cachefile specification"; this->uri = std::string(uri, dlm - uri); std::ostringstream os; os << dlm + 1; if (num_parts != 1) { os << ".split" << num_parts << ".part" << part_index; } cache_file = os.str(); } else { this->uri = uri; } } }; } // namespace io } // namespace dmlc #endif // DMLC_IO_URI_SPEC_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/uri_spec.h ===== //===== EXPANDIND: mxnet/dmlc-core/src/data/parser.h ===== /*! * Copyright (c) 2015 by Contributors * \file libsvm_parser.h * \brief iterator parser to parse libsvm format * \author Tianqi Chen */ #ifndef DMLC_DATA_PARSER_H_ #define DMLC_DATA_PARSER_H_ //===== EXPANDIND: mxnet/dmlc-core/src/data/row_block.h ===== /*! * Copyright (c) 2015 by Contributors * \file row_block.h * \brief additional data structure to support * RowBlock data structure * \author Tianqi Chen */ #ifndef DMLC_DATA_ROW_BLOCK_H_ #define DMLC_DATA_ROW_BLOCK_H_ namespace dmlc { namespace data { /*! * \brief dynamic data structure that holds * a row block of data * \tparam IndexType the type of index we are using */ template struct RowBlockContainer { /*! \brief array[size+1], row pointer to beginning of each rows */ std::vector offset; /*! \brief array[size] label of each instance */ std::vector label; /*! \brief array[size] weight of each instance */ std::vector weight; /*! \brief feature index */ std::vector index; /*! \brief feature value */ std::vector value; /*! \brief maximum value of index */ IndexType max_index; // constructor RowBlockContainer(void) { this->Clear(); } /*! \brief convert to a row block */ inline RowBlock GetBlock(void) const; /*! * \brief write the row block to a binary stream * \param fo output stream */ inline void Save(Stream *fo) const; /*! 
* \brief load row block from a binary stream * \param fi output stream * \return false if at end of file */ inline bool Load(Stream *fi); /*! \brief clear the container */ inline void Clear(void) { offset.clear(); offset.push_back(0); label.clear(); index.clear(); value.clear(); weight.clear(); max_index = 0; } /*! \brief size of the data */ inline size_t Size(void) const { return offset.size() - 1; } /*! \return estimation of memory cost of this container */ inline size_t MemCostBytes(void) const { return offset.size() * sizeof(size_t) + label.size() * sizeof(real_t) + weight.size() * sizeof(real_t) + index.size() * sizeof(IndexType) + value.size() * sizeof(real_t); } /*! * \brief push the row into container * \param row the row to push back * \tparam I the index type of the row */ template inline void Push(Row row) { label.push_back(row.label); weight.push_back(row.weight); for (size_t i = 0; i < row.length; ++i) { CHECK_LE(row.index[i], std::numeric_limits::max()) << "index exceed numeric bound of current type"; IndexType findex = static_cast(row.index[i]); index.push_back(findex); max_index = std::max(max_index, findex); } if (row.value != NULL) { for (size_t i = 0; i < row.length; ++i) { value.push_back(row.value[i]); } } offset.push_back(index.size()); } /*! * \brief push the row block into container * \param row the row to push back * \tparam I the index type of the row */ template inline void Push(RowBlock batch) { size_t size = label.size(); label.resize(label.size() + batch.size); std::memcpy(BeginPtr(label) + size, batch.label, batch.size * sizeof(real_t)); if (batch.weight != NULL) { weight.insert(weight.end(), batch.weight, batch.weight + batch.size); } size_t ndata = batch.offset[batch.size] - batch.offset[0]; index.resize(index.size() + ndata); IndexType *ihead = BeginPtr(index) + offset.back(); for (size_t i = 0; i < ndata; ++i) { CHECK_LE(batch.index[i], std::numeric_limits::max()) << "index exceed numeric bound of current type"; IndexType findex = static_cast(batch.index[i]); ihead[i] = findex; max_index = std::max(max_index, findex); } if (batch.value != NULL) { value.resize(value.size() + ndata); std::memcpy(BeginPtr(value) + value.size() - ndata, batch.value, ndata * sizeof(real_t)); } size_t shift = offset[size]; offset.resize(offset.size() + batch.size); size_t *ohead = BeginPtr(offset) + size + 1; for (size_t i = 0; i < batch.size; ++i) { ohead[i] = shift + batch.offset[i + 1] - batch.offset[0]; } } }; template inline RowBlock RowBlockContainer::GetBlock(void) const { // consistency check CHECK_EQ(label.size() + 1, offset.size()); CHECK_EQ(offset.back(), index.size()); CHECK(offset.back() == value.size() || value.size() == 0); RowBlock data; data.size = offset.size() - 1; data.offset = BeginPtr(offset); data.label = BeginPtr(label); data.weight = BeginPtr(weight); data.index = BeginPtr(index); data.value = BeginPtr(value); return data; } template inline void RowBlockContainer::Save(Stream *fo) const { fo->Write(offset); fo->Write(label); fo->Write(weight); fo->Write(index); fo->Write(value); fo->Write(&max_index, sizeof(IndexType)); } template inline bool RowBlockContainer::Load(Stream *fi) { if (!fi->Read(&offset)) return false; CHECK(fi->Read(&label)) << "Bad RowBlock format"; CHECK(fi->Read(&weight)) << "Bad RowBlock format"; CHECK(fi->Read(&index)) << "Bad RowBlock format"; CHECK(fi->Read(&value)) << "Bad RowBlock format"; CHECK(fi->Read(&max_index, sizeof(IndexType))) << "Bad RowBlock format"; return true; } } // namespace data } // namespace dmlc #endif // 
DMLC_DATA_ROW_BLOCK_H_ //===== EXPANDED: mxnet/dmlc-core/src/data/row_block.h ===== namespace dmlc { namespace data { /*! \brief declare thread class */ template class ThreadedParser; /*! \brief base class for parser to parse data */ template class ParserImpl : public Parser { public: ParserImpl() : data_ptr_(0), data_end_(0) {} // virtual destructor virtual ~ParserImpl() {} /*! \brief implement next */ virtual bool Next(void) { while (true) { while (data_ptr_ < data_end_) { data_ptr_ += 1; if (data_[data_ptr_ - 1].Size() != 0) { block_ = data_[data_ptr_ - 1].GetBlock(); return true; } } if (!ParseNext(&data_)) break; data_ptr_ = 0; data_end_ = static_cast(data_.size()); } return false; } virtual const RowBlock &Value(void) const { return block_; } /*! \return size of bytes read so far */ virtual size_t BytesRead(void) const = 0; protected: // allow ThreadedParser to see ParseNext friend class ThreadedParser; /*! * \brief read in next several blocks of data * \param data vector of data to be returned * \return true if the data is loaded, false if reach end */ virtual bool ParseNext(std::vector > *data) = 0; /*! \brief pointer to begin and end of data */ IndexType data_ptr_, data_end_; /*! \brief internal data */ std::vector > data_; /*! \brief internal row block */ RowBlock block_; }; #if DMLC_USE_CXX11 template class ThreadedParser : public ParserImpl { public: explicit ThreadedParser(ParserImpl *base) : base_(base), tmp_(NULL) { iter_.set_max_capacity(8); iter_.Init([base](std::vector > **dptr) { if (*dptr == NULL) { *dptr = new std::vector >(); } return base->ParseNext(*dptr); }, [base]() {base->BeforeFirst();}); } virtual ~ThreadedParser(void) { // stop things before base is deleted iter_.Destroy(); delete base_; delete tmp_; } virtual void BeforeFirst() { iter_.BeforeFirst(); } /*! \brief implement next */ using ParserImpl::data_ptr_; using ParserImpl::data_end_; virtual bool Next(void) { while (true) { while (data_ptr_ < data_end_) { data_ptr_ += 1; if ((*tmp_)[data_ptr_ - 1].Size() != 0) { this->block_ = (*tmp_)[data_ptr_ - 1].GetBlock(); return true; } } if (tmp_ != NULL) iter_.Recycle(&tmp_); if (!iter_.Next(&tmp_)) break; data_ptr_ = 0; data_end_ = tmp_->size(); } return false; } virtual size_t BytesRead(void) const { return base_->BytesRead(); } protected: virtual bool ParseNext(std::vector > *data) { LOG(FATAL) << "cannot call ParseNext"; return false; } private: /*! \brief the place where we get the data */ Parser *base_; /*! \brief backend threaded iterator */ ThreadedIter > > iter_; /*! \brief current chunk of data */ std::vector > *tmp_; }; #endif // DMLC_USE_CXX11 } // namespace data } // namespace dmlc #endif // DMLC_DATA_PARSER_H_ //===== EXPANDED: mxnet/dmlc-core/src/data/parser.h ===== //===== EXPANDIND: mxnet/dmlc-core/src/data/basic_row_iter.h ===== /*! * Copyright (c) 2015 by Contributors * \file basic_row_iter.h * \brief row based iterator that * loads in everything into memory and returns * \author Tianqi Chen */ #ifndef DMLC_DATA_BASIC_ROW_ITER_H_ #define DMLC_DATA_BASIC_ROW_ITER_H_ namespace dmlc { namespace data { /*! 
* \brief basic set of row iterators that provides * \tparam IndexType the type of index we are using */ template class BasicRowIter: public RowBlockIter { public: explicit BasicRowIter(Parser *parser) : at_head_(true) { this->Init(parser); delete parser; } virtual ~BasicRowIter() {} virtual void BeforeFirst(void) { at_head_ = true; } virtual bool Next(void) { if (at_head_) { at_head_ = false; return true; } else { return false; } } virtual const RowBlock &Value(void) const { return row_; } virtual size_t NumCol(void) const { return static_cast(data_.max_index) + 1; } private: // at head bool at_head_; // row block to store RowBlock row_; // back end data RowBlockContainer data_; // initialize inline void Init(Parser *parser); }; template inline void BasicRowIter::Init(Parser *parser) { data_.Clear(); double tstart = GetTime(); size_t bytes_expect = 10UL << 20UL; while (parser->Next()) { data_.Push(parser->Value()); double tdiff = GetTime() - tstart; size_t bytes_read = parser->BytesRead(); if (bytes_read >= bytes_expect) { bytes_read = bytes_read >> 20UL; LOG(INFO) << bytes_read << "MB read," << bytes_read / tdiff << " MB/sec"; bytes_expect += 10UL << 20UL; } } row_ = data_.GetBlock(); double tdiff = GetTime() - tstart; LOG(INFO) << "finish reading at " << (parser->BytesRead() >> 20UL) / tdiff << " MB/sec"; } } // namespace data } // namespace dmlc #endif // DMLC_DATA_BASIC_ROW_ITER_H__ //===== EXPANDED: mxnet/dmlc-core/src/data/basic_row_iter.h ===== //===== EXPANDIND: mxnet/dmlc-core/src/data/disk_row_iter.h ===== /*! * Copyright (c) 2015 by Contributors * \file basic_row_iter.h * \brief row based iterator that * caches things into disk and then load segments * \author Tianqi Chen */ #ifndef DMLC_DATA_DISK_ROW_ITER_H_ #define DMLC_DATA_DISK_ROW_ITER_H_ //===== EXPANDIND: mxnet/dmlc-core/src/data/libsvm_parser.h ===== /*! * Copyright (c) 2015 by Contributors * \file libsvm_parser.h * \brief iterator parser to parse libsvm format * \author Tianqi Chen */ #ifndef DMLC_DATA_LIBSVM_PARSER_H_ #define DMLC_DATA_LIBSVM_PARSER_H_ //===== EXPANDIND: mxnet/dmlc-core/src/data/strtonum.h ===== /*! *x Copyright (c) 2015 by Contributors * \file strtonum.h * \brief A faster implementation of strtod, ... */ #ifndef DMLC_DATA_STRTONUM_H_ #define DMLC_DATA_STRTONUM_H_ namespace dmlc { namespace data { inline bool isspace(char c) { return (c == ' ' || c == '\t' || c == '\r' || c == '\n' || c == '\f'); } inline bool isblank(char c) { return (c == ' ' || c == '\t'); } inline bool isdigit(char c) { return (c >= '0' && c <= '9'); } inline bool isdigitchars(char c) { return (c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.' || c == 'e' || c == 'E'; } /*! * \brief A faster version of strtof * TODO the current version does not support INF, NAN, and hex number */ inline float strtof(const char *nptr, char **endptr) { const char *p = nptr; // Skip leading white space, if any. Not necessary while (isspace(*p) ) ++p; // Get sign, if any. bool sign = true; if (*p == '-') { sign = false; ++p; } else if (*p == '+') { ++p; } // Get digits before decimal point or exponent, if any. float value; for (value = 0; isdigit(*p); ++p) { value = value * 10.0f + (*p - '0'); } // Get digits after decimal point, if any. if (*p == '.') { unsigned pow10 = 1; unsigned val2 = 0; ++p; while (isdigit(*p)) { val2 = val2 * 10 + (*p - '0'); pow10 *= 10; ++p; } value += static_cast(val2) / static_cast(pow10); } // Handle exponent, if any. 
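// The exponent handling that follows clamps the absolute exponent at 38 (roughly the
// limit of single-precision range), builds the scale factor in chunks of 1e8 for speed,
// and then multiplies or divides the mantissa depending on the exponent sign.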
if ((*p == 'e') || (*p == 'E')) { ++p; bool frac = false; float scale = 1.0; unsigned expon; // Get sign of exponent, if any. if (*p == '-') { frac = true; ++p; } else if (*p == '+') { ++p; } // Get digits of exponent, if any. for (expon = 0; isdigit(*p); p += 1) { expon = expon * 10 + (*p - '0'); } if (expon > 38) expon = 38; // Calculate scaling factor. while (expon >= 8) { scale *= 1E8; expon -= 8; } while (expon > 0) { scale *= 10.0; expon -= 1; } // Return signed and scaled floating point result. value = frac ? (value / scale) : (value * scale); } if (endptr) *endptr = (char*)p; // NOLINT(*) return sign ? value : - value; } /** * \brief A faster string to integer convertor * TODO only support base <=10 */ template inline V strtoint(const char* nptr, char **endptr, int base) { const char *p = nptr; // Skip leading white space, if any. Not necessary while (isspace(*p) ) ++p; // Get sign if any bool sign = true; if (*p == '-') { sign = false; ++p; } else if (*p == '+') { ++p; } V value; for (value = 0; isdigit(*p); ++p) { value = value * base + (*p - '0'); } if (endptr) *endptr = (char*)p; // NOLINT(*) return sign ? value : - value; } template inline V strtouint(const char* nptr, char **endptr, int base) { const char *p = nptr; // Skip leading white space, if any. Not necessary while (isspace(*p)) ++p; // Get sign if any bool sign = true; if (*p == '-') { sign = false; ++p; } else if (*p == '+') { ++p; } // we are parsing unsigned, so no minus sign should be found CHECK_EQ(sign, true); V value; for (value = 0; isdigit(*p); ++p) { value = value * base + (*p - '0'); } if (endptr) *endptr = (char*)p; // NOLINT(*) return value; } inline uint64_t strtoull(const char* nptr, char **endptr, int base) { return strtouint(nptr, endptr, base); } inline long atol(const char* p) { // NOLINT(*) return strtoint(p, 0, 10); // NOLINT(*) } inline float atof(const char *nptr) { return strtof(nptr, 0); } template class Str2T { public: static inline T get(const char * begin, const char * end); }; template inline T Str2Type(const char * begin, const char * end) { return Str2T::get(begin, end); } template<> class Str2T { public: static inline int32_t get(const char * begin, const char * end) { return strtoint(begin, NULL, 10); } }; template<> class Str2T { public: static inline uint32_t get(const char * begin, const char * end) { return strtouint(begin, NULL, 10); } }; template<> class Str2T { public: static inline int64_t get(const char * begin, const char * end) { return strtoint(begin, NULL, 10); } }; template<> class Str2T { public: static inline uint64_t get(const char * begin, const char * end) { return strtouint(begin, NULL, 10); } }; template<> class Str2T { public: static inline float get(const char * begin, const char * end) { return atof(begin); } }; /** * \brief Parse colon seperated pair v1[:v2] * \param begin: pointer to string * \param end: one past end of string * \param parseEnd: end string of parsed string * \param v1: first value in the pair * \param v2: second value in the pair * \output number of values parsed */ template inline int ParsePair(const char * begin, const char * end, const char ** endptr, T1 &v1, T2 &v2) { // NOLINT(*) const char * p = begin; while (p != end && !isdigitchars(*p)) ++p; if (p == end) { *endptr = end; return 0; } const char * q = p; while (q != end && isdigitchars(*q)) ++q; v1 = Str2Type(p, q); p = q; while (p != end && isblank(*p)) ++p; if (p == end || *p != ':') { // only v1 *endptr = p; return 1; } p++; while (p != end && !isdigitchars(*p)) ++p; q = p; while (q 
!= end && isdigitchars(*q)) ++q; *endptr = q; v2 = Str2Type(p, q); return 2; } } // namespace data } // namespace dmlc #endif // DMLC_DATA_STRTONUM_H_ //===== EXPANDED: mxnet/dmlc-core/src/data/strtonum.h ===== namespace dmlc { namespace data { /*! * \brief libsvm parser that parses the input lines * and returns rows in input data */ template class LibSVMParser : public ParserImpl { public: explicit LibSVMParser(InputSplit *source, int nthread) : bytes_read_(0), source_(source) { int maxthread; #pragma omp parallel { maxthread = std::max(omp_get_num_procs() / 2 - 4, 1); } nthread_ = std::min(maxthread, nthread); } virtual ~LibSVMParser() { delete source_; } virtual void BeforeFirst(void) { source_->BeforeFirst(); } virtual size_t BytesRead(void) const { return bytes_read_; } virtual bool ParseNext(std::vector > *data) { return FillData(data); } protected: /*! * \brief read in next several blocks of data * \param data vector of data to be returned * \return true if the data is loaded, false if reach end */ inline bool FillData(std::vector > *data); /*! * \brief parse data into out * \param begin beginning of buffer * \param end end of buffer */ inline void ParseBlock(char *begin, char *end, RowBlockContainer *out); /*! * \brief start from bptr, go backward and find first endof line * \param bptr end position to go backward * \param begin the beginning position of buffer * \return position of first endof line going backward */ inline char* BackFindEndLine(char *bptr, char *begin) { for (; bptr != begin; --bptr) { if (*bptr == '\n' || *bptr == '\r') return bptr; } return begin; } private: // nthread int nthread_; // number of bytes readed size_t bytes_read_; // source split that provides the data InputSplit *source_; }; // implementation template inline bool LibSVMParser:: FillData(std::vector > *data) { InputSplit::Blob chunk; if (!source_->NextChunk(&chunk)) return false; int nthread; #pragma omp parallel num_threads(nthread_) { nthread = omp_get_num_threads(); } // reserve space for data data->resize(nthread); bytes_read_ += chunk.size; CHECK_NE(chunk.size, 0); char *head = reinterpret_cast(chunk.dptr); #pragma omp parallel num_threads(nthread_) { // threadid int tid = omp_get_thread_num(); size_t nstep = (chunk.size + nthread - 1) / nthread; size_t sbegin = std::min(tid * nstep, chunk.size); size_t send = std::min((tid + 1) * nstep, chunk.size); char *pbegin = BackFindEndLine(head + sbegin, head); char *pend; if (tid + 1 == nthread) { pend = head + send; } else { pend = BackFindEndLine(head + send, head); } ParseBlock(pbegin, pend, &(*data)[tid]); } this->data_ptr_ = 0; return true; } template inline void LibSVMParser:: ParseBlock(char *begin, char *end, RowBlockContainer *out) { out->Clear(); char * lbegin = begin; char * lend = lbegin; while (lbegin != end) { // get line end lend = lbegin + 1; while (lend != end && *lend != '\n' && *lend != '\r') ++lend; // parse label[:weight] const char * p = lbegin; const char * q = NULL; real_t label; real_t weight; int r = ParsePair(p, lend, &q, label, weight); if (r < 1) { // empty line lbegin = lend; continue; } if (r == 2) { // has weight out->weight.push_back(weight); } if (out->label.size() != 0) { out->offset.push_back(out->index.size()); } out->label.push_back(label); // parse feature[:value] p = q; while (p != lend) { IndexType featureId; real_t value; int r = ParsePair(p, lend, &q, featureId, value); if (r < 1) { p = q; continue; } out->index.push_back(featureId); if (r == 2) { // has value out->value.push_back(value); } p = q; } // next 
line lbegin = lend; } if (out->label.size() != 0) { out->offset.push_back(out->index.size()); } CHECK(out->label.size() + 1 == out->offset.size()); } } // namespace data } // namespace dmlc #endif // DMLC_DATA_LIBSVM_PARSER_H_ //===== EXPANDED: mxnet/dmlc-core/src/data/libsvm_parser.h ===== #if DMLC_USE_CXX11 namespace dmlc { namespace data { /*! * \brief basic set of row iterators that provides * \tparam IndexType the type of index we are using */ template class DiskRowIter: public RowBlockIter { public: // page size 64MB static const size_t kPageSize = 64UL << 20UL; /*! * \brief disk row iterator constructor * \param parser parser used to generate this */ explicit DiskRowIter(Parser *parser, const char *cache_file, bool reuse_cache) : cache_file_(cache_file), fi_(NULL) { if (reuse_cache) { if (!TryLoadCache()) { this->BuildCache(parser); CHECK(TryLoadCache()) << "failed to build cache file " << cache_file; } } else { this->BuildCache(parser); CHECK(TryLoadCache()) << "failed to build cache file " << cache_file; } delete parser; } virtual ~DiskRowIter(void) { iter_.Destroy(); delete fi_; } virtual void BeforeFirst(void) { iter_.BeforeFirst(); } virtual bool Next(void) { if (iter_.Next()) { row_ = iter_.Value().GetBlock(); return true; } else { return false; } } virtual const RowBlock &Value(void) const { return row_; } virtual size_t NumCol(void) const { return num_col_; } private: // file place std::string cache_file_; // input stream SeekStream *fi_; // maximum feature dimension size_t num_col_; // row block to store RowBlock row_; // iterator ThreadedIter > iter_; // load disk cache file inline bool TryLoadCache(void); // build disk cache inline void BuildCache(Parser *parser); }; // build disk cache template inline bool DiskRowIter::TryLoadCache(void) { SeekStream *fi = SeekStream::CreateForRead(cache_file_.c_str(), true); if (fi == NULL) return false; this->fi_ = fi; iter_.Init([fi](RowBlockContainer **dptr) { if (*dptr ==NULL) { *dptr = new RowBlockContainer(); } return (*dptr)->Load(fi); }, [fi]() { fi->Seek(0); }); return true; } template inline void DiskRowIter:: BuildCache(Parser *parser) { Stream *fo = Stream::Create(cache_file_.c_str(), "w"); // back end data RowBlockContainer data; num_col_ = 0; double tstart = GetTime(); while (parser->Next()) { data.Push(parser->Value()); double tdiff = GetTime() - tstart; if (data.MemCostBytes() >= kPageSize) { size_t bytes_read = parser->BytesRead(); bytes_read = bytes_read >> 20UL; LOG(INFO) << bytes_read << "MB read," << bytes_read / tdiff << " MB/sec"; data.Save(fo); data.Clear(); num_col_ = std::max(num_col_, static_cast(data.max_index) + 1); } } if (data.Size() != 0) { data.Save(fo); } delete fo; double tdiff = GetTime() - tstart; LOG(INFO) << "finish reading at %g MB/sec" << (parser->BytesRead() >> 20UL) / tdiff; } } // namespace data } // namespace dmlc #endif // DMLC_USE_CXX11 #endif // DMLC_DATA_DISK_ROW_ITER_H_ //===== EXPANDED: mxnet/dmlc-core/src/data/disk_row_iter.h ===== namespace dmlc { /*! 
\brief namespace for useful input data structure */ namespace data { template inline ParserImpl * CreateParser_(const char *uri_, unsigned part_index, unsigned num_parts, const char *type) { using namespace std; // create parser ParserImpl *parser = NULL; if (!strcmp(type, "libsvm")) { InputSplit* source = InputSplit::Create( uri_, part_index, num_parts, "text"); parser = new LibSVMParser(source, 2); } else { LOG(FATAL) << "unknown datatype " << type; } #if DMLC_USE_CXX11 parser = new ThreadedParser(parser); #endif return parser; } template inline RowBlockIter * CreateIter_(const char *uri_, unsigned part_index, unsigned num_parts, const char *type) { using namespace std; io::URISpec spec(uri_, part_index, num_parts); Parser *parser = CreateParser_ (spec.uri.c_str(), part_index, num_parts, type); if (spec.cache_file.length() != 0) { #if DMLC_USE_CXX11 return new DiskRowIter(parser, spec.cache_file.c_str(), true); #else LOG(FATAL) << "compile with c++0x or c++11 to enable cache file"; return NULL; #endif } else { return new BasicRowIter(parser); } } } // namespace data template<> RowBlockIter * RowBlockIter::Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type) { return data::CreateIter_(uri, part_index, num_parts, type); } template<> RowBlockIter * RowBlockIter::Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type) { return data::CreateIter_(uri, part_index, num_parts, type); } template<> Parser * Parser::Create(const char *uri_, unsigned part_index, unsigned num_parts, const char *type) { return data::CreateParser_(uri_, part_index, num_parts, type); } template<> Parser * Parser::Create(const char *uri_, unsigned part_index, unsigned num_parts, const char *type) { return data::CreateParser_(uri_, part_index, num_parts, type); } } // namespace dmlc //===== EXPANDED: mxnet/dmlc-core/src/data.cc ===== //===== EXPANDIND: mxnet/dmlc-core/src/io/input_split_base.cc ===== // Copyright by Contributors //===== EXPANDIND: mxnet/dmlc-core/src/io/line_split.h ===== /*! * Copyright (c) 2015 by Contributors * \file line_split.h * \brief base class implementation of input splitter * \author Tianqi Chen */ #ifndef DMLC_IO_LINE_SPLIT_H_ #define DMLC_IO_LINE_SPLIT_H_ //===== EXPANDIND: mxnet/dmlc-core/src/io/input_split_base.h ===== /*! * Copyright (c) 2015 by Contributors * \file input_split_base.h * \brief base class to construct input split from multiple files * \author Tianqi Chen */ #ifndef DMLC_IO_INPUT_SPLIT_BASE_H_ #define DMLC_IO_INPUT_SPLIT_BASE_H_ namespace dmlc { namespace io { /*! \brief class to construct input split from multiple files */ class InputSplitBase : public InputSplit { public: /*! 
* \brief helper struct to hold chunk data * with internal pointer to move along the record */ struct Chunk { char *begin; char *end; std::vector data; explicit Chunk(size_t buffer_size) : begin(NULL), end(NULL), data(buffer_size + 1) {} // load chunk from split bool Load(InputSplitBase *split, size_t buffer_size); }; // 2MB static const size_t kBufferSize = 1UL << 15UL; // destructor virtual ~InputSplitBase(void); // implement BeforeFirst virtual void BeforeFirst(void); virtual void HintChunkSize(size_t chunk_size) { buffer_size_ = std::max(chunk_size / sizeof(size_t), buffer_size_); } // implement next record virtual bool NextRecord(Blob *out_rec) { while (!ExtractNextRecord(out_rec, &tmp_chunk_)) { if (!tmp_chunk_.Load(this, buffer_size_)) return false; } return true; } // implement next chunk virtual bool NextChunk(Blob *out_chunk) { while (!ExtractNextChunk(out_chunk, &tmp_chunk_)) { if (!tmp_chunk_.Load(this, buffer_size_)) return false; } return true; } /*! * \brief read a chunk of data into buf * the data can span multiple records, * but cannot contain partial records * * \param buf the memory region of the buffer, * should be properly aligned to 64 bits * \param size the maximum size of memory, * after the function returns, it stores the size of the chunk * \return whether end of file was reached */ bool ReadChunk(void *buf, size_t *size); /*! * \brief extract next chunk from the chunk * \param out_chunk the output record * \param chunk the chunk information * \return true if non-empty record is extracted * false if the chunk is already finishes its life */ bool ExtractNextChunk(Blob *out_rchunk, Chunk *chunk); /*! * \brief extract next record from the chunk * \param out_rec the output record * \param chunk the chunk information * \return true if non-empty record is extracted * false if the chunk is already finishes its life */ virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk) = 0; protected: // constructor InputSplitBase() : fs_(NULL), tmp_chunk_(kBufferSize), buffer_size_(kBufferSize) {} /*! * \brief intialize the base before doing anything * \param fs the filesystem ptr * \param uri the uri of the files * \param rank the rank of the split * \param nsplit number of splits * \param align_bytes the head split must be multiple of align_bytes * this also checks if file size are multiple of align_bytes */ void Init(FileSystem *fs, const char *uri, unsigned rank, unsigned nsplit, size_t align_bytes); // to be implemented by child class /*! * \brief seek to the beginning of the first record * in current file pointer * \return how many bytes we read past */ virtual size_t SeekRecordBegin(Stream *fi) = 0; /*! * \brief find the last occurance of record header * \param begin beginning of the buffer * \param end end of the buffer * \return the pointer between [begin, end] indicating the * last record head */ virtual const char* FindLastRecordBegin(const char *begin, const char *end) = 0; private: /*! \brief FileSystem */ FileSystem *filesys_; /*! \brief information about files */ std::vector files_; /*! \brief current input stream */ SeekStream *fs_; /*! \brief file pointer of which file to read on */ size_t file_ptr_; /*! \brief file pointer where the end of file lies */ size_t file_ptr_end_; /*! \brief get the current offset */ size_t offset_curr_; /*! \brief beginning of offset */ size_t offset_begin_; /*! \brief end of the offset */ size_t offset_end_; /*! \brief temporal chunk */ Chunk tmp_chunk_; /*! \brief buffer size */ size_t buffer_size_; /*! 
\brief byte-offset of each file */ std::vector file_offset_; /*! \brief internal overflow buffer */ std::string overflow_; /*! \brief initialize information in files */ void InitInputFileInfo(const char *uri); /*! \brief same as stream.Read */ size_t Read(void *ptr, size_t size); }; } // namespace io } // namespace dmlc #endif // DMLC_IO_INPUT_SPLIT_BASE_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/input_split_base.h ===== namespace dmlc { namespace io { /*! \brief class that split the files by line */ class LineSplitter : public InputSplitBase { public: LineSplitter(FileSystem *fs, const char *uri, unsigned rank, unsigned nsplit) { this->Init(fs, uri, rank, nsplit, 1); } virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk); protected: virtual size_t SeekRecordBegin(Stream *fi); virtual const char* FindLastRecordBegin(const char *begin, const char *end); }; } // namespace io } // namespace dmlc #endif // DMLC_IO_LINE_SPLIT_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/line_split.h ===== namespace dmlc { namespace io { void InputSplitBase::Init(FileSystem *filesys, const char *uri, unsigned rank, unsigned nsplit, size_t align_bytes) { this->filesys_ = filesys; // initialize the path this->InitInputFileInfo(uri); file_offset_.resize(files_.size() + 1); file_offset_[0] = 0; for (size_t i = 0; i < files_.size(); ++i) { file_offset_[i + 1] = file_offset_[i] + files_[i].size; CHECK(files_[i].size % align_bytes == 0) << "file do not align by " << align_bytes << " bytes"; } size_t ntotal = file_offset_.back(); size_t nstep = (ntotal + nsplit - 1) / nsplit; // align the nstep to 4 bytes nstep = ((nstep + align_bytes - 1) / align_bytes) * align_bytes; offset_begin_ = std::min(nstep * rank, ntotal); offset_end_ = std::min(nstep * (rank + 1), ntotal); offset_curr_ = offset_begin_; if (offset_begin_ == offset_end_) return; file_ptr_ = std::upper_bound(file_offset_.begin(), file_offset_.end(), offset_begin_) - file_offset_.begin() - 1; file_ptr_end_ = std::upper_bound(file_offset_.begin(), file_offset_.end(), offset_end_) - file_offset_.begin() - 1; // find the exact ending position if (offset_end_ != file_offset_[file_ptr_end_]) { CHECK(offset_end_ >file_offset_[file_ptr_end_]); CHECK(file_ptr_end_ < files_.size()); fs_ = filesys_->OpenForRead(files_[file_ptr_end_].path); fs_->Seek(offset_end_ - file_offset_[file_ptr_end_]); offset_end_ += SeekRecordBegin(fs_); delete fs_; } fs_ = filesys_->OpenForRead(files_[file_ptr_].path); if (offset_begin_ != file_offset_[file_ptr_]) { fs_->Seek(offset_begin_ - file_offset_[file_ptr_]); offset_begin_ += SeekRecordBegin(fs_); } this->BeforeFirst(); } void InputSplitBase::BeforeFirst(void) { if (offset_begin_ >= offset_end_) return; size_t fp = std::upper_bound(file_offset_.begin(), file_offset_.end(), offset_begin_) - file_offset_.begin() - 1; if (file_ptr_ != fp) { delete fs_; file_ptr_ = fp; fs_ = filesys_->OpenForRead(files_[file_ptr_].path); } // seek to beginning of stream fs_->Seek(offset_begin_ - file_offset_[file_ptr_]); offset_curr_ = offset_begin_; tmp_chunk_.begin = tmp_chunk_.end = NULL; // clear overflow buffer overflow_.clear(); } InputSplitBase::~InputSplitBase(void) { delete fs_; // no need to delete filesystem, it was singleton } void InputSplitBase::InitInputFileInfo(const char *uri) { // split by : const char *dlm = ";"; std::string uri_ = uri; char *p = std::strtok(BeginPtr(uri_), dlm); std::vector vec; while (p != NULL) { URI path(p); FileInfo info = filesys_->GetPathInfo(path); if (info.type == kDirectory) { std::vector dfiles; 
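// a directory URI is expanded to the non-empty regular files it contains;
// empty files and nested directories are skipped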
filesys_->ListDirectory(info.path, &dfiles); for (size_t i = 0; i < dfiles.size(); ++i) { if (dfiles[i].size != 0 && dfiles[i].type == kFile) { files_.push_back(dfiles[i]); } } } else { if (info.size != 0) { files_.push_back(info); } } p = std::strtok(NULL, dlm); } } size_t InputSplitBase::Read(void *ptr, size_t size) { if (offset_begin_ >= offset_end_) return 0; if (offset_curr_ + size > offset_end_) { size = offset_end_ - offset_curr_; } if (size == 0) return 0; size_t nleft = size; char *buf = reinterpret_cast(ptr); while (true) { size_t n = fs_->Read(buf, nleft); nleft -= n; buf += n; offset_curr_ += n; if (nleft == 0) break; if (n == 0) { if (offset_curr_ != file_offset_[file_ptr_ + 1]) { LOG(ERROR) << "curr=" << offset_curr_ << ",begin=" << offset_begin_ << ",end=" << offset_end_ << ",fileptr=" << file_ptr_ << ",fileoffset=" << file_offset_[file_ptr_ + 1]; for (size_t i = 0; i < file_ptr_; ++i) { LOG(ERROR) << "offset[" << i << "]=" << file_offset_[i]; } LOG(FATAL) << "file offset not calculated correctly"; } if (file_ptr_ + 1 >= files_.size()) break; file_ptr_ += 1; delete fs_; fs_ = filesys_->OpenForRead(files_[file_ptr_].path); } } return size - nleft; } bool InputSplitBase::ReadChunk(void *buf, size_t *size) { size_t max_size = *size; if (max_size <= overflow_.length()) { *size = 0; return true; } if (overflow_.length() != 0) { std::memcpy(buf, BeginPtr(overflow_), overflow_.length()); } size_t olen = overflow_.length(); overflow_.resize(0); size_t nread = this->Read(reinterpret_cast(buf) + olen, max_size - olen); nread += olen; if (nread == 0) return false; if (nread != max_size) { *size = nread; return true; } else { const char *bptr = reinterpret_cast(buf); // return the last position where a record starts const char *bend = this->FindLastRecordBegin(bptr, bptr + max_size); *size = bend - bptr; overflow_.resize(max_size - *size); if (overflow_.length() != 0) { std::memcpy(BeginPtr(overflow_), bend, overflow_.length()); } return true; } } bool InputSplitBase::Chunk::Load(InputSplitBase *split, size_t buffer_size) { if (buffer_size + 1 > data.size()) { data.resize(buffer_size + 1); } while (true) { // leave one tail chunk size_t size = (data.size() - 1) * sizeof(size_t); // set back to 0 for string safety data.back() = 0; if (!split->ReadChunk(BeginPtr(data), &size)) return false; if (size == 0) { data.resize(data.size() * 2); } else { begin = reinterpret_cast(BeginPtr(data)); end = begin + size; break; } } return true; } bool InputSplitBase::ExtractNextChunk(Blob *out_chunk, Chunk *chunk) { if (chunk->begin == chunk->end) return false; out_chunk->dptr = chunk->begin; out_chunk->size = chunk->end - chunk->begin; chunk->begin = chunk->end; return true; } } // namespace io } // namespace dmlc //===== EXPANDED: mxnet/dmlc-core/src/io/input_split_base.cc ===== //===== EXPANDIND: mxnet/dmlc-core/src/io/line_split.cc ===== // Copyright by Contributors namespace dmlc { namespace io { size_t LineSplitter::SeekRecordBegin(Stream *fi) { char c = '\0'; size_t nstep = 0; // search till fist end-of-line while (true) { if (fi->Read(&c, sizeof(c)) == 0) return nstep; nstep += 1; if (c == '\n' || c == '\r') break; } // search until first non-endofline while (true) { if (fi->Read(&c, sizeof(c)) == 0) return nstep; if (c != '\n' && c != '\r') break; // non-end-of-line should not count nstep += 1; } return nstep; } const char* LineSplitter::FindLastRecordBegin(const char *begin, const char *end) { CHECK(begin != end); for (const char *p = end - 1; p != begin; --p) { if (*p == '\n' || *p == '\r') 
return p + 1; } return begin; } bool LineSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) { if (chunk->begin == chunk->end) return false; char *p; for (p = chunk->begin; p != chunk->end; ++p) { if (*p == '\n' || *p == '\r') break; } for (; p != chunk->end; ++p) { if (*p != '\n' && *p != '\r') break; } // set the string end sign for safety if (p == chunk->end) { *p = '\0'; } else { *(p - 1) = '\0'; } out_rec->dptr = chunk->begin; out_rec->size = p - chunk->begin; chunk->begin = p; return true; } } // namespace io } // namespace dmlc //===== EXPANDED: mxnet/dmlc-core/src/io/line_split.cc ===== //===== EXPANDIND: mxnet/dmlc-core/src/io/local_filesys.cc ===== // Copyright by Contributors extern "C" { } #ifndef _MSC_VER extern "C" { } #else #define stat _stat64 #endif //===== EXPANDIND: mxnet/dmlc-core/src/io/local_filesys.h ===== /*! * Copyright (c) 2015 by Contributors * \file local_filesys.h * \brief local access module * \author Tianqi Chen */ #ifndef DMLC_IO_LOCAL_FILESYS_H_ #define DMLC_IO_LOCAL_FILESYS_H_ namespace dmlc { namespace io { /*! \brief local file system */ class LocalFileSystem : public FileSystem { public: /*! \brief destructor */ virtual ~LocalFileSystem() {} /*! * \brief get information about a path * \param path the path to the file * \return the information about the file */ virtual FileInfo GetPathInfo(const URI &path); /*! * \brief list files in a directory * \param path to the file * \param out_list the output information about the files */ virtual void ListDirectory(const URI &path, std::vector *out_list); /*! * \brief open a stream, will report error and exit if bad thing happens * NOTE: the IStream can continue to work even when filesystem was destructed * \param path path to file * \param uri the uri of the input * \param allow_null whether NULL can be returned, or directly report error * \return the created stream, can be NULL when allow_null == true and file do not exist */ virtual SeekStream *Open(const URI &path, const char* const flag, bool allow_null); /*! * \brief open a seekable stream for read * \param path the path to the file * \param allow_null whether NULL can be returned, or directly report error * \return the created stream, can be NULL when allow_null == true and file do not exist */ virtual SeekStream *OpenForRead(const URI &path, bool allow_null); /*! * \brief get a singleton of LocalFileSystem when needed * \return a singleton instance */ inline static LocalFileSystem *GetInstance(void) { static LocalFileSystem instance; return &instance; } private: LocalFileSystem() {} }; } // namespace io } // namespace dmlc #endif // DMLC_IO_LOCAL_FILESYS_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/local_filesys.h ===== namespace dmlc { namespace io { /*! 
\brief implementation of file i/o stream */ class FileStream : public SeekStream { public: explicit FileStream(FILE *fp, bool use_stdio) : fp_(fp), use_stdio_(use_stdio) {} virtual ~FileStream(void) { this->Close(); } virtual size_t Read(void *ptr, size_t size) { return std::fread(ptr, 1, size, fp_); } virtual void Write(const void *ptr, size_t size) { CHECK(std::fwrite(ptr, 1, size, fp_) == size) << "FileStream.Write incomplete"; } virtual void Seek(size_t pos) { std::fseek(fp_, static_cast(pos), SEEK_SET); // NOLINT(*) } virtual size_t Tell(void) { return std::ftell(fp_); } virtual bool AtEnd(void) const { return std::feof(fp_) != 0; } inline void Close(void) { if (fp_ != NULL && !use_stdio_) { std::fclose(fp_); fp_ = NULL; } } private: std::FILE *fp_; bool use_stdio_; }; FileInfo LocalFileSystem::GetPathInfo(const URI &path) { struct stat sb; if (stat(path.name.c_str(), &sb) == -1) { int errsv = errno; LOG(FATAL) << "LocalFileSystem.GetPathInfo " << path.name << " Error:" << strerror(errsv); } FileInfo ret; ret.path = path; ret.size = sb.st_size; if ((sb.st_mode & S_IFMT) == S_IFDIR) { ret.type = kDirectory; } else { ret.type = kFile; } return ret; } void LocalFileSystem::ListDirectory(const URI &path, std::vector *out_list) { #ifndef _MSC_VER DIR *dir = opendir(path.name.c_str()); if (dir == NULL) { int errsv = errno; LOG(FATAL) << "LocalFileSystem.ListDirectory " << path.str() <<" error: " << strerror(errsv); } out_list->clear(); struct dirent *ent; /* print all the files and directories within directory */ while ((ent = readdir(dir)) != NULL) { if (!strcmp(ent->d_name, ".")) continue; if (!strcmp(ent->d_name, "..")) continue; URI pp = path; if (pp.name[pp.name.length() - 1] != '/') { pp.name += '/'; } pp.name += ent->d_name; out_list->push_back(GetPathInfo(pp)); } closedir(dir); #else WIN32_FIND_DATA fd; std::string pattern = path.name + "/*"; HANDLE handle = FindFirstFile(pattern.c_str(), &fd); if (handle == INVALID_HANDLE_VALUE) { int errsv = GetLastError(); LOG(FATAL) << "LocalFileSystem.ListDirectory " << path.str() << " error: " << strerror(errsv); } do { if (strcmp(fd.cFileName, ".") && strcmp(fd.cFileName, "..")) { URI pp = path; char clast = pp.name[pp.name.length() - 1]; if (pp.name == ".") { pp.name = fd.cFileName; } else if (clast != '/' && clast != '\\') { pp.name += '/'; pp.name += fd.cFileName; } out_list->push_back(GetPathInfo(pp)); } } while (FindNextFile(handle, &fd)); FindClose(handle); #endif } SeekStream *LocalFileSystem::Open(const URI &path, const char* const mode, bool allow_null) { bool use_stdio = false; FILE *fp = NULL; const char *fname = path.name.c_str(); using namespace std; #ifndef DMLC_STRICT_CXX98_ if (!strcmp(fname, "stdin")) { use_stdio = true; fp = stdin; } if (!strcmp(fname, "stdout")) { use_stdio = true; fp = stdout; } #endif if (!strncmp(fname, "file://", 7)) fname += 7; if (!use_stdio) { std::string flag = mode; if (flag == "w") flag = "wb"; if (flag == "r") flag = "rb"; fp = fopen64(fname, flag.c_str()); } if (fp != NULL) { return new FileStream(fp, use_stdio); } else { CHECK(allow_null) << " LocalFileSystem: fail to open \"" << path.str() << '\"'; return NULL; } } SeekStream *LocalFileSystem::OpenForRead(const URI &path, bool allow_null) { return Open(path, "r", allow_null); } } // namespace io } // namespace dmlc //===== EXPANDED: mxnet/dmlc-core/src/io/local_filesys.cc ===== //===== EXPANDIND: mxnet/dmlc-core/src/io/recordio_split.cc ===== // Copyright by Contributors //===== EXPANDIND: mxnet/dmlc-core/src/io/recordio_split.h ===== /*! 
* Copyright (c) 2015 by Contributors * \file recordio_split.h * \brief input split that splits recordio files * \author Tianqi Chen */ #ifndef DMLC_IO_RECORDIO_SPLIT_H_ #define DMLC_IO_RECORDIO_SPLIT_H_ namespace dmlc { namespace io { /*! \brief class that split the files by line */ class RecordIOSplitter : public InputSplitBase { public: RecordIOSplitter(FileSystem *fs, const char *uri, unsigned rank, unsigned nsplit) { this->Init(fs, uri, rank, nsplit, 4); } virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk); protected: virtual size_t SeekRecordBegin(Stream *fi); virtual const char* FindLastRecordBegin(const char *begin, const char *end); }; } // namespace io } // namespace dmlc #endif // DMLC_IO_RECORDIO_SPLIT_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/recordio_split.h ===== namespace dmlc { namespace io { size_t RecordIOSplitter::SeekRecordBegin(Stream *fi) { size_t nstep = 0; uint32_t v, lrec; while (true) { if (fi->Read(&v, sizeof(v)) == 0) return nstep; nstep += sizeof(v); if (v == RecordIOWriter::kMagic) { CHECK(fi->Read(&lrec, sizeof(lrec)) != 0) << "invalid record io format"; nstep += sizeof(lrec); uint32_t cflag = RecordIOWriter::DecodeFlag(lrec); if (cflag == 0 || cflag == 1) break; } } // should point at head of record return nstep - 2 * sizeof(uint32_t); } const char* RecordIOSplitter::FindLastRecordBegin(const char *begin, const char *end) { CHECK_EQ((reinterpret_cast(begin) & 3UL), 0); CHECK_EQ((reinterpret_cast(end) & 3UL), 0); const uint32_t *pbegin = reinterpret_cast(begin); const uint32_t *p = reinterpret_cast(end); CHECK(p >= pbegin + 2); for (p = p - 2; p != pbegin; --p) { if (p[0] == RecordIOWriter::kMagic) { uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); if (cflag == 0 || cflag == 1) { return reinterpret_cast(p); } } } return begin; } bool RecordIOSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) { if (chunk->begin == chunk->end) return false; CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end) << "Invalid RecordIO Format"; CHECK_EQ((reinterpret_cast(chunk->begin) & 3UL), 0); CHECK_EQ((reinterpret_cast(chunk->end) & 3UL), 0); uint32_t *p = reinterpret_cast(chunk->begin); uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); uint32_t clen = RecordIOWriter::DecodeLength(p[1]); // skip header out_rec->dptr = chunk->begin + 2 * sizeof(uint32_t); // move pbegin chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); CHECK(chunk->begin <= chunk->end) << "Invalid RecordIO Format"; out_rec->size = clen; if (cflag == 0) return true; const uint32_t kMagic = RecordIOWriter::kMagic; // abnormal path, move data around to make a full part CHECK(cflag == 1U) << "Invalid RecordIO Format"; while (cflag != 3U) { CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end); p = reinterpret_cast(chunk->begin); CHECK(p[0] == RecordIOWriter::kMagic); cflag = RecordIOWriter::DecodeFlag(p[1]); clen = RecordIOWriter::DecodeLength(p[1]); // pad kmagic in between std::memcpy(reinterpret_cast(out_rec->dptr) + out_rec->size, &kMagic, sizeof(kMagic)); out_rec->size += sizeof(kMagic); // move the rest of the blobs if (clen != 0) { std::memmove(reinterpret_cast(out_rec->dptr) + out_rec->size, chunk->begin + 2 * sizeof(uint32_t), clen); out_rec->size += clen; } chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); } return true; } } // namespace io } // namespace dmlc //===== EXPANDED: mxnet/dmlc-core/src/io/recordio_split.cc ===== //===== EXPANDIND: mxnet/dmlc-core/src/io.cc ===== // Copyright by Contributors //===== EXPANDIND: 
mxnet/dmlc-core/src/io/single_file_split.h ===== /*! * Copyright (c) 2015 by Contributors * \file single_file_split.h * \brief base implementation of line-spliter * \author Tianqi Chen */ #ifndef DMLC_IO_SINGLE_FILE_SPLIT_H_ #define DMLC_IO_SINGLE_FILE_SPLIT_H_ namespace dmlc { namespace io { /*! * \brief line split implementation from single FILE * simply returns lines of files, used for stdin */ class SingleFileSplit : public InputSplit { public: explicit SingleFileSplit(const char *fname) : use_stdin_(false), buffer_size_(kBufferSize), chunk_begin_(NULL), chunk_end_(NULL) { if (!std::strcmp(fname, "stdin")) { #ifndef DMLC_STRICT_CXX98_ use_stdin_ = true; fp_ = stdin; #endif } if (!use_stdin_) { fp_ = fopen64(fname, "rb"); CHECK(fp_ != NULL) << "SingleFileSplit: fail to open " << fname; } buffer_.resize(kBufferSize); } virtual ~SingleFileSplit(void) { if (!use_stdin_) std::fclose(fp_); } virtual void BeforeFirst(void) { fseek(fp_, 0, SEEK_SET); } virtual void HintChunkSize(size_t chunk_size) { buffer_size_ = std::max(chunk_size, buffer_size_); } virtual size_t Read(void *ptr, size_t size) { return std::fread(ptr, 1, size, fp_); } virtual void Write(const void *ptr, size_t size) { LOG(FATAL) << "InputSplit do not support write"; } virtual bool NextRecord(Blob *out_rec) { if (chunk_begin_ == chunk_end_) { if (!LoadChunk()) return false; } char *next = FindNextRecord(chunk_begin_, chunk_end_); out_rec->dptr = chunk_begin_; out_rec->size = next - chunk_begin_; chunk_begin_ = next; return true; } virtual bool NextChunk(Blob *out_chunk) { if (chunk_begin_ == chunk_end_) { if (!LoadChunk()) return false; } out_chunk->dptr = chunk_begin_; out_chunk->size = chunk_end_ - chunk_begin_; chunk_begin_ = chunk_end_; return true; } inline bool ReadChunk(void *buf, size_t *size) { size_t max_size = *size; if (max_size <= overflow_.length()) { *size = 0; return true; } if (overflow_.length() != 0) { std::memcpy(buf, BeginPtr(overflow_), overflow_.length()); } size_t olen = overflow_.length(); overflow_.resize(0); size_t nread = this->Read(reinterpret_cast(buf) + olen, max_size - olen); nread += olen; if (nread == 0) return false; if (nread != max_size) { *size = nread; return true; } else { const char *bptr = reinterpret_cast(buf); // return the last position where a record starts const char *bend = this->FindLastRecordBegin(bptr, bptr + max_size); *size = bend - bptr; overflow_.resize(max_size - *size); if (overflow_.length() != 0) { std::memcpy(BeginPtr(overflow_), bend, overflow_.length()); } return true; } } protected: inline const char* FindLastRecordBegin(const char *begin, const char *end) { if (begin == end) return begin; for (const char *p = end - 1; p != begin; --p) { if (*p == '\n' || *p == '\r') return p + 1; } return begin; } inline char* FindNextRecord(char *begin, char *end) { char *p; for (p = begin; p != end; ++p) { if (*p == '\n' || *p == '\r') break; } for (; p != end; ++p) { if (*p != '\n' && *p != '\r') return p; } return end; } inline bool LoadChunk(void) { if (buffer_.length() < buffer_size_) { buffer_.resize(buffer_size_); } while (true) { size_t size = buffer_.length(); if (!ReadChunk(BeginPtr(buffer_), &size)) return false; if (size == 0) { buffer_.resize(buffer_.length() * 2); } else { chunk_begin_ = reinterpret_cast(BeginPtr(buffer_)); chunk_end_ = chunk_begin_ + size; break; } } return true; } private: // buffer size static const size_t kBufferSize = 1 << 18UL; // file std::FILE *fp_; bool use_stdin_; // internal overflow std::string overflow_; // internal buffer std::string 
buffer_; // internal buffer size size_t buffer_size_; // beginning of chunk char *chunk_begin_; // end of chunk char *chunk_end_; }; } // namespace io } // namespace dmlc #endif // DMLC_IO_SINGLE_FILE_SPLIT_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/single_file_split.h ===== //===== EXPANDIND: mxnet/dmlc-core/src/io/threaded_input_split.h ===== /*! * Copyright (c) 2015 by Contributors * \file threaded_input_split.h * \brief a threaded version of InputSplit with a prefetch thread * \author Tianqi Chen */ #ifndef DMLC_IO_THREADED_INPUT_SPLIT_H_ #define DMLC_IO_THREADED_INPUT_SPLIT_H_ // this code depends on c++11 #if DMLC_USE_CXX11 namespace dmlc { namespace io { /*! * \brief a threaded version of InputSplit * wraps an InputSplitBase to use an thread to prefetch the data */ class ThreadedInputSplit : public InputSplit { public: /*! * \brief constructor * \param base an base object to define how to read data */ explicit ThreadedInputSplit(InputSplitBase *base) : buffer_size_(InputSplitBase::kBufferSize), base_(base), tmp_chunk_(NULL) { iter_.set_max_capacity(8); // initalize the iterator iter_.Init([this](InputSplitBase::Chunk **dptr) { if (*dptr == NULL) { *dptr = new InputSplitBase::Chunk(buffer_size_); } return (*dptr)->Load(base_, buffer_size_); }, [base]() { base->BeforeFirst(); }); } // destructor virtual ~ThreadedInputSplit(void) { iter_.Destroy(); delete tmp_chunk_; delete base_; } virtual void BeforeFirst() { iter_.BeforeFirst(); if (tmp_chunk_ != NULL) { iter_.Recycle(&tmp_chunk_); } } virtual void HintChunkSize(size_t chunk_size) { buffer_size_ = std::max(chunk_size / sizeof(size_t), buffer_size_); } // implement next record virtual bool NextRecord(Blob *out_rec) { if (tmp_chunk_ == NULL) { if (!iter_.Next(&tmp_chunk_)) return false; } while (!base_->ExtractNextRecord(out_rec, tmp_chunk_)) { iter_.Recycle(&tmp_chunk_); if (!iter_.Next(&tmp_chunk_)) return false; } return true; } // implement next chunk virtual bool NextChunk(Blob *out_chunk) { if (tmp_chunk_ == NULL) { if (!iter_.Next(&tmp_chunk_)) return false; } while (!base_->ExtractNextChunk(out_chunk, tmp_chunk_)) { iter_.Recycle(&tmp_chunk_); if (!iter_.Next(&tmp_chunk_)) return false; } return true; } private: /*! \brief internal buffer size */ size_t buffer_size_; /*! \brief the place where we get the data */ InputSplitBase *base_; /*! \brief backend thread iterator */ ThreadedIter iter_; /*! \brief current chunk of data */ InputSplitBase::Chunk *tmp_chunk_; }; } // namespace io } // namespace dmlc #endif // DMLC_USE_CXX11 #endif // DMLC_IO_THREADED_INPUT_SPLIT_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/threaded_input_split.h ===== //===== EXPANDIND: mxnet/dmlc-core/src/io/cached_input_split.h ===== /*! * Copyright (c) 2015 by Contributors * \file cached_input_split.h * \brief InputSplit that reads from an existing InputSplit * and cache the data into local disk, the second iteration * will be reading from the local cached data * \author Tianqi Chen */ #ifndef DMLC_IO_CACHED_INPUT_SPLIT_H_ #define DMLC_IO_CACHED_INPUT_SPLIT_H_ // this code depends on c++11 #if DMLC_USE_CXX11 namespace dmlc { namespace io { /*! * \brief InputSplit that reads from an existing InputSplit * and cache the data into local disk, the second iteration * will be reading from the local cached data */ class CachedInputSplit : public InputSplit { public: /*! 
* \brief constructor * \param base source input split * \param cache_file the path to cache file * \param reuse_exist_cache whether reuse existing cache file, if any */ CachedInputSplit(InputSplitBase *base, const char *cache_file, bool reuse_exist_cache = true) : buffer_size_(InputSplitBase::kBufferSize), cache_file_(cache_file), fo_(NULL), fi_(NULL), base_(base), tmp_chunk_(NULL), iter_preproc_(NULL) { if (reuse_exist_cache) { if (!this->InitCachedIter()) { this->InitPreprocIter(); } } else { this->InitPreprocIter(); } } // destructor virtual ~CachedInputSplit(void) { // NOTE delete can handle NULL ptr // deletion order matters delete iter_preproc_; delete fo_; iter_cached_.Destroy(); delete tmp_chunk_; delete base_; delete fi_; } virtual void BeforeFirst(void) { // if preprocessing did not end // pull data from preprocessing module if (iter_preproc_ != NULL) { if (tmp_chunk_ != NULL) { iter_preproc_->Recycle(&tmp_chunk_); } while (iter_preproc_->Next(&tmp_chunk_)) { iter_preproc_->Recycle(&tmp_chunk_); } // finalize the push out process delete iter_preproc_; delete fo_; iter_preproc_ = NULL; fo_ = NULL; CHECK(this->InitCachedIter()) << "Failed to initialize CachedIter"; } else { iter_cached_.BeforeFirst(); } if (tmp_chunk_ != NULL) { iter_cached_.Recycle(&tmp_chunk_); } } virtual void HintChunkSize(size_t chunk_size) { buffer_size_ = std::max(chunk_size / sizeof(size_t), buffer_size_); } // implement next record virtual bool NextRecord(Blob *out_rec) { auto *iter = iter_preproc_ != NULL ? iter_preproc_ : &iter_cached_; if (tmp_chunk_ == NULL) { if (!iter->Next(&tmp_chunk_)) return false; } while (!base_->ExtractNextRecord(out_rec, tmp_chunk_)) { iter->Recycle(&tmp_chunk_); if (!iter->Next(&tmp_chunk_)) return false; } return true; } // implement next chunk virtual bool NextChunk(Blob *out_chunk) { auto *iter = iter_preproc_ != NULL ? iter_preproc_ : &iter_cached_; if (tmp_chunk_ == NULL) { if (!iter->Next(&tmp_chunk_)) return false; } while (!base_->ExtractNextChunk(out_chunk, tmp_chunk_)) { iter->Recycle(&tmp_chunk_); if (!iter->Next(&tmp_chunk_)) return false; } return true; } private: /*! \brief internal buffer size */ size_t buffer_size_; /*! \brief cache file path */ std::string cache_file_; /*! \brief output stream to cache file*/ dmlc::Stream *fo_; /*! \brief input stream from cache file */ dmlc::SeekStream *fi_; /*! \brief the place where we get the data */ InputSplitBase *base_; /*! \brief current chunk of data */ InputSplitBase::Chunk *tmp_chunk_; /*! \brief backend thread iterator for preprocessing */ ThreadedIter *iter_preproc_; /*! \brief backend thread iterator for cache */ ThreadedIter iter_cached_; /*! \brief initialize the cached iterator */ inline void InitPreprocIter(void); /*! 
* \brief initialize the cached iterator * \return wheher the file exist and * initialization is successful */ inline bool InitCachedIter(void); }; inline void CachedInputSplit:: InitPreprocIter(void) { fo_ = dmlc::Stream::Create(cache_file_.c_str(), "w"); iter_preproc_ = new ThreadedIter(); iter_preproc_->set_max_capacity(16); iter_preproc_->Init([this](InputSplitBase::Chunk **dptr) { if (*dptr == NULL) { *dptr = new InputSplitBase::Chunk(buffer_size_); } auto *p = *dptr; if (!p->Load(base_, buffer_size_)) return false; // after loading, save to disk size_t size = p->end - p->begin; fo_->Write(&size, sizeof(size)); fo_->Write(p->begin, size); return true; }); } inline bool CachedInputSplit::InitCachedIter(void) { fi_ = dmlc::SeekStream::CreateForRead(cache_file_.c_str(), true); if (fi_ == NULL) return false; iter_cached_.Init([this](InputSplitBase::Chunk **dptr) { if (*dptr == NULL) { *dptr = new InputSplitBase::Chunk(buffer_size_); } auto *p = *dptr; // read data from cache file size_t size; size_t nread = fi_->Read(&size, sizeof(size)); if (nread == 0) return false; CHECK(nread == sizeof(size)) << cache_file_ << " has invalid cache file format"; p->data.resize(size / sizeof(size_t) + 1); p->begin = reinterpret_cast(BeginPtr(p->data)); p->end = p->begin + size; CHECK(fi_->Read(p->begin, size) == size) << cache_file_ << " has invalid cache file format"; return true; }, [this]() { fi_->Seek(0); }); return true; } } // namespace io } // namespace dmlc #endif // DMLC_USE_CXX11 #endif // DMLC_IO_CACHED_INPUT_SPLIT_H_ //===== EXPANDED: mxnet/dmlc-core/src/io/cached_input_split.h ===== #if DMLC_USE_HDFS #endif #if DMLC_USE_S3 #endif #if DMLC_USE_AZURE #endif namespace dmlc { namespace io { FileSystem *FileSystem::GetInstance(const std::string &protocol) { if (protocol == "file://" || protocol.length() == 0) { return LocalFileSystem::GetInstance(); } if (protocol == "hdfs://") { #if DMLC_USE_HDFS return HDFSFileSystem::GetInstance(); #else LOG(FATAL) << "Please compile with DMLC_USE_HDFS=1 to use hdfs"; #endif } if (protocol == "s3://" || protocol == "http://" || protocol == "https://") { #if DMLC_USE_S3 return S3FileSystem::GetInstance(); #else LOG(FATAL) << "Please compile with DMLC_USE_S3=1 to use S3"; #endif } if (protocol == "azure://") { #if DMLC_USE_AZURE return AzureFileSystem::GetInstance(); #else LOG(FATAL) << "Please compile with DMLC_USE_AZURE=1 to use Azure"; #endif } LOG(FATAL) << "unknown filesystem protocol " + protocol; return NULL; } } // namespace io InputSplit* InputSplit::Create(const char *uri_, unsigned part, unsigned nsplit, const char *type) { using namespace std; using namespace dmlc::io; // allow cachefile in format path#cachefile io::URISpec spec(uri_, part, nsplit); if (!strcmp(spec.uri.c_str(), "stdin")) { return new SingleFileSplit(spec.uri.c_str()); } CHECK(part < nsplit) << "invalid input parameter for InputSplit::Create"; URI path(spec.uri.c_str()); InputSplitBase *split = NULL; if (!strcmp(type, "text")) { split = new LineSplitter(FileSystem::GetInstance(path.protocol), spec.uri.c_str(), part, nsplit); } else if (!strcmp(type, "recordio")) { split = new RecordIOSplitter(FileSystem::GetInstance(path.protocol), spec.uri.c_str(), part, nsplit); } else { LOG(FATAL) << "unknown input split type " << type; } #if DMLC_USE_CXX11 if (spec.cache_file.length() == 0) { return new ThreadedInputSplit(split); } else { return new CachedInputSplit(split, spec.cache_file.c_str()); } #else CHECK(spec.cache_file.length() == 0) << "to enable cached file, compile with c++11"; return 
split; #endif } Stream *Stream::Create(const char *uri, const char * const flag, bool try_create) { io::URI path(uri); return io::FileSystem:: GetInstance(path.protocol)->Open(path, flag, try_create); } SeekStream *SeekStream::CreateForRead(const char *uri, bool try_create) { io::URI path(uri); return io::FileSystem:: GetInstance(path.protocol)->OpenForRead(path, try_create); } } // namespace dmlc //===== EXPANDED: mxnet/dmlc-core/src/io.cc ===== //===== EXPANDIND: mxnet/dmlc-core/src/recordio.cc ===== // Copyright by Contributors namespace dmlc { // implemmentation void RecordIOWriter::WriteRecord(const void *buf, size_t size) { CHECK(size < (1 << 29U)) << "RecordIO only accept record less than 2^29 bytes"; const uint32_t umagic = kMagic; // initialize the magic number, in stack const char *magic = reinterpret_cast(&umagic); const char *bhead = reinterpret_cast(buf); uint32_t len = static_cast(size); uint32_t lower_align = (len >> 2U) << 2U; uint32_t upper_align = ((len + 3U) >> 2U) << 2U; uint32_t dptr = 0; for (uint32_t i = 0; i < lower_align ; i += 4) { // use char check for alignment safety reason if (bhead[i] == magic[0] && bhead[i + 1] == magic[1] && bhead[i + 2] == magic[2] && bhead[i + 3] == magic[3]) { uint32_t lrec = EncodeLRec(dptr == 0 ? 1U : 2U, i - dptr); stream_->Write(magic, 4); stream_->Write(&lrec, sizeof(lrec)); if (i != dptr) { stream_->Write(bhead + dptr, i - dptr); } dptr = i + 4; except_counter_ += 1; } } uint32_t lrec = EncodeLRec(dptr != 0 ? 3U : 0U, len - dptr); stream_->Write(magic, 4); stream_->Write(&lrec, sizeof(lrec)); if (len != dptr) { stream_->Write(bhead + dptr, len - dptr); } // write padded bytes uint32_t zero = 0; if (upper_align != len) { stream_->Write(&zero, upper_align - len); } } bool RecordIOReader::NextRecord(std::string *out_rec) { if (end_of_stream_) return false; const uint32_t kMagic = RecordIOWriter::kMagic; out_rec->clear(); size_t size = 0; while (true) { uint32_t header[2]; size_t nread = stream_->Read(header, sizeof(header)); if (nread == 0) { end_of_stream_ = true; return false; } CHECK(nread == sizeof(header)) << "Inavlid RecordIO File"; CHECK(header[0] == RecordIOWriter::kMagic) << "Invalid RecordIO File"; uint32_t cflag = RecordIOWriter::DecodeFlag(header[1]); uint32_t len = RecordIOWriter::DecodeLength(header[1]); uint32_t upper_align = ((len + 3U) >> 2U) << 2U; out_rec->resize(size + upper_align); if (upper_align != 0) { CHECK(stream_->Read(BeginPtr(*out_rec) + size, upper_align) == upper_align) << "Invalid RecordIO File upper_align=" << upper_align; } // squeeze back size += len; out_rec->resize(size); if (cflag == 0U || cflag == 3U) break; out_rec->resize(size + sizeof(kMagic)); std::memcpy(BeginPtr(*out_rec) + size, &kMagic, sizeof(kMagic)); size += sizeof(kMagic); } return true; } // helper function to find next recordio head inline char *FindNextRecordIOHead(char *begin, char *end) { CHECK_EQ((reinterpret_cast(begin) & 3UL), 0); CHECK_EQ((reinterpret_cast(end) & 3UL), 0); uint32_t *p = reinterpret_cast(begin); uint32_t *pend = reinterpret_cast(end); for (; p + 1 < pend; ++p) { if (p[0] == RecordIOWriter::kMagic) { uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); if (cflag == 0 || cflag == 1) { return reinterpret_cast(p); } } } return end; } RecordIOChunkReader::RecordIOChunkReader(InputSplit::Blob chunk, unsigned part_index, unsigned num_parts) { size_t nstep = (chunk.size + num_parts - 1) / num_parts; // align nstep = ((nstep + 3UL) >> 2UL) << 2UL; size_t begin = std::min(chunk.size, nstep * part_index); size_t end = 
std::min(chunk.size, nstep * (part_index + 1)); char *head = reinterpret_cast(chunk.dptr); pbegin_ = FindNextRecordIOHead(head + begin, head + chunk.size); pend_ = FindNextRecordIOHead(head + end, head + chunk.size); } bool RecordIOChunkReader::NextRecord(InputSplit::Blob *out_rec) { if (pbegin_ >= pend_) return false; uint32_t *p = reinterpret_cast(pbegin_); CHECK(p[0] == RecordIOWriter::kMagic); uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]); uint32_t clen = RecordIOWriter::DecodeLength(p[1]); if (cflag == 0) { // skip header out_rec->dptr = pbegin_ + 2 * sizeof(uint32_t); // move pbegin pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); CHECK(pbegin_ <= pend_) << "Invalid RecordIO Format"; out_rec->size = clen; return true; } else { const uint32_t kMagic = RecordIOWriter::kMagic; // abnormal path, read into string CHECK(cflag == 1U) << "Invalid RecordIO Format"; temp_.resize(0); while (true) { CHECK(pbegin_ + 2 * sizeof(uint32_t) <= pend_); p = reinterpret_cast(pbegin_); CHECK(p[0] == RecordIOWriter::kMagic); cflag = RecordIOWriter::DecodeFlag(p[1]); clen = RecordIOWriter::DecodeLength(p[1]); size_t tsize = temp_.length(); temp_.resize(tsize + clen); if (clen != 0) { std::memcpy(BeginPtr(temp_) + tsize, pbegin_ + 2 * sizeof(uint32_t), clen); tsize += clen; } pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U); if (cflag == 3U) break; temp_.resize(tsize + sizeof(kMagic)); std::memcpy(BeginPtr(temp_) + tsize, &kMagic, sizeof(kMagic)); } out_rec->dptr = BeginPtr(temp_); out_rec->size = temp_.length(); return true; } } } // namespace dmlc //===== EXPANDED: mxnet/dmlc-core/src/recordio.cc ===== //===== EXPANDED: mxnet0.cc =====
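//
// The block below is an illustrative usage sketch, not part of the amalgamation.
// It is guarded by DMLC_IO_USAGE_EXAMPLE, a hypothetical macro that is never defined
// in this file, and it assumes the RowBlockIter<unsigned> / Parser<unsigned>
// specializations instantiated in data.cc above (the template arguments were stripped
// in this dump); the file name "train.libsvm" is a placeholder.  It shows how the
// pieces in this file fit together: InputSplit yields raw text records (optionally
// sharded and cached), while RowBlockIter runs the libsvm parser and exposes the
// parsed rows as RowBlock batches.
#ifdef DMLC_IO_USAGE_EXAMPLE
#include <cstdio>
int main() {
  // read raw lines: part 0 of 1 split, "text" record format
  dmlc::InputSplit *split = dmlc::InputSplit::Create("train.libsvm", 0, 1, "text");
  dmlc::InputSplit::Blob rec;
  size_t nline = 0;
  while (split->NextRecord(&rec)) ++nline;   // rec.dptr / rec.size hold one line
  std::printf("%lu lines\n", static_cast<unsigned long>(nline));
  delete split;
  // parse the same file into row blocks with the libsvm parser
  dmlc::RowBlockIter<unsigned> *iter =
      dmlc::RowBlockIter<unsigned>::Create("train.libsvm", 0, 1, "libsvm");
  while (iter->Next()) {
    const dmlc::RowBlock<unsigned> &batch = iter->Value();
    for (size_t i = 0; i < batch.size; ++i) {
      // batch.offset delimits the features of row i inside batch.index / batch.value
      std::printf("row %lu: label=%g nnz=%lu\n",
                  static_cast<unsigned long>(i), batch.label[i],
                  static_cast<unsigned long>(batch.offset[i + 1] - batch.offset[i]));
    }
  }
  delete iter;
  return 0;
}
#endif  // DMLC_IO_USAGE_EXAMPLE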