12 #ifndef NOM_REPRESENTATIONS_NEURALNET_H 13 #define NOM_REPRESENTATIONS_NEURALNET_H 15 #include "nomnigraph/Graph/Graph.h" 16 #include "nomnigraph/Representations/Compiler.h" 17 #include "nomnigraph/Representations/ControlFlow.h" 18 #include "nomnigraph/Support/Casting.h" 19 #include "nomnigraph/Support/Pointer.h" 22 #include <type_traits> 39 enum class AnnotationKind { Generic, Device };
42 Annotation() : Kind(AnnotationKind::Generic) {}
44 AnnotationKind getKind()
const {
return Kind; }
49 void *getSaved()
const {
return Saved; }
50 void setSaved(
void *saved) { Saved = saved; }
53 const AnnotationKind Kind;
/// Opaque pointer exposed through getSaved()/setSaved(); starts out null.
void *Saved{nullptr};
61 :
Annotation(AnnotationKind::Device), Device(device) {}
62 void setDevice(std::string device) { Device = device; }
63 const std::string getDevice()
const {
return Device; }
66 return A->getKind() == AnnotationKind::Device;
/// Device identifier. The constructor stores its argument here, and
/// setDevice() may replace it later. The default member initializer must
/// leave it empty: the former `= 0` built a std::string from a null
/// `const char*`, which is undefined behaviour.
std::string Device;
95 :
Instruction(I), Kind(K), Layout(NNLayout::Undefined) {}
98 :
Instruction(), Kind(K), Layout(NNLayout::Undefined) {}
100 :
Instruction(), Kind(NNKind::Undefined), Layout(NNLayout::Undefined) {}
102 NNKind getKind()
const {
return Kind; }
104 void setLayout(
NNLayout L) { Layout = L; }
106 NNLayout getLayout()
const {
return Layout; }
108 void setAnnotation(std::unique_ptr<Annotation> extraAnnotation) {
109 ExtraAnnotation = std::move(extraAnnotation);
112 const Annotation *getAnnotation()
const {
return ExtraAnnotation.get(); }
113 Annotation *getMutableAnnotation() {
return ExtraAnnotation.get(); }
115 const std::string getName()
const;
125 std::vector<const NeuralNetData *> outputs) {
137 std::unique_ptr<Annotation> ExtraAnnotation;
153 const std::string getName()
const;
166 return D->getKind() == NNDataKind::Tensor;
171 const std::string getName()
const {
return name_; }
184 #define NOMNIGRAPH_DEFINE_NN_RTTI(op) \ 185 static bool classof(const NeuralNetOperator *N) { \ 186 return N->getKind() == NNKind::op; \ 191 Conv(std::vector<int> kernelShape, std::vector<int> dilations = {1, 1},
192 int group = 1, std::vector<int> pads = {0, 0},
193 std::vector<int> strides = {1, 1})
195 Dilations(dilations), Group(group), Pads(pads), Strides(strides) {}
197 NOMNIGRAPH_DEFINE_NN_RTTI(
Conv);
201 void setDilations(std::vector<int> &&dilations) {
202 Dilations = std::move(dilations);
205 void setGroup(
int &&group) { Group = std::move(group); }
207 void setPads(std::vector<int> &&pads) { Pads = std::move(pads); }
209 void setStrides(std::vector<int> &&strides) { Strides = std::move(strides); }
211 std::vector<int> getDilations() {
return Dilations; }
213 int getGroup() {
return Group; }
215 std::vector<int> getPads() {
return Pads; }
217 std::vector<int> getStrides() {
return Strides; }
219 std::vector<int> getKernelShape() {
return KernelShape; }
221 bool checkInputsAndOutputs(std::vector<const NeuralNetData *> inputs,
222 std::vector<const NeuralNetData *> outputs) {
223 assert(KernelShape.size() == Dilations.size());
224 assert(KernelShape.size() == Pads.size());
225 assert(KernelShape.size() == Strides.size());
// Convolution parameters filled in by the constructor (dilations, pads and
// strides default to {1, 1}, {0, 0} and {1, 1}). checkInputsAndOutputs()
// asserts that the last three have the same size as KernelShape.
std::vector<int> KernelShape{};
std::vector<int> Dilations{};
std::vector<int> Pads{};
std::vector<int> Strides{};
239 ConvRelu(std::vector<int> kernelShape, std::vector<int> dilations = {1, 1},
240 int group = 1, std::vector<int> pads = {0, 0},
241 std::vector<int> strides = {1, 1})
243 ConvPtr(util::make_unique<Conv>(kernelShape, dilations, group, pads,
249 NOMNIGRAPH_DEFINE_NN_RTTI(
ConvRelu);
254 std::unique_ptr<Conv> ConvPtr =
nullptr;
260 NOMNIGRAPH_DEFINE_NN_RTTI(
Relu);
267 NOMNIGRAPH_DEFINE_NN_RTTI(
Send);
274 NOMNIGRAPH_DEFINE_NN_RTTI(
Receive);
281 NOMNIGRAPH_DEFINE_NN_RTTI(
While);
288 NOMNIGRAPH_DEFINE_NN_RTTI(
NNPhi);
298 std::string getName()
const {
return name_; }
/// C++11 stand-in for C++14's std::enable_if_t. It names
/// std::enable_if<Condition, Result>::type, which only exists when Condition
/// is true; otherwise the alias triggers SFINAE. Result defaults to void.
template <bool Condition, class Result = void>
using enable_if_t = typename std::enable_if<Condition, Result>::type;
324 template <
typename T>
325 constexpr
bool inheritedFromNeuralNetOperator() {
326 return std::is_base_of<NeuralNetOperator, T>::value &&
327 !std::is_same<NeuralNetOperator, T>::value;
330 template <
typename T>
331 constexpr
bool inheritedFromNeuralNetData() {
332 return std::is_base_of<NeuralNetData, T>::value &&
333 !std::is_same<NeuralNetData, T>::value;
338 template <
typename T,
typename N,
typename =
void>
struct is_impl {
339 inline static bool impl(N n) {
return isa<T>(n->data()); }
342 template <
typename T,
typename N>
343 struct is_impl<T, N, enable_if_t<inheritedFromNeuralNetOperator<T>()>> {
344 inline static bool impl(N n) {
350 template <
typename T,
typename N>
351 struct is_impl<T, N, enable_if_t<inheritedFromNeuralNetData<T>()>> {
352 inline static bool impl(N n) {
358 template <
typename T,
typename N>
inline bool is(N n) {
364 template <
typename T,
typename N,
typename =
void>
struct get_impl {
365 inline static T *impl(N n) {
return dyn_cast<T>(n->data().get()); }
368 template <
typename T,
typename N>
369 struct get_impl<T, N, enable_if_t<inheritedFromNeuralNetOperator<T>()>> {
370 inline static T *impl(N n) {
372 return dyn_cast<T>(nno);
376 template <
typename T,
typename N>
377 struct get_impl<T, N, enable_if_t<inheritedFromNeuralNetData<T>()>> {
378 inline static T *impl(N n) {
380 return dyn_cast<T>(nno);
384 template <
typename T,
typename N>
inline T *
get(N n) {
388 template <
typename T,
typename G>
389 std::vector<std::pair<T *, typename G::NodeRef>> dataIterator(G &g) {
390 std::vector<std::pair<T *, typename G::NodeRef>> out;
391 for (
auto node : g.getMutableNodes()) {
395 auto d = get<T>(node);
396 out.emplace_back(std::make_pair(d, node));
412 template <NNGraph* G>
421 #endif // NOM_REPRESENTATIONS_NEURALNET_H
Annotations allow for generic manipulation of neural network operations.
NNLayout
An optional tensor-type specifier.
bool checkInputsAndOutputs(std::vector< const NeuralNetData * > inputs, std::vector< const NeuralNetData * > outputs)
Validate the inputs and outputs to this operator.
NNKind
Discriminator for LLVM-style RTTI (isa<>)
Opcode
All the different types of execution.
NNDataKind
Discriminator for LLVM-style RTTI (isa<>)