tesseract  3.05.02
conv_net_classifier.cpp
Go to the documentation of this file.
1 /**********************************************************************
2  * File: conv_net_classifier.cpp
3  * Description: Implementation of Convolutional-NeuralNet Character Classifier
4  * Author: Ahmad Abdulkader
5  * Created: 2007
6  *
7  * (C) Copyright 2008, Google Inc.
8  ** Licensed under the Apache License, Version 2.0 (the "License");
9  ** you may not use this file except in compliance with the License.
10  ** You may obtain a copy of the License at
11  ** http://www.apache.org/licenses/LICENSE-2.0
12  ** Unless required by applicable law or agreed to in writing, software
13  ** distributed under the License is distributed on an "AS IS" BASIS,
14  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  ** See the License for the specific language governing permissions and
16  ** limitations under the License.
17  *
18  **********************************************************************/
19 
20 #include <algorithm>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string>
24 #include <vector>
25 #include <wctype.h>
26 
27 #include "char_set.h"
28 #include "classifier_base.h"
29 #include "const.h"
30 #include "conv_net_classifier.h"
31 #include "cube_utils.h"
32 #include "feature_base.h"
33 #include "feature_bmp.h"
34 #include "tess_lang_model.h"
35 
36 namespace tesseract {
37 
39  TuningParams *params,
40  FeatureBase *feat_extract)
41  : CharClassifier(char_set, params, feat_extract) {
42  char_net_ = NULL;
43  net_input_ = NULL;
44  net_output_ = NULL;
45 }
46 
48  if (char_net_ != NULL) {
49  delete char_net_;
50  char_net_ = NULL;
51  }
52 
53  if (net_input_ != NULL) {
54  delete []net_input_;
55  net_input_ = NULL;
56  }
57 
58  if (net_output_ != NULL) {
59  delete []net_output_;
60  net_output_ = NULL;
61  }
62 }
63 
69 bool ConvNetCharClassifier::Train(CharSamp *char_samp, int ClassID) {
70  return false;
71 }
72 
78 bool ConvNetCharClassifier::SetLearnParam(char *var_name, float val) {
79  // TODO(ahmadab): implementation of parameter initializing.
80  return false;
81 }
82 
86 void ConvNetCharClassifier::Fold() {
87  // in case insensitive mode
88  if (case_sensitive_ == false) {
89  int class_cnt = char_set_->ClassCount();
90  // fold case
91  for (int class_id = 0; class_id < class_cnt; class_id++) {
92  // get class string
93  const char_32 *str32 = char_set_->ClassString(class_id);
94  // get the upper case form of the string
95  string_32 upper_form32 = str32;
96  for (int ch = 0; ch < upper_form32.length(); ch++) {
97  if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
98  upper_form32[ch] = towupper(upper_form32[ch]);
99  }
100  }
101 
102  // find out the upperform class-id if any
103  int upper_class_id =
104  char_set_->ClassID(reinterpret_cast<const char_32 *>(
105  upper_form32.c_str()));
106  if (upper_class_id != -1 && class_id != upper_class_id) {
107  float max_out = MAX(net_output_[class_id], net_output_[upper_class_id]);
108  net_output_[class_id] = max_out;
109  net_output_[upper_class_id] = max_out;
110  }
111  }
112  }
113 
114  // The folding sets specify how groups of classes should be folded
115  // Folding involved assigning a min-activation to all the members
116  // of the folding set. The min-activation is a fraction of the max-activation
117  // of the members of the folding set
118  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
119  if (fold_set_len_[fold_set] == 0)
120  continue;
121  float max_prob = net_output_[fold_sets_[fold_set][0]];
122  for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) {
123  if (net_output_[fold_sets_[fold_set][ch]] > max_prob) {
124  max_prob = net_output_[fold_sets_[fold_set][ch]];
125  }
126  }
127  for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
128  net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio,
129  net_output_[fold_sets_[fold_set][ch]]);
130  }
131  }
132 }
133 
138 bool ConvNetCharClassifier::RunNets(CharSamp *char_samp) {
139  if (char_net_ == NULL) {
140  fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
141  "NeuralNet is NULL\n");
142  return false;
143  }
144  int feat_cnt = char_net_->in_cnt();
145  int class_cnt = char_set_->ClassCount();
146 
147  // allocate i/p and o/p buffers if needed
148  if (net_input_ == NULL) {
149  net_input_ = new float[feat_cnt];
150  net_output_ = new float[class_cnt];
151  }
152 
153  // compute input features
154  if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) {
155  fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
156  "unable to compute features\n");
157  return false;
158  }
159 
160  if (char_net_ != NULL) {
161  if (char_net_->FeedForward(net_input_, net_output_) == false) {
162  fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): "
163  "unable to run feed-forward\n");
164  return false;
165  }
166  } else {
167  return false;
168  }
169  Fold();
170  return true;
171 }
172 
177  if (RunNets(char_samp) == false) {
178  return 0;
179  }
180  return CubeUtils::Prob2Cost(1.0f - net_output_[0]);
181 }
182 
188  // run the needed nets
189  if (RunNets(char_samp) == false) {
190  return NULL;
191  }
192 
193  int class_cnt = char_set_->ClassCount();
194 
195  // create an altlist
196  CharAltList *alt_list = new CharAltList(char_set_, class_cnt);
197 
198  for (int out = 1; out < class_cnt; out++) {
199  int cost = CubeUtils::Prob2Cost(net_output_[out]);
200  alt_list->Insert(out, cost);
201  }
202 
203  return alt_list;
204 }
205 
210  if (char_net_ != NULL) {
211  delete char_net_;
212  char_net_ = NULL;
213  }
214  char_net_ = char_net;
215 }
216 
221 bool ConvNetCharClassifier::LoadFoldingSets(const string &data_file_path,
222  const string &lang,
223  LangModel *lang_mod) {
224  fold_set_cnt_ = 0;
225  string fold_file_name;
226  fold_file_name = data_file_path + lang;
227  fold_file_name += ".cube.fold";
228 
229  // folding sets are optional
230  FILE *fp = fopen(fold_file_name.c_str(), "rb");
231  if (fp == NULL) {
232  return true;
233  }
234  fclose(fp);
235 
236  string fold_sets_str;
237  if (!CubeUtils::ReadFileToString(fold_file_name,
238  &fold_sets_str)) {
239  return false;
240  }
241 
242  // split into lines
243  vector<string> str_vec;
244  CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec);
245  fold_set_cnt_ = str_vec.size();
246 
247  fold_sets_ = new int *[fold_set_cnt_];
248  fold_set_len_ = new int[fold_set_cnt_];
249 
250  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
251  reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters(
252  &str_vec[fold_set]);
253 
254  // if all or all but one character are invalid, invalidate this set
255  if (str_vec[fold_set].length() <= 1) {
256  fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): "
257  "invalidating folding set %d\n", fold_set);
258  fold_set_len_[fold_set] = 0;
259  fold_sets_[fold_set] = NULL;
260  continue;
261  }
262 
263  string_32 str32;
264  CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32);
265  fold_set_len_[fold_set] = str32.length();
266  fold_sets_[fold_set] = new int[fold_set_len_[fold_set]];
267  for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
268  fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]);
269  }
270  }
271  return true;
272 }
273 
277 bool ConvNetCharClassifier::Init(const string &data_file_path,
278  const string &lang,
279  LangModel *lang_mod) {
280  if (init_) {
281  return true;
282  }
283 
284  // load the nets if any. This function will return true if the net file
285  // does not exist. But will fail if the net did not pass the sanity checks
286  if (!LoadNets(data_file_path, lang)) {
287  return false;
288  }
289 
290  // load the folding sets if any. This function will return true if the
291  // file does not exist. But will fail if the it did not pass the sanity checks
292  if (!LoadFoldingSets(data_file_path, lang, lang_mod)) {
293  return false;
294  }
295 
296  init_ = true;
297  return true;
298 }
299 
305 bool ConvNetCharClassifier::LoadNets(const string &data_file_path,
306  const string &lang) {
307  string char_net_file;
308 
309  // add the lang identifier
310  char_net_file = data_file_path + lang;
311  char_net_file += ".cube.nn";
312 
313  // neural network is optional
314  FILE *fp = fopen(char_net_file.c_str(), "rb");
315  if (fp == NULL) {
316  return true;
317  }
318  fclose(fp);
319 
320  // load main net
321  char_net_ = tesseract::NeuralNet::FromFile(char_net_file);
322  if (char_net_ == NULL) {
323  fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): "
324  "could not load %s\n", char_net_file.c_str());
325  return false;
326  }
327 
328  // validate net
329  if (char_net_->in_cnt()!= feat_extract_->FeatureCnt()) {
330  fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): "
331  "could not validate net %s\n", char_net_file.c_str());
332  return false;
333  }
334 
335  // alloc net i/o buffers
336  int feat_cnt = char_net_->in_cnt();
337  int class_cnt = char_set_->ClassCount();
338 
339  if (char_net_->out_cnt() != class_cnt) {
340  fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): "
341  "output count (%d) and class count (%d) are not equal\n",
342  char_net_->out_cnt(), class_cnt);
343  return false;
344  }
345 
346  // allocate i/p and o/p buffers if needed
347  if (net_input_ == NULL) {
348  net_input_ = new float[feat_cnt];
349  net_output_ = new float[class_cnt];
350  }
351 
352  return true;
353 }
354 } // tesseract
static NeuralNet * FromFile(const string file_name)
Definition: neural_net.cpp:210
virtual bool SetLearnParam(char *var_name, float val)
int in_cnt() const
Definition: neural_net.h:49
bool Insert(int class_id, int cost, void *tag=NULL)
bool FeedForward(const Type *inputs, Type *outputs)
Definition: neural_net.cpp:88
virtual CharAltList * Classify(CharSamp *char_samp)
static bool ReadFileToString(const string &file_name, string *str)
Definition: cube_utils.cpp:189
ConvNetCharClassifier(CharSet *char_set, TuningParams *params, FeatureBase *feat_extract)
static int Prob2Cost(double prob_val)
Definition: cube_utils.cpp:37
int ClassCount() const
Definition: char_set.h:111
static void SplitStringUsing(const string &str, const string &delims, vector< string > *str_vec)
Definition: cube_utils.cpp:220
int out_cnt() const
Definition: neural_net.h:50
#define MAX(x, y)
Definition: ndminx.h:24
virtual bool ComputeFeatures(CharSamp *char_samp, float *features)=0
virtual int CharCost(CharSamp *char_samp)
signed int char_32
Definition: string_32.h:40
basic_string< char_32 > string_32
Definition: string_32.h:41
virtual bool Train(CharSamp *char_samp, int ClassID)
static void UTF8ToUTF32(const char *utf8_str, string_32 *str32)
Definition: cube_utils.cpp:256
int ClassID(const char_32 *str) const
Definition: char_set.h:54
virtual int FeatureCnt()=0
const char_32 * ClassString(int class_id) const
Definition: char_set.h:104
void SetNet(tesseract::NeuralNet *net)