tesseract  3.05.02
tess_lang_model.cpp
Go to the documentation of this file.
1 /**********************************************************************
2  * File: tess_lang_model.cpp
3  * Description: Implementation of the Tesseract Language Model Class
4  * Author: Ahmad Abdulkader
5  * Created: 2008
6  *
7  * (C) Copyright 2008, Google Inc.
8  ** Licensed under the Apache License, Version 2.0 (the "License");
9  ** you may not use this file except in compliance with the License.
10  ** You may obtain a copy of the License at
11  ** http://www.apache.org/licenses/LICENSE-2.0
12  ** Unless required by applicable law or agreed to in writing, software
13  ** distributed under the License is distributed on an "AS IS" BASIS,
14  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  ** See the License for the specific language governing permissions and
16  ** limitations under the License.
17  *
18  **********************************************************************/
19 
20 // The TessLangModel class abstracts the Tesseract language model. It inherits
21 // from the LangModel class. The Tesseract language model encompasses several
22 // Dawgs (words from training data, punctuation, numbers, document words).
23 // On top of this Cube adds an OOD state machine
24 // The class provides methods to traverse the language model in a generative
25 // fashion. Given any node in the DAWG, the language model can generate a list
26 // of children (or fan-out) edges
27 
28 #include <string>
29 #include <vector>
30 
31 #include "char_samp.h"
32 #include "cube_utils.h"
33 #include "dict.h"
34 #include "tesseractclass.h"
35 #include "tess_lang_model.h"
36 #include "tessdatamanager.h"
37 #include "unicharset.h"
38 
39 namespace tesseract {
40 // max fan-out (used for preallocation). Initialized here, but modified by
41 // constructor
42 int TessLangModel::max_edge_ = 4096;
43 
44 // Language model extra State machines
45 const Dawg *TessLangModel::ood_dawg_ = reinterpret_cast<Dawg *>(DAWG_OOD);
46 const Dawg *TessLangModel::number_dawg_ = reinterpret_cast<Dawg *>(DAWG_NUMBER);
47 
48 // number state machine
49 const int TessLangModel::num_state_machine_[kStateCnt][kNumLiteralCnt] = {
50  {0, 1, 1, NUM_TRM, NUM_TRM},
51  {NUM_TRM, 1, 1, 3, 2},
52  {NUM_TRM, NUM_TRM, 1, NUM_TRM, 2},
53  {NUM_TRM, NUM_TRM, 3, NUM_TRM, 2},
54 };
55 const int TessLangModel::num_max_repeat_[kStateCnt] = {3, 32, 8, 3};
56 
57 // thresholds and penalties
58 int TessLangModel::max_ood_shape_cost_ = CubeUtils::Prob2Cost(1e-4);
59 
60 TessLangModel::TessLangModel(const string &lm_params,
61  const string &data_file_path,
62  bool load_system_dawg,
63  TessdataManager *tessdata_manager,
64  CubeRecoContext *cntxt) {
65  cntxt_ = cntxt;
66  has_case_ = cntxt_->HasCase();
67  // Load the rest of the language model elements from file
68  LoadLangModelElements(lm_params);
69  // Load word_dawgs_ if needed.
70  if (tessdata_manager->SeekToStart(TESSDATA_CUBE_UNICHARSET)) {
71  word_dawgs_ = new DawgVector();
72  if (load_system_dawg &&
73  tessdata_manager->SeekToStart(TESSDATA_CUBE_SYSTEM_DAWG)) {
74  // The last parameter to the Dawg constructor (the debug level) is set to
75  // false, until Cube has a way to express its preferred debug level.
76  *word_dawgs_ += new SquishedDawg(tessdata_manager->GetDataFilePtr(),
78  cntxt_->Lang().c_str(),
79  SYSTEM_DAWG_PERM, false);
80  }
81  } else {
82  word_dawgs_ = NULL;
83  }
84 }
85 
86 // Cleanup an edge array
87 void TessLangModel::FreeEdges(int edge_cnt, LangModEdge **edge_array) {
88  if (edge_array != NULL) {
89  for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) {
90  if (edge_array[edge_idx] != NULL) {
91  delete edge_array[edge_idx];
92  }
93  }
94  delete []edge_array;
95  }
96 }
97 
98 // Determines if a sequence of 32-bit chars is valid in this language model
99 // starting from the specified edge. If the eow_flag is ON, also checks for
100 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last
101 // edge
102 bool TessLangModel::IsValidSequence(LangModEdge *edge,
103  const char_32 *sequence,
104  bool eow_flag,
105  LangModEdge **final_edge) {
106  // get the edges emerging from this edge
107  int edge_cnt = 0;
108  LangModEdge **edge_array = GetEdges(NULL, edge, &edge_cnt);
109 
110  // find the 1st char in the sequence in the children
111  for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) {
112  // found a match
113  if (sequence[0] == edge_array[edge_idx]->EdgeString()[0]) {
114  // if this is the last char
115  if (sequence[1] == 0) {
116  // succeed if we are in prefix mode or this is a terminal edge
117  if (eow_flag == false || edge_array[edge_idx]->IsEOW()) {
118  if (final_edge != NULL) {
119  (*final_edge) = edge_array[edge_idx];
120  edge_array[edge_idx] = NULL;
121  }
122 
123  FreeEdges(edge_cnt, edge_array);
124  return true;
125  }
126  } else {
127  // not the last char continue checking
128  if (IsValidSequence(edge_array[edge_idx], sequence + 1, eow_flag,
129  final_edge) == true) {
130  FreeEdges(edge_cnt, edge_array);
131  return true;
132  }
133  }
134  }
135  }
136 
137  FreeEdges(edge_cnt, edge_array);
138  return false;
139 }
140 
141 // Determines if a sequence of 32-bit chars is valid in this language model
142 // starting from the root. If the eow_flag is ON, also checks for
143 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last
144 // edge
145 bool TessLangModel::IsValidSequence(const char_32 *sequence, bool eow_flag,
146  LangModEdge **final_edge) {
147  if (final_edge != NULL) {
148  (*final_edge) = NULL;
149  }
150 
151  return IsValidSequence(NULL, sequence, eow_flag, final_edge);
152 }
153 
155  return lead_punc_.find(ch) != string::npos;
156 }
157 
159  return trail_punc_.find(ch) != string::npos;
160 }
161 
163  return digits_.find(ch) != string::npos;
164 }
165 
166 // The general fan-out generation function. Returns the list of edges
167 // fanning-out of the specified edge and their count. If an AltList is
168 // specified, only the class-ids with a minimum cost are considered
170  LangModEdge *lang_mod_edge,
171  int *edge_cnt) {
172  TessLangModEdge *tess_lm_edge =
173  reinterpret_cast<TessLangModEdge *>(lang_mod_edge);
174  LangModEdge **edge_array = NULL;
175  (*edge_cnt) = 0;
176 
177  // if we are starting from the root, we'll instantiate every DAWG
178  // and get the all the edges that emerge from the root
179  if (tess_lm_edge == NULL) {
180  // get DAWG count from Tesseract
181  int dawg_cnt = NumDawgs();
182  // preallocate the edge buffer
183  (*edge_cnt) = dawg_cnt * max_edge_;
184  edge_array = new LangModEdge *[(*edge_cnt)];
185 
186  for (int dawg_idx = (*edge_cnt) = 0; dawg_idx < dawg_cnt; dawg_idx++) {
187  const Dawg *curr_dawg = GetDawg(dawg_idx);
188  // Only look through word Dawgs (since there is a special way of
189  // handling numbers and punctuation).
190  if (curr_dawg->type() == DAWG_TYPE_WORD) {
191  (*edge_cnt) += FanOut(alt_list, curr_dawg, 0, 0, NULL, true,
192  edge_array + (*edge_cnt));
193  }
194  } // dawg
195 
196  (*edge_cnt) += FanOut(alt_list, number_dawg_, 0, 0, NULL, true,
197  edge_array + (*edge_cnt));
198 
199  // OOD: it is intentionally not added to the list to make sure it comes
200  // at the end
201  (*edge_cnt) += FanOut(alt_list, ood_dawg_, 0, 0, NULL, true,
202  edge_array + (*edge_cnt));
203 
204  // set the root flag for all root edges
205  for (int edge_idx = 0; edge_idx < (*edge_cnt); edge_idx++) {
206  edge_array[edge_idx]->SetRoot(true);
207  }
208  } else { // not starting at the root
209  // preallocate the edge buffer
210  (*edge_cnt) = max_edge_;
211  // allocate memory for edges
212  edge_array = new LangModEdge *[(*edge_cnt)];
213 
214  // get the FanOut edges from the root of each dawg
215  (*edge_cnt) = FanOut(alt_list,
216  tess_lm_edge->GetDawg(),
217  tess_lm_edge->EndEdge(), tess_lm_edge->EdgeMask(),
218  tess_lm_edge->EdgeString(), false, edge_array);
219  }
220  return edge_array;
221 }
222 
223 // generate edges from an NULL terminated string
224 // (used for punctuation, operators and digits)
225 int TessLangModel::Edges(const char *strng, const Dawg *dawg,
226  EDGE_REF edge_ref, EDGE_REF edge_mask,
227  LangModEdge **edge_array) {
228  int edge_idx,
229  edge_cnt = 0;
230 
231  for (edge_idx = 0; strng[edge_idx] != 0; edge_idx++) {
232  int class_id = cntxt_->CharacterSet()->ClassID((char_32)strng[edge_idx]);
233  if (class_id != INVALID_UNICHAR_ID) {
234  // create an edge object
235  edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg, edge_ref,
236  class_id);
237 
238  reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])->
239  SetEdgeMask(edge_mask);
240  edge_cnt++;
241  }
242  }
243 
244  return edge_cnt;
245 }
246 
247 // generate OOD edges
248 int TessLangModel::OODEdges(CharAltList *alt_list, EDGE_REF edge_ref,
249  EDGE_REF edge_ref_mask, LangModEdge **edge_array) {
250  int class_cnt = cntxt_->CharacterSet()->ClassCount();
251  int edge_cnt = 0;
252  for (int class_id = 0; class_id < class_cnt; class_id++) {
253  // produce an OOD edge only if the cost of the char is low enough
254  if ((alt_list == NULL ||
255  alt_list->ClassCost(class_id) <= max_ood_shape_cost_)) {
256  // create an edge object
257  edge_array[edge_cnt] = new TessLangModEdge(cntxt_, class_id);
258  edge_cnt++;
259  }
260  }
261 
262  return edge_cnt;
263 }
264 
265 // computes and returns the edges that fan out of an edge ref
266 int TessLangModel::FanOut(CharAltList *alt_list, const Dawg *dawg,
267  EDGE_REF edge_ref, EDGE_REF edge_mask,
268  const char_32 *str, bool root_flag,
269  LangModEdge **edge_array) {
270  int edge_cnt = 0;
271  NODE_REF next_node = NO_EDGE;
272 
273  // OOD
274  if (dawg == reinterpret_cast<Dawg *>(DAWG_OOD)) {
275  if (ood_enabled_ == true) {
276  return OODEdges(alt_list, edge_ref, edge_mask, edge_array);
277  } else {
278  return 0;
279  }
280  } else if (dawg == reinterpret_cast<Dawg *>(DAWG_NUMBER)) {
281  // Number
282  if (numeric_enabled_ == true) {
283  return NumberEdges(edge_ref, edge_array);
284  } else {
285  return 0;
286  }
287  } else if (IsTrailingPuncEdge(edge_mask)) {
288  // a TRAILING PUNC MASK, generate more trailing punctuation and return
289  if (punc_enabled_ == true) {
290  EDGE_REF trail_cnt = TrailingPuncCount(edge_mask);
291  return Edges(trail_punc_.c_str(), dawg, edge_ref,
292  TrailingPuncEdgeMask(trail_cnt + 1), edge_array);
293  } else {
294  return 0;
295  }
296  } else if (root_flag == true || edge_ref == 0) {
297  // Root, generate leading punctuation and continue
298  if (root_flag) {
299  if (punc_enabled_ == true) {
300  edge_cnt += Edges(lead_punc_.c_str(), dawg, 0, LEAD_PUNC_EDGE_REF_MASK,
301  edge_array);
302  }
303  }
304  next_node = 0;
305  } else {
306  // a node in the main trie
307  bool eow_flag = (dawg->end_of_word(edge_ref) != 0);
308 
309  // for EOW
310  if (eow_flag == true) {
311  // generate trailing punctuation
312  if (punc_enabled_ == true) {
313  edge_cnt += Edges(trail_punc_.c_str(), dawg, edge_ref,
314  TrailingPuncEdgeMask((EDGE_REF)1), edge_array);
315  // generate a hyphen and go back to the root
316  edge_cnt += Edges("-/", dawg, 0, 0, edge_array + edge_cnt);
317  }
318  }
319 
320  // advance node
321  next_node = dawg->next_node(edge_ref);
322  if (next_node == 0 || next_node == NO_EDGE) {
323  return edge_cnt;
324  }
325  }
326 
327  // now get all the emerging edges if word list is enabled
328  if (word_list_enabled_ == true && next_node != NO_EDGE) {
329  // create child edges
330  int child_edge_cnt =
331  TessLangModEdge::CreateChildren(cntxt_, dawg, next_node,
332  edge_array + edge_cnt);
333  int strt_cnt = edge_cnt;
334 
335  // set the edge mask
336  for (int child = 0; child < child_edge_cnt; child++) {
337  reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt++])->
338  SetEdgeMask(edge_mask);
339  }
340 
341  // if we are at the root, create upper case forms of these edges if possible
342  if (root_flag == true) {
343  for (int child = 0; child < child_edge_cnt; child++) {
344  TessLangModEdge *child_edge =
345  reinterpret_cast<TessLangModEdge *>(edge_array[strt_cnt + child]);
346 
347  if (has_case_ == true) {
348  const char_32 *edge_str = child_edge->EdgeString();
349  if (edge_str != NULL && islower(edge_str[0]) != 0 &&
350  edge_str[1] == 0) {
351  int class_id =
352  cntxt_->CharacterSet()->ClassID(toupper(edge_str[0]));
353  if (class_id != INVALID_UNICHAR_ID) {
354  // generate an upper case edge for lower case chars
355  edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg,
356  child_edge->StartEdge(), child_edge->EndEdge(), class_id);
357 
358  reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])->
359  SetEdgeMask(edge_mask);
360  edge_cnt++;
361  }
362  }
363  }
364  }
365  }
366  }
367  return edge_cnt;
368 }
369 
370 // Generate the edges fanning-out from an edge in the number state machine
371 int TessLangModel::NumberEdges(EDGE_REF edge_ref, LangModEdge **edge_array) {
372  EDGE_REF new_state,
373  state;
374 
375  inT64 repeat_cnt,
376  new_repeat_cnt;
377 
378  state = ((edge_ref & NUMBER_STATE_MASK) >> NUMBER_STATE_SHIFT);
379  repeat_cnt = ((edge_ref & NUMBER_REPEAT_MASK) >> NUMBER_REPEAT_SHIFT);
380 
381  if (state < 0 || state >= kStateCnt) {
382  return 0;
383  }
384 
385  // go through all valid transitions from the state
386  int edge_cnt = 0;
387 
388  EDGE_REF new_edge_ref;
389 
390  for (int lit = 0; lit < kNumLiteralCnt; lit++) {
391  // move to the new state
392  new_state = num_state_machine_[state][lit];
393  if (new_state == NUM_TRM) {
394  continue;
395  }
396 
397  if (new_state == state) {
398  new_repeat_cnt = repeat_cnt + 1;
399  } else {
400  new_repeat_cnt = 1;
401  }
402 
403  // not allowed to repeat beyond this
404  if (new_repeat_cnt > num_max_repeat_[state]) {
405  continue;
406  }
407 
408  new_edge_ref = (new_state << NUMBER_STATE_SHIFT) |
409  (lit << NUMBER_LITERAL_SHIFT) |
410  (new_repeat_cnt << NUMBER_REPEAT_SHIFT);
411 
412  edge_cnt += Edges(literal_str_[lit]->c_str(), number_dawg_,
413  new_edge_ref, 0, edge_array + edge_cnt);
414  }
415 
416  return edge_cnt;
417 }
418 
419 // Loads Language model elements from contents of the <lang>.cube.lm file
420 bool TessLangModel::LoadLangModelElements(const string &lm_params) {
421  bool success = true;
422  // split into lines, each corresponding to a token type below
423  vector<string> str_vec;
424  CubeUtils::SplitStringUsing(lm_params, "\r\n", &str_vec);
425  for (int entry = 0; entry < str_vec.size(); entry++) {
426  vector<string> tokens;
427  // should be only two tokens: type and value
428  CubeUtils::SplitStringUsing(str_vec[entry], "=", &tokens);
429  if (tokens.size() != 2)
430  success = false;
431  if (tokens[0] == "LeadPunc") {
432  lead_punc_ = tokens[1];
433  } else if (tokens[0] == "TrailPunc") {
434  trail_punc_ = tokens[1];
435  } else if (tokens[0] == "NumLeadPunc") {
436  num_lead_punc_ = tokens[1];
437  } else if (tokens[0] == "NumTrailPunc") {
438  num_trail_punc_ = tokens[1];
439  } else if (tokens[0] == "Operators") {
440  operators_ = tokens[1];
441  } else if (tokens[0] == "Digits") {
442  digits_ = tokens[1];
443  } else if (tokens[0] == "Alphas") {
444  alphas_ = tokens[1];
445  } else {
446  success = false;
447  }
448  }
449 
450  RemoveInvalidCharacters(&num_lead_punc_);
451  RemoveInvalidCharacters(&num_trail_punc_);
452  RemoveInvalidCharacters(&digits_);
453  RemoveInvalidCharacters(&operators_);
454  RemoveInvalidCharacters(&alphas_);
455 
456  // form the array of literal strings needed for number state machine
457  // It is essential that the literal strings go in the order below
458  literal_str_[0] = &num_lead_punc_;
459  literal_str_[1] = &num_trail_punc_;
460  literal_str_[2] = &digits_;
461  literal_str_[3] = &operators_;
462  literal_str_[4] = &alphas_;
463 
464  return success;
465 }
466 
468  CharSet *char_set = cntxt_->CharacterSet();
469  tesseract::string_32 lm_str32;
470  CubeUtils::UTF8ToUTF32(lm_str->c_str(), &lm_str32);
471 
472  int len = CubeUtils::StrLen(lm_str32.c_str());
473  char_32 *clean_str32 = new char_32[len + 1];
474  int clean_len = 0;
475  for (int i = 0; i < len; ++i) {
476  int class_id = char_set->ClassID((char_32)lm_str32[i]);
477  if (class_id != INVALID_UNICHAR_ID) {
478  clean_str32[clean_len] = lm_str32[i];
479  ++clean_len;
480  }
481  }
482  clean_str32[clean_len] = 0;
483  if (clean_len < len) {
484  lm_str->clear();
485  CubeUtils::UTF32ToUTF8(clean_str32, lm_str);
486  }
487  delete [] clean_str32;
488 }
489 
490 int TessLangModel::NumDawgs() const {
491  return (word_dawgs_ != NULL) ?
492  word_dawgs_->size() : cntxt_->TesseractObject()->getDict().NumDawgs();
493 }
494 
495 // Returns the dawgs with the given index from either the dawgs
496 // stored by the Tesseract object, or the word_dawgs_.
497 const Dawg *TessLangModel::GetDawg(int index) const {
498  if (word_dawgs_ != NULL) {
499  ASSERT_HOST(index < word_dawgs_->size());
500  return (*word_dawgs_)[index];
501  } else {
502  ASSERT_HOST(index < cntxt_->TesseractObject()->getDict().NumDawgs());
503  return cntxt_->TesseractObject()->getDict().GetDawg(index);
504  }
505 }
506 }
bool IsTrailingPunc(char_32 ch)
#define NUMBER_REPEAT_MASK
#define NUMBER_REPEAT_SHIFT
#define TrailingPuncEdgeMask(Cnt)
inT64 EDGE_REF
Definition: dawg.h:54
#define NUMBER_STATE_MASK
inT64 NODE_REF
Definition: dawg.h:55
#define TrailingPuncCount(edge_mask)
#define IsTrailingPuncEdge(edge_mask)
const Dawg * GetDawg(int index) const
Return i-th dawg pointer recorded in the dawgs_ vector.
Definition: dict.h:412
#define LEAD_PUNC_EDGE_REF_MASK
void RemoveInvalidCharacters(string *lm_str)
static int CreateChildren(CubeRecoContext *cntxt, const Dawg *edges, NODE_REF edge_reg, LangModEdge **lm_edges)
DawgType type() const
Definition: dawg.h:127
LangModEdge ** GetEdges(CharAltList *alt_list, LangModEdge *edge, int *edge_cnt)
int NumDawgs() const
Return the number of dawgs in the dawgs_ vector.
Definition: dict.h:410
static int Prob2Cost(double prob_val)
Definition: cube_utils.cpp:37
int ClassCount() const
Definition: char_set.h:111
static void UTF32ToUTF8(const char_32 *utf32_str, string *str)
Definition: cube_utils.cpp:272
virtual void SetRoot(bool flag)=0
#define NUM_TRM
tesseract::Tesseract * TesseractObject() const
static int StrLen(const char_32 *str)
Definition: cube_utils.cpp:54
const Dawg * GetDawg() const
bool IsValidSequence(const char_32 *sequence, bool eow_flag, LangModEdge **final_edge=NULL)
static void SplitStringUsing(const string &str, const string &delims, vector< string > *str_vec)
Definition: cube_utils.cpp:220
long long int inT64
Definition: host.h:41
const string & Lang() const
bool IsLeadingPunc(char_32 ch)
TessLangModel(const string &lm_params, const string &data_file_path, bool load_system_dawg, TessdataManager *tessdata_manager, CubeRecoContext *cntxt)
Dict & getDict()
Definition: classify.h:65
CharSet * CharacterSet() const
const int kNumLiteralCnt
signed int char_32
Definition: string_32.h:40
int size() const
Definition: genericvector.h:72
const char_32 * EdgeString() const
GenericVector< Dawg * > DawgVector
Definition: dict.h:49
basic_string< char_32 > string_32
Definition: string_32.h:41
#define NUMBER_LITERAL_SHIFT
#define DAWG_OOD
bool SeekToStart(TessdataType tessdata_type)
#define DAWG_NUMBER
const int kStateCnt
static void UTF8ToUTF32(const char *utf8_str, string_32 *str32)
Definition: cube_utils.cpp:256
int ClassID(const char_32 *str) const
Definition: char_set.h:54
#define NUMBER_STATE_SHIFT
#define ASSERT_HOST(x)
Definition: errcode.h:84