/**********************************************************************
 * File:        hybrid_neural_net_classifier.cpp
 * Description: Implementation of the Hybrid Neural Net Character Classifier
 * Author:      Ahmad Abdulkader
 * Created:     2007
 *
 * (C) Copyright 2008, Google Inc.
 ** Licensed under the Apache License, Version 2.0 (the "License");
 ** you may not use this file except in compliance with the License.
 ** You may obtain a copy of the License at
 ** http://www.apache.org/licenses/LICENSE-2.0
 ** Unless required by applicable law or agreed to in writing, software
 ** distributed under the License is distributed on an "AS IS" BASIS,
 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 ** See the License for the specific language governing permissions and
 ** limitations under the License.
 *
 **********************************************************************/

#include <algorithm>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <vector>
#include <wctype.h>

#include "classifier_base.h"
#include "char_set.h"
#include "const.h"
#include "conv_net_classifier.h"
#include "cube_utils.h"
#include "feature_base.h"
#include "feature_bmp.h"
#include "hybrid_neural_net_classifier.h"
#include "tess_lang_model.h"

namespace tesseract {

HybridNeuralNetCharClassifier::HybridNeuralNetCharClassifier(
    CharSet *char_set,
    TuningParams *params,
    FeatureBase *feat_extract)
    : CharClassifier(char_set, params, feat_extract) {
  net_input_ = NULL;
  net_output_ = NULL;
}

// Frees the component nets and the network input/output buffers
HybridNeuralNetCharClassifier::~HybridNeuralNetCharClassifier() {
  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
    if (nets_[net_idx] != NULL) {
      delete nets_[net_idx];
    }
  }
  nets_.clear();

  if (net_input_ != NULL) {
    delete []net_input_;
    net_input_ = NULL;
  }

  if (net_output_ != NULL) {
    delete []net_output_;
    net_output_ = NULL;
  }
}

// The main training function. Given a sample and a class ID the classifier
// updates its parameters according to its learning algorithm. This function
// is currently not implemented. TODO(ahmadab): implement end-2-end training
bool HybridNeuralNetCharClassifier::Train(CharSamp *char_samp, int ClassID) {
  return false;
}

// A secondary function needed for training. Allows the trainer to set the
// value of any train-time parameter. This function is currently not
// implemented. TODO(ahmadab): implement end-2-end training
bool HybridNeuralNetCharClassifier::SetLearnParam(char *var_name, float val) {
  // TODO(ahmadab): implementation of parameter initializing.
  return false;
}

// Folds the output of the NeuralNet using the loaded folding sets
void HybridNeuralNetCharClassifier::Fold() {
  // if running in case-insensitive mode
  if (case_sensitive_ == false) {
    int class_cnt = char_set_->ClassCount();
    // fold case
    for (int class_id = 0; class_id < class_cnt; class_id++) {
      // get class string
      const char_32 *str32 = char_set_->ClassString(class_id);
      // get the upper-case form of the string
      string_32 upper_form32 = str32;
      for (int ch = 0; ch < upper_form32.length(); ch++) {
        if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
          upper_form32[ch] = towupper(upper_form32[ch]);
        }
      }

      // find the class-id of the upper-case form, if any
      int upper_class_id =
          char_set_->ClassID(reinterpret_cast<const char_32 *>(
              upper_form32.c_str()));
      if (upper_class_id != -1 && class_id != upper_class_id) {
        float max_out = MAX(net_output_[class_id], net_output_[upper_class_id]);
        net_output_[class_id] = max_out;
        net_output_[upper_class_id] = max_out;
      }
    }
  }

  // The folding sets specify how groups of classes should be folded.
  // Folding involves assigning a min-activation to all the members
  // of the folding set. The min-activation is a fraction of the
  // max-activation of the members of the folding set.
  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
    // skip folding sets that were invalidated in LoadFoldingSets()
    if (fold_set_len_[fold_set] == 0) {
      continue;
    }
    float max_prob = net_output_[fold_sets_[fold_set][0]];

    for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) {
      if (net_output_[fold_sets_[fold_set][ch]] > max_prob) {
        max_prob = net_output_[fold_sets_[fold_set][ch]];
      }
    }
    for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
      net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio,
          net_output_[fold_sets_[fold_set][ch]]);
    }
  }
}
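
// Illustrative note (sketch, not part of the original source): assuming a
// folding set that groups the classes for 'O', 'o' and '0', and assuming
// kFoldingRatio were 0.75, activations {0.8, 0.1, 0.05} would fold to
// {0.8, max(0.75 * 0.8, 0.1), max(0.75 * 0.8, 0.05)} = {0.8, 0.6, 0.6};
// every member of the set is raised to at least that fraction of the
// strongest member's activation.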

// compute the features of the specified charsamp and
// feed them forward through the loaded nets
bool HybridNeuralNetCharClassifier::RunNets(CharSamp *char_samp) {
  int feat_cnt = feat_extract_->FeatureCnt();
  int class_cnt = char_set_->ClassCount();

  // allocate input and output buffers if needed
  if (net_input_ == NULL) {
    net_input_ = new float[feat_cnt];
    net_output_ = new float[class_cnt];
  }

  // compute input features
  if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) {
    return false;
  }

  // go through all the nets
  memset(net_output_, 0, class_cnt * sizeof(*net_output_));
  float *inputs = net_input_;
  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
    // run each net
    vector<float> net_out(class_cnt, 0.0);
    if (!nets_[net_idx]->FeedForward(inputs, &net_out[0])) {
      return false;
    }
    // accumulate this net's output, scaled by its weight
    for (int class_idx = 0; class_idx < class_cnt; class_idx++) {
      net_output_[class_idx] += (net_out[class_idx] * net_wgts_[net_idx]);
    }
    // advance the input pointer past this net's slice of the feature vector
    inputs += nets_[net_idx]->in_cnt();
  }

  Fold();

  return true;
}
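
// Illustrative note (sketch, not part of the original source): the feature
// extractor fills net_input_ with one concatenated feature vector, and each
// component net consumes its own slice of it. For example, with two nets
// whose in_cnt() values were (hypothetically) 64 and 32, the first net would
// read net_input_[0..63] and the second net_input_[64..95]; LoadNets()
// verifies that the slices add up to feat_extract_->FeatureCnt().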

// return the cost of being a char
int HybridNeuralNetCharClassifier::CharCost(CharSamp *char_samp) {
  // it is by design that a character cost is equal to zero
  // when no nets are present. This is the case during training.
  if (RunNets(char_samp) == false) {
    return 0;
  }

  return CubeUtils::Prob2Cost(1.0f - net_output_[0]);
}
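
// Illustrative note (sketch, not part of the original source): net_output_[0]
// is used here as the non-character activation (and Classify() below skips
// output 0 for the same reason), so 1.0f - net_output_[0] acts as the
// probability that the sample is a character. For example, if net_output_[0]
// were 0.2, the returned value would be Prob2Cost(0.8).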

// classifies a charsamp and returns an alternate list
// of chars sorted by char costs
CharAltList *HybridNeuralNetCharClassifier::Classify(CharSamp *char_samp) {
  // run the needed nets
  if (RunNets(char_samp) == false) {
    return NULL;
  }

  int class_cnt = char_set_->ClassCount();

  // create an altlist
  CharAltList *alt_list = new CharAltList(char_set_, class_cnt);

  // start at 1: output 0 is used as the non-char activation in CharCost()
  for (int out = 1; out < class_cnt; out++) {
    int cost = CubeUtils::Prob2Cost(net_output_[out]);
    alt_list->Insert(out, cost);
  }

  return alt_list;
}

// set an external net (for training purposes)
void HybridNeuralNetCharClassifier::SetNet(tesseract::NeuralNet *char_net) {
}

// Load folding sets
// Folding sets are optional, so this function returns true if the folding
// file does not exist; it returns false only if the file exists but cannot
// be read or parsed.
bool HybridNeuralNetCharClassifier::LoadFoldingSets(
    const string &data_file_path, const string &lang, LangModel *lang_mod) {
  fold_set_cnt_ = 0;
  string fold_file_name;
  fold_file_name = data_file_path + lang;
  fold_file_name += ".cube.fold";

  // folding sets are optional
  FILE *fp = fopen(fold_file_name.c_str(), "rb");
  if (fp == NULL) {
    return true;
  }
  fclose(fp);

  string fold_sets_str;
  if (!CubeUtils::ReadFileToString(fold_file_name,
                                   &fold_sets_str)) {
    return false;
  }

  // split into lines
  vector<string> str_vec;
  CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec);
  fold_set_cnt_ = str_vec.size();
  fold_sets_ = new int *[fold_set_cnt_];
  fold_set_len_ = new int[fold_set_cnt_];

  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
    reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters(
        &str_vec[fold_set]);

    // if all or all but one character are invalid, invalidate this set
    if (str_vec[fold_set].length() <= 1) {
      fprintf(stderr, "Cube WARNING (HybridNeuralNetCharClassifier::"
              "LoadFoldingSets): invalidating folding set %d\n", fold_set);
      fold_set_len_[fold_set] = 0;
      fold_sets_[fold_set] = NULL;
      continue;
    }

    string_32 str32;
    CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32);
    fold_set_len_[fold_set] = str32.length();
    fold_sets_[fold_set] = new int[fold_set_len_[fold_set]];
    for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
      fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]);
    }
  }
  return true;
}
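
// Illustrative note (sketch, not part of the original source): the optional
// <lang>.cube.fold file is parsed above as plain text with one folding set
// per line, each line being a UTF-8 string whose characters are folded
// together. A hypothetical line such as "oO0" would make the classes for
// 'o', 'O' and '0' share activations (subject to kFoldingRatio) in Fold().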

// Init the classifier provided a data-path and a language string
bool HybridNeuralNetCharClassifier::Init(const string &data_file_path,
                                         const string &lang,
                                         LangModel *lang_mod) {
  if (init_ == true) {
    return true;
  }

  // load the nets if any. This function will return true if the net file
  // does not exist, but will fail if the nets do not pass the sanity checks
  if (!LoadNets(data_file_path, lang)) {
    return false;
  }

  // load the folding sets if any. This function will return true if the
  // fold file does not exist, but will fail if it cannot be parsed
  if (!LoadFoldingSets(data_file_path, lang, lang_mod)) {
    return false;
  }

  init_ = true;
  return true;
}
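
// Illustrative usage sketch (hypothetical caller, not part of this file;
// the variable names are assumptions):
//
//   HybridNeuralNetCharClassifier classifier(char_set, params, feat_extract);
//   if (classifier.Init(data_file_path, lang, lang_mod)) {
//     CharAltList *alt_list = classifier.Classify(char_samp);
//     int char_cost = classifier.CharCost(char_samp);
//   }
//
// Init() loads the optional <lang>.cube.hybrid net list and <lang>.cube.fold
// folding sets from data_file_path; Classify() and CharCost() then run the
// loaded nets on a CharSamp.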

// Load the classifier's Neural Nets
// Nets are optional, so this function will return true if the net file does
// not exist, but will fail if the nets do not pass the sanity checks
bool HybridNeuralNetCharClassifier::LoadNets(const string &data_file_path,
                                             const string &lang) {
  string hybrid_net_file;
  string junk_net_file;

  // add the lang identifier
  hybrid_net_file = data_file_path + lang;
  hybrid_net_file += ".cube.hybrid";

  // neural network is optional
  FILE *fp = fopen(hybrid_net_file.c_str(), "rb");
  if (fp == NULL) {
    return true;
  }
  fclose(fp);

  string str;
  if (!CubeUtils::ReadFileToString(hybrid_net_file, &str)) {
    return false;
  }

  // split into lines
  vector<string> str_vec;
  CubeUtils::SplitStringUsing(str, "\r\n", &str_vec);
  if (str_vec.empty()) {
    return false;
  }

  // create and add the nets
  nets_.resize(str_vec.size(), NULL);
  net_wgts_.resize(str_vec.size(), 0);
  int total_input_size = 0;
  for (int net_idx = 0; net_idx < str_vec.size(); net_idx++) {
    // parse the line
    vector<string> tokens_vec;
    CubeUtils::SplitStringUsing(str_vec[net_idx], " \t", &tokens_vec);
    // has to be 2 tokens: the net file name and the net weight
    if (tokens_vec.size() != 2) {
      return false;
    }
    // load the net
    string net_file_name = data_file_path + tokens_vec[0];
    nets_[net_idx] = tesseract::NeuralNet::FromFile(net_file_name);
    if (nets_[net_idx] == NULL) {
      return false;
    }
    // parse the net weight and validate it
    net_wgts_[net_idx] = atof(tokens_vec[1].c_str());
    if (net_wgts_[net_idx] < 0.0) {
      return false;
    }
    total_input_size += nets_[net_idx]->in_cnt();
  }
  // the combined input size of the nets must match the feature count
  if (total_input_size != feat_extract_->FeatureCnt()) {
    return false;
  }
  // success
  return true;
}
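
// Illustrative note (sketch, not part of the original source): as parsed
// above, the optional <lang>.cube.hybrid file lists one net per line in the
// form "<net-file-name> <weight>". A hypothetical line such as
// "eng.cube.word-net 0.5" would load data_file_path + "eng.cube.word-net"
// via NeuralNet::FromFile() and combine its outputs with weight 0.5 in
// RunNets().
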
}  // namespace tesseract