[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_visitors.hxx
1/************************************************************************/
2/* */
3/* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35#ifndef RF_VISITORS_HXX
36#define RF_VISITORS_HXX
37
38#ifdef HasHDF5
39# include "vigra/hdf5impex.hxx"
40#endif // HasHDF5
41#include <vigra/windows.h>
42#include <iostream>
43#include <iomanip>
44#include <random>
45
46#include <vigra/metaprogramming.hxx>
47#include <vigra/multi_pointoperators.hxx>
48#include <vigra/timing.hxx>
49
50namespace vigra
51{
52namespace rf
53{
54/** \brief Visitors to extract information during training of \ref vigra::RandomForest version 2.
55
56 \ingroup MachineLearning
57
58 This namespace contains all classes and methods related to extracting information during
59 learning of the random forest. All Visitors share the same interface defined in
60 visitors::VisitorBase. The member methods are invoked at certain points of the main code in
61 the order they were supplied.
62
63 For the Random Forest the Visitor concept is implemented as a statically linked list
64 (Using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The
65 VisitorNode object calls the Next Visitor after one of its visit() methods have terminated.
66
67 To simplify usage create_visitor() factory methods are supplied.
68 Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
69 It is possible to supply more than one visitor. They will then be invoked in serial order.
70
71 The calculated information are stored as public data members of the class. - see documentation
72 of the individual visitors
73
74 While creating a new visitor the new class should therefore publicly inherit from this class
75 (i.e.: see visitors::OOB_Error).
76
77 \code
78
79 typedef xxx feature_t // replace xxx with whichever type
80 typedef yyy label_t // same for the label type.
81 MultiArrayView<2, feature_t> f = get_some_features();
82 MultiArrayView<2, label_t> l = get_some_labels();
83 RandomForest<> rf;
84
85 //calculate OOB Error
86 visitors::OOB_Error oob_v;
87 //calculate Variable Importance
88 visitors::VariableImportanceVisitor varimp_v;
89
90 double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
91 //the data can be found in the attributes of oob_v and varimp_v now
92
93 \endcode
94*/
95namespace visitors
96{
97
98
99/** Base Class from which all Visitors derive. Can be used as a template to create new
100 * Visitors.
101 */
103{
104 public:
105 bool active_;
106 bool is_active()
107 {
108 return active_;
109 }
110
111 bool has_value()
112 {
113 return false;
114 }
115
117 : active_(true)
118 {}
119
120 void deactivate()
121 {
122 active_ = false;
123 }
124 void activate()
125 {
126 active_ = true;
127 }
128
129 /** do something after the Split has decided how to process the Region
130 * (Stack entry)
131 *
132 * \param tree reference to the tree that is currently being learned
133 * \param split reference to the split object
134 * \param parent current stack entry which was used to decide the split
135 * \param leftChild left stack entry that will be pushed
136 * \param rightChild
137 * right stack entry that will be pushed.
138 * \param features features matrix
139 * \param labels label matrix
140 * \sa RF_Traits::StackEntry_t
141 */
142 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
144 Split & split,
145 Region & parent,
148 Feature_t & features,
149 Label_t & labels)
150 {
151 ignore_argument(tree,split,parent,leftChild,rightChild,features,labels);
152 }
153
154 /** do something after each tree has been learned
155 *
156 * \param rf reference to the random forest object that called this
157 * visitor
158 * \param pr reference to the preprocessor that processed the input
159 * \param sm reference to the sampler object
160 * \param st reference to the first stack entry
161 * \param index index of current tree
162 */
163 template<class RF, class PR, class SM, class ST>
164 void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
165 {
166 ignore_argument(rf,pr,sm,st,index);
167 }
168
169 /** do something after all trees have been learned
170 *
171 * \param rf reference to the random forest object that called this
172 * visitor
173 * \param pr reference to the preprocessor that processed the input
174 */
175 template<class RF, class PR>
176 void visit_at_end(RF const & rf, PR const & pr)
177 {
178 ignore_argument(rf,pr);
179 }
180
181 /** do something before learning starts
182 *
183 * \param rf reference to the random forest object that called this
184 * visitor
185 * \param pr reference to the Processor class used.
186 */
187 template<class RF, class PR>
188 void visit_at_beginning(RF const & rf, PR const & pr)
189 {
190 ignore_argument(rf,pr);
191 }
192 /** do something while traversing the tree after it has been learned
193 * (external nodes)
194 *
195 * \param tr reference to the tree object that called this visitor
196 * \param index index in the topology_ array we currently are at
197 * \param node_t type of node we have (will be e_.... - )
198 * \param features feature matrix
199 * \sa NodeTags;
200 *
201 * you can create the node by using a switch on node_tag and using the
202 * corresponding Node objects. Or - if you do not care about the type
203 * use the NodeBase class.
204 */
205 template<class TR, class IntT, class TopT,class Feat>
206 void visit_external_node(TR & tr, IntT index, TopT node_t, Feat & features)
207 {
208 ignore_argument(tr,index,node_t,features);
209 }
210
211 /** do something when visiting an internal node after it has been learned
212 *
213 * \sa visit_external_node
214 */
215 template<class TR, class IntT, class TopT,class Feat>
216 void visit_internal_node(TR & /* tr */, IntT /* index */, TopT /* node_t */, Feat & /* features */)
217 {}
218
219 /** return a double value. The value of the first
220 * visitor encountered that has a return value is returned with the
221 * RandomForest::learn() method - or -1.0 if no return value visitor
222 * existed. This functionality basically only exists so that the
223 * OOB - visitor can return the oob error rate like in the old version
224 * of the random forest.
225 */
226 double return_val()
227 {
228 return -1.0;
229 }
230};
231
232
233/** Last Visitor that should be called to stop the recursion.
234 */
236{
237 public:
238 bool has_value()
239 {
240 return true;
241 }
242 double return_val()
243 {
244 return -1.0;
245 }
246};
247namespace detail
248{
249/** Container elements of the statically linked Visitor list.
250 *
251 * use the create_visitor() factory functions to create visitors up to size 10;
252 *
253 */
254template <class Visitor, class Next = StopVisiting>
256{
257 public:
258
259 StopVisiting stop_;
260 Next next_;
261 Visitor & visitor_;
262 VisitorNode(Visitor & visitor, Next & next)
263 :
264 next_(next), visitor_(visitor)
265 {}
266
267 VisitorNode(Visitor & visitor)
268 :
269 next_(stop_), visitor_(visitor)
270 {}
271
272 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
273 void visit_after_split( Tree & tree,
274 Split & split,
275 Region & parent,
278 Feature_t & features,
279 Label_t & labels)
280 {
281 if(visitor_.is_active())
282 visitor_.visit_after_split(tree, split,
283 parent, leftChild, rightChild,
284 features, labels);
285 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
286 features, labels);
287 }
288
289 template<class RF, class PR, class SM, class ST>
290 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
291 {
292 if(visitor_.is_active())
293 visitor_.visit_after_tree(rf, pr, sm, st, index);
294 next_.visit_after_tree(rf, pr, sm, st, index);
295 }
296
297 template<class RF, class PR>
298 void visit_at_beginning(RF & rf, PR & pr)
299 {
300 if(visitor_.is_active())
301 visitor_.visit_at_beginning(rf, pr);
302 next_.visit_at_beginning(rf, pr);
303 }
304 template<class RF, class PR>
305 void visit_at_end(RF & rf, PR & pr)
306 {
307 if(visitor_.is_active())
308 visitor_.visit_at_end(rf, pr);
309 next_.visit_at_end(rf, pr);
310 }
311
312 template<class TR, class IntT, class TopT,class Feat>
313 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
314 {
315 if(visitor_.is_active())
316 visitor_.visit_external_node(tr, index, node_t,features);
317 next_.visit_external_node(tr, index, node_t,features);
318 }
319 template<class TR, class IntT, class TopT,class Feat>
320 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
321 {
322 if(visitor_.is_active())
323 visitor_.visit_internal_node(tr, index, node_t,features);
324 next_.visit_internal_node(tr, index, node_t,features);
325 }
326
327 double return_val()
328 {
329 if(visitor_.is_active() && visitor_.has_value())
330 return visitor_.return_val();
331 return next_.return_val();
332 }
333};
334
335} //namespace detail
336
337//////////////////////////////////////////////////////////////////////////////
338// Visitor Factory function up to 10 visitors //
339//////////////////////////////////////////////////////////////////////////////
340
341/** factory method to be used with RandomForest::learn()
342 */
343template<class A>
346{
348 _0_t _0(a);
349 return _0;
350}
351
352
353/** factory method to be used with RandomForest::learn()
354 */
355template<class A, class B>
356detail::VisitorNode<A, detail::VisitorNode<B> >
357create_visitor(A & a, B & b)
358{
360 _1_t _1(b);
362 _0_t _0(a, _1);
363 return _0;
364}
365
366
367/** factory method to be used with RandomForest::learn()
368 */
369template<class A, class B, class C>
370detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
371create_visitor(A & a, B & b, C & c)
372{
374 _2_t _2(c);
376 _1_t _1(b, _2);
378 _0_t _0(a, _1);
379 return _0;
380}
381
382
383/** factory method to be used with RandomForest::learn()
384 */
385template<class A, class B, class C, class D>
386detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
387 detail::VisitorNode<D> > > >
388create_visitor(A & a, B & b, C & c, D & d)
389{
391 _3_t _3(d);
393 _2_t _2(c, _3);
395 _1_t _1(b, _2);
397 _0_t _0(a, _1);
398 return _0;
399}
400
401
402/** factory method to be used with RandomForest::learn()
403 */
404template<class A, class B, class C, class D, class E>
405detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
406 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
407create_visitor(A & a, B & b, C & c,
408 D & d, E & e)
409{
411 _4_t _4(e);
413 _3_t _3(d, _4);
415 _2_t _2(c, _3);
417 _1_t _1(b, _2);
419 _0_t _0(a, _1);
420 return _0;
421}
422
423
424/** factory method to be used with RandomForest::learn()
425 */
426template<class A, class B, class C, class D, class E,
427 class F>
428detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
429 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
430create_visitor(A & a, B & b, C & c,
431 D & d, E & e, F & f)
432{
434 _5_t _5(f);
436 _4_t _4(e, _5);
438 _3_t _3(d, _4);
440 _2_t _2(c, _3);
442 _1_t _1(b, _2);
444 _0_t _0(a, _1);
445 return _0;
446}
447
448
449/** factory method to be used with RandomForest::learn()
450 */
451template<class A, class B, class C, class D, class E,
452 class F, class G>
453detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
454 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
455 detail::VisitorNode<G> > > > > > >
456create_visitor(A & a, B & b, C & c,
457 D & d, E & e, F & f, G & g)
458{
460 _6_t _6(g);
462 _5_t _5(f, _6);
464 _4_t _4(e, _5);
466 _3_t _3(d, _4);
468 _2_t _2(c, _3);
470 _1_t _1(b, _2);
472 _0_t _0(a, _1);
473 return _0;
474}
475
476
477/** factory method to be used with RandomForest::learn()
478 */
479template<class A, class B, class C, class D, class E,
480 class F, class G, class H>
481detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
482 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
483 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
484create_visitor(A & a, B & b, C & c,
485 D & d, E & e, F & f,
486 G & g, H & h)
487{
489 _7_t _7(h);
491 _6_t _6(g, _7);
493 _5_t _5(f, _6);
495 _4_t _4(e, _5);
497 _3_t _3(d, _4);
499 _2_t _2(c, _3);
501 _1_t _1(b, _2);
503 _0_t _0(a, _1);
504 return _0;
505}
506
507
508/** factory method to be used with RandomForest::learn()
509 */
510template<class A, class B, class C, class D, class E,
511 class F, class G, class H, class I>
512detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
513 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
514 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
515create_visitor(A & a, B & b, C & c,
516 D & d, E & e, F & f,
517 G & g, H & h, I & i)
518{
520 _8_t _8(i);
522 _7_t _7(h, _8);
524 _6_t _6(g, _7);
526 _5_t _5(f, _6);
528 _4_t _4(e, _5);
530 _3_t _3(d, _4);
532 _2_t _2(c, _3);
534 _1_t _1(b, _2);
536 _0_t _0(a, _1);
537 return _0;
538}
539
540/** factory method to be used with RandomForest::learn()
541 */
542template<class A, class B, class C, class D, class E,
543 class F, class G, class H, class I, class J>
544detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
545 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
546 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
547 detail::VisitorNode<J> > > > > > > > > >
548create_visitor(A & a, B & b, C & c,
549 D & d, E & e, F & f,
550 G & g, H & h, I & i,
551 J & j)
552{
554 _9_t _9(j);
556 _8_t _8(i, _9);
558 _7_t _7(h, _8);
560 _6_t _6(g, _7);
562 _5_t _5(f, _6);
564 _4_t _4(e, _5);
566 _3_t _3(d, _4);
568 _2_t _2(c, _3);
570 _1_t _1(b, _2);
572 _0_t _0(a, _1);
573 return _0;
574}
575
576//////////////////////////////////////////////////////////////////////////////
577// Visitors of communal interest. //
578//////////////////////////////////////////////////////////////////////////////
579
580
581/** Visitor to gain information, later needed for online learning.
582 */
583
585{
586public:
587 //Set if we adjust thresholds
588 bool adjust_thresholds;
589 //Current tree id
590 int tree_id;
591 //Last node id for finding parent
592 int last_node_id;
593 //Need to now the label for interior node visiting
594 vigra::Int32 current_label;
595 //marginal distribution for interior nodes
596 //
598 adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
599 {}
600 struct MarginalDistribution
601 {
602 ArrayVector<Int32> leftCounts;
603 Int32 leftTotalCounts;
604 ArrayVector<Int32> rightCounts;
605 Int32 rightTotalCounts;
606 double gap_left;
607 double gap_right;
608 };
610
611 //All information for one tree
612 struct TreeOnlineInformation
613 {
614 std::vector<MarginalDistribution> mag_distributions;
615 std::vector<IndexList> index_lists;
616 //map for linear index of mag_distributions
617 std::map<int,int> interior_to_index;
618 //map for linear index of index_lists
619 std::map<int,int> exterior_to_index;
620 };
621
622 //All trees
623 std::vector<TreeOnlineInformation> trees_online_information;
624
625 /** Initialize, set the number of trees
626 */
627 template<class RF,class PR>
628 void visit_at_beginning(RF & rf,const PR & /* pr */)
629 {
630 tree_id=0;
631 trees_online_information.resize(rf.options_.tree_count_);
632 }
633
634 /** Reset a tree
635 */
636 void reset_tree(int tree_id)
637 {
638 trees_online_information[tree_id].mag_distributions.clear();
639 trees_online_information[tree_id].index_lists.clear();
640 trees_online_information[tree_id].interior_to_index.clear();
641 trees_online_information[tree_id].exterior_to_index.clear();
642 }
643
644 /** simply increase the tree count
645 */
646 template<class RF, class PR, class SM, class ST>
647 void visit_after_tree(RF & /* rf */, PR & /* pr */, SM & /* sm */, ST & /* st */, int /* index */)
648 {
649 tree_id++;
650 }
651
652 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
653 void visit_after_split( Tree & tree,
654 Split & split,
655 Region & parent,
658 Feature_t & features,
659 Label_t & /* labels */)
660 {
661 int linear_index;
662 int addr=tree.topology_.size();
663 if(split.createNode().typeID() == i_ThresholdNode)
664 {
665 if(adjust_thresholds)
666 {
667 //Store marginal distribution
668 linear_index=trees_online_information[tree_id].mag_distributions.size();
669 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
670 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
671
672 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
673 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
674
675 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
676 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
677 //Store the gap
678 double gap_left,gap_right;
679 int i;
680 gap_left=features(leftChild[0],split.bestSplitColumn());
681 for(i=1;i<leftChild.size();++i)
682 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
683 gap_left=features(leftChild[i],split.bestSplitColumn());
684 gap_right=features(rightChild[0],split.bestSplitColumn());
685 for(i=1;i<rightChild.size();++i)
686 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
687 gap_right=features(rightChild[i],split.bestSplitColumn());
688 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
689 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
690 }
691 }
692 else
693 {
694 //Store index list
695 linear_index=trees_online_information[tree_id].index_lists.size();
696 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
697
698 trees_online_information[tree_id].index_lists.push_back(IndexList());
699
700 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
701 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
702 }
703 }
704 void add_to_index_list(int tree,int node,int index)
705 {
706 if(!this->active_)
707 return;
708 TreeOnlineInformation &ti=trees_online_information[tree];
709 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
710 }
711 void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
712 {
713 if(!this->active_)
714 return;
715 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
716 trees_online_information[src_tree].exterior_to_index.erase(src_index);
717 }
718 /** do something when visiting a internal node during getToLeaf
719 *
720 * remember as last node id, for finding the parent of the last external node
721 * also: adjust class counts and borders
722 */
723 template<class TR, class IntT, class TopT,class Feat>
724 void visit_internal_node(TR & tr, IntT index, [[maybe_unused]] TopT node_t, Feat & features)
725 {
726 last_node_id=index;
727 if(adjust_thresholds)
728 {
729 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
730 //Check if we are in the gap
731 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
732 TreeOnlineInformation &ti=trees_online_information[tree_id];
733 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
734 if(value>m.gap_left && value<m.gap_right)
735 {
736 //Check which site we want to go
737 if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
738 {
739 //We want to go left
740 m.gap_left=value;
741 }
742 else
743 {
744 //We want to go right
745 m.gap_right=value;
746 }
747 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
748 }
749 //Adjust class counts
750 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
751 {
752 ++m.rightTotalCounts;
753 ++m.rightCounts[current_label];
754 }
755 else
756 {
757 ++m.leftTotalCounts;
758 ++m.rightCounts[current_label];
759 }
760 }
761 }
762 /** do something when visiting an external node during getToLeaf
763 *
764 * Store the new index!
765 */
766};
767
768//////////////////////////////////////////////////////////////////////////////
769// Out of Bag Error estimates //
770//////////////////////////////////////////////////////////////////////////////
771
772
773/** Visitor that calculates the oob error of each individual randomized
774 * decision tree.
775 *
776 * After training a tree, all those samples that are OOB for this particular tree
777 * are put down the tree and the error estimated.
778 * the per tree oob error is the average of the individual error estimates.
779 * (oobError = average error of one randomized tree)
780 * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error
781 * visitor)
782 */
784{
785public:
786 /** Average error of one randomized decision tree
787 */
788 double oobError;
789
790 int totalOobCount;
791 ArrayVector<int> oobCount,oobErrorCount;
792
794 : oobError(0.0),
795 totalOobCount(0)
796 {}
797
798
799 bool has_value()
800 {
801 return true;
802 }
803
804
805 /** does the basic calculation per tree*/
806 template<class RF, class PR, class SM, class ST>
807 void visit_after_tree(RF & rf, PR & pr, SM & sm, ST &, int index)
808 {
809 //do the first time called.
810 if(int(oobCount.size()) != rf.ext_param_.row_count_)
811 {
812 oobCount.resize(rf.ext_param_.row_count_, 0);
813 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
814 }
815 // go through the samples
816 for(int l = 0; l < rf.ext_param_.row_count_; ++l)
817 {
818 // if the lth sample is oob...
819 if(!sm.is_used()[l])
820 {
821 ++oobCount[l];
822 if( rf.tree(index)
823 .predictLabel(rowVector(pr.features(), l))
824 != pr.response()(l,0))
825 {
826 ++oobErrorCount[l];
827 }
828 }
829
830 }
831 }
832
833 /** Does the normalisation
834 */
835 template<class RF, class PR>
836 void visit_at_end(RF & rf, PR &)
837 {
838 // do some normalisation
839 for(int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
840 {
841 if(oobCount[l])
842 {
843 oobError += double(oobErrorCount[l]) / oobCount[l];
844 ++totalOobCount;
845 }
846 }
847 oobError/=totalOobCount;
848 }
849
850};
851
852/** Visitor that calculates the oob error of the ensemble
853 *
854 * This rate serves as a quick estimate for the crossvalidation
855 * error rate.
856 * Here, each sample is put down the trees for which this sample
857 * is OOB, i.e., if sample #1 is OOB for trees 1, 3 and 5, we calculate
858 * the output using the ensemble consisting only of trees 1 3 and 5.
859 *
860 * Using normal bagged sampling each sample is OOB for approx. 33% of trees.
861 * The error rate obtained as such therefore corresponds to a crossvalidation
862 * rate obtained using a ensemble containing 33% of the trees.
863 */
864class OOB_Error : public VisitorBase
865{
867 int class_count;
868 bool is_weighted;
869 MultiArray<2,double> tmp_prob;
870 public:
871
872 MultiArray<2, double> prob_oob;
873 /** Ensemble oob error rate
874 */
876
877 MultiArray<2, double> oobCount;
878 ArrayVector< int> indices;
879 OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
880#ifdef HasHDF5
881 void save(std::string filen, std::string pathn)
882 {
883 if(*(pathn.end()-1) != '/')
884 pathn += "/";
885 const char* filename = filen.c_str();
886 MultiArray<2, double> temp(Shp(1,1), 0.0);
887 temp[0] = oob_breiman;
888 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
889 }
890#endif
891 // negative value if sample was ib, number indicates how often.
892 // value >=0 if sample was oob, 0 means fail 1, correct
893
894 template<class RF, class PR>
895 void visit_at_beginning(RF & rf, PR &)
896 {
897 class_count = rf.class_count();
898 tmp_prob.reshape(Shp(1, class_count), 0);
899 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
900 is_weighted = rf.options().predict_weighted_;
901 indices.resize(rf.ext_param().row_count_);
902 if(int(oobCount.size()) != rf.ext_param_.row_count_)
903 {
904 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
905 }
906 for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
907 {
908 indices[ii] = ii;
909 }
910 }
911
912 template<class RF, class PR, class SM, class ST>
913 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
914 {
915 // go through the samples
916 // FIXME: magic number 10000: invoke special treatment when when msample << sample_count
917 // (i.e. the OOB sample ist very large)
918 // 40000: use at most 40000 OOB samples per class for OOB error estimate
919 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
920 {
921 ArrayVector<int> oob_indices;
922 ArrayVector<int> cts(class_count, 0);
923 std::random_device rd;
924 std::mt19937 g(rd());
925 std::shuffle(indices.begin(), indices.end(), g);
926 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
927 {
928 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
929 {
930 oob_indices.push_back(indices[ii]);
931 ++cts[pr.response()(indices[ii], 0)];
932 }
933 }
934 for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
935 {
936 // update number of trees in which current sample is oob
937 ++oobCount[oob_indices[ll]];
938
939 // get the predicted votes ---> tmp_prob;
940 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
941 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
942 rf.tree(index).parameters_,
943 pos);
944 tmp_prob.init(0);
945 for(int ii = 0; ii < class_count; ++ii)
946 {
947 tmp_prob[ii] = node.prob_begin()[ii];
948 }
949 if(is_weighted)
950 {
951 for(int ii = 0; ii < class_count; ++ii)
952 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
953 }
954 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
955
956 }
957 }else
958 {
959 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
960 {
961 // if the lth sample is oob...
962 if(!sm.is_used()[ll])
963 {
964 // update number of trees in which current sample is oob
965 ++oobCount[ll];
966
967 // get the predicted votes ---> tmp_prob;
968 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
969 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
970 rf.tree(index).parameters_,
971 pos);
972 tmp_prob.init(0);
973 for(int ii = 0; ii < class_count; ++ii)
974 {
975 tmp_prob[ii] = node.prob_begin()[ii];
976 }
977 if(is_weighted)
978 {
979 for(int ii = 0; ii < class_count; ++ii)
980 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
981 }
982 rowVector(prob_oob, ll) += tmp_prob;
983 }
984 }
985 }
986 // go through the ib samples;
987 }
988
989 /** Normalise variable importance after the number of trees is known.
990 */
991 template<class RF, class PR>
992 void visit_at_end(RF & rf, PR & pr)
993 {
994 // ullis original metric and breiman style stuff
995 int totalOobCount =0;
996 int breimanstyle = 0;
997 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
998 {
999 if(oobCount[ll])
1000 {
1001 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1002 ++breimanstyle;
1003 ++totalOobCount;
1004 }
1005 }
1006 oob_breiman = double(breimanstyle)/totalOobCount;
1007 }
1008};
1009
1010
1011/** Visitor that calculates different OOB error statistics
1012 */
1014{
1016 int class_count;
1017 bool is_weighted;
1018 MultiArray<2,double> tmp_prob;
1019 public:
1020
1021 /** OOB Error rate of each individual tree
1022 */
1024 /** Mean of oob_per_tree
1025 */
1026 double oob_mean;
1027 /**Standard deviation of oob_per_tree
1028 */
1029 double oob_std;
1030
1031 MultiArray<2, double> prob_oob;
1032 /** Ensemble OOB error
1033 *
1034 * \sa OOB_Error
1035 */
1037
1038 MultiArray<2, double> oobCount;
1039 MultiArray<2, double> oobErrorCount;
1040 /** Per Tree OOB error calculated as in OOB_PerTreeError
1041 * (Ulli's version)
1042 */
1044
1045 /**Column containing the development of the Ensemble
1046 * error rate with increasing number of trees
1047 */
1049 /** 4 dimensional array containing the development of confusion matrices
1050 * with number of trees - can be used to estimate ROC curves etc.
1051 *
1052 * oobroc_per_tree(ii,jj,kk,ll)
1053 * corresponds true label = ii
1054 * predicted label = jj
1055 * confusion matrix after ll trees
1056 *
1057 * explanation of third index:
1058 *
1059 * Two class case:
1060 * kk = 0 - (treeCount-1)
1061 * Threshold is on Probability for class 0 is kk/(treeCount-1);
1062 * More classes:
1063 * kk = 0. Threshold on probability set by argMax of the probability array.
1064 */
1066
1068
1069#ifdef HasHDF5
1070 /** save to HDF5 file
1071 */
1072 void save(std::string filen, std::string pathn)
1073 {
1074 if(*(pathn.end()-1) != '/')
1075 pathn += "/";
1076 const char* filename = filen.c_str();
1077 MultiArray<2, double> temp(Shp(1,1), 0.0);
1078 writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1079 writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1080 writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1081 temp[0] = oob_mean;
1082 writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1083 temp[0] = oob_std;
1084 writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1085 temp[0] = oob_breiman;
1086 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1087 temp[0] = oob_per_tree2;
1088 writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1089 }
1090#endif
1091 // negative value if sample was ib, number indicates how often.
1092 // value >=0 if sample was oob, 0 means fail 1, correct
1093
1094 template<class RF, class PR>
1095 void visit_at_beginning(RF & rf, PR &)
1096 {
1097 class_count = rf.class_count();
1098 if(class_count == 2)
1099 oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
1100 else
1101 oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
1102 tmp_prob.reshape(Shp(1, class_count), 0);
1103 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1104 is_weighted = rf.options().predict_weighted_;
1105 oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1106 breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1107 //do the first time called.
1108 if(int(oobCount.size()) != rf.ext_param_.row_count_)
1109 {
1110 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1111 oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
1112 }
1113 }
1114
 1115 template<class RF, class PR, class SM, class ST>
      // Update all OOB statistics with the tree that was just learned: for
      // every sample that is out-of-bag for this tree, the leaf's class
      // probabilities are accumulated into prob_oob, and the per-tree and
      // cumulative (Breiman-style) error rates are recomputed.
 1116 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
 1117 {
 1118 // go through the samples
 1119 int total_oob =0;
 1120 int wrong_oob =0;
 1121 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
 1122 {
 1123 // if the lth sample is oob...
 1124 if(!sm.is_used()[ll])
 1125 {
 1126 // update number of trees in which current sample is oob
 1127 ++oobCount[ll];
 1128
 1129 // update number of oob samples in this tree.
 1130 ++total_oob;
 1131 // get the predicted votes ---> tmp_prob;
 1132 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
 1133 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
 1134 rf.tree(index).parameters_,
 1135 pos);
 1136 tmp_prob.init(0);
 1137 for(int ii = 0; ii < class_count; ++ii)
 1138 {
 1139 tmp_prob[ii] = node.prob_begin()[ii];
 1140 }
 1141 if(is_weighted)
 1142 {
      // weighted prediction: scale each class probability by the value
      // stored directly before the probability block of this leaf node
 1143 for(int ii = 0; ii < class_count; ++ii)
 1144 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
 1145 }
 1146 rowVector(prob_oob, ll) += tmp_prob;
 1147 int label = argMax(tmp_prob);
 1148
 1149 if(label != pr.response()(ll, 0))
 1150 {
 1151 // update number of wrong oob samples in this tree.
 1152 ++wrong_oob;
 1153 // update number of trees in which current sample is wrong oob
 1154 ++oobErrorCount[ll];
 1155 }
 1156 }
 1157 }
      // second pass: cumulative (ensemble-so-far) error over all samples
      // that have been OOB at least once, plus the confusion matrix in the
      // multi-class case (third ROC dimension == 1).
 1158 int breimanstyle = 0;
 1159 int totalOobCount = 0;
 1160 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
 1161 {
 1162 if(oobCount[ll])
 1163 {
 1164 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
 1165 ++breimanstyle;
 1166 ++totalOobCount;
 1167 if(oobroc_per_tree.shape(2) == 1)
 1168 {
 1169 oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
 1170 }
 1171 }
 1172 }
 1173 if(oobroc_per_tree.shape(2) == 1)
 1174 oobroc_per_tree.bindOuter(index)/=totalOobCount;
      // two-class case: sweep shape(2) thresholds on the class-1
      // probability to build a ROC-style confusion matrix per threshold
 1175 if(oobroc_per_tree.shape(2) > 1)
 1176 {
 1177 MultiArrayView<3, double> current_roc
 1178 = oobroc_per_tree.bindOuter(index);
 1179 for(int gg = 0; gg < current_roc.shape(2); ++gg)
 1180 {
 1181 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
 1182 {
 1183 if(oobCount[ll])
 1184 {
 1185 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
 1186 1 : 0;
 1187 current_roc(pr.response()(ll, 0), pred, gg)+= 1;
 1188 }
 1189 }
 1190 current_roc.bindOuter(gg)/= totalOobCount;
 1191 }
 1192 }
      // NOTE(review): if no sample was OOB for this tree (total_oob == 0)
      // or no sample has ever been OOB (totalOobCount == 0), the divisions
      // below are undefined — TODO confirm callers guarantee OOB samples.
 1193 breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
 1194 oob_per_tree[index] = double(wrong_oob)/double(total_oob);
 1195 // go through the ib samples;
 1196 }
1197
 1198 /** Compute the final OOB statistics (mean/std per-tree error, Breiman
 1199 * error) once the number of trees is known.
 1200 template<class RF, class PR>
 1201 void visit_at_end(RF & rf, PR & pr)
 1202 {
 1203 // ullis original metric and breiman style stuff
 1204 oob_per_tree2 = 0;
 1205 int totalOobCount =0;
 1206 int breimanstyle = 0;
 1207 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
 1208 {
 1209 if(oobCount[ll])
 1210 {
      // ensemble vote over accumulated leaf probabilities (Breiman error)
 1211 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
 1212 ++breimanstyle;
      // per-sample error rate averaged over the trees where it was OOB
 1213 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
 1214 ++totalOobCount;
 1215 }
 1216 }
 1217 oob_per_tree2 /= totalOobCount;
 1218 oob_breiman = double(breimanstyle)/totalOobCount;
 1219 // mean error of each tree
      // NOTE(review): this rendering is missing the original line 1220,
      // presumably the declaration of a 1x1 view `mean` bound to oob_mean
      // (mirroring stdDev below) — confirm against the full source.
 1221 MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
 1222 rowStatistics(oob_per_tree, mean, stdDev);
 1223 }
1224};
1225
1226/** calculate variable importance while learning.
1227 */
1229{
1230 public:
1231
 1232 /** This Array has the same entries as the R - random forest variable
 1233 * importance.
 1234 * Matrix is featureCount by (classCount +2)
 1235 * variable_importance_(ii,jj) is the variable importance measure of
 1236 * the ii-th variable according to:
 1237 * jj = 0 - (classCount-1)
 1238 * classwise permutation importance
 1239 * jj = columnCount(variable_importance_) -2
 1240 * permutation importance
 1241 * jj = columnCount(variable_importance_) -1
 1242 * gini decrease importance.
 1243 *
 1244 * permutation importance:
 1245 * The difference between the fraction of OOB samples classified correctly
 1246 * before and after permuting (randomizing) the ii-th column is calculated.
 1247 * The ii-th column is permuted rep_cnt times.
 1248 *
 1249 * class wise permutation importance:
 1250 * same as permutation importance. We only look at those OOB samples whose
 1251 * response corresponds to class jj.
 1252 *
 1253 * gini decrease importance:
 1254 * row ii corresponds to the sum of all gini decreases induced by variable ii
 1255 * in each node of the random forest.
 1256 */
      // NOTE(review): the declaration of variable_importance_ (original
      // line 1257, a MultiArray<2, double>) is missing from this rendering.
      // how many times each column is permuted for the permutation measure
 1258 int repetition_count_;
      // whether permutation importance is computed without copying features
 1259 bool in_place_;
 1260
 1261#ifdef HasHDF5
      // write variable_importance_ to HDF5 under "variable_importance_<prefix>"
 1262 void save(std::string filename, std::string prefix)
 1263 {
 1264 prefix = "variable_importance_" + prefix;
 1265 writeHDF5(filename.c_str(),
 1266 prefix.c_str(),
      // NOTE(review): the third writeHDF5 argument (original line 1267,
      // presumably `variable_importance_);`) is missing from this rendering.
 1268 }
 1269#endif
 1270
 1271 /* Constructor
 1272 * \param rep_cnt (default: 10) how often should
 1273 * the permutation take place. Set to 1 to make calculation faster (but
 1274 * possibly more instable)
 1275 */
      // NOTE(review): the constructor signature (original line 1276) is
      // missing from this rendering.
 1277 : repetition_count_(rep_cnt)
 1278
 1279 {}
1280
 1281 /** calculates impurity decrease based variable importance after every
 1282 * split.
 1283 */
 1284 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
      // NOTE(review): the first line of the signature (original line 1285,
      // presumably `void visit_after_split( Tree & tree,`) is missing
      // from this rendering.
 1286 Split & split,
 1287 Region & /* parent */,
 1288 Region & /* leftChild */,
 1289 Region & /* rightChild */,
 1290 Feature_t & /* features */,
 1291 Label_t & /* labels */)
 1292 {
 1293 //resize to right size when called the first time
 1294
 1295 Int32 const class_count = tree.ext_param_.class_count_;
 1296 Int32 const column_count = tree.ext_param_.column_count_;
 1297 if(variable_importance_.size() == 0)
 1298 {
 1299
      // NOTE(review): original line 1300 (`variable_importance_`) is
      // missing from this rendering; the reshape below applies to it.
 1301 .reshape(MultiArrayShape<2>::type(column_count,
 1302 class_count+2));
 1303 }
 1304
      // accumulate the gini decrease of this split into the last column
      // (gini decrease importance) for the split variable
 1305 if(split.createNode().typeID() == i_ThresholdNode)
 1306 {
 1307 Node<i_ThresholdNode> node(split.createNode());
 1308 variable_importance_(node.column(),class_count+1)
 1309 += split.region_gini_ - split.minGini();
 1310 }
 1311 }
1312
 1313 /**compute permutation based var imp.
 1314 * (Only an Array of size oob_sample_count x 1 is created.
 1315 * - as opposed to oob_sample_count x feature_count in the other method.
 1316 *
 1317 * \sa FieldProxy
 1318 */
 1319 template<class RF, class PR, class SM, class ST>
 1320 void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & /* st */, int index)
 1321 {
      // NOTE(review): original line 1322 (presumably a typedef for Shp_t)
      // is missing from this rendering.
 1323 Int32 column_count = rf.ext_param_.column_count_;
 1324 Int32 class_count = rf.ext_param_.class_count_;
 1325
 1326 /* This solution saves memory uptake but not multithreading
 1327 * compatible
 1328 */
 1329 // remove the const cast on the features (yep , I know what I am
 1330 // doing here.) data is not destroyed.
 1331 //typename PR::Feature_t & features
 1332 // = const_cast<typename PR::Feature_t &>(pr.features());
 1333
 1334 typedef typename PR::FeatureWithMemory_t FeatureArray;
 1335 typedef typename FeatureArray::value_type FeatureValue;
 1336
      // working copy of the feature matrix; columns are permuted in place
      // and restored from backup_column afterwards
 1337 FeatureArray features = pr.features();
 1338
 1339 //find the oob indices of current tree.
      // NOTE(review): original line 1340 (the declaration of
      // `ArrayVector<Int32> oob_indices;`) is missing from this rendering.
 1341 ArrayVector<Int32>::iterator
 1342 iter;
 1343 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
 1344 if(!sm.is_used()[ii])
 1345 oob_indices.push_back(ii);
 1346
 1347 //create space to back up a column
      // NOTE(review): original line 1348 (the declaration of
      // `backup_column`, an ArrayVector of FeatureValue) is missing here.
 1349
 1350 // Random foo
      // fixed seed under CLASSIFIER_TEST for reproducible unit tests
 1351#ifdef CLASSIFIER_TEST
 1352 RandomMT19937 random(1);
 1353#else
 1354 RandomMT19937 random(RandomSeed);
 1355#endif
      // NOTE(review): original line 1356 (the declaration of the
      // `randint` functor, e.g. UniformIntRandomFunctor<RandomMT19937>)
      // is missing from this rendering.
 1357 randint(random);
 1358
 1359
 1360 //make some space for the results
      // NOTE(review): original lines 1361 and 1363 (the MultiArray
      // declarations for oob_right and perm_oob_right) are missing here.
 1362 oob_right(Shp_t(1, class_count + 1));
 1364 perm_oob_right (Shp_t(1, class_count + 1));
 1365
 1366
 1367 // get the oob success rate with the original samples
 1368 for(iter = oob_indices.begin();
 1369 iter != oob_indices.end();
 1370 ++iter)
 1371 {
 1372 if(rf.tree(index)
 1373 .predictLabel(rowVector(features, *iter))
 1374 == pr.response()(*iter, 0))
 1375 {
 1376 //per class
 1377 ++oob_right[pr.response()(*iter,0)];
 1378 //total
 1379 ++oob_right[class_count];
 1380 }
 1381 }
 1382 //get the oob rate after permuting the ii'th dimension.
 1383 for(int ii = 0; ii < column_count; ++ii)
 1384 {
 1385 perm_oob_right.init(0.0);
 1386 //make backup of original column
 1387 backup_column.clear();
 1388 for(iter = oob_indices.begin();
 1389 iter != oob_indices.end();
 1390 ++iter)
 1391 {
 1392 backup_column.push_back(features(*iter,ii));
 1393 }
 1394
 1395 //get the oob rate after permuting the ii'th dimension.
 1396 for(int rr = 0; rr < repetition_count_; ++rr)
 1397 {
 1398 //permute dimension.
      // Fisher-Yates shuffle restricted to the oob rows of column ii
 1399 int n = oob_indices.size();
 1400 for(int jj = n-1; jj >= 1; --jj)
 1401 std::swap(features(oob_indices[jj], ii),
 1402 features(oob_indices[randint(jj+1)], ii));
 1403
 1404 //get the oob success rate after permuting
 1405 for(iter = oob_indices.begin();
 1406 iter != oob_indices.end();
 1407 ++iter)
 1408 {
 1409 if(rf.tree(index)
 1410 .predictLabel(rowVector(features, *iter))
 1411 == pr.response()(*iter, 0))
 1412 {
 1413 //per class
 1414 ++perm_oob_right[pr.response()(*iter, 0)];
 1415 //total
 1416 ++perm_oob_right[class_count];
 1417 }
 1418 }
 1419 }
 1420
 1421
 1422 //normalise and add to the variable_importance array.
 1423 perm_oob_right /= repetition_count_;
      // NOTE(review): original lines 1424 and 1426-1427 are missing from
      // this rendering — presumably the subtraction of oob_right and the
      // `variable_importance_` expression that the subarray below is
      // applied to; confirm against the full source.
 1425 perm_oob_right *= -1;
 1428 .subarray(Shp_t(ii,0),
 1429 Shp_t(ii+1,class_count+1)) += perm_oob_right;
 1430 //copy back permuted dimension
 1431 for(int jj = 0; jj < int(oob_indices.size()); ++jj)
 1432 features(oob_indices[jj], ii) = backup_column[jj];
 1433 }
 1434 }
1435
 1436 /** calculate permutation based impurity after every tree has been
 1437 * learned. Default behaviour is that this happens out of place.
 1438 * If you have very big data sets and want to avoid copying of data
 1439 * set the in_place_ flag to true.
 1440 */
 1441 template<class RF, class PR, class SM, class ST>
 1442 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
 1443 {
      // NOTE(review): only the in-place implementation is called here; the
      // in_place_ flag does not appear to be consulted in this rendering —
      // confirm against the full source.
 1444 after_tree_ip_impl(rf, pr, sm, st, index);
 1445 }
1446
 1447 /** Normalise variable importance after the number of trees is known.
 1448 */
 1449 template<class RF, class PR>
 1450 void visit_at_end(RF & rf, PR & /* pr */)
 1451 {
      // average the importances accumulated per tree over all trees
 1452 variable_importance_ /= rf.trees_.size();
 1453 }
1454};
1455
1456/** Verbose output
1457 */
1459 public:
1461
1462 template<class RF, class PR, class SM, class ST>
1463 void visit_after_tree(RF& rf, PR &, SM &, ST &, int index){
1464 if(index != rf.options().tree_count_-1) {
1465 std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1466 << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1467 }
1468 else {
1469 std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1470 }
1471 }
1472
 1473 template<class RF, class PR>
      // Report total learning time once all trees are done.
 1474 void visit_at_end(RF const & rf, PR const &) {
      // TOCS (timing.hxx) yields the elapsed time since TIC as a string
 1475 std::string a = TOCS;
 1476 std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
 1477 }
1478
 1479 template<class RF, class PR>
      // Start the timer and announce the forest size before learning begins.
 1480 void visit_at_beginning(RF const & rf, PR const &) {
      // TIC (timing.hxx) starts the timer read by TOCS in visit_at_end
 1481 TIC;
 1482 std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
 1483 }
1484
1485 private:
1486 USETICTOC;
1487};
1488
1489
1490/** Computes Correlation/Similarity Matrix of features while learning
1491 * random forest.
1492 */
1494{
1495 public:
 1496 /** gini_missc(ii, jj) describes how well variable jj can describe a partition
 1497 * created on variable ii(when variable ii was chosen)
 1498 */
      // NOTE(review): the declaration of gini_missc (original line 1499)
      // is missing from this rendering.
      // scratch array holding the binary partition labels of the current split
 1500 MultiArray<2, int> tmp_labels;
 1501 /** additional noise features.
 1502 */
      // NOTE(review): the declaration of noise (original line 1503) is
      // missing from this rendering; noise_l holds binary noise columns.
 1504 MultiArray<2, double> noise_l;
 1505 /** how well can a noise column describe a partition created on variable ii.
 1506 */
      // NOTE(review): the declaration of corr_noise (original line 1507)
      // is missing from this rendering.
 1508 MultiArray<2, double> corr_l;
 1509
 1510 /** Similarity Matrix
 1511 *
 1512 * (numberOfFeatures + 1) by (number Of Features + 1) Matrix
 1513 * gini_missc
 1514 * - row normalized by the number of times the column was chosen
 1515 * - mean of corr_noise subtracted
 1516 * - and symmetrised.
 1517 *
 1518 */
      // NOTE(review): the declarations of similarity (original line 1519)
      // and distance (original line 1522) are missing from this rendering.
 1520 /** Distance Matrix 1-similarity
 1521 */
      // scratch class-count array reused across visit_after_split calls
 1523 ArrayVector<int> tmp_cc;
 1524
 1525 /** How often was variable ii chosen
 1526 */
      // HDF5 export is currently disabled: the entire body is commented
      // out, so this function is a no-op.
 1530 void save(std::string, std::string)
 1531 {
 1532 /*
 1533 std::string tmp;
 1534#define VAR_WRITE(NAME) \
 1535 tmp = #NAME;\
 1536 tmp += "_";\
 1537 tmp += prefix;\
 1538 vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
 1539 VAR_WRITE(gini_missc);
 1540 VAR_WRITE(corr_noise);
 1541 VAR_WRITE(distance);
 1542 VAR_WRITE(similarity);
 1543 vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
 1544#undef VAR_WRITE
 1545*/
 1546 }
1547
1548 template<class RF, class PR>
1549 void visit_at_beginning(RF const & rf, PR & pr)
1550 {
1551 typedef MultiArrayShape<2>::type Shp;
1552 int n = rf.ext_param_.column_count_;
1553 gini_missc.reshape(Shp(n +1,n+ 1));
1554 corr_noise.reshape(Shp(n + 1, 10));
1555 corr_l.reshape(Shp(n +1, 10));
1556
1557 noise.reshape(Shp(pr.features().shape(0), 10));
1558 noise_l.reshape(Shp(pr.features().shape(0), 10));
1559 RandomMT19937 random(RandomSeed);
1560 for(int ii = 0; ii < noise.size(); ++ii)
1561 {
1562 noise[ii] = random.uniform53();
1563 noise_l[ii] = random.uniform53() > 0.5;
1564 }
1565 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1566 tmp_labels.reshape(pr.response().shape());
1567 tmp_cc.resize(2);
1568 numChoices.resize(n+1);
1569 // look at all axes
1570 }
 1571 template<class RF, class PR>
      // Turn the accumulated gini statistics into the final (normalised,
      // noise-corrected, symmetrised) similarity and distance matrices.
 1572 void visit_at_end(RF const &, PR const &)
 1573 {
 1574 typedef MultiArrayShape<2>::type Shp;
      // NOTE(review): original lines 1575-1577 and 1579 are missing from
      // this rendering — presumably the declaration of mean_noise and the
      // initial copy of gini_missc into similarity; confirm against the
      // full source.
 1578 rowStatistics(corr_noise, mean_noise);
 1580 int rC = similarity.shape(0);
      // normalise every feature row by how often that feature was chosen,
      // and subtract the mean noise correlation as a baseline
 1581 for(int jj = 0; jj < rC-1; ++jj)
 1582 {
 1583 rowVector(similarity, jj) /= numChoices[jj];
 1584 rowVector(similarity, jj) -= mean_noise(jj, 0);
 1585 }
 1586 for(int jj = 0; jj < rC; ++jj)
 1587 {
 1588 similarity(rC -1, jj) /= numChoices[jj];
 1589 }
 1590 rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
 1591 similarity = abs(similarity);
 1592 FindMinMax<double> minmax;
 1593 inspectMultiArray(srcMultiArrayRange(similarity), minmax);
 1594
      // temporarily set the diagonal to the maximum so it does not
      // distort the symmetrisation below
 1595 for(int jj = 0; jj < rC; ++jj)
 1596 similarity(jj, jj) = minmax.max;
 1597
      // symmetrise the feature-feature part: S = (S + S^T) / 2
 1598 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
 1599 += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
 1600 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
 1601 columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
 1602 for(int jj = 0; jj < rC; ++jj)
 1603 similarity(jj, jj) = 0;
 1604
      // diagonal entries are defined as the overall maximum similarity
 1605 FindMinMax<double> minmax2;
 1606 inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
 1607 for(int jj = 0; jj < rC; ++jj)
 1608 similarity(jj, jj) = minmax2.max;
 1609 distance.reshape(gini_missc.shape(), minmax2.max);
      // NOTE(review): original line 1610 is missing from this rendering —
      // presumably `distance -= similarity;` to realise "distance =
      // max - similarity"; confirm against the full source.
 1611 }
1612
 1613 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
      // After each threshold split: measure how well every other feature
      // (and the noise columns, and the response) can reproduce the binary
      // partition that the chosen split variable induced on this region.
 1614 void visit_after_split( Tree &,
 1615 Split & split,
 1616 Region & parent,
 1617 Region &,
 1618 Region &,
 1619 Feature_t & features,
 1620 Label_t & labels)
 1621 {
 1622 if(split.createNode().typeID() == i_ThresholdNode)
 1623 {
 1624 double wgini;
      // build the binary partition labels induced by the chosen split
 1625 tmp_cc.init(0);
 1626 for(int ii = 0; ii < parent.size(); ++ii)
 1627 {
 1628 tmp_labels[parent[ii]]
 1629 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
 1630 ++tmp_cc[tmp_labels[parent[ii]]];
 1631 }
 1632 double region_gini = bgfunc.loss_of_region(tmp_labels,
 1633 parent.begin(),
 1634 parent.end(),
 1635 tmp_cc);
 1636
      // count how often this variable (and "any variable", last slot)
      // was chosen as split column
 1637 int n = split.bestSplitColumn();
 1638 ++numChoices[n];
 1639 ++(*(numChoices.end()-1));
 1640 //this functor does all the work
 1641 for(int k = 0; k < features.shape(1); ++k)
 1642 {
 1643 bgfunc(columnVector(features, k),
 1644 tmp_labels,
 1645 parent.begin(), parent.end(),
 1646 tmp_cc);
 1647 wgini = (region_gini - bgfunc.min_gini_);
 1648 gini_missc(n, k)
 1649 += wgini;
 1650 }
      // same measurement against the ten continuous noise columns
 1651 for(int k = 0; k < 10; ++k)
 1652 {
 1653 bgfunc(columnVector(noise, k),
 1654 tmp_labels,
 1655 parent.begin(), parent.end(),
 1656 tmp_cc);
 1657 wgini = (region_gini - bgfunc.min_gini_);
 1658 corr_noise(n, k)
 1659 += wgini;
 1660 }
 1661
      // and against the ten binary noise columns
 1662 for(int k = 0; k < 10; ++k)
 1663 {
 1664 bgfunc(columnVector(noise_l, k),
 1665 tmp_labels,
 1666 parent.begin(), parent.end(),
 1667 tmp_cc);
 1668 wgini = (region_gini - bgfunc.min_gini_);
 1669 corr_l(n, k)
 1670 += wgini;
 1671 }
      // how well the response itself reproduces the partition
 1672 bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
 1673 wgini = (region_gini - bgfunc.min_gini_);
      // NOTE(review): original line 1674 is missing from this rendering —
      // presumably the gini_missc(n, <last column>) expression that the
      // += below applies to; confirm against the full source.
 1675 += wgini;
 1676
 1677 region_gini = split.region_gini_;
 1678#if 1
 1679 Node<i_ThresholdNode> node(split.createNode());
      // NOTE(review): original line 1680 is missing from this rendering —
      // presumably the start of the expression the += below applies to
      // (indexed by the response row and node.column()).
 1681 node.column())
 1682 +=split.region_gini_ - split.minGini();
 1683#endif
      // noise columns measured against the response-based partition
 1684 for(int k = 0; k < 10; ++k)
 1685 {
 1686 split.bgfunc(columnVector(noise, k),
 1687 labels,
 1688 parent.begin(), parent.end(),
 1689 parent.classCounts());
      // NOTE(review): original line 1690 is missing from this rendering —
      // presumably the wgini assignment and the start of the corr_noise
      // indexing expression continued on the next line.
 1691 k)
 1692 += wgini;
 1693 }
 1694#if 0
 1695 for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
 1696 {
 1697 wgini = region_gini - split.min_gini_[k];
 1698
      // NOTE(review): original line 1699 is missing (disabled #if 0 code).
 1700 split.splitColumns[k])
 1701 += wgini;
 1702 }
 1703
 1704 for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
 1705 {
 1706 split.bgfunc(columnVector(features, split.splitColumns[k]),
 1707 labels,
 1708 parent.begin(), parent.end(),
 1709 parent.classCounts());
 1710 wgini = region_gini - split.bgfunc.min_gini_;
      // NOTE(review): original line 1711 is missing (disabled #if 0 code).
 1712 split.splitColumns[k]) += wgini;
 1713 }
 1714#endif
 1715 // remember to partition the data according to the best.
      // NOTE(review): original lines 1716-1717 are missing from this
      // rendering — presumably the expression the += below applies to.
 1718 += region_gini;
 1719 SortSamplesByDimensions<Feature_t>
 1720 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
 1721 std::partition(parent.begin(), parent.end(), sorter);
 1722 }
 1723 }
1724};
1725
1726
1727} // namespace visitors
1728} // namespace rf
1729} // namespace vigra
1730
1731#endif // RF_VISITORS_HXX
const_pointer data() const
Definition array_vector.hxx:209
const_iterator end() const
Definition array_vector.hxx:237
MultiArrayView subarray(difference_type p, difference_type q) const
Definition multi_array.hxx:1530
const difference_type & shape() const
Definition multi_array.hxx:1650
MultiArrayView< N-M, T, StrideTag > bindOuter(const TinyVector< Index, M > &d) const
Definition multi_array.hxx:2186
difference_type_1 size() const
Definition multi_array.hxx:1643
MultiArrayView< N, T, StridedArrayTag > transpose() const
Definition multi_array.hxx:1569
void reshape(const difference_type &shape)
Definition multi_array.hxx:2863
Class for a single RGB value.
Definition rgbvalue.hxx:128
void init(Iterator i, Iterator end)
Definition tinyvector.hxx:708
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Definition rf_visitors.hxx:1014
double oob_per_tree2
Definition rf_visitors.hxx:1043
MultiArray< 2, double > breiman_per_tree
Definition rf_visitors.hxx:1048
double oob_mean
Definition rf_visitors.hxx:1026
double oob_breiman
Definition rf_visitors.hxx:1036
MultiArray< 2, double > oob_per_tree
Definition rf_visitors.hxx:1023
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:1201
MultiArray< 4, double > oobroc_per_tree
Definition rf_visitors.hxx:1065
double oob_std
Definition rf_visitors.hxx:1029
Definition rf_visitors.hxx:1494
MultiArray< 2, double > distance
Definition rf_visitors.hxx:1522
MultiArray< 2, double > corr_noise
Definition rf_visitors.hxx:1507
MultiArray< 2, double > gini_missc
Definition rf_visitors.hxx:1499
MultiArray< 2, double > similarity
Definition rf_visitors.hxx:1519
ArrayVector< int > numChoices
Definition rf_visitors.hxx:1527
MultiArray< 2, double > noise
Definition rf_visitors.hxx:1503
Definition rf_visitors.hxx:865
double oob_breiman
Definition rf_visitors.hxx:875
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:992
Definition rf_visitors.hxx:784
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:807
double oobError
Definition rf_visitors.hxx:788
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:836
Definition rf_visitors.hxx:585
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:724
void reset_tree(int tree_id)
Definition rf_visitors.hxx:636
void visit_after_tree(RF &, PR &, SM &, ST &, int)
Definition rf_visitors.hxx:647
void visit_at_beginning(RF &rf, const PR &)
Definition rf_visitors.hxx:628
Definition rf_visitors.hxx:236
Definition rf_visitors.hxx:1229
void visit_after_split(Tree &tree, Split &split, Region &, Region &, Region &, Feature_t &, Label_t &)
Definition rf_visitors.hxx:1285
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:1450
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:1442
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:1320
MultiArray< 2, double > variable_importance_
Definition rf_visitors.hxx:1257
Definition rf_visitors.hxx:103
void visit_at_beginning(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:188
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:206
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition rf_visitors.hxx:143
void visit_internal_node(TR &, IntT, TopT, Feat &)
Definition rf_visitors.hxx:216
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:164
void visit_at_end(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:176
double return_val()
Definition rf_visitors.hxx:226
Definition rf_visitors.hxx:256
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:684
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:697
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:671
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:727
detail::VisitorNode< A > create_visitor(A &a)
Definition rf_visitors.hxx:345
void writeHDF5(...)
Store array data in an HDF5 file.
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array.
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
#define TIC
Definition timing.hxx:321
#define TOCS
Definition timing.hxx:324

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.2