Point Cloud Library (PCL)  1.11.1-dev
decision_tree_trainer.hpp
1 /*
2  * Software License Agreement (BSD License)
3  *
4  * Point Cloud Library (PCL) - www.pointclouds.org
5  * Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * * Redistributions of source code must retain the above copyright
14  * notice, this list of conditions and the following disclaimer.
15  * * Redistributions in binary form must reproduce the above
16  * copyright notice, this list of conditions and the following
17  * disclaimer in the documentation and/or other materials provided
18  * with the distribution.
19  * * Neither the name of Willow Garage, Inc. nor the names of its
20  * contributors may be used to endorse or promote products derived
21  * from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  * POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #pragma once
39 
40 namespace pcl {
41 
42 template <class FeatureType,
43  class DataSet,
44  class LabelType,
45  class ExampleIndex,
46  class NodeType>
49 : max_tree_depth_(15)
50 , num_of_features_(1000)
51 , num_of_thresholds_(10)
52 , feature_handler_(nullptr)
53 , stats_estimator_(nullptr)
54 , data_set_()
55 , label_data_()
56 , examples_()
57 , decision_tree_trainer_data_provider_()
58 , random_features_at_split_node_(false)
59 {}
60 
61 template <class FeatureType,
62  class DataSet,
63  class LabelType,
64  class ExampleIndex,
65  class NodeType>
68 {}
69 
70 template <class FeatureType,
71  class DataSet,
72  class LabelType,
73  class ExampleIndex,
74  class NodeType>
75 void
78 {
79  // create random features
80  std::vector<FeatureType> features;
81 
82  if (!random_features_at_split_node_)
83  feature_handler_->createRandomFeatures(num_of_features_, features);
84 
85  // recursively build decision tree
86  NodeType root_node;
87  tree.setRoot(root_node);
88 
89  if (decision_tree_trainer_data_provider_) {
90  std::cerr << "use decision_tree_trainer_data_provider_" << std::endl;
91 
92  decision_tree_trainer_data_provider_->getDatasetAndLabels(
93  data_set_, label_data_, examples_);
94  trainDecisionTreeNode(
95  features, examples_, label_data_, max_tree_depth_, tree.getRoot());
96  label_data_.clear();
97  data_set_.clear();
98  examples_.clear();
99  }
100  else {
101  trainDecisionTreeNode(
102  features, examples_, label_data_, max_tree_depth_, tree.getRoot());
103  }
104 }
105 
106 template <class FeatureType,
107  class DataSet,
108  class LabelType,
109  class ExampleIndex,
110  class NodeType>
111 void
113  trainDecisionTreeNode(std::vector<FeatureType>& features,
114  std::vector<ExampleIndex>& examples,
115  std::vector<LabelType>& label_data,
116  const std::size_t max_depth,
117  NodeType& node)
118 {
119  const std::size_t num_of_examples = examples.size();
120  if (num_of_examples == 0) {
121  PCL_ERROR(
122  "Reached invalid point in decision tree training: Number of examples is 0!\n");
123  return;
124  };
125 
126  if (max_depth == 0) {
127  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
128  return;
129  };
130 
131  if (examples.size() < min_examples_for_split_) {
132  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
133  return;
134  }
135 
136  if (random_features_at_split_node_) {
137  features.clear();
138  feature_handler_->createRandomFeatures(num_of_features_, features);
139  }
140 
141  std::vector<float> feature_results;
142  std::vector<unsigned char> flags;
143 
144  feature_results.reserve(num_of_examples);
145  flags.reserve(num_of_examples);
146 
147  // find best feature for split
148  int best_feature_index = -1;
149  float best_feature_threshold = 0.0f;
150  float best_feature_information_gain = 0.0f;
151 
152  const std::size_t num_of_features = features.size();
153  for (std::size_t feature_index = 0; feature_index < num_of_features;
154  ++feature_index) {
155  // evaluate features
156  feature_handler_->evaluateFeature(
157  features[feature_index], data_set_, examples, feature_results, flags);
158 
159  // get list of thresholds
160  if (!thresholds_.empty()) {
161  // compute information gain for each threshold and store threshold with highest
162  // information gain
163  for (std::size_t threshold_index = 0; threshold_index < thresholds_.size();
164  ++threshold_index) {
165 
166  const float information_gain =
167  stats_estimator_->computeInformationGain(data_set_,
168  examples,
169  label_data,
170  feature_results,
171  flags,
172  thresholds_[threshold_index]);
173 
174  if (information_gain > best_feature_information_gain) {
175  best_feature_information_gain = information_gain;
176  best_feature_index = static_cast<int>(feature_index);
177  best_feature_threshold = thresholds_[threshold_index];
178  }
179  }
180  }
181  else {
182  std::vector<float> thresholds;
183  thresholds.reserve(num_of_thresholds_);
184  createThresholdsUniform(num_of_thresholds_, feature_results, thresholds);
185 
186  // compute information gain for each threshold and store threshold with highest
187  // information gain
188  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
189  ++threshold_index) {
190  const float threshold = thresholds[threshold_index];
191 
192  // compute information gain
193  const float information_gain = stats_estimator_->computeInformationGain(
194  data_set_, examples, label_data, feature_results, flags, threshold);
195 
196  if (information_gain > best_feature_information_gain) {
197  best_feature_information_gain = information_gain;
198  best_feature_index = static_cast<int>(feature_index);
199  best_feature_threshold = threshold;
200  }
201  }
202  }
203  }
204 
205  if (best_feature_index == -1) {
206  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
207  return;
208  }
209 
210  // get branch indices for best feature and best threshold
211  std::vector<unsigned char> branch_indices;
212  branch_indices.reserve(num_of_examples);
213  {
214  feature_handler_->evaluateFeature(
215  features[best_feature_index], data_set_, examples, feature_results, flags);
216 
217  stats_estimator_->computeBranchIndices(
218  feature_results, flags, best_feature_threshold, branch_indices);
219  }
220 
221  stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
222 
223  // separate data
224  {
225  const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
226 
227  std::vector<std::size_t> branch_counts(num_of_branches, 0);
228  for (std::size_t example_index = 0; example_index < num_of_examples;
229  ++example_index) {
230  ++branch_counts[branch_indices[example_index]];
231  }
232 
233  node.feature = features[best_feature_index];
234  node.threshold = best_feature_threshold;
235  node.sub_nodes.resize(num_of_branches);
236 
237  for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
238  if (branch_counts[branch_index] == 0) {
239  NodeType branch_node;
240  stats_estimator_->computeAndSetNodeStats(
241  data_set_, examples, label_data, branch_node);
242  // branch_node->num_of_sub_nodes = 0;
243 
244  node.sub_nodes[branch_index] = branch_node;
245 
246  continue;
247  }
248 
249  std::vector<LabelType> branch_labels;
250  std::vector<ExampleIndex> branch_examples;
251  branch_labels.reserve(branch_counts[branch_index]);
252  branch_examples.reserve(branch_counts[branch_index]);
253 
254  for (std::size_t example_index = 0; example_index < num_of_examples;
255  ++example_index) {
256  if (branch_indices[example_index] == branch_index) {
257  branch_examples.push_back(examples[example_index]);
258  branch_labels.push_back(label_data[example_index]);
259  }
260  }
261 
262  trainDecisionTreeNode(features,
263  branch_examples,
264  branch_labels,
265  max_depth - 1,
266  node.sub_nodes[branch_index]);
267  }
268  }
269 }
270 
271 template <class FeatureType,
272  class DataSet,
273  class LabelType,
274  class ExampleIndex,
275  class NodeType>
276 void
278  createThresholdsUniform(const std::size_t num_of_thresholds,
279  std::vector<float>& values,
280  std::vector<float>& thresholds)
281 {
282  // estimate range of values
283  float min_value = ::std::numeric_limits<float>::max();
284  float max_value = -::std::numeric_limits<float>::max();
285 
286  const std::size_t num_of_values = values.size();
287  for (std::size_t value_index = 0; value_index < num_of_values; ++value_index) {
288  const float value = values[value_index];
289 
290  if (value < min_value)
291  min_value = value;
292  if (value > max_value)
293  max_value = value;
294  }
295 
296  const float range = max_value - min_value;
297  const float step = range / static_cast<float>(num_of_thresholds + 2);
298 
299  // compute thresholds
300  thresholds.resize(num_of_thresholds);
301 
302  for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds;
303  ++threshold_index) {
304  thresholds[threshold_index] =
305  min_value + step * (static_cast<float>(threshold_index + 1));
306  }
307 }
308 
309 } // namespace pcl
pcl
Definition: convolution.h:46
pcl::DecisionTree::getRoot
NodeType & getRoot()
Returns the root node of the tree.
Definition: decision_tree.h:69
pcl::DecisionTreeTrainer::DecisionTreeTrainer
DecisionTreeTrainer()
Constructor.
Definition: decision_tree_trainer.hpp:48
pcl::DecisionTreeTrainer::train
void train(DecisionTree< NodeType > &tree)
Trains a decision tree using the set training data and settings.
Definition: decision_tree_trainer.hpp:76
pcl::DecisionTreeTrainer::createThresholdsUniform
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
Definition: decision_tree_trainer.hpp:278
pcl::DecisionTree
Class representing a decision tree.
Definition: decision_tree.h:49
pcl::DecisionTreeTrainer::trainDecisionTreeNode
void trainDecisionTreeNode(std::vector< FeatureType > &features, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data, std::size_t max_depth, NodeType &node)
Trains a decision tree node from the specified features, label data, and examples.
Definition: decision_tree_trainer.hpp:113
pcl::DecisionTreeTrainer::~DecisionTreeTrainer
virtual ~DecisionTreeTrainer()
Destructor.
Definition: decision_tree_trainer.hpp:67
pcl::DecisionTree::setRoot
void setRoot(const NodeType &root)
Sets the root node of the tree.
Definition: decision_tree.h:62