Caffe2 - C++ API
A deep learning, cross platform ML framework
mkl_context.h
1 #ifndef CAFFE2_UTILS_MKL_CONTEXT_H_
2 #define CAFFE2_UTILS_MKL_CONTEXT_H_
3 
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <memory>
#include <random>
#include <type_traits>
#include <utility>
7 
8 #include "caffe2/core/context.h"
9 
10 namespace caffe2 {
11 
20 class MKLContext final {
21  public:
22  MKLContext() : random_seed_(RandomNumberSeed()) {}
23  explicit MKLContext(const DeviceOption& option)
24  : random_seed_(
25  option.has_random_seed() ? option.random_seed()
26  : RandomNumberSeed()) {
27  CAFFE_ENFORCE_EQ(option.device_type(), MKLDNN);
28  }
29 
30  ~MKLContext() {}
31 
32  inline void SwitchToDevice(int /*stream_id*/ = 0) {}
33 
34  inline void WaitEvent(const Event& ev) {
35  ev.Wait(MKLDNN, this);
36  }
37 
38  inline void Record(Event* ev, const char* err_msg = nullptr) const {
39  CAFFE_ENFORCE(ev, "Event must not be null.");
40  ev->Record(MKLDNN, this, err_msg);
41  }
42 
43  inline void FinishDeviceComputation() {}
44 
45  inline std::mt19937& RandGenerator() {
46  if (!random_generator_.get()) {
47  random_generator_.reset(new std::mt19937(random_seed_));
48  }
49  return *random_generator_.get();
50  }
51 
52  inline static std::pair<void*, MemoryDeleter> New(size_t nbytes) {
53  return GetCPUAllocator()->New(nbytes);
54  }
55 
56  // Two copy functions that deals with cross-device copies.
57  template <class SrcContext, class DstContext>
58  inline void CopyBytes(size_t nbytes, const void* src, void* dst);
59 
60  template <typename T, class SrcContext, class DstContext>
61  inline void Copy(size_t n, const T* src, T* dst) {
62  if (std::is_fundamental<T>::value) {
63  CopyBytes<SrcContext, DstContext>(
64  n * sizeof(T),
65  static_cast<const void*>(src),
66  static_cast<void*>(dst));
67  } else {
68  for (int i = 0; i < n; ++i) {
69  dst[i] = src[i];
70  }
71  }
72  }
73 
74  template <class SrcContext, class DstContext>
75  inline void
76  CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) {
77  if (meta.copy()) {
78  meta.copy()(src, dst, n);
79  } else {
80  CopyBytes<SrcContext, DstContext>(n * meta.itemsize(), src, dst);
81  }
82  }
83 
84  // By default MKL operators don't have async device parts
85  static bool HasAsyncPartDefault() {
86  return false;
87  }
88 
89  static bool SupportsAsyncScheduling() {
90  return false;
91  }
92 
93  static bool IsStreamFree(const DeviceOption& /* unused */, int /* unused */) {
94  return true;
95  }
96 
97  protected:
98  // TODO(jiayq): instead of hard-coding a generator, make it more flexible.
99  int random_seed_{1701};
100  std::unique_ptr<std::mt19937> random_generator_;
101 };
102 
103 template <>
104 inline void MKLContext::CopyBytes<MKLContext, MKLContext>(
105  size_t nbytes,
106  const void* src,
107  void* dst) {
108  memcpy(dst, src, nbytes);
109 }
110 
111 template <>
112 inline void MKLContext::CopyBytes<CPUContext, MKLContext>(
113  size_t nbytes,
114  const void* src,
115  void* dst) {
116  memcpy(dst, src, nbytes);
117 }
118 
119 template <>
120 inline void MKLContext::CopyBytes<MKLContext, CPUContext>(
121  size_t nbytes,
122  const void* src,
123  void* dst) {
124  memcpy(dst, src, nbytes);
125 }
126 } // namespace caffe2
127 
128 #endif // CAFFE2_UTILS_MKL_CONTEXT_H_
The MKL Context, which is largely the same as the CPUContext.
Definition: mkl_context.h:20
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
TypedCopy copy() const
Returns the typed copy function pointer for individual items.
Definition: typeid.h:155
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:88
const size_t & itemsize() const
Returns the size of the item.
Definition: typeid.h:143
uint32_t RandomNumberSeed()
A function to generate a random number seed that is unique on a best-effort basis, using an ever-incrementing seed and the current time.
Definition: context.cc:10