49 for (
int net_idx = 0; net_idx < nets_.size(); net_idx++) {
50 if (nets_[net_idx] != NULL) {
51 delete nets_[net_idx];
56 if (net_input_ != NULL) {
61 if (net_output_ != NULL) {
83 void HybridNeuralNetCharClassifier::Fold() {
88 for (
int class_id = 0; class_id < class_cnt; class_id++) {
93 for (
int ch = 0; ch < upper_form32.length(); ch++) {
94 if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
95 upper_form32[ch] = towupper(upper_form32[ch]);
102 upper_form32.c_str()));
103 if (upper_class_id != -1 && class_id != upper_class_id) {
104 float max_out =
MAX(net_output_[class_id], net_output_[upper_class_id]);
105 net_output_[class_id] = max_out;
106 net_output_[upper_class_id] = max_out;
115 for (
int fold_set = 0; fold_set <
fold_set_cnt_; fold_set++) {
116 float max_prob = net_output_[
fold_sets_[fold_set][0]];
119 if (net_output_[
fold_sets_[fold_set][ch]] > max_prob) {
120 max_prob = net_output_[
fold_sets_[fold_set][ch]];
124 net_output_[
fold_sets_[fold_set][ch]] =
MAX(max_prob * kFoldingRatio,
132 bool HybridNeuralNetCharClassifier::RunNets(CharSamp *char_samp) {
137 if (net_input_ == NULL) {
138 net_input_ =
new float[feat_cnt];
139 net_output_ =
new float[class_cnt];
148 memset(net_output_, 0, class_cnt *
sizeof(*net_output_));
149 float *inputs = net_input_;
150 for (
int net_idx = 0; net_idx < nets_.size(); net_idx++) {
152 vector<float> net_out(class_cnt, 0.0);
153 if (!nets_[net_idx]->FeedForward(inputs, &net_out[0])) {
157 for (
int class_idx = 0; class_idx < class_cnt; class_idx++) {
158 net_output_[class_idx] += (net_out[class_idx] * net_wgts_[net_idx]);
161 inputs += nets_[net_idx]->in_cnt();
173 if (RunNets(char_samp) ==
false) {
184 if (RunNets(char_samp) ==
false) {
193 for (
int out = 1; out < class_cnt; out++) {
195 alt_list->
Insert(out, cost);
208 bool HybridNeuralNetCharClassifier::LoadFoldingSets(
209 const string &data_file_path,
const string &
lang,
LangModel *lang_mod) {
211 string fold_file_name;
212 fold_file_name = data_file_path +
lang;
213 fold_file_name +=
".cube.fold";
216 FILE *fp = fopen(fold_file_name.c_str(),
"rb");
222 string fold_sets_str;
229 vector<string> str_vec;
235 for (
int fold_set = 0; fold_set <
fold_set_cnt_; fold_set++) {
236 reinterpret_cast<TessLangModel *
>(lang_mod)->RemoveInvalidCharacters(
240 if (str_vec[fold_set].length() <= 1) {
241 fprintf(stderr,
"Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): " 242 "invalidating folding set %d\n", fold_set);
260 bool HybridNeuralNetCharClassifier::Init(
const string &data_file_path,
262 LangModel *lang_mod) {
269 if (!LoadNets(data_file_path,
lang)) {
275 if (!LoadFoldingSets(data_file_path,
lang, lang_mod)) {
286 bool HybridNeuralNetCharClassifier::LoadNets(
const string &data_file_path,
287 const string &
lang) {
288 string hybrid_net_file;
289 string junk_net_file;
292 hybrid_net_file = data_file_path +
lang;
293 hybrid_net_file +=
".cube.hybrid";
296 FILE *fp = fopen(hybrid_net_file.c_str(),
"rb");
308 vector<string> str_vec;
310 if (str_vec.empty()) {
315 nets_.resize(str_vec.size(), NULL);
316 net_wgts_.resize(str_vec.size(), 0);
317 int total_input_size = 0;
318 for (
int net_idx = 0; net_idx < str_vec.size(); net_idx++) {
320 vector<string> tokens_vec;
323 if (tokens_vec.size() != 2) {
327 string net_file_name = data_file_path + tokens_vec[0];
329 if (nets_[net_idx] == NULL) {
333 net_wgts_[net_idx] = atof(tokens_vec[1].c_str());
334 if (net_wgts_[net_idx] < 0.0) {
337 total_input_size += nets_[net_idx]->in_cnt();
virtual int CharCost(CharSamp *char_samp)
static NeuralNet * FromFile(const string file_name)
bool Insert(int class_id, int cost, void *tag=NULL)
virtual bool Train(CharSamp *char_samp, int ClassID)
static bool ReadFileToString(const string &file_name, string *str)
virtual CharAltList * Classify(CharSamp *char_samp)
HybridNeuralNetCharClassifier(CharSet *char_set, TuningParams *params, FeatureBase *feat_extract)
static int Prob2Cost(double prob_val)
FeatureBase * feat_extract_
void SetNet(tesseract::NeuralNet *net)
static void SplitStringUsing(const string &str, const string &delims, vector< string > *str_vec)
virtual ~HybridNeuralNetCharClassifier()
virtual bool ComputeFeatures(CharSamp *char_samp, float *features)=0
basic_string< char_32 > string_32
static void UTF8ToUTF32(const char *utf8_str, string_32 *str32)
virtual bool SetLearnParam(char *var_name, float val)
int ClassID(const char_32 *str) const
virtual int FeatureCnt()=0
const char_32 * ClassString(int class_id) const