iRoCS Toolbox  1.1.0
lRandomTree.hh
Go to the documentation of this file.
1 /**************************************************************************
2  *
3  * Copyright (C) 2015 Kun Liu, Thorsten Falk
4  *
5  * Image Analysis Lab, University of Freiburg, Germany
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 3 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  *
21  **************************************************************************/
22 
23 /*
24  * lRandomTree.hh
25  *
26  * Created on: Jun 11, 2011
27  * Author: liu
28  */
29 
30 #ifndef LRANDOMTREE_H_
31 #define LRANDOMTREE_H_
32 
33 #ifdef HAVE_CONFIG_H
34 #include <config.hh>
35 #endif
36 
37 #include <sstream>
38 #include <vector>
39 #include <map>
40 
41 #include <opencv/cxcore.h>
42 //#include <opencv/cv.h>
43 
// Auxiliary (value, index) pair. lRandomTree::evaluateTest() fills a
// std::vector<IntIndex> ("valSet") that split() later consumes; presumably
// `val` is a per-sample test response and `index` identifies the sample —
// confirm against the implementation file.
struct IntIndex
{
  float val;          // value attached to the sample
  unsigned int index; // which sample the value belongs to

  // Ordering used when sorting IntIndex entries; defined out of line.
  bool operator<(const IntIndex& other) const;
};
52 
// Structure for the leaves of the tree.
struct LeafNode
{
  // Label stored at this leaf (presumably the class returned by
  // lRandomTree::predict() for samples reaching it — confirm in the .cc).
  int label;
  // Number of entries referenced by instanceIndex (assumed; verify).
  int numOfInstance;
  // NOTE(review): raw pointer; who allocates/frees this array is decided in
  // the implementation file — confirm ownership before copying a LeafNode.
  int* instanceIndex;
};
60 
// Inner (non-leaf) node of the tree. Children are referred to by integer
// index, not by pointer; a negative index marks a leaf child.
struct Node
{
  Node();

  // Apply this node's split test to the feature vector (defined out of line).
  bool test(const float* features) const;
  // Compute the value the split test is based on (defined out of line).
  float eval(const float* features) const;

  int fIdx;  // feature index used by the test (assumed index into features)
  float t;   // presumably the split threshold — confirm in the .cc
  int left;  // index of the left subtree; negative for a leaf
  int right; // index of the right subtree; negative for a leaf
};
72 
74 {
75 
76 public:
77 
78  lRandomTree(const char* filename);
79  lRandomTree(std::stringstream & in);
80  lRandomTree();
81  ~lRandomTree();
82 
83  // Set/Get functions
84  unsigned int GetMaxDepth() const;
85 
86  // Classification
87  int predict(const float *f) const;
88 
89  // Proximity
90  void proximity(
91  const float *f, int cl, std::map<int, double>& proximityCounter) const;
92 
93  // IO functions
94  bool saveTree(const char *filename) const;
95  bool saveTree(std::stringstream &out) const;
96 
97  // Training
98  void growTree(
99  float **TrainX, int *TrainL, std::vector<int>& TrainSet, int m,
100  int n, int maxLabel, float *classWeight, CvRNG *pRNG, int m_try,
101  int max_depth, int min_samples = 1, int num_grid = 20);
102 
103  int getMaxLabel() const;
104  void setMaxLabel(int _maxLabel);
105 
106 private:
107 
108  // Private functions for training
109  void grow(const std::vector<int>& TrainSet, int& node, unsigned int depth);
110 
111  bool findSplit(
112  const std::vector<int>& TrainSet, Node& test,
113  std::vector<int> & SetA, std::vector<int>& SetB);
114 
115  void generateTest(Node& test);
116 
117  void evaluateTest(
118  std::vector<IntIndex> & valSet, const Node& test,
119  const std::vector<int>& TrainSet);
120 
121  void split(
122  std::vector<int>& SetA, std::vector<int>& SetB,
123  const std::vector<int>& TrainSet, const std::vector<IntIndex>& valSet,
124  float t);
125 
126  float measureGini(const std::vector<int>& Set, float& pt);
127 
128  void makeLeaf(const std::vector<int>& TrainSet, int& node);
129 
130  // tree table Data structure
131  // 2^(max_depth+1)-1 x colTreeTable matrix as vector
132  // column: leafindex f1 f2 b c
133  // the test f1 + f2 * b > c
134  // if node is not a leaf, leaf=-1
135  std::vector<Node*> _treetable;
136  int _root;
137  std::vector<LeafNode*> _leaves;
138 
139  // number of features used for splitting
140  unsigned int _m_try;
141 
142  // stop growing when number of patches is less than min_samples
143  unsigned int _min_samples;
144 
145  // depth of the tree: 0-max_depth
146  unsigned int _max_depth;
147 
148  // number of nodes: 2^(max_depth+1)-1
149  unsigned int _num_nodes;
150 
151  // number of leafs
152  unsigned int _num_leaf;
153 
154  // number of iterations for optimizing splitting
155  unsigned int _num_grid;
156 
157  CvRNG *_cvRNG;
158 
159  int _featureDim;
160 
161  int _maxLabel;
162 
163  // for training
164  float** _TrainX;
165  int* _TrainL;
166  int _nSample;
167  float* _classWeight;
168 };
169 
170 #endif /* LRANDOMTREE_H_ */
int fIdx
Definition: lRandomTree.hh:67
int left
Definition: lRandomTree.hh:69
int right
Definition: lRandomTree.hh:70
int * instanceIndex
Definition: lRandomTree.hh:58
float val
Definition: lRandomTree.hh:47
unsigned int index
Definition: lRandomTree.hh:48
float t
Definition: lRandomTree.hh:68
int numOfInstance
Definition: lRandomTree.hh:57
bool operator<(const IntIndex &a) const