Point Cloud Library (PCL) 1.15.0
Loading...
Searching...
No Matches
decision_tree_trainer.hpp
/*
 * Software License Agreement (BSD License)
 *
 * Point Cloud Library (PCL) - www.pointclouds.org
 * Copyright (c) 2010-2011, Willow Garage, Inc.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * * Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above
 *   copyright notice, this list of conditions and the following
 *   disclaimer in the documentation and/or other materials provided
 *   with the distribution.
 * * Neither the name of Willow Garage, Inc. nor the names of its
 *   contributors may be used to endorse or promote products derived
 *   from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

38#pragma once
39
40namespace pcl {
41
42template <class FeatureType,
43 class DataSet,
44 class LabelType,
45 class ExampleIndex,
46 class NodeType>
48 DecisionTreeTrainer() = default;
49
50template <class FeatureType,
51 class DataSet,
52 class LabelType,
53 class ExampleIndex,
54 class NodeType>
56 ~DecisionTreeTrainer() = default;
57
58template <class FeatureType,
59 class DataSet,
60 class LabelType,
61 class ExampleIndex,
62 class NodeType>
63void
66{
67 // create random features
68 std::vector<FeatureType> features;
69
70 if (!random_features_at_split_node_)
71 feature_handler_->createRandomFeatures(num_of_features_, features);
72
73 // recursively build decision tree
74 NodeType root_node;
75 tree.setRoot(root_node);
76
77 if (decision_tree_trainer_data_provider_) {
78 std::cerr << "use decision_tree_trainer_data_provider_" << std::endl;
79
80 decision_tree_trainer_data_provider_->getDatasetAndLabels(
81 data_set_, label_data_, examples_);
82 trainDecisionTreeNode(
83 features, examples_, label_data_, max_tree_depth_, tree.getRoot());
84 label_data_.clear();
85 data_set_.clear();
86 examples_.clear();
87 }
88 else {
89 trainDecisionTreeNode(
90 features, examples_, label_data_, max_tree_depth_, tree.getRoot());
91 }
92}
93
94template <class FeatureType,
95 class DataSet,
96 class LabelType,
97 class ExampleIndex,
98 class NodeType>
99void
101 trainDecisionTreeNode(std::vector<FeatureType>& features,
102 std::vector<ExampleIndex>& examples,
103 std::vector<LabelType>& label_data,
104 const std::size_t max_depth,
105 NodeType& node)
106{
107 const std::size_t num_of_examples = examples.size();
108 if (num_of_examples == 0) {
109 PCL_ERROR(
110 "Reached invalid point in decision tree training: Number of examples is 0!\n");
111 return;
112 };
113
114 if (max_depth == 0) {
115 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
116 return;
117 };
118
119 if (examples.size() < min_examples_for_split_) {
120 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
121 return;
122 }
123
124 if (random_features_at_split_node_) {
125 features.clear();
126 feature_handler_->createRandomFeatures(num_of_features_, features);
127 }
128
129 std::vector<float> feature_results;
130 std::vector<unsigned char> flags;
131
132 feature_results.reserve(num_of_examples);
133 flags.reserve(num_of_examples);
134
135 // find best feature for split
136 int best_feature_index = -1;
137 float best_feature_threshold = 0.0f;
138 float best_feature_information_gain = 0.0f;
139
140 const std::size_t num_of_features = features.size();
141 for (std::size_t feature_index = 0; feature_index < num_of_features;
142 ++feature_index) {
143 // evaluate features
144 feature_handler_->evaluateFeature(
145 features[feature_index], data_set_, examples, feature_results, flags);
146
147 // get list of thresholds
148 if (!thresholds_.empty()) {
149 // compute information gain for each threshold and store threshold with highest
150 // information gain
151 for (const float& threshold : thresholds_) {
152
153 const float information_gain = stats_estimator_->computeInformationGain(
154 data_set_, examples, label_data, feature_results, flags, threshold);
155
156 if (information_gain > best_feature_information_gain) {
157 best_feature_information_gain = information_gain;
158 best_feature_index = static_cast<int>(feature_index);
159 best_feature_threshold = threshold;
160 }
161 }
162 }
163 else {
164 std::vector<float> thresholds;
165 thresholds.reserve(num_of_thresholds_);
166 createThresholdsUniform(num_of_thresholds_, feature_results, thresholds);
167
168 // compute information gain for each threshold and store threshold with highest
169 // information gain
170 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
171 ++threshold_index) {
172 const float threshold = thresholds[threshold_index];
173
174 // compute information gain
175 const float information_gain = stats_estimator_->computeInformationGain(
176 data_set_, examples, label_data, feature_results, flags, threshold);
177
178 if (information_gain > best_feature_information_gain) {
179 best_feature_information_gain = information_gain;
180 best_feature_index = static_cast<int>(feature_index);
181 best_feature_threshold = threshold;
182 }
183 }
184 }
185 }
186
187 if (best_feature_index == -1) {
188 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
189 return;
190 }
191
192 // get branch indices for best feature and best threshold
193 std::vector<unsigned char> branch_indices;
194 branch_indices.reserve(num_of_examples);
195 {
196 feature_handler_->evaluateFeature(
197 features[best_feature_index], data_set_, examples, feature_results, flags);
198
199 stats_estimator_->computeBranchIndices(
200 feature_results, flags, best_feature_threshold, branch_indices);
201 }
202
203 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
204
205 // separate data
206 {
207 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
208
209 std::vector<std::size_t> branch_counts(num_of_branches, 0);
210 for (std::size_t example_index = 0; example_index < num_of_examples;
211 ++example_index) {
212 ++branch_counts[branch_indices[example_index]];
213 }
214
215 node.feature = features[best_feature_index];
216 node.threshold = best_feature_threshold;
217 node.sub_nodes.resize(num_of_branches);
218
219 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
220 if (branch_counts[branch_index] == 0) {
221 NodeType branch_node;
222 stats_estimator_->computeAndSetNodeStats(
223 data_set_, examples, label_data, branch_node);
224 // branch_node->num_of_sub_nodes = 0;
225
226 node.sub_nodes[branch_index] = branch_node;
227
228 continue;
229 }
230
231 std::vector<LabelType> branch_labels;
232 std::vector<ExampleIndex> branch_examples;
233 branch_labels.reserve(branch_counts[branch_index]);
234 branch_examples.reserve(branch_counts[branch_index]);
235
236 for (std::size_t example_index = 0; example_index < num_of_examples;
237 ++example_index) {
238 if (branch_indices[example_index] == branch_index) {
239 branch_examples.push_back(examples[example_index]);
240 branch_labels.push_back(label_data[example_index]);
241 }
242 }
243
244 trainDecisionTreeNode(features,
245 branch_examples,
246 branch_labels,
247 max_depth - 1,
248 node.sub_nodes[branch_index]);
249 }
250 }
251}
252
253template <class FeatureType,
254 class DataSet,
255 class LabelType,
256 class ExampleIndex,
257 class NodeType>
258void
260 createThresholdsUniform(const std::size_t num_of_thresholds,
261 std::vector<float>& values,
262 std::vector<float>& thresholds)
263{
264 // estimate range of values
265 float min_value = ::std::numeric_limits<float>::max();
266 float max_value = -::std::numeric_limits<float>::max();
267
268 const std::size_t num_of_values = values.size();
269 for (std::size_t value_index = 0; value_index < num_of_values; ++value_index) {
270 const float value = values[value_index];
271
272 if (value < min_value)
273 min_value = value;
274 if (value > max_value)
275 max_value = value;
276 }
277
278 const float range = max_value - min_value;
279 const float step = range / static_cast<float>(num_of_thresholds + 2);
280
281 // compute thresholds
282 thresholds.resize(num_of_thresholds);
283
284 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds;
285 ++threshold_index) {
286 thresholds[threshold_index] =
287 min_value + step * (static_cast<float>(threshold_index + 1));
288 }
289}
290
291} // namespace pcl
Class representing a decision tree.
NodeType & getRoot()
Returns the root node of the tree.
void setRoot(const NodeType &root)
Sets the root node of the tree.
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
void trainDecisionTreeNode(std::vector< FeatureType > &features, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data, std::size_t max_depth, NodeType &node)
Trains a decision tree node from the specified features, label data, and examples.
void train(DecisionTree< NodeType > &tree)
Trains a decision tree using the set training data and settings.
virtual ~DecisionTreeTrainer()
Destructor.
DecisionTreeTrainer()
Constructor.