48 if (char_net_ != NULL) {
53 if (net_input_ != NULL) {
58 if (net_output_ != NULL) {
86 void ConvNetCharClassifier::Fold() {
91 for (
int class_id = 0; class_id < class_cnt; class_id++) {
96 for (
int ch = 0; ch < upper_form32.length(); ch++) {
97 if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
98 upper_form32[ch] = towupper(upper_form32[ch]);
105 upper_form32.c_str()));
106 if (upper_class_id != -1 && class_id != upper_class_id) {
107 float max_out =
MAX(net_output_[class_id], net_output_[upper_class_id]);
108 net_output_[class_id] = max_out;
109 net_output_[upper_class_id] = max_out;
118 for (
int fold_set = 0; fold_set <
fold_set_cnt_; fold_set++) {
121 float max_prob = net_output_[
fold_sets_[fold_set][0]];
123 if (net_output_[
fold_sets_[fold_set][ch]] > max_prob) {
124 max_prob = net_output_[
fold_sets_[fold_set][ch]];
128 net_output_[
fold_sets_[fold_set][ch]] =
MAX(max_prob * kFoldingRatio,
138 bool ConvNetCharClassifier::RunNets(CharSamp *char_samp) {
139 if (char_net_ == NULL) {
140 fprintf(stderr,
"Cube ERROR (ConvNetCharClassifier::RunNets): " 141 "NeuralNet is NULL\n");
144 int feat_cnt = char_net_->
in_cnt();
148 if (net_input_ == NULL) {
149 net_input_ =
new float[feat_cnt];
150 net_output_ =
new float[class_cnt];
155 fprintf(stderr,
"Cube ERROR (ConvNetCharClassifier::RunNets): " 156 "unable to compute features\n");
160 if (char_net_ != NULL) {
161 if (char_net_->
FeedForward(net_input_, net_output_) ==
false) {
162 fprintf(stderr,
"Cube ERROR (ConvNetCharClassifier::RunNets): " 163 "unable to run feed-forward\n");
177 if (RunNets(char_samp) ==
false) {
189 if (RunNets(char_samp) ==
false) {
198 for (
int out = 1; out < class_cnt; out++) {
200 alt_list->
Insert(out, cost);
210 if (char_net_ != NULL) {
214 char_net_ = char_net;
221 bool ConvNetCharClassifier::LoadFoldingSets(
const string &data_file_path,
225 string fold_file_name;
226 fold_file_name = data_file_path +
lang;
227 fold_file_name +=
".cube.fold";
230 FILE *fp = fopen(fold_file_name.c_str(),
"rb");
236 string fold_sets_str;
243 vector<string> str_vec;
250 for (
int fold_set = 0; fold_set <
fold_set_cnt_; fold_set++) {
251 reinterpret_cast<TessLangModel *
>(lang_mod)->RemoveInvalidCharacters(
255 if (str_vec[fold_set].length() <= 1) {
256 fprintf(stderr,
"Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): " 257 "invalidating folding set %d\n", fold_set);
277 bool ConvNetCharClassifier::Init(
const string &data_file_path,
279 LangModel *lang_mod) {
286 if (!LoadNets(data_file_path,
lang)) {
292 if (!LoadFoldingSets(data_file_path,
lang, lang_mod)) {
305 bool ConvNetCharClassifier::LoadNets(
const string &data_file_path,
306 const string &
lang) {
307 string char_net_file;
310 char_net_file = data_file_path +
lang;
311 char_net_file +=
".cube.nn";
314 FILE *fp = fopen(char_net_file.c_str(),
"rb");
322 if (char_net_ == NULL) {
323 fprintf(stderr,
"Cube ERROR (ConvNetCharClassifier::LoadNets): " 324 "could not load %s\n", char_net_file.c_str());
330 fprintf(stderr,
"Cube ERROR (ConvNetCharClassifier::LoadNets): " 331 "could not validate net %s\n", char_net_file.c_str());
336 int feat_cnt = char_net_->
in_cnt();
339 if (char_net_->
out_cnt() != class_cnt) {
340 fprintf(stderr,
"Cube ERROR (ConvNetCharClassifier::LoadNets): " 341 "output count (%d) and class count (%d) are not equal\n",
342 char_net_->
out_cnt(), class_cnt);
347 if (net_input_ == NULL) {
348 net_input_ =
new float[feat_cnt];
349 net_output_ =
new float[class_cnt];
static NeuralNet * FromFile(const string file_name)
virtual bool SetLearnParam(char *var_name, float val)
bool Insert(int class_id, int cost, void *tag=NULL)
bool FeedForward(const Type *inputs, Type *outputs)
virtual CharAltList * Classify(CharSamp *char_samp)
static bool ReadFileToString(const string &file_name, string *str)
ConvNetCharClassifier(CharSet *char_set, TuningParams *params, FeatureBase *feat_extract)
static int Prob2Cost(double prob_val)
FeatureBase * feat_extract_
static void SplitStringUsing(const string &str, const string &delims, vector< string > *str_vec)
virtual bool ComputeFeatures(CharSamp *char_samp, float *features)=0
virtual int CharCost(CharSamp *char_samp)
virtual ~ConvNetCharClassifier()
basic_string< char_32 > string_32
virtual bool Train(CharSamp *char_samp, int ClassID)
static void UTF8ToUTF32(const char *utf8_str, string_32 *str32)
int ClassID(const char_32 *str) const
virtual int FeatureCnt()=0
const char_32 * ClassString(int class_id) const
void SetNet(tesseract::NeuralNet *net)