tesseract  3.05.02
hybrid_neural_net_classifier.h
Go to the documentation of this file.
1 /**********************************************************************
2  * File: hybrid_neural_net_classifier.h
3  * Description: Declaration of Hybrid Neural-Net Character Classifier
4  * Author: Ahmad Abdulkader
5  * Created: 2007
6  *
7  * (C) Copyright 2008, Google Inc.
8  ** Licensed under the Apache License, Version 2.0 (the "License");
9  ** you may not use this file except in compliance with the License.
10  ** You may obtain a copy of the License at
11  ** http://www.apache.org/licenses/LICENSE-2.0
12  ** Unless required by applicable law or agreed to in writing, software
13  ** distributed under the License is distributed on an "AS IS" BASIS,
14  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  ** See the License for the specific language governing permissions and
16  ** limitations under the License.
17  *
18  **********************************************************************/
19 
20 #ifndef HYBRID_NEURAL_NET_CLASSIFIER_H
21 #define HYBRID_NEURAL_NET_CLASSIFIER_H
22 
23 #include <string>
24 #include <vector>
25 
26 #include "char_samp.h"
27 #include "char_altlist.h"
28 #include "char_set.h"
29 #include "classifier_base.h"
30 #include "feature_base.h"
31 #include "lang_model.h"
32 #include "neural_net.h"
33 #include "tuning_params.h"
34 
35 namespace tesseract {
36 
37 // The folding ratio is the fraction of the maximum activation within a
38 // folding set that is used to compute the minimum activation of the set's other members.
39 // static const float kFoldingRatio = 0.75; // see conv_net_classifier.h
40 
42  public:
    // NOTE(review): listing lines 41, 43 and 45 are missing from this page, so
    // the class head and the first line of the constructor declaration are not
    // shown. The member index lists the constructor as
    // HybridNeuralNetCharClassifier(CharSet *char_set, TuningParams *params,
    //                               FeatureBase *feat_extract).
44  FeatureBase *feat_extract);
46  // The main training function. Given a sample and a class ID the classifier
47  // updates its parameters according to its learning algorithm. This function
48  // is currently not implemented. TODO(ahmadab): implement end-2-end training
49  virtual bool Train(CharSamp *char_samp, int ClassID);
50  // A secondary function needed for training. Allows the trainer to set the
51  // value of any train-time parameter. This function is currently not
52  // implemented. TODO(ahmadab): implement end-2-end training
53  virtual bool SetLearnParam(char *var_name, float val);
54  // Externally sets the Neural Net used by the classifier. Used for training.
    // NOTE(review): the classifier stores a vector of nets (nets_); how `net`
    // is stored and who owns it afterwards is not visible in this header --
    // confirm in the implementation.
55  void SetNet(tesseract::NeuralNet *net);
56 
57  // Classifies an input charsamp and returns a CharAltList object containing
58  // the possible candidates and their corresponding scores.
    // NOTE(review): ownership of the returned CharAltList is not shown in this
    // header -- confirm whether the caller must delete it.
59  virtual CharAltList *Classify(CharSamp *char_samp);
60  // Computes the cost of a specific charsamp being a character (versus a
61  // non-character: part-of-a-character OR more-than-one-character)
62  virtual int CharCost(CharSamp *char_samp);
63 
64  private:
65  // Neural Net objects used for classification
66  vector<tesseract::NeuralNet *> nets_;
    // Per-net weights; presumably one entry per element of nets_ -- confirm in
    // the implementation.
67  vector<float> net_wgts_;
68 
69  // Data buffers used to hold Neural Net inputs and outputs.
    // NOTE(review): raw pointers -- their allocation size and release are not
    // visible in this header; confirm that the destructor frees them.
70  float *net_input_;
71  float *net_output_;
72 
73  // Initializes the classifier given a data path, a language string and a
    // language model (lang_mod). The bool result is not documented here;
    // presumably true on success -- confirm in the implementation.
74  virtual bool Init(const string &data_file_path, const string &lang,
75  LangModel *lang_mod);
76  // Loads the NeuralNets needed for the classifier
77  bool LoadNets(const string &data_file_path, const string &lang);
78  // Loads the folding sets.
79  // Returns true on success and also when the folding-set file can't be read;
80  // returns false if any other error is encountered.
81  virtual bool LoadFoldingSets(const string &data_file_path,
82  const string &lang,
83  LangModel *lang_mod);
84  // Folds the output of the NeuralNet using the loaded folding sets
85  virtual void Fold();
86  // Scales the input char_samp and feeds it to the NeuralNet as input
    // NOTE(review): the name suggests every net in nets_ is run; the meaning
    // of the bool result is not documented here -- confirm in the implementation.
87  bool RunNets(CharSamp *char_samp);
88 };
89 }
90 #endif // HYBRID_NEURAL_NET_CLASSIFIER_H
virtual bool Train(CharSamp *char_samp, int ClassID)
virtual CharAltList * Classify(CharSamp *char_samp)
HybridNeuralNetCharClassifier(CharSet *char_set, TuningParams *params, FeatureBase *feat_extract)
virtual bool SetLearnParam(char *var_name, float val)