37#ifndef VIGRA_RANDOM_FOREST_HXX
38#define VIGRA_RANDOM_FOREST_HXX
46#include "mathutil.hxx"
47#include "array_vector.hxx"
48#include "sized_int.hxx"
50#include "metaprogramming.hxx"
52#include "functorexpression.hxx"
53#include "random_forest/rf_common.hxx"
54#include "random_forest/rf_nodeproxy.hxx"
55#include "random_forest/rf_split.hxx"
56#include "random_forest/rf_decisionTree.hxx"
57#include "random_forest/rf_visitors.hxx"
58#include "random_forest/rf_region.hxx"
59#include "sampling.hxx"
60#include "random_forest/rf_preprocessing.hxx"
61#include "random_forest/rf_online_prediction_set.hxx"
62#include "random_forest/rf_earlystopping.hxx"
63#include "random_forest/rf_ridge_split.hxx"
83inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
85 SamplerOptions return_opt;
87 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
146template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
153 typedef detail::DecisionTree DecisionTree_t;
160 typedef LabelType LabelT;
227 template<
class TopologyIterator,
class ParameterIterator>
243 trees_[
k].topology_ = *topology_begin;
262 vigra_precondition(ext_param_.used() ==
true,
263 "RandomForest::ext_param(): "
264 "Random forest has not been trained yet.");
281 vigra_precondition(ext_param_.used() ==
false,
282 "RandomForest::set_ext_param():"
283 "Random forest has been trained! Call reset()"
284 "before specifying new extrinsic parameters.");
308 DecisionTree_t
const &
tree(
int index)
const
310 return trees_[index];
315 DecisionTree_t &
tree(
int index)
317 return trees_[index];
325 return ext_param_.column_count_;
336 return ext_param_.column_count_;
344 return ext_param_.class_count_;
351 return options_.tree_count_;
392 template <
class U,
class C1,
403 Random_t
const & random);
405 template <
class U,
class C1,
426 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
427 void learn( MultiArrayView<2, U, C1>
const & features,
428 MultiArrayView<2, U2,C2>
const & labels,
438 template <
class U,
class C1,
class U2,
class C2,
439 class Visitor_t,
class Split_t>
440 void learn( MultiArrayView<2, U, C1>
const & features,
441 MultiArrayView<2, U2,C2>
const & labels,
470 template <
class U,
class C1,
class U2,
class C2>
482 template<
class U,
class C1,
495 bool adjust_thresholds=
false);
497 template <
class U,
class C1,
class U2,
class C2>
502 onlineLearn(features,
512 template<
class U,
class C1,
518 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
519 MultiArrayView<2,U2,C2>
const & response,
526 template<
class U,
class C1,
class U2,
class C2>
527 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
528 MultiArrayView<2, U2, C2>
const & labels,
531 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
561 template <
class U,
class C,
class Stop>
562 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features, Stop & stop)
const;
564 template <
class U,
class C>
565 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features)
575 template <
class U,
class C>
576 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
577 ArrayVectorView<double> prior)
const;
589 template <
class U,
class C1,
class T,
class C2>
593 vigra_precondition(features.shape(0) == labels.shape(0),
594 "RandomForest::predictLabels(): Label array has wrong size.");
595 for(
int k=0;
k<features.shape(0); ++
k)
597 vigra_precondition(!detail::contains_nan(rowVector(features,
k)),
598 "RandomForest::predictLabels(): NaN in feature matrix.");
613 template <
class U,
class C1,
class T,
class C2>
618 vigra_precondition(features.shape(0) == labels.shape(0),
619 "RandomForest::predictLabels(): Label array has wrong size.");
620 for(
int k=0;
k<features.shape(0); ++
k)
622 if(detail::contains_nan(rowVector(features,
k)))
638 template <
class U,
class C1,
class T,
class C2,
class Stop>
643 vigra_precondition(features.shape(0) == labels.shape(0),
644 "RandomForest::predictLabels(): Label array has wrong size.");
645 for(
int k=0;
k<features.shape(0); ++
k)
646 labels(
k,0) = detail::RequiresExplicitCast<T>::cast(
predictLabel(rowVector(features,
k), stop));
660 template <
class U,
class C1,
class T,
class C2,
class Stop>
664 template <
class T1,
class T2,
class C>
674 template <
class U,
class C1,
class T,
class C2>
681 template <
class U,
class C1,
class T,
class C2>
691template <
class LabelType,
class PreprocessorTag>
692template<
class U,
class C1,
698void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
699 MultiArrayView<2,U2,C2>
const & response,
705 bool adjust_thresholds)
707 online_visitor_.activate();
708 online_visitor_.adjust_thresholds=adjust_thresholds;
720 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
722 typename RF_CHOOSER(
Stop_t)::type stop
725 typename RF_CHOOSER(
Split_t)::type split
728 typedef rf::visitors::detail::VisitorNode
729 <rf::visitors::OnlineLearnVisitor,
735 vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
741 ext_param_.class_count_=0;
743 options_, ext_param_);
749 split.set_external_parameters(ext_param_);
750 stop.set_external_parameters(ext_param_);
762 online_visitor_.tree_id=
ii;
770 online_visitor_.current_label=
preprocessor.response()(sample,0);
771 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
772 int leaf=trees_[
ii].getToLeaf(rowVector(features,sample),online_visitor_);
776 online_visitor_.add_to_index_list(
ii,
leaf,sample);
779 if(Node<e_ConstProbNode>(trees_[
ii].topology_,trees_[
ii].parameters_,
leaf).prob_begin()[
preprocessor.response()(sample,0)]!=1.0)
791 int lin_index=online_visitor_.trees_online_information[
ii].exterior_to_index[
leaf];
797 ext_param_.class_count_);
802 if(NodeBase(trees_[
ii].topology_,trees_[
ii].parameters_,parent).child(0)==
leaf)
808 vigra_assert(NodeBase(trees_[
ii].topology_,trees_[
ii].parameters_,parent).child(1)==
leaf,
"last_node_id seems to be wrong");
815 online_visitor_.move_exterior_node(
ii,trees_[
ii].topology_.size(),
ii,
leaf);
828 online_visitor_.deactivate();
831template<
class LabelType,
class PreprocessorTag>
832template<
class U,
class C1,
853 ext_param_.class_count_=0;
861 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
863 typename RF_CHOOSER(
Stop_t)::type stop
866 typename RF_CHOOSER(
Split_t)::type split
875 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
876 online_visitor_.activate();
886 options_, ext_param_);
889 split.set_external_parameters(ext_param_);
890 stop.set_external_parameters(ext_param_);
898 detail::make_sampler_opt(options_)
899 .sampleSize(ext_param().actual_msample_),
908 ext_param_.class_count_);
913 online_visitor_.tree_id=
treeId;
924 .visit_after_tree( *
this,
930 online_visitor_.deactivate();
933template <
class LabelType,
class PreprocessorTag>
934template <
class U,
class C1,
946 Random_t
const & random)
957 vigra_precondition(features.shape(0) == response.shape(0),
958 "RandomForest::learn(): shape mismatch between features and response.");
965 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
967 typename RF_CHOOSER(
Stop_t)::type stop
970 typename RF_CHOOSER(
Split_t)::type split
979 if(options_.prepare_online_learning_)
980 online_visitor_.activate();
982 online_visitor_.deactivate();
994 options_, ext_param_);
997 split.set_external_parameters(ext_param_);
998 stop.set_external_parameters(ext_param_);
1002 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
1006 detail::make_sampler_opt(options_)
1007 .sampleSize(ext_param().actual_msample_),
1021 ext_param_.class_count_);
1034 .visit_after_tree( *
this,
1043 online_visitor_.deactivate();
1049template <
class LabelType,
class Tag>
1050template <
class U,
class C,
class Stop>
1054 vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1055 "RandomForestn::predictLabel():"
1056 " Too few columns in feature matrix.");
1057 vigra_precondition(rowCount(features) == 1,
1058 "RandomForestn::predictLabel():"
1059 " Feature matrix must have a singlerow.");
1069template <
class LabelType,
class PreprocessorTag>
1070template <
class U,
class C>
1075 using namespace functor;
1076 vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1077 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1078 vigra_precondition(rowCount(features) == 1,
1079 "RandomForestn::predictLabel():"
1080 " Feature matrix must have a single row.");
1082 predictProbabilities(features,
prob);
1091template<
class LabelType,
class PreprocessorTag>
1092template <
class T1,
class T2,
class C>
1101 "RandomFroest::predictProbabilities():"
1102 " Feature matrix and probability matrix size mismatch.");
1105 vigra_precondition( columnCount(
predictionSet.features) >= ext_param_.column_count_,
1106 "RandomForestn::predictProbabilities():"
1107 " Too few columns in feature matrix.");
1108 vigra_precondition( columnCount(
prob)
1110 "RandomForestn::predictProbabilities():"
1111 " Probability matrix must have as many columns as there are classes.");
1117 for(
int k=0;
k<options_.tree_count_; ++
k)
1120 typedef std::set<SampleRange<T1> >
my_set;
1121 typedef typename my_set::iterator
set_it;
1124 std::vector<std::pair<int,set_it> >
stack;
1128 stack.push_back(std::pair<int,set_it>(2,
i));
1131 while(!
stack.empty())
1134 int index=
stack.back().first;
1138 if(trees_[
k].isLeafNode(trees_[
k].topology_[index]))
1141 trees_[
k].parameters_,
1142 index).prob_begin();
1143 for(
int i=range->start;
i!=range->
end;++
i)
1146 for(
int l=0;
l<ext_param_.class_count_; ++
l)
1157 if(trees_[
k].topology_[index]!=i_ThresholdNode)
1159 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1161 Node<i_ThresholdNode> node(trees_[
k].topology_,trees_[
k].parameters_,index);
1162 if(range->min_boundaries[node.column()]>=node.threshold())
1165 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1168 if(range->max_boundaries[node.column()]<node.threshold())
1171 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1177 range->max_boundaries[node.column()]=-
FLT_MAX;
1180 while(
i!=range->
end)
1185 new_range.min_boundaries[node.column()]=std::min(
new_range.min_boundaries[node.column()],
1194 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1200 if(range->start==range->
end)
1206 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1212 stack.push_back(std::pair<int,set_it>(node.child(1),
new_it.first));
1222 for(
int l=0;
l<ext_param_.class_count_; ++
l)
1232template <
class LabelType,
class PreprocessorTag>
1233template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1234void RandomForest<LabelType, PreprocessorTag>
1235 ::predictProbabilities(MultiArrayView<2, U, C1>
const & features,
1236 MultiArrayView<2, T, C2> & prob,
1237 Stop_t & stop_)
const
1243 "RandomForestn::predictProbabilities():"
1244 " Feature matrix and probability matrix size mismatch.");
1248 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1249 "RandomForestn::predictProbabilities():"
1250 " Too few columns in feature matrix.");
1253 "RandomForestn::predictProbabilities():"
1254 " Probability matrix must have as many columns as there are classes.");
1256 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1258 typename RF_CHOOSER(
Stop_t)::type & stop
1261 stop.set_external_parameters(ext_param_, tree_count());
1290 for(
int k=0;
k<options_.tree_count_; ++
k)
1296 int weighted = options_.predict_weighted_;
1297 for(
int l=0;
l<ext_param_.class_count_; ++
l)
1305 if(stop.after_prediction(weights,
1315 for(
int l=0;
l< ext_param_.class_count_; ++
l)
1323template <
class LabelType,
class PreprocessorTag>
1324template <
class U,
class C1,
class T,
class C2>
1325void RandomForest<LabelType, PreprocessorTag>
1326 ::predictRaw(MultiArrayView<2, U, C1>
const & features,
1327 MultiArrayView<2, T, C2> & prob)
const
1333 "RandomForestn::predictProbabilities():"
1334 " Feature matrix and probability matrix size mismatch.");
1338 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1339 "RandomForestn::predictProbabilities():"
1340 " Too few columns in feature matrix.");
1343 "RandomForestn::predictProbabilities():"
1344 " Probability matrix must have as many columns as there are classes.");
1346 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1362 for(
int k=0;
k<options_.tree_count_; ++
k)
1368 int weighted = options_.predict_weighted_;
1369 for(
int l=0;
l<ext_param_.class_count_; ++
l)
1377 prob/= options_.tree_count_;
1383#include "random_forest/rf_algorithm.hxx"
Standard early stopping criterion.
Definition rf_common.hxx:886
Class for a single RGB value.
Definition rgbvalue.hxx:128
RGBValue()
Definition rgbvalue.hxx:209
Options object for the random forest.
Definition rf_common.hxx:171
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition random_forest.hxx:148
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition random_forest.hxx:228
Options_t const & options() const
access const random forest options
Definition random_forest.hxx:301
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition random_forest.hxx:941
DecisionTree_t const & tree(int index) const
access const trees
Definition random_forest.hxx:308
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition random_forest.hxx:278
DecisionTree_t & tree(int index)
access trees
Definition random_forest.hxx:315
int tree_count() const
return number of trees
Definition random_forest.hxx:349
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition random_forest.hxx:197
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple labels
int column_count() const
return number of features used while training.
Definition random_forest.hxx:334
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict multiple labels with given features
Definition random_forest.hxx:614
int feature_count() const
return number of features used while training.
Definition random_forest.hxx:323
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition random_forest.hxx:590
int class_count() const
return number of classes used while training.
Definition random_forest.hxx:342
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label given a feature.
Definition random_forest.hxx:1052
Options_t & set_options()
access random forest options
Definition random_forest.hxx:291
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition random_forest.hxx:260
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition random_forest.hxx:838
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition random_forest.hxx:471
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict multiple labels with given features
Definition random_forest.hxx:639
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple labels
Definition random_forest.hxx:675
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement.
Definition sampling.hxx:83
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities,...
Definition sampling.hxx:141
void init(Iterator i, Iterator end)
Definition tinyvector.hxx:708
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Definition rf_visitors.hxx:585
void reset_tree(int tree_id)
Definition rf_visitors.hxx:636
Definition rf_visitors.hxx:236
Definition rf_visitors.hxx:256
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:684
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:697
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:671
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
std::ptrdiff_t MultiArrayIndex
Definition multi_fwd.hxx:60
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition rf_common.hxx:131