Point Cloud Library (PCL) 1.15.0
decision_forest_trainer.h
/*
 * Software License Agreement (BSD License)
 *
 * Point Cloud Library (PCL) - www.pointclouds.org
 * Copyright (c) 2010-2011, Willow Garage, Inc.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials provided
 *    with the distribution.
 *  * Neither the name of Willow Garage, Inc. nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

#pragma once

#include <pcl/common/common.h>
#include <pcl/ml/dt/decision_forest.h>
#include <pcl/ml/dt/decision_tree.h>
#include <pcl/ml/dt/decision_tree_trainer.h>
#include <pcl/ml/feature_handler.h>
#include <pcl/ml/stats_estimator.h>

#include <vector>

namespace pcl {

/** Trainer for decision forests. */
template <class FeatureType,
          class DataSet,
          class LabelType,
          class ExampleIndex,
          class NodeType>
class PCL_EXPORTS DecisionForestTrainer {

public:
  /** Constructor. */
  DecisionForestTrainer();

  /** Destructor. */
  virtual ~DecisionForestTrainer();

  /** Sets the number of trees to train.
   *
   * \param[in] num_of_trees the number of trees
   */
  inline void
  setNumberOfTreesToTrain(const std::size_t num_of_trees)
  {
    num_of_trees_to_train_ = num_of_trees;
  }

  /** Sets the feature handler used to create and evaluate features.
   *
   * \param[in] feature_handler the feature handler
   */
  inline void
  setFeatureHandler(pcl::FeatureHandler<FeatureType, DataSet, ExampleIndex>& feature_handler)
  {
    decision_tree_trainer_.setFeatureHandler(feature_handler);
  }

  /** Sets the object for estimating the statistics for tree nodes.
   *
   * \param[in] stats_estimator the statistics estimator
   */
  inline void
  setStatsEstimator(pcl::StatsEstimator<LabelType, NodeType, DataSet, ExampleIndex>& stats_estimator)
  {
    decision_tree_trainer_.setStatsEstimator(stats_estimator);
  }

  /** Sets the maximum depth of the learned tree.
   *
   * \param[in] max_tree_depth maximum depth of the learned tree
   */
  inline void
  setMaxTreeDepth(const std::size_t max_tree_depth)
  {
    decision_tree_trainer_.setMaxTreeDepth(max_tree_depth);
  }

  /** Sets the number of features used to find optimal decision features.
   *
   * \param[in] num_of_features the number of features
   */
  inline void
  setNumOfFeatures(const std::size_t num_of_features)
  {
    decision_tree_trainer_.setNumOfFeatures(num_of_features);
  }

  /** Sets the number of thresholds tested for finding the optimal decision threshold on
   * the feature responses.
   *
   * \param[in] num_of_threshold the number of thresholds
   */
  inline void
  setNumOfThresholds(const std::size_t num_of_threshold)
  {
    decision_tree_trainer_.setNumOfThresholds(num_of_threshold);
  }

  /** Sets the input data set used for training.
   *
   * \param[in] data_set the data set used for training
   */
  inline void
  setTrainingDataSet(DataSet& data_set)
  {
    decision_tree_trainer_.setTrainingDataSet(data_set);
  }

  /** Sets the example indices that specify the data used for training.
   *
   * \param[in] examples the examples
   */
  inline void
  setExamples(std::vector<ExampleIndex>& examples)
  {
    decision_tree_trainer_.setExamples(examples);
  }

  /** Sets the label data corresponding to the example data.
   *
   * \param[in] label_data the label data
   */
  inline void
  setLabelData(std::vector<LabelType>& label_data)
  {
    decision_tree_trainer_.setLabelData(label_data);
  }

  /** Sets the minimum number of examples to continue growing a tree.
   *
   * \param[in] n number of examples
   */
  inline void
  setMinExamplesForSplit(std::size_t n)
  {
    decision_tree_trainer_.setMinExamplesForSplit(n);
  }

  /** Specify the thresholds to be used when evaluating features.
   *
   * \param[in] thres the threshold values
   */
  void
  setThresholds(std::vector<float>& thres)
  {
    decision_tree_trainer_.setThresholds(thres);
  }

  /** Specify the data provider.
   *
   * \param[in] dtdp the data provider, which should implement the getDatasetAndLabels()
   * function
   */
  void
  setDecisionTreeDataProvider(
      typename pcl::DecisionTreeTrainerDataProvider<FeatureType,
                                                    DataSet,
                                                    LabelType,
                                                    ExampleIndex,
                                                    NodeType>::Ptr& dtdp)
  {
    decision_tree_trainer_.setDecisionTreeDataProvider(dtdp);
  }

  /** Specify whether the features are randomly generated at each split node.
   *
   * \param[in] b if true, generate a new random set of candidate features at every
   * split node
   */
  void
  setRandomFeaturesAtSplitNode(bool b)
  {
    decision_tree_trainer_.setRandomFeaturesAtSplitNode(b);
  }

  /** Trains a decision forest using the set training data and settings.
   *
   * \param[out] forest destination for the trained forest
   */
  void
  train(DecisionForest<NodeType>& forest);

private:
  /** The number of trees to train. */
  std::size_t num_of_trees_to_train_{1};

  /** The trainer for the decision trees of the forest. */
  pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType>
      decision_tree_trainer_;
};

} // namespace pcl

#include <pcl/ml/impl/dt/decision_forest_trainer.hpp>
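
For reference, a minimal usage sketch follows. The types MyFeature, MyDataSet, MyLabel, MyExampleIndex and MyNode stand in for the template arguments, and MyFeatureHandler / MyStatsEstimator stand in for concrete implementations of the abstract pcl::FeatureHandler and pcl::StatsEstimator interfaces, all of which the application has to supply; the parameter values below are illustrative only.

// Sketch: wire up a DecisionForestTrainer and train a forest.
// MyFeature, MyDataSet, MyLabel, MyExampleIndex, MyNode, MyFeatureHandler and
// MyStatsEstimator are hypothetical application-defined types (not part of PCL).
#include <pcl/ml/dt/decision_forest.h>
#include <pcl/ml/dt/decision_forest_trainer.h>

#include <vector>

void
trainMyForest(MyDataSet& data_set,
              std::vector<MyExampleIndex>& examples,
              std::vector<MyLabel>& labels,
              MyFeatureHandler& feature_handler,   // derives from pcl::FeatureHandler
              MyStatsEstimator& stats_estimator,   // derives from pcl::StatsEstimator
              pcl::DecisionForest<MyNode>& forest)
{
  pcl::DecisionForestTrainer<MyFeature, MyDataSet, MyLabel, MyExampleIndex, MyNode>
      trainer;

  // Forest and tree growth settings (values are examples only).
  trainer.setNumberOfTreesToTrain(10);
  trainer.setMaxTreeDepth(15);
  trainer.setNumOfFeatures(1000);      // candidate features evaluated per split node
  trainer.setNumOfThresholds(50);      // candidate thresholds evaluated per feature
  trainer.setMinExamplesForSplit(10);

  // Feature creation/evaluation and node statistics are delegated to the
  // user-supplied handler and estimator.
  trainer.setFeatureHandler(feature_handler);
  trainer.setStatsEstimator(stats_estimator);

  // Training data: the data set, the example indices into it, and their labels.
  trainer.setTrainingDataSet(data_set);
  trainer.setExamples(examples);
  trainer.setLabelData(labels);

  // Train all trees; the result is stored in 'forest'.
  trainer.train(forest);
}

Alternatively, a pcl::DecisionTreeTrainerDataProvider can be passed via setDecisionTreeDataProvider(); as noted in the header, such a provider should implement getDatasetAndLabels(), which then supplies the data set and labels in place of the explicit setTrainingDataSet()/setLabelData() calls.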