tesseract  3.05.02
neural_net.h
Go to the documentation of this file.
1 // Copyright 2008 Google Inc.
2 // All Rights Reserved.
3 // Author: ahmadab@google.com (Ahmad Abdulkader)
4 //
5 // neural_net.h: Declarations of a class for an object that
6 // represents an arbitrary network of neurons
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 
18 #ifndef NEURAL_NET_H
19 #define NEURAL_NET_H
20 
21 #include <string>
22 #include <vector>
23 #include "neuron.h"
24 #include "input_file_buffer.h"
25 
26 namespace tesseract {
27 
28 // Minimum input range below which we set the input weight to zero
   // NOTE(review): `static const` at namespace scope in a header has internal
   // linkage, so each translation unit that includes this file gets its own
   // copy. The constant is not referenced anywhere else in this header.
29 static const float kMinInputRange = 1e-6f;
30 
31 class NeuralNet {
   // Represents an arbitrary network of neurons that can be loaded from a
   // binary stream and evaluated (feed-forward).
   // NOTE(review): this copy of the listing skips several line numbers
   // (38, 56, 64-65, 76, 80, 90, 94, 164, 166, 168); the declarations and
   // statements on those lines are not visible here. Members used below but
   // not declared in the visible text: neurons_, neuron_cnt_, auto_encoder_.
32  public:
33  NeuralNet();
34  virtual ~NeuralNet();
35  // Create a net object from a file. Uses stdio.
36  static NeuralNet *FromFile(const string file_name);
37  // Create a net object from an input buffer.
   // NOTE(review): the declaration this comment belongs to (listing line 38)
   // is missing from this copy; presumably FromInputBuffer(InputFileBuffer *),
   // which the cross-reference index at the end of the page lists.
39  // Different flavors of feed forward function.
   // Takes `inputs` and writes results through `outputs`; returns bool.
   // Defined out of line (not in this header).
40  template <typename Type> bool FeedForward(const Type *inputs,
41  Type *outputs);
42  // Compute the output of a specific output node.
43  // This function is useful for applications that are interested in a single
44  // output of the net and do not want to waste time on the rest.
   // `output_id` selects the output node; the value is written to `*output`.
45  template <typename Type> bool GetNetOutput(const Type *inputs,
46  int output_id,
47  Type *output);
48  // Accessor functions
   // Number of network inputs, as read by ReadBinary.
49  int in_cnt() const { return in_cnt_; }
   // Number of network outputs, as read by ReadBinary.
50  int out_cnt() const { return out_cnt_; }
51 
52  protected:
   // Forward declaration of Node, which is defined below WeightedNode.
53  struct Node;
54  // A node-weight pair
   // NOTE(review): listing line 56 (presumably the node half of the pair) is
   // missing from this copy; only the weight member is visible.
55  struct WeightedNode {
57  float input_weight;
58  };
59  // Node struct used for fast feedforward in
60  // read-only nets
   // NOTE(review): listing lines 64-65 are missing from this copy; the
   // cross-reference index lists `WeightedNode * inputs` at line 65.
61  struct Node {
62  float out;
63  float bias;
66  };
67  // Read-Only flag (no training: On by default)
68  // will presumably be set to false by
69  // the inheriting TrainableNeuralNet class
70  bool read_only_;
71  // input count
72  int in_cnt_;
73  // output count
74  int out_cnt_;
75  // Total neuron count (including inputs)
   // NOTE(review): the member this comment describes (listing line 76) is
   // missing from this copy; the code below uses neuron_cnt_ for this role.
77  // count of unique weights
78  int wts_cnt_;
79  // Neuron vector
   // NOTE(review): declaration (listing line 80) is missing from this copy;
   // ReadBinary assigns `new Neuron[neuron_cnt_]` to neurons_.
81  // size of allocated weight chunk (in weights)
82  // This is basically the size of the biggest network
83  // that I have trained. However, the class will allow
84  // a bigger sized net if desired
85  static const int kWgtChunkSize = 0x10000;
86  // Magic number expected at the beginning of the NN
87  // binary file
88  static const unsigned int kNetSignature = 0xFEFEABD0;
89  // count of allocated wgts in the last chunk
   // NOTE(review): declaration (listing line 90) is missing from this copy.
91  // vector of weights buffers
92  vector<vector<float> *>wts_vec_;
93  // Is the net an auto-encoder type
   // NOTE(review): declaration (listing line 94) is missing from this copy;
   // ReadBinary stores the flag in auto_encoder_.
95  // vector of input max values
96  vector<float> inputs_max_;
97  // vector of input min values
98  vector<float> inputs_min_;
99  // vector of input mean values
100  vector<float> inputs_mean_;
101  // vector of input standard deviation values
102  vector<float> inputs_std_dev_;
103  // vector of input offsets used by fast read-only
104  // feedforward function
105  vector<Node> fast_nodes_;
106  // Network Initialization function
107  void Init();
108  // Clears all neurons
   // Calls Clear() on each of the neuron_cnt_ entries of neurons_.
109  void Clear() {
110  for (int node = 0; node < neuron_cnt_; node++) {
111  neurons_[node].Clear();
112  }
113  }
114  // Reads the net from an input buffer
   //
   // ReadBuffType must provide Read(dst, size); a return value different from
   // the requested size is treated as a failed read.
   //
   // Stream layout, in the order consumed below:
   //   1. unsigned int signature, must equal kNetSignature
   //   2. unsigned int auto-encoder flag
   //   3. unsigned int neuron count, input count, output count
   //      (each rejected when <= 0 after being stored in an int)
   //   4. for every neuron: unsigned int fan-out count, followed by that
   //      many unsigned int target ids, each wired with SetConnection
   //   5. for every neuron: whatever Neuron::ReadBinary consumes
   //   6. four arrays of in_cnt_ floats: mean, std dev, min, max
   //
   // Returns false on any short read, signature mismatch, non-positive count,
   // failed SetConnection or failed neuron read. Otherwise returns the result
   // of CreateFastNet() when read_only_ is set, and true when it is not.
   //
   // NOTE(review): neurons_ is allocated with new[] before the remaining
   // validation; cleanup on the early-return paths is not visible in this
   // block - confirm Init()/the destructor release it.
   // NOTE(review): in_cnt_ and out_cnt_ are not cross-checked against
   // neuron_cnt_ here, and the fan-out count has no upper bound check.
115  template<class ReadBuffType> bool ReadBinary(ReadBuffType *input_buff) {
116  // Init vars
117  Init();
118  // scratch value for each unsigned int field, and the raw auto-encoder flag
119  unsigned int read_val;
120  unsigned int auto_encode;
121  // read and verify signature
122  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
123  return false;
124  }
125  if (read_val != kNetSignature) {
126  return false;
127  }
   // read the auto-encoder flag
128  if (input_buff->Read(&auto_encode, sizeof(auto_encode)) !=
129  sizeof(auto_encode)) {
130  return false;
131  }
132  auto_encoder_ = auto_encode;
133  // read and validate total # of nodes
134  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
135  return false;
136  }
137  neuron_cnt_ = read_val;
138  if (neuron_cnt_ <= 0) {
139  return false;
140  }
141  // allocate the neurons array
142  neurons_ = new Neuron[neuron_cnt_];
143  // read & validate inputs
144  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
145  return false;
146  }
147  in_cnt_ = read_val;
148  if (in_cnt_ <= 0) {
149  return false;
150  }
151  // read & validate outputs
152  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
153  return false;
154  }
155  out_cnt_ = read_val;
156  if (out_cnt_ <= 0) {
157  return false;
158  }
159  // set neuron ids and types
   // The first in_cnt_ indices are inputs, the last out_cnt_ indices are
   // outputs, everything in between falls into the final else branch.
   // NOTE(review): the three branch bodies (listing lines 164, 166, 168) are
   // missing from this copy; presumably set_node_type() calls, which the
   // cross-reference index at the end of the page lists.
160  for (int idx = 0; idx < neuron_cnt_; idx++) {
161  neurons_[idx].set_id(idx);
162  // input type
163  if (idx < in_cnt_) {
165  } else if (idx >= (neuron_cnt_ - out_cnt_)) {
167  } else {
169  }
170  }
171  // read the connections
172  for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
173  // read fanout
174  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
175  return false;
176  }
177  // read the neuron's info
   // NOTE(review): the unsigned count is stored in an int without a range
   // check; a value that converts to <= 0 skips the inner loop.
178  int fan_out_cnt = read_val;
179  for (int fan_out_idx = 0; fan_out_idx < fan_out_cnt; fan_out_idx++) {
180  // read the neuron id
181  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
182  return false;
183  }
184  // create the connection
   // read_val (unsigned) is converted to SetConnection's int `to` parameter;
   // any id validation is left to SetConnection.
185  if (!SetConnection(node_idx, read_val)) {
186  return false;
187  }
188  }
189  }
190  // read all the neurons' fan-in connections
191  for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
192  // delegate to the neuron's own reader
193  if (!neurons_[node_idx].ReadBinary(input_buff)) {
194  return false;
195  }
196  }
197  // size input stats vector to expected input size
198  inputs_mean_.resize(in_cnt_);
199  inputs_std_dev_.resize(in_cnt_);
200  inputs_min_.resize(in_cnt_);
201  inputs_max_.resize(in_cnt_);
202  // read stats
   // Each array is read in one call straight into the vector's storage;
   // in_cnt_ > 0 was checked above, so front() refers to a real element.
203  if (input_buff->Read(&(inputs_mean_.front()),
204  sizeof(inputs_mean_[0]) * in_cnt_) !=
205  sizeof(inputs_mean_[0]) * in_cnt_) {
206  return false;
207  }
208  if (input_buff->Read(&(inputs_std_dev_.front()),
209  sizeof(inputs_std_dev_[0]) * in_cnt_) !=
210  sizeof(inputs_std_dev_[0]) * in_cnt_) {
211  return false;
212  }
213  if (input_buff->Read(&(inputs_min_.front()),
214  sizeof(inputs_min_[0]) * in_cnt_) !=
215  sizeof(inputs_min_[0]) * in_cnt_) {
216  return false;
217  }
218  if (input_buff->Read(&(inputs_max_.front()),
219  sizeof(inputs_max_[0]) * in_cnt_) !=
220  sizeof(inputs_max_[0]) * in_cnt_) {
221  return false;
222  }
223  // create a readonly version for fast feedforward
224  if (read_only_) {
225  return CreateFastNet();
226  }
227  return true;
228  }
229 
230  // Creates a connection between two nodes, identified by index.
   // Returns bool; ReadBinary aborts the load when it returns false.
231  bool SetConnection(int from, int to);
232  // Create a read only version of the net that
233  // has faster feedforward performance
   // Called at the end of ReadBinary when read_only_ is set.
234  bool CreateFastNet();
235  // internal function to allocate a new set of weights
236  // Centralized weight allocation attempts to increase
237  // weights locality of reference making it more cache friendly
238  float *AllocWgt(int wgt_cnt);
239  // Different flavors of the read-only feedforward function
240  template <typename Type> bool FastFeedForward(const Type *inputs,
241  Type *outputs);
242  // Compute the output of a specific output node.
243  // This function is useful for applications that are interested in a single
244  // output of the net and do not want to waste time on the rest
245  // This is the fast-read-only version of this function
246  template <typename Type> bool FastGetNetOutput(const Type *inputs,
247  int output_id,
248  Type *output);
249 };
250 }
251 
252 #endif // NEURAL_NET_H
void set_node_type(NeuronTypes type)
Definition: neuron.cpp:71
bool SetConnection(int from, int to)
Definition: neural_net.cpp:121
static NeuralNet * FromFile(const string file_name)
Definition: neural_net.cpp:210
int in_cnt() const
Definition: neural_net.h:49
static NeuralNet * FromInputBuffer(InputFileBuffer *ib)
Definition: neural_net.cpp:219
WeightedNode * inputs
Definition: neural_net.h:65
bool FeedForward(const Type *inputs, Type *outputs)
Definition: neural_net.cpp:88
vector< float > inputs_max_
Definition: neural_net.h:96
bool FastFeedForward(const Type *inputs, Type *outputs)
Definition: neural_net.cpp:61
void set_id(int id)
Definition: neuron.h:120
bool FastGetNetOutput(const Type *inputs, int output_id, Type *output)
Definition: neural_net.cpp:234
vector< vector< float > * > wts_vec_
Definition: neural_net.h:92
static const int kWgtChunkSize
Definition: neural_net.h:85
int out_cnt() const
Definition: neural_net.h:50
bool GetNetOutput(const Type *inputs, int output_id, Type *output)
Definition: neural_net.cpp:268
vector< float > inputs_min_
Definition: neural_net.h:98
bool ReadBinary(ReadBuffType *input_buff)
Definition: neural_net.h:115
void Clear()
Definition: neuron.h:46
vector< float > inputs_std_dev_
Definition: neural_net.h:102
static const unsigned int kNetSignature
Definition: neural_net.h:88
vector< float > inputs_mean_
Definition: neural_net.h:100
float * AllocWgt(int wgt_cnt)
Definition: neural_net.cpp:195
vector< Node > fast_nodes_
Definition: neural_net.h:105