Caffe2 - C++ API
A deep-learning, cross-platform ML framework
NeuralNet.h
1 //=== nomnigraph/Representations/NeuralNet.h - NN interface -----*- C++ -*-===//
2 //
3 // TODO Licensing.
4 //
5 //===----------------------------------------------------------------------===//
6 //
7 // This file defines classes that can be used in a
8 // nom::Graph<nom::repr::NeuralNetOperator, nom::repr::NeuralNetData> graph.
9 //
10 //===----------------------------------------------------------------------===//
11 
12 #ifndef NOM_REPRESENTATIONS_NEURALNET_H
13 #define NOM_REPRESENTATIONS_NEURALNET_H
14 
15 #include "nomnigraph/Graph/Graph.h"
16 #include "nomnigraph/Representations/Compiler.h"
17 #include "nomnigraph/Representations/ControlFlow.h"
18 #include "nomnigraph/Support/Casting.h"
19 #include "nomnigraph/Support/Pointer.h"
20 
21 #include <string>
22 #include <type_traits>
23 #include <vector>
24 
25 #include <assert.h>
26 
27 namespace nom {
28 namespace repr {
29 
30 class NeuralNetData;
31 
/// \brief Annotations allow for generic manipulation of neural network
/// operations.
///
/// The base class carries a kind discriminator (used by the static `classof`
/// hooks for LLVM-style RTTI) and an opaque `void *` slot for external use
/// that this class never dereferences or frees. Derived classes such as
/// DeviceAnnotation attach richer data.
class Annotation {
public:
  enum class AnnotationKind { Generic, Device };

  Annotation(AnnotationKind K) : Kind(K) {}
  Annotation() : Kind(AnnotationKind::Generic) {}

  /// Annotations are owned through base pointers (NeuralNetOperator holds a
  /// std::unique_ptr<Annotation>), so the destructor must be virtual for a
  /// derived annotation's members to be destroyed correctly.
  virtual ~Annotation() = default;

  AnnotationKind getKind() const { return Kind; }

  // Non-copyable. Declaring the copy operations also suppresses the implicit
  // move operations.
  Annotation(const Annotation &) = delete;
  Annotation &operator=(const Annotation &) = delete;

  /// Returns the opaque pointer last passed to setSaved (nullptr initially).
  void *getSaved() const { return Saved; }
  /// Stores an opaque, non-owning pointer for external use.
  void setSaved(void *saved) { Saved = saved; }

private:
  const AnnotationKind Kind;
  void *Saved = nullptr;
};

/// Annotation that records a device as a string.
class DeviceAnnotation : public Annotation {
public:
  /// Creates a device annotation with an empty device string.
  DeviceAnnotation() : Annotation(AnnotationKind::Device) {}
  DeviceAnnotation(std::string device)
      : Annotation(AnnotationKind::Device), Device(std::move(device)) {}

  void setDevice(std::string device) { Device = std::move(device); }
  std::string getDevice() const { return Device; }

  /// LLVM-style RTTI hook: true iff \p A was constructed as a device
  /// annotation.
  static bool classof(const Annotation *A) {
    return A->getKind() == AnnotationKind::Device;
  }

private:
  // Empty by default. Initializing it from `0` would construct a std::string
  // from a null `const char *`, which is undefined behavior.
  std::string Device;
};
72 
// NOTE(review): the class head (embedded line 73) is missing from this
// rendering. These are the members of NeuralNetOperator: the derived
// operators in this file construct it as NeuralNetOperator(NNKind::...), and
// every constructor below initializes an Instruction base.
74 public:
  /// Discriminator for LLVM-style RTTI (isa<>); returned by getKind().
  // NOTE(review): embedded lines 81 and 86 are missing here. Derived classes
  // in this file use NNKind::DynamicInput and NNKind::GenericOperator, so
  // those enumerators presumably live on the missing lines — confirm.
76  enum class NNKind {
77  Undefined,
78  Conv,
79  Relu,
80  ConvRelu,
82  Send,
83  Receive,
84  While,
85  NNPhi,
87  };
88 
  /// An optional tensor-type specifier. Constructors that are not given one
  /// set it to NNLayout::Undefined.
90  enum class NNLayout { Undefined, NCHW, NHWC };
91 
  // NOTE(review): the constructor heads on embedded lines 92, 94, 97 and 99
  // are missing from this rendering. From the initializers:
  // - some constructors forward an Opcode `I` to Instruction(I);
  // - the others default-construct Instruction;
  // - Kind defaults to NNKind::Undefined only in the no-kind constructor.
  // Derived classes here call NeuralNetOperator(NNKind) and
  // NeuralNetOperator(NNKind, Opcode).
93  : Instruction(I), Kind(K), Layout(L) {}
95  : Instruction(I), Kind(K), Layout(NNLayout::Undefined) {}
96  NeuralNetOperator(NNKind K, NNLayout L) : Instruction(), Kind(K), Layout(L) {}
98  : Instruction(), Kind(K), Layout(NNLayout::Undefined) {}
100  : Instruction(), Kind(NNKind::Undefined), Layout(NNLayout::Undefined) {}
101 
  /// Kind is fixed at construction (the member is const).
102  NNKind getKind() const { return Kind; }
103 
  /// Layout is the only mutable tag on an operator.
104  void setLayout(NNLayout L) { Layout = L; }
105 
106  NNLayout getLayout() const { return Layout; }
107 
  /// Takes ownership of the annotation; any previously set annotation is
  /// destroyed by the unique_ptr move-assignment.
  // NOTE(review): destroying a derived annotation (e.g. DeviceAnnotation)
  // through unique_ptr<Annotation> is only well-defined if Annotation
  // declares a virtual destructor.
108  void setAnnotation(std::unique_ptr<Annotation> extraAnnotation) {
109  ExtraAnnotation = std::move(extraAnnotation);
110  }
111 
  /// Non-owning access to the annotation; nullptr when none has been set.
112  const Annotation *getAnnotation() const { return ExtraAnnotation.get(); }
113  Annotation *getMutableAnnotation() { return ExtraAnnotation.get(); }
114 
  /// Defined out of line (not in this header).
115  const std::string getName() const;
116 
  /// Validate the inputs and outputs to this operator. The base version
  /// accepts everything and returns true.
  // NOTE(review): this is not virtual, so a derived class's same-named member
  // (e.g. Conv::checkInputsAndOutputs) hides it rather than overriding it. A
  // call through a NeuralNetOperator pointer always returns true.
124  bool checkInputsAndOutputs(std::vector<const NeuralNetData *> inputs,
125  std::vector<const NeuralNetData *> outputs) {
126  return true;
127  }
128 
  // Pure virtual makes the class abstract. C++ still requires a definition
  // (not visible in this header), because derived destructors call it.
129  virtual ~NeuralNetOperator() = 0;
130 
  // Non-copyable. Declaring these also suppresses the implicit move
  // operations.
131  NeuralNetOperator(const NeuralNetOperator &) = delete;
132  NeuralNetOperator &operator=(NeuralNetOperator &) = delete;
133 
134 private:
135  const NNKind Kind;
136  NNLayout Layout; // Mutable attribute, much like a type cast
  // Owned annotation; may be null.
137  std::unique_ptr<Annotation> ExtraAnnotation;
138 };
139 
140 class NeuralNetData : public Data {
141 public:
143  enum class NNDataKind { Generic, Tensor };
144 
145  NeuralNetData(NNDataKind kind) : Kind(kind) {}
146 
147  NeuralNetData() : Kind(NNDataKind::Generic) {}
148 
149  NNDataKind getKind() const { return Kind; }
150 
151  virtual NeuralNetData *clone() = 0;
152 
153  const std::string getName() const;
154 
155  virtual ~NeuralNetData() = 0;
156 
157 private:
158  NNDataKind Kind;
159  size_t Version = 0;
160 };
161 
162 class Tensor : public NeuralNetData {
163 public:
164  Tensor(std::string name) : NeuralNetData(NNDataKind::Tensor), name_(name) {}
165  static bool classof(const NeuralNetData *D) {
166  return D->getKind() == NNDataKind::Tensor;
167  }
168 
169  NeuralNetData *clone() { return new Tensor(name_); }
170 
171  const std::string getName() const { return name_; }
172  ~Tensor() {}
173 
174 private:
175  std::string name_;
176 };
177 
// NOTE(review): the class head (embedded line 178) is missing from this
// rendering. The initializer below shows the class derives from
// NeuralNetOperator and is tagged NNKind::DynamicInput.
// Unlike the other operators in this header it declares no static classof():
// NOMNIGRAPH_DEFINE_NN_RTTI is only #defined after this class. As a result,
// isa<DynamicInput>/dyn_cast<DynamicInput> have no DynamicInput-specific hook
// to call. TODO confirm whether that is intended.
179 public:
180  DynamicInput() : NeuralNetOperator(NNKind::DynamicInput) {}
181  ~DynamicInput() {}
182 };
183 
// Stamps out the static `classof` hook used by LLVM-style isa<>/dyn_cast<>:
// a NeuralNetOperator matches when its kind tag equals NNKind::op. Expand it
// inside a class derived from NeuralNetOperator, so that NNKind resolves to
// the inherited enum.
#define NOMNIGRAPH_DEFINE_NN_RTTI(op) \
  static bool classof(const NeuralNetOperator *N) { \
    return NNKind::op == N->getKind(); \
  }
188 
189 class Conv : public NeuralNetOperator {
190 public:
191  Conv(std::vector<int> kernelShape, std::vector<int> dilations = {1, 1},
192  int group = 1, std::vector<int> pads = {0, 0},
193  std::vector<int> strides = {1, 1})
194  : NeuralNetOperator(NNKind::Conv), KernelShape(kernelShape),
195  Dilations(dilations), Group(group), Pads(pads), Strides(strides) {}
196 
197  NOMNIGRAPH_DEFINE_NN_RTTI(Conv);
198 
199  ~Conv() {}
200 
201  void setDilations(std::vector<int> &&dilations) {
202  Dilations = std::move(dilations);
203  }
204 
205  void setGroup(int &&group) { Group = std::move(group); }
206 
207  void setPads(std::vector<int> &&pads) { Pads = std::move(pads); }
208 
209  void setStrides(std::vector<int> &&strides) { Strides = std::move(strides); }
210 
211  std::vector<int> getDilations() { return Dilations; }
212 
213  int getGroup() { return Group; }
214 
215  std::vector<int> getPads() { return Pads; }
216 
217  std::vector<int> getStrides() { return Strides; }
218 
219  std::vector<int> getKernelShape() { return KernelShape; }
220 
221  bool checkInputsAndOutputs(std::vector<const NeuralNetData *> inputs,
222  std::vector<const NeuralNetData *> outputs) {
223  assert(KernelShape.size() == Dilations.size());
224  assert(KernelShape.size() == Pads.size());
225  assert(KernelShape.size() == Strides.size());
226  return true;
227  }
228 
229 protected:
230  std::vector<int> KernelShape;
231  std::vector<int> Dilations;
232  int Group;
233  std::vector<int> Pads;
234  std::vector<int> Strides;
235 };
236 
237 class ConvRelu : public NeuralNetOperator {
238 public:
239  ConvRelu(std::vector<int> kernelShape, std::vector<int> dilations = {1, 1},
240  int group = 1, std::vector<int> pads = {0, 0},
241  std::vector<int> strides = {1, 1})
242  : NeuralNetOperator(NNKind::ConvRelu),
243  ConvPtr(util::make_unique<Conv>(kernelShape, dilations, group, pads,
244  strides)) {}
245 
246  ConvRelu(Conv *conv)
247  : NeuralNetOperator(NNKind::ConvRelu), ConvPtr(std::move(conv)) {}
248 
249  NOMNIGRAPH_DEFINE_NN_RTTI(ConvRelu);
250 
251  ~ConvRelu() {}
252 
253 private:
254  std::unique_ptr<Conv> ConvPtr = nullptr;
255 };
256 
257 class Relu : public NeuralNetOperator {
258 public:
259  Relu() : NeuralNetOperator(NNKind::Relu) {}
260  NOMNIGRAPH_DEFINE_NN_RTTI(Relu);
261  ~Relu() {}
262 };
263 
264 class Send : public NeuralNetOperator {
265 public:
266  Send() : NeuralNetOperator(NNKind::Send) {}
267  NOMNIGRAPH_DEFINE_NN_RTTI(Send);
268  ~Send() {}
269 };
270 
271 class Receive : public NeuralNetOperator {
272 public:
273  Receive() : NeuralNetOperator(NNKind::Receive) {}
274  NOMNIGRAPH_DEFINE_NN_RTTI(Receive);
275  ~Receive() {}
276 };
277 
278 class While : public NeuralNetOperator {
279 public:
280  While() : NeuralNetOperator(NNKind::While, Opcode::Branch) {}
281  NOMNIGRAPH_DEFINE_NN_RTTI(While);
282  ~While() {}
283 };
284 
285 class NNPhi : public NeuralNetOperator {
286 public:
287  NNPhi() : NeuralNetOperator(NNKind::NNPhi, Opcode::Phi) {}
288  NOMNIGRAPH_DEFINE_NN_RTTI(NNPhi);
289  ~NNPhi() {}
290 };
291 
// NOTE(review): the class head (embedded line 292) is missing from this
// rendering. The initializers below show the class derives from
// NeuralNetOperator and is tagged NNKind::GenericOperator.
// An operator identified only by a free-form name, which is empty when
// default-constructed.
293 public:
294  GenericOperator() : NeuralNetOperator(NNKind::GenericOperator) {}
  // Copies `name` into name_.
295  GenericOperator(std::string name)
296  : NeuralNetOperator(NNKind::GenericOperator), name_(name) {}
297  NOMNIGRAPH_DEFINE_NN_RTTI(GenericOperator);
  // NOTE(review): NeuralNetOperator::getName is not virtual, so this hides it
  // rather than overriding it. A call through a NeuralNetOperator pointer uses
  // the base (out-of-line) definition, not name_.
298  std::string getName() const { return name_; }
299  ~GenericOperator() {}
300 
301 private:
302  std::string name_;
303 };
304 
307 
308 struct NNModule {
309  NNGraph dataFlow;
310  NNCFGraph controlFlow;
311  NNModule(const NNModule &) = delete;
312  NNModule(NNModule &&) = default;
313  NNModule() {}
314 };
315 
316 // Although these seem generic, they make subtle assumptions about the
317 // structure of the graph that hold for NNModule graphs but not for
318 // arbitrary graphs (such as node data being held in a unique_ptr).
319 namespace nn {
320 
// Local stand-in for C++14's std::enable_if_t. It names Type when Cond is
// true and has no ::type otherwise, so using it in a template argument
// removes the candidate via SFINAE.
template <bool Cond, typename Type = void>
using enable_if_t = typename std::enable_if<Cond, Type>::type;
323 
324 template <typename T>
325 constexpr bool inheritedFromNeuralNetOperator() {
326  return std::is_base_of<NeuralNetOperator, T>::value &&
327  !std::is_same<NeuralNetOperator, T>::value;
328 }
329 
330 template <typename T>
331 constexpr bool inheritedFromNeuralNetData() {
332  return std::is_base_of<NeuralNetData, T>::value &&
333  !std::is_same<NeuralNetData, T>::value;
334 }
335 
336 // This is just a way to fix issues when the isa<> implementation
337 // can't automatically downcast.
338 template <typename T, typename N, typename = void> struct is_impl {
339  inline static bool impl(N n) { return isa<T>(n->data()); }
340 };
341 
342 template <typename T, typename N>
343 struct is_impl<T, N, enable_if_t<inheritedFromNeuralNetOperator<T>()>> {
344  inline static bool impl(N n) {
345  auto nno = dyn_cast<NeuralNetOperator>(n->data().get());
346  return isa<T>(nno);
347  }
348 };
349 
350 template <typename T, typename N>
351 struct is_impl<T, N, enable_if_t<inheritedFromNeuralNetData<T>()>> {
352  inline static bool impl(N n) {
353  auto nno = dyn_cast<NeuralNetData>(n->data().get());
354  return isa<T>(nno);
355  }
356 };
357 
358 template <typename T, typename N> inline bool is(N n) {
359  return is_impl<T, N>::impl(n);
360 }
361 
362 // This is just a way to fix issues when the dyn_cast<> implementation
363 // can't automatically downcast.
364 template <typename T, typename N, typename = void> struct get_impl {
365  inline static T *impl(N n) { return dyn_cast<T>(n->data().get()); }
366 };
367 
368 template <typename T, typename N>
369 struct get_impl<T, N, enable_if_t<inheritedFromNeuralNetOperator<T>()>> {
370  inline static T *impl(N n) {
371  auto nno = dyn_cast<NeuralNetOperator>(n->data().get());
372  return dyn_cast<T>(nno);
373  }
374 };
375 
376 template <typename T, typename N>
377 struct get_impl<T, N, enable_if_t<inheritedFromNeuralNetData<T>()>> {
378  inline static T *impl(N n) {
379  auto nno = dyn_cast<NeuralNetData>(n->data().get());
380  return dyn_cast<T>(nno);
381  }
382 };
383 
384 template <typename T, typename N> inline T *get(N n) {
385  return get_impl<T, N>::impl(n);
386 }
387 
388 template <typename T, typename G>
389 std::vector<std::pair<T *, typename G::NodeRef>> dataIterator(G &g) {
390  std::vector<std::pair<T *, typename G::NodeRef>> out;
391  for (auto node : g.getMutableNodes()) {
392  if (!is<T>(node)) {
393  continue;
394  }
395  auto d = get<T>(node);
396  out.emplace_back(std::make_pair(d, node));
397  }
398  return out;
399 }
400 
402 bool hasProducer(NNGraph::NodeRef n);
403 NNGraph::NodeRef getProducer(NNGraph::NodeRef n);
404 std::vector<NNGraph::NodeRef> getConsumers(NNGraph::NodeRef n);
405 
406 bool hasInputs(NNGraph::NodeRef n);
407 std::vector<NNGraph::NodeRef> getInputs(NNGraph::NodeRef n);
408 std::vector<NNGraph::NodeRef> getOutputs(NNGraph::NodeRef n);
409 
410 void coalesceInsertedDataDependencies(repr::NNModule* m);
411 
412 template <NNGraph* G>
413 struct NodeHelper {
414 };
415 
416 } // namespace nn
417 
418 } // namespace repr
419 } // namespace nom
420 
421 #endif // NOM_REPRESENTATIONS_NEURALNET_H
Annotations allow for generic manipulation of neural network operations.
Definition: NeuralNet.h:37
NNLayout
An optional tensor-type specifier.
Definition: NeuralNet.h:90
bool checkInputsAndOutputs(std::vector< const NeuralNetData * > inputs, std::vector< const NeuralNetData * > outputs)
Validate the inputs and outputs to this operator.
Definition: NeuralNet.h:124
Definition: Caffe2.cc:16
NNKind
Discriminator for LLVM-style RTTI (isa<>)
Definition: NeuralNet.h:76
Opcode
All the different types of execution.
Definition: Compiler.h:40
NNDataKind
Discriminator for LLVM-style RTTI (isa<>)
Definition: NeuralNet.h:143