Point Cloud Library (PCL) 1.15.0
Loading...
Searching...
No Matches
fern_evaluator.hpp
1/*
2 * Software License Agreement (BSD License)
3 *
4 * Point Cloud Library (PCL) - www.pointclouds.org
5 * Copyright (c) 2010-2011, Willow Garage, Inc.
6 *
7 * All rights reserved.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * * Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * * Redistributions in binary form must reproduce the above
16 * copyright notice, this list of conditions and the following
17 * disclaimer in the documentation and/or other materials provided
18 * with the distribution.
19 * * Neither the name of Willow Garage, Inc. nor the names of its
20 * contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 * POSSIBILITY OF SUCH DAMAGE.
35 *
36 */
37
38#pragma once
39
40#include <pcl/common/common.h>
41#include <pcl/ml/feature_handler.h>
42#include <pcl/ml/ferns/fern.h>
43#include <pcl/ml/stats_estimator.h>
44
45#include <vector>
46
47namespace pcl {
48
49template <class FeatureType,
50 class DataSet,
51 class LabelType,
52 class ExampleIndex,
53 class NodeType>
56
57template <class FeatureType,
58 class DataSet,
59 class LabelType,
60 class ExampleIndex,
61 class NodeType>
62void
67 DataSet& data_set,
68 std::vector<ExampleIndex>& examples,
69 std::vector<LabelType>& label_data)
70{
71 const std::size_t num_of_examples = examples.size();
72 const std::size_t num_of_branches = stats_estimator.getNumOfBranches();
73 const std::size_t num_of_features = fern.getNumOfFeatures();
74
75 label_data.resize(num_of_examples);
76
77 std::vector<std::vector<float>> results(num_of_features);
78 std::vector<std::vector<unsigned char>> flags(num_of_features);
79 std::vector<std::vector<unsigned char>> branch_indices(num_of_features);
80
81 for (std::size_t feature_index = 0; feature_index < num_of_features;
82 ++feature_index) {
83 results[feature_index].reserve(num_of_examples);
84 flags[feature_index].reserve(num_of_examples);
85 branch_indices[feature_index].reserve(num_of_examples);
86
87 feature_handler.evaluateFeature(fern.accessFeature(feature_index),
88 data_set,
89 examples,
90 results[feature_index],
91 flags[feature_index]);
92 stats_estimator.computeBranchIndices(results[feature_index],
93 flags[feature_index],
94 fern.accessThreshold(feature_index),
95 branch_indices[feature_index]);
96 }
97
98 for (std::size_t example_index = 0; example_index < num_of_examples;
99 ++example_index) {
100 std::size_t node_index = 0;
101 for (std::size_t feature_index = 0; feature_index < num_of_features;
102 ++feature_index) {
103 node_index *= num_of_branches;
104 node_index += branch_indices[feature_index][example_index];
105 }
106
107 label_data[example_index] = stats_estimator.getLabelOfNode(fern[node_index]);
108 }
109}
110
111template <class FeatureType,
112 class DataSet,
113 class LabelType,
114 class ExampleIndex,
115 class NodeType>
116void
121 DataSet& data_set,
122 std::vector<ExampleIndex>& examples,
123 std::vector<LabelType>& label_data)
124{
125 const std::size_t num_of_examples = examples.size();
126 const std::size_t num_of_branches = stats_estimator.getNumOfBranches();
127 const std::size_t num_of_features = fern.getNumOfFeatures();
128
129 std::vector<std::vector<float>> results(num_of_features);
130 std::vector<std::vector<unsigned char>> flags(num_of_features);
131 std::vector<std::vector<unsigned char>> branch_indices(num_of_features);
132
133 for (std::size_t feature_index = 0; feature_index < num_of_features;
134 ++feature_index) {
135 results[feature_index].reserve(num_of_examples);
136 flags[feature_index].reserve(num_of_examples);
137 branch_indices[feature_index].reserve(num_of_examples);
138
139 feature_handler.evaluateFeature(fern.accessFeature(feature_index),
140 data_set,
141 examples,
142 results[feature_index],
143 flags[feature_index]);
144 stats_estimator.computeBranchIndices(results[feature_index],
145 flags[feature_index],
146 fern.accessThreshold(feature_index),
147 branch_indices[feature_index]);
148 }
149
150 for (std::size_t example_index = 0; example_index < num_of_examples;
151 ++example_index) {
152 std::size_t node_index = 0;
153 for (std::size_t feature_index = 0; feature_index < num_of_features;
154 ++feature_index) {
155 node_index *= num_of_branches;
156 node_index += branch_indices[feature_index][example_index];
157 }
158
159 label_data[example_index] = stats_estimator.getLabelOfNode(fern[node_index]);
160 }
161}
162
163template <class FeatureType,
164 class DataSet,
165 class LabelType,
166 class ExampleIndex,
167 class NodeType>
168void
173 DataSet& data_set,
174 std::vector<ExampleIndex>& examples,
175 std::vector<NodeType*>& nodes)
176{
177 const std::size_t num_of_examples = examples.size();
178 const std::size_t num_of_branches = stats_estimator.getNumOfBranches();
179 const std::size_t num_of_features = fern.getNumOfFeatures();
180
181 nodes.reserve(num_of_examples);
182
183 std::vector<std::vector<float>> results(num_of_features);
184 std::vector<std::vector<unsigned char>> flags(num_of_features);
185 std::vector<std::vector<unsigned char>> branch_indices(num_of_features);
186
187 for (std::size_t feature_index = 0; feature_index < num_of_features;
188 ++feature_index) {
189 results[feature_index].reserve(num_of_examples);
190 flags[feature_index].reserve(num_of_examples);
191 branch_indices[feature_index].reserve(num_of_examples);
192
193 feature_handler.evaluateFeature(fern.accessFeature(feature_index),
194 data_set,
195 examples,
196 results[feature_index],
197 flags[feature_index]);
198 stats_estimator.computeBranchIndices(results[feature_index],
199 flags[feature_index],
200 fern.accessThreshold(feature_index),
201 branch_indices[feature_index]);
202 }
203
204 for (std::size_t example_index = 0; example_index < num_of_examples;
205 ++example_index) {
206 std::size_t node_index = 0;
207 for (std::size_t feature_index = 0; feature_index < num_of_features;
208 ++feature_index) {
209 node_index *= num_of_branches;
210 node_index += branch_indices[feature_index][example_index];
211 }
212
213 nodes.push_back(&(fern[node_index]));
214 }
215}
216
217} // namespace pcl
Utility class interface which is used for creating and evaluating features.
virtual void evaluateFeature(const FeatureType &feature, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< float > &results, std::vector< unsigned char > &flags) const =0
Evaluates a feature on the specified data.
void evaluate(pcl::Fern< FeatureType, NodeType > &fern, pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler, pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data)
Evaluates the specified examples using the supplied fern.
void evaluateAndAdd(pcl::Fern< FeatureType, NodeType > &fern, pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler, pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data)
Evaluates the specified examples using the supplied fern and adds the resulting labels to the corresponding entries of the supplied label data.
FernEvaluator()
Constructor.
void getNodes(pcl::Fern< FeatureType, NodeType > &fern, pcl::FeatureHandler< FeatureType, DataSet, ExampleIndex > &feature_handler, pcl::StatsEstimator< LabelType, NodeType, DataSet, ExampleIndex > &stats_estimator, DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< NodeType * > &nodes)
Evaluates the specified examples using the supplied fern and appends, for each example, a pointer to the fern node it reaches.
Class representing a Fern.
Definition fern.h:49
float & accessThreshold(const std::size_t threshold_index)
Access operator for thresholds.
Definition fern.h:177
std::size_t getNumOfFeatures()
Returns the number of features the Fern has.
Definition fern.h:76
FeatureType & accessFeature(const std::size_t feature_index)
Access operator for features.
Definition fern.h:157
Class interface for gathering statistics for decision tree learning.
virtual std::size_t getNumOfBranches() const =0
Returns the number of branches a node can have (e.g. 2 for a binary split).
virtual LabelDataType getLabelOfNode(NodeType &node) const =0
Returns the label of the specified node.
virtual void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const =0
Computes the branch indices obtained by applying the specified threshold to the supplied feature evaluation results.
Define standard C methods and C++ classes that are common to all methods.