tesseract  3.05.02
beam_search.cpp
Go to the documentation of this file.
1 /**********************************************************************
2  * File: beam_search.cpp
3  * Description: Class to implement Beam Word Search Algorithm
4  * Author: Ahmad Abdulkader
5  * Created: 2007
6  *
7  * (C) Copyright 2008, Google Inc.
8  ** Licensed under the Apache License, Version 2.0 (the "License");
9  ** you may not use this file except in compliance with the License.
10  ** You may obtain a copy of the License at
11  ** http://www.apache.org/licenses/LICENSE-2.0
12  ** Unless required by applicable law or agreed to in writing, software
13  ** distributed under the License is distributed on an "AS IS" BASIS,
14  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  ** See the License for the specific language governing permissions and
16  ** limitations under the License.
17  *
18  **********************************************************************/
19 
20 #include <algorithm>
21 
22 #include "beam_search.h"
23 #include "tesseractclass.h"
24 
25 namespace tesseract {
26 
27 BeamSearch::BeamSearch(CubeRecoContext *cntxt, bool word_mode) {
28  cntxt_ = cntxt;
29  seg_pt_cnt_ = 0;
30  col_cnt_ = 1;
31  col_ = NULL;
32  word_mode_ = word_mode;
33 }
34 
35 // Cleanup the lattice corresponding to the last search
36 void BeamSearch::Cleanup() {
37  if (col_ != NULL) {
38  for (int col = 0; col < col_cnt_; col++) {
39  delete col_[col];
40  }
41  delete []col_;
42  }
43  col_ = NULL;
44 }
45 
47  Cleanup();
48 }
49 
50 // Creates a set of children nodes emerging from a parent node based on
51 // the character alternate list and the language model.
52 void BeamSearch::CreateChildren(SearchColumn *out_col, LangModel *lang_mod,
53  SearchNode *parent_node,
54  LangModEdge *lm_parent_edge,
55  CharAltList *char_alt_list, int extra_cost) {
56  // get all the edges from this parent
57  int edge_cnt;
58  LangModEdge **lm_edges = lang_mod->GetEdges(char_alt_list,
59  lm_parent_edge, &edge_cnt);
60  if (lm_edges) {
61  // add them to the ending column with the appropriate parent
62  for (int edge = 0; edge < edge_cnt; edge++) {
63  // add a node to the column if the current column is not the
64  // last one, or if the lang model edge indicates it is valid EOW
65  if (!cntxt_->NoisyInput() && out_col->ColIdx() >= seg_pt_cnt_ &&
66  !lm_edges[edge]->IsEOW()) {
67  // free edge since no object is going to own it
68  delete lm_edges[edge];
69  continue;
70  }
71 
72  // compute the recognition cost of this node
73  int recognition_cost = MIN_PROB_COST;
74  if (char_alt_list && char_alt_list->AltCount() > 0) {
75  recognition_cost = MAX(0, char_alt_list->ClassCost(
76  lm_edges[edge]->ClassID()));
77  // Add the no space cost. This should zero in word mode
78  recognition_cost += extra_cost;
79  }
80 
81  // Note that the edge will be freed inside the column if
82  // AddNode is called
83  if (recognition_cost >= 0) {
84  out_col->AddNode(lm_edges[edge], recognition_cost, parent_node,
85  cntxt_);
86  } else {
87  delete lm_edges[edge];
88  }
89  } // edge
90  // free edge array
91  delete []lm_edges;
92  } // lm_edges
93 }
94 
95 // Performs a beam search in the specified search using the specified
96 // language model; returns an alternate list of possible words as a result.
98  // verifications
99  if (!lang_mod)
100  lang_mod = cntxt_->LangMod();
101  if (!lang_mod) {
102  fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
103  "LangModel\n");
104  return NULL;
105  }
106 
107  // free existing state
108  Cleanup();
109 
110  // get seg pt count
111  seg_pt_cnt_ = srch_obj->SegPtCnt();
112  if (seg_pt_cnt_ < 0) {
113  return NULL;
114  }
115  col_cnt_ = seg_pt_cnt_ + 1;
116 
117  // disregard suspicious cases
118  if (seg_pt_cnt_ > 128) {
119  fprintf(stderr, "Cube ERROR (BeamSearch::Search): segment point count is "
120  "suspiciously high; bailing out\n");
121  return NULL;
122  }
123 
124  // alloc memory for columns
125  col_ = new SearchColumn *[col_cnt_];
126  memset(col_, 0, col_cnt_ * sizeof(*col_));
127 
128  // for all possible segments
129  for (int end_seg = 1; end_seg <= (seg_pt_cnt_ + 1); end_seg++) {
130  // create a search column
131  col_[end_seg - 1] = new SearchColumn(end_seg - 1,
132  cntxt_->Params()->BeamWidth());
133 
134  // for all possible start segments
135  int init_seg = MAX(0, end_seg - cntxt_->Params()->MaxSegPerChar());
136  for (int strt_seg = init_seg; strt_seg < end_seg; strt_seg++) {
137  int parent_nodes_cnt;
138  SearchNode **parent_nodes;
139 
140  // for the root segment, we do not have a parent
141  if (strt_seg == 0) {
142  parent_nodes_cnt = 1;
143  parent_nodes = NULL;
144  } else {
145  // for all the existing nodes in the starting column
146  parent_nodes_cnt = col_[strt_seg - 1]->NodeCount();
147  parent_nodes = col_[strt_seg - 1]->Nodes();
148  }
149 
150  // run the shape recognizer
151  CharAltList *char_alt_list = srch_obj->RecognizeSegment(strt_seg - 1,
152  end_seg - 1);
153  // for all the possible parents
154  for (int parent_idx = 0; parent_idx < parent_nodes_cnt; parent_idx++) {
155  // point to the parent node
156  SearchNode *parent_node = !parent_nodes ? NULL
157  : parent_nodes[parent_idx];
158  LangModEdge *lm_parent_edge = !parent_node ? lang_mod->Root()
159  : parent_node->LangModelEdge();
160 
161  // compute the cost of not having spaces within the segment range
162  int contig_cost = srch_obj->NoSpaceCost(strt_seg - 1, end_seg - 1);
163 
164  // In phrase mode, compute the cost of not having a space before
165  // this character
166  int no_space_cost = 0;
167  if (!word_mode_ && strt_seg > 0) {
168  no_space_cost = srch_obj->NoSpaceCost(strt_seg - 1);
169  }
170 
171  // if the no space cost is low enough
172  if ((contig_cost + no_space_cost) < MIN_PROB_COST) {
173  // Add the children nodes
174  CreateChildren(col_[end_seg - 1], lang_mod, parent_node,
175  lm_parent_edge, char_alt_list,
176  contig_cost + no_space_cost);
177  }
178 
179  // In phrase mode and if not starting at the root
180  if (!word_mode_ && strt_seg > 0) { // parent_node must be non-NULL
181  // consider starting a new word for nodes that are valid EOW
182  if (parent_node->LangModelEdge()->IsEOW()) {
183  // get the space cost
184  int space_cost = srch_obj->SpaceCost(strt_seg - 1);
185  // if the space cost is low enough
186  if ((contig_cost + space_cost) < MIN_PROB_COST) {
187  // Restart the language model and add nodes as children to the
188  // space node.
189  CreateChildren(col_[end_seg - 1], lang_mod, parent_node, NULL,
190  char_alt_list, contig_cost + space_cost);
191  }
192  }
193  }
194  } // parent
195  } // strt_seg
196 
197  // prune the column nodes
198  col_[end_seg - 1]->Prune();
199 
200  // Free the column hash table. No longer needed
201  col_[end_seg - 1]->FreeHashTable();
202  } // end_seg
203 
204  WordAltList *alt_list = CreateWordAltList(srch_obj);
205  return alt_list;
206 }
207 
208 // Creates a Word alternate list from the results in the lattice.
209 WordAltList *BeamSearch::CreateWordAltList(SearchObject *srch_obj) {
210  // create an alternate list of all the nodes in the last column
211  int node_cnt = col_[col_cnt_ - 1]->NodeCount();
212  SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes();
213  CharBigrams *bigrams = cntxt_->Bigrams();
214  WordUnigrams *word_unigrams = cntxt_->WordUnigramsObj();
215 
216  // Save the index of the best-cost node before the alt list is
217  // sorted, so that we can retrieve it from the node list when backtracking.
218  best_presorted_node_idx_ = 0;
219  int best_cost = -1;
220 
221  if (node_cnt <= 0)
222  return NULL;
223 
224  // start creating the word alternate list
225  WordAltList *alt_list = new WordAltList(node_cnt + 1);
226  for (int node_idx = 0; node_idx < node_cnt; node_idx++) {
227  // recognition cost
228  int recognition_cost = srch_nodes[node_idx]->BestCost();
229  // compute the size cost of the alternate
230  char_32 *ch_buff = NULL;
231  int size_cost = SizeCost(srch_obj, srch_nodes[node_idx], &ch_buff);
232  // accumulate other costs
233  if (ch_buff) {
234  int cost = 0;
235  // char bigram cost
236  int bigram_cost = !bigrams ? 0 :
237  bigrams->Cost(ch_buff, cntxt_->CharacterSet());
238  // word unigram cost
239  int unigram_cost = !word_unigrams ? 0 :
240  word_unigrams->Cost(ch_buff, cntxt_->LangMod(),
241  cntxt_->CharacterSet());
242  // overall cost
243  cost = static_cast<int>(
244  (size_cost * cntxt_->Params()->SizeWgt()) +
245  (bigram_cost * cntxt_->Params()->CharBigramWgt()) +
246  (unigram_cost * cntxt_->Params()->WordUnigramWgt()) +
247  (recognition_cost * cntxt_->Params()->RecoWgt()));
248 
249  // insert into word alt list
250  alt_list->Insert(ch_buff, cost,
251  static_cast<void *>(srch_nodes[node_idx]));
252  // Note that strict < is necessary because WordAltList::Sort()
253  // uses it in a bubble sort to swap entries.
254  if (best_cost < 0 || cost < best_cost) {
255  best_presorted_node_idx_ = node_idx;
256  best_cost = cost;
257  }
258  delete []ch_buff;
259  }
260  }
261 
262  // sort the alternates based on cost
263  alt_list->Sort();
264  return alt_list;
265 }
266 
267 // Returns the lattice column corresponding to the specified column index.
269  if (col < 0 || col >= col_cnt_ || !col_)
270  return NULL;
271  return col_[col];
272 }
273 
274 // Returns the best node in the last column of last performed search.
276  if (col_cnt_ < 1 || !col_ || !col_[col_cnt_ - 1])
277  return NULL;
278 
279  int node_cnt = col_[col_cnt_ - 1]->NodeCount();
280  SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes();
281  if (node_cnt < 1 || !srch_nodes || !srch_nodes[0])
282  return NULL;
283  return srch_nodes[0];
284 }
285 
286 // Returns the string corresponding to the specified alt.
287 char_32 *BeamSearch::Alt(int alt) const {
288  // get the last column of the lattice
289  if (col_cnt_ <= 0)
290  return NULL;
291 
292  SearchColumn *srch_col = col_[col_cnt_ - 1];
293  if (!srch_col)
294  return NULL;
295 
296  // point to the last node in the selected path
297  if (alt >= srch_col->NodeCount() || srch_col->Nodes() == NULL) {
298  return NULL;
299  }
300 
301  SearchNode *srch_node = srch_col->Nodes()[alt];
302  if (!srch_node)
303  return NULL;
304 
305  // get string
306  char_32 *str32 = srch_node->PathString();
307  if (!str32)
308  return NULL;
309 
310  return str32;
311 }
312 
313 // Backtracks from the specified node index and returns the corresponding
314 // character mapped segments and character count. Optional return
315 // arguments are the char_32 result string and character bounding
316 // boxes, if non-NULL values are passed in.
317 CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, int node_index,
318  int *char_cnt, char_32 **str32,
319  Boxa **char_boxes) const {
320  // get the last column of the lattice
321  if (col_cnt_ <= 0)
322  return NULL;
323  SearchColumn *srch_col = col_[col_cnt_ - 1];
324  if (!srch_col)
325  return NULL;
326 
327  // point to the last node in the selected path
328  if (node_index >= srch_col->NodeCount() || !srch_col->Nodes())
329  return NULL;
330 
331  SearchNode *srch_node = srch_col->Nodes()[node_index];
332  if (!srch_node)
333  return NULL;
334  return BackTrack(srch_obj, srch_node, char_cnt, str32, char_boxes);
335 }
336 
337 // Backtracks from the specified node index and returns the corresponding
338 // character mapped segments and character count. Optional return
339 // arguments are the char_32 result string and character bounding
340 // boxes, if non-NULL values are passed in.
342  int *char_cnt, char_32 **str32,
343  Boxa **char_boxes) const {
344  if (!srch_node)
345  return NULL;
346 
347  if (str32) {
348  delete [](*str32); // clear existing value
349  *str32 = srch_node->PathString();
350  if (!*str32)
351  return NULL;
352  }
353 
354  if (char_boxes && *char_boxes) {
355  boxaDestroy(char_boxes); // clear existing value
356  }
357 
358  CharSamp **chars;
359  chars = SplitByNode(srch_obj, srch_node, char_cnt, char_boxes);
360  if (!chars && str32)
361  delete []*str32;
362  return chars;
363 }
364 
365 // Backtracks from the given lattice node and return the corresponding
366 // char mapped segments and character count. The character bounding
367 // boxes are optional return arguments, if non-NULL values are passed in.
368 CharSamp **BeamSearch::SplitByNode(SearchObject *srch_obj,
369  SearchNode *srch_node,
370  int *char_cnt,
371  Boxa **char_boxes) const {
372  // Count the characters (could be less than the path length when in
373  // phrase mode)
374  *char_cnt = 0;
375  SearchNode *node = srch_node;
376  while (node) {
377  node = node->ParentNode();
378  (*char_cnt)++;
379  }
380 
381  if (*char_cnt == 0)
382  return NULL;
383 
384  // Allocate box array
385  if (char_boxes) {
386  if (*char_boxes)
387  boxaDestroy(char_boxes); // clear existing value
388  *char_boxes = boxaCreate(*char_cnt);
389  if (*char_boxes == NULL)
390  return NULL;
391  }
392 
393  // Allocate memory for CharSamp array.
394  CharSamp **chars = new CharSamp *[*char_cnt];
395 
396  int ch_idx = *char_cnt - 1;
397  int seg_pt_cnt = srch_obj->SegPtCnt();
398  bool success=true;
399  while (srch_node && ch_idx >= 0) {
400  // Parent node (could be null)
401  SearchNode *parent_node = srch_node->ParentNode();
402 
403  // Get the seg pts corresponding to the search node
404  int st_col = !parent_node ? 0 : parent_node->ColIdx() + 1;
405  int st_seg_pt = st_col <= 0 ? -1 : st_col - 1;
406  int end_col = srch_node->ColIdx();
407  int end_seg_pt = end_col >= seg_pt_cnt ? seg_pt_cnt : end_col;
408 
409  // Get a char sample corresponding to the segmentation points
410  CharSamp *samp = srch_obj->CharSample(st_seg_pt, end_seg_pt);
411  if (!samp) {
412  success = false;
413  break;
414  }
415  samp->SetLabel(srch_node->NodeString());
416  chars[ch_idx] = samp;
417  if (char_boxes) {
418  // Create the corresponding character bounding box
419  Box *char_box = boxCreate(samp->Left(), samp->Top(),
420  samp->Width(), samp->Height());
421  if (!char_box) {
422  success = false;
423  break;
424  }
425  boxaAddBox(*char_boxes, char_box, L_INSERT);
426  }
427  srch_node = parent_node;
428  ch_idx--;
429  }
430  if (!success) {
431  delete []chars;
432  if (char_boxes)
433  boxaDestroy(char_boxes);
434  return NULL;
435  }
436 
437  // Reverse the order of boxes.
438  if (char_boxes) {
439  int char_boxa_size = boxaGetCount(*char_boxes);
440  int limit = char_boxa_size / 2;
441  for (int i = 0; i < limit; ++i) {
442  int box1_idx = i;
443  int box2_idx = char_boxa_size - 1 - i;
444  Box *box1 = boxaGetBox(*char_boxes, box1_idx, L_CLONE);
445  Box *box2 = boxaGetBox(*char_boxes, box2_idx, L_CLONE);
446  boxaReplaceBox(*char_boxes, box2_idx, box1);
447  boxaReplaceBox(*char_boxes, box1_idx, box2);
448  }
449  }
450  return chars;
451 }
452 
453 // Returns the size cost of a string for a lattice path that
454 // ends at the specified lattice node.
456  char_32 **str32) const {
457  CharSamp **chars = NULL;
458  int char_cnt = 0;
459  if (!node)
460  return 0;
461  // Backtrack to get string and character segmentation
462  chars = BackTrack(srch_obj, node, &char_cnt, str32, NULL);
463  if (!chars)
464  return WORST_COST;
465  int size_cost = (cntxt_->SizeModel() == NULL) ? 0 :
466  cntxt_->SizeModel()->Cost(chars, char_cnt);
467  delete []chars;
468  return size_cost;
469 }
470 } // namespace tesseract
WordAltList * Search(SearchObject *srch_obj, LangModel *lang_mod=NULL)
Definition: beam_search.cpp:97
virtual LangModEdge * Root()=0
virtual CharSamp * CharSample(int start_pt, int end_pt)=0
SearchNode * BestNode() const
SearchNode * ParentNode()
Definition: search_node.h:69
CharSamp ** BackTrack(SearchObject *srch_obj, int node_index, int *char_cnt, char_32 **str32, Boxa **char_boxes) const
double RecoWgt() const
Definition: tuning_params.h:48
SearchNode ** Nodes() const
Definition: search_column.h:44
SearchColumn * Column(int col_idx) const
virtual int SpaceCost(int seg_pt)=0
virtual LangModEdge ** GetEdges(CharAltList *alt_list, LangModEdge *parent_edge, int *edge_cnt)=0
WordUnigrams * WordUnigramsObj() const
virtual int NoSpaceCost(int seg_pt)=0
LangModel * LangMod() const
virtual int SegPtCnt()=0
void SetLabel(char_32 label)
Definition: char_samp.h:68
#define WORST_COST
Definition: cube_const.h:30
LangModEdge * LangModelEdge()
Definition: search_node.h:70
TuningParams * Params() const
int SizeCost(SearchObject *srch_obj, SearchNode *node, char_32 **str32=NULL) const
int Cost(CharSamp **samp_array, int samp_cnt) const
double SizeWgt() const
Definition: tuning_params.h:49
int Cost(const char_32 *str, CharSet *char_set) const
CharSet * CharacterSet() const
#define MAX(x, y)
Definition: ndminx.h:24
signed int char_32
Definition: string_32.h:40
#define MIN_PROB_COST
Definition: cube_const.h:26
virtual bool IsEOW() const =0
int ClassCost(int class_id) const
Definition: char_altlist.h:42
WordSizeModel * SizeModel() const
BeamSearch(CubeRecoContext *cntxt, bool word_mode=true)
Definition: beam_search.cpp:27
int AltCount() const
Definition: altlist.h:39
bool Insert(char_32 *char_ptr, int cost, void *tag=NULL)
SearchNode * AddNode(LangModEdge *edge, int score, SearchNode *parent, CubeRecoContext *cntxt)
const char_32 * NodeString()
Definition: search_node.h:51
int Cost(const char_32 *str32, LangModel *lang_mod, CharSet *char_set) const
char_32 * Alt(int alt) const
CharBigrams * Bigrams() const
double WordUnigramWgt() const
Definition: tuning_params.h:51
double CharBigramWgt() const
Definition: tuning_params.h:50
virtual CharAltList * RecognizeSegment(int start_pt, int end_pt)=0