iRoCS Toolbox  1.1.0
BasicSVMAdapterTempl.hh
Go to the documentation of this file.
1 /**************************************************************************
2  *
3  * Copyright (C) 2004-2015 Olaf Ronneberger, Florian Pigorsch, Jörg Mechnich,
4  * Thorsten Falk
5  *
6  * Image Analysis Lab, University of Freiburg, Germany
7  *
8  * This program is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software Foundation,
20  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  *
22  **************************************************************************/
23 
24 /**************************************************************************
25 ** Title:
26 ** $RCSfile$
27 ** $Revision: 4820 $$Name$
28 ** $Date: 2011-11-08 10:57:01 +0100 (Tue, 08 Nov 2011) $
29 ** Copyright: GPL $Author: tschmidt $
30 ** Description:
31 **
32 ** Adapter for any Combination of multi-class, two-class and kernel
33 ** algorithms
34 **
35 **-------------------------------------------------------------------------
36 **
37 ** $Log$
38 ** Revision 1.15 2006/10/06 13:50:05 fehr
39 ** linear model optimizer added
40 **
41 ** Revision 1.14 2006/01/25 08:53:51 fehr
42 ** added HistIntersect to Kmatrix Kernels and new BasicSVMAdapter interface to TwoClassModels
43 **
44 ** Revision 1.13 2005/11/07 22:18:37 mechnich
45 ** changes for g++ 4
46 **
47 ** Revision 1.12 2005/10/26 06:58:06 ronneber
48 ** - added set/getStoreClassificationDetailsFlag()
49 ** - instead of classify() that returns a double detailedResultsVec, now
50 ** the details can be requested after classify() with the more general
51 ** saveClassificationDetailsASCII()
52 **
53 ** Revision 1.11 2005/07/19 13:03:59 haasdonk
54 ** removed redundant cout-messages, erroneous filename and added a new function
55 ** for computing a training-kernel matrix in BasicSVMAdapter*
56 **
57 ** Revision 1.10 2005/06/06 21:23:31 haasdonk
58 ** added updateCache() with two FV-lists, required for classification with precomputed kernel-matrices
59 **
60 ** Revision 1.9 2005/03/29 17:50:58 ronneber
61 ** - added updateKernelCache() and clearKernelCache()
62 **
63 ** Revision 1.8 2005/02/24 16:58:17 fehr
64 ** sometimes... it's better to go home: FINAL bugfix
65 **
66 ** Revision 1.7 2005/02/24 16:33:58 fehr
67 ** bugfix
68 **
69 ** Revision 1.6 2005/02/24 15:23:05 fehr
70 ** bugfix
71 **
72 ** Revision 1.5 2005/02/24 13:45:49 fehr
73 ** enable SV access through SVMAdapter
74 **
75 ** Revision 1.4 2005/02/24 12:57:50 fehr
76 ** some more 64-bit netcdf bugfixing
77 **
78 ** Revision 1.3 2005/02/23 16:04:05 fehr
79 ** added SV access for SVMAdapter
80 **
81 ** Revision 1.2 2004/09/08 14:15:36 ronneber
82 ** - added saveOnlyKernelParameters()
83 **
84 ** Revision 1.1 2004/08/26 08:36:58 ronneber
85 ** initial import
86 **
87 **
88 **
89 **************************************************************************/
90 
91 
92 #ifndef BASICSVMADAPTERTEMPL_HH
93 #define BASICSVMADAPTERTEMPL_HH
94 
95 #ifdef HAVE_CONFIG_H
96 #include <config.hh>
97 #endif
98 
99 #include "BasicSVMAdapter.hh"
100 #include "CrossValidator.hh"
102 
103 namespace svt
104 {
105 
106  template< typename FV,
107  typename STDATA,
108  typename MCSVMTYPE>
109  class BasicSVMAdapterTempl : public BasicSVMAdapter<FV,STDATA>
110  {
111  public:
/* Constructor member-initializer list and empty body: the progress reporter
 * pointer _pr starts out as 0 (null).
 * NOTE(review): the constructor's declaration line (listing line 112) is
 * absent from this listing - restore it from the original header. */
113  :_pr(0)
114  {}
115 
/* Body of the progress-reporter setter: stores pr in _pr and hands the same
 * pointer to the wrapped SVM via _svm.setProgressReporter().
 * NOTE(review): the signature line (listing line 116) is absent from this
 * listing - restore it from the original header. */
117  {
118  _pr = pr;
119  _svm.setProgressReporter( pr);
120  }
121 
122 
123 
/**
 * Load the SVM parameters from the given structured data.
 * Pure delegation to _svm.loadParameters().
 *
 * @param stData  structured data object (of the STDATA template type) that
 *                is handed to the wrapped SVM
 */
124  virtual void loadParameters( STDATA& stData)
125  {
126  _svm.loadParameters( stData);
127  }
128 
129 
/**
 * Overload of loadParameters() taking a StDataASCII object instead of the
 * STDATA template type. Pure delegation to _svm.loadParameters().
 *
 * @param stData  StDataASCII object that is handed to the wrapped SVM
 */
130  virtual void loadParameters( StDataASCII& stData)
131  {
132  _svm.loadParameters( stData);
133  }
134 
135 
/**
 * Overload of loadParameters() taking a StDataCmdLine object.
 * Pure delegation to _svm.loadParameters().
 *
 * @param stData  StDataCmdLine object that is handed to the wrapped SVM
 */
136  virtual void loadParameters( StDataCmdLine& stData)
137  {
138  _svm.loadParameters( stData);
139  }
140 
141 
142 
/**
 * Load the model from the given structured data.
 * Delegates to _model.loadParameters(); the SVM object _svm is not touched.
 *
 * @param stData  structured data object that is handed to _model
 */
143  virtual void loadModel( STDATA& stData)
144  {
145  _model.loadParameters( stData);
146  }
147 
148 
/**
 * Save the SVM parameters to the given structured data.
 * Pure delegation to _svm.saveParameters().
 *
 * @param stData  structured data object (of the STDATA template type) that
 *                is handed to the wrapped SVM
 */
149  virtual void saveParameters( STDATA& stData) const
150  {
151  _svm.saveParameters( stData);
152  }
153 
/**
 * Overload of saveParameters() taking a StDataASCII object instead of the
 * STDATA template type. Pure delegation to _svm.saveParameters().
 *
 * @param stData  StDataASCII object that is handed to the wrapped SVM
 */
154  virtual void saveParameters( StDataASCII& stData) const
155  {
156  _svm.saveParameters( stData);
157  }
158 
/**
 * Save only the kernel's parameters to the given StDataASCII object.
 * Calls saveParameters() on the kernel reached through
 * _svm.twoClassSVM().kernel(); nothing else of the SVM is saved here.
 *
 * @param stData  StDataASCII object that is handed to the kernel
 */
159  virtual void saveOnlyKernelParameters( StDataASCII& stData) const
160  {
161  _svm.twoClassSVM().kernel().saveParameters( stData);
162  }
163 
164 
165 
/**
 * Save the model to the given structured data.
 * Delegates to _model.saveParameters().
 *
 * @param stData  structured data object that is handed to _model
 */
166  virtual void saveModel( STDATA& stData) const
167  {
168  _model.saveParameters( stData);
169  }
170 
/**
 * Save training information to the given structured data.
 * Forwards both arguments unchanged to saveTrainingInfosTempl().
 *
 * @param stData       structured data object (STDATA template type)
 * @param detailLevel  passed through to saveTrainingInfosTempl();
 *                     defaults to 1
 */
171  virtual void saveTrainingInfos( STDATA& stData,
172  int detailLevel = 1)
173  {
174  saveTrainingInfosTempl( stData, detailLevel);
175  }
176 
/**
 * Overload of saveTrainingInfos() taking a StDataASCII object instead of
 * the STDATA template type. Forwards both arguments unchanged to
 * saveTrainingInfosTempl().
 *
 * @param stData       StDataASCII object
 * @param detailLevel  passed through to saveTrainingInfosTempl();
 *                     defaults to 1
 */
177  virtual void saveTrainingInfos( StDataASCII& stData,
178  int detailLevel = 1)
179  {
180  saveTrainingInfosTempl( stData, detailLevel);
181  }
182 
/**
 * Update the kernel cache for the feature vectors of the given training
 * data. Passes trainData.FV_begin() and trainData.FV_end() to
 * _svm.updateKernelCache().
 *
 * NOTE(review): listing line 188 (the remaining argument(s) and the closing
 * of the call) is absent from this listing - restore it from the original
 * header before compiling.
 *
 * @param trainData  grouped training data whose feature-vector range is used
 */
183  virtual void updateKernelCache( const GroupedTrainingData<FV>& trainData)
184  {
185  _svm.updateKernelCache(
186  trainData.FV_begin(),
187  trainData.FV_end(),
189  }
190 
/**
 * Overload of updateKernelCache() for feature vectors held in an
 * SVM_Problem. Passes problem.FV_begin() and problem.FV_end() to
 * _svm.updateKernelCache().
 *
 * NOTE(review): listing line 196 (the remaining argument(s) and the closing
 * of the call) is absent from this listing - restore it from the original
 * header before compiling.
 *
 * @param problem  SVM problem whose feature-vector range is used
 */
191  virtual void updateKernelCache( const SVM_Problem<FV>& problem)
192  {
193  _svm.updateKernelCache(
194  problem.FV_begin(),
195  problem.FV_end(),
197  }
198 
/**
 * Update the kernel's cache for a range of test feature vectors against the
 * model's collected support vectors.
 *
 * Takes begin/end iterators of _model.collectedSupportVectors() and passes
 * them, together with the given test range and a DirectAccessor, to
 * updateCache() of the kernel reached through _svm.twoClassSVM().kernel().
 *
 * NOTE(review): listing line 211 (the remaining argument(s) and the closing
 * of the updateCache() call) is absent from this listing - restore it from
 * the original header before compiling.
 *
 * @param FV_begin  iterator to the first test feature vector
 * @param FV_end    iterator one past the last test feature vector
 */
199  virtual void updateTestKernelCache(typename
200  std::vector<FV>::iterator FV_begin,
201  const typename
202  std::vector<FV>::iterator& FV_end)
203  {
204  typename std::vector<FV*>::const_iterator svbegin=
205  _model.collectedSupportVectors().begin();
206  typename std::vector<FV*>::const_iterator svend=
207  _model.collectedSupportVectors().end();
208  _svm.twoClassSVM().kernel().updateCache(FV_begin,FV_end,
209  DirectAccessor(),
210  svbegin,svend,
212  }
213 
214 
/**
 * Train the SVM with the given grouped training data.
 * Delegates to _svm.train(), passing the member _model along with the data
 * (presumably _model receives the training result - confirm against
 * MCSVMTYPE::train()).
 *
 * @param trainData  grouped training data handed to the wrapped SVM
 */
215  virtual void train( const GroupedTrainingData<FV>& trainData)
216  {
217 
218  _svm.train( trainData, _model);
219  }
220 
221 
222 
/**
 * Overload of train() taking an SVM_Problem.
 * Delegates to _svm.train(), passing the member _model along with the
 * problem (presumably _model receives the training result - confirm against
 * MCSVMTYPE::train()).
 *
 * @param problem  SVM problem handed to the wrapped SVM
 */
223  virtual void train( const SVM_Problem<FV>& problem)
224  {
225  _svm.train( problem, _model);
226  }
227 
228 
/**
 * Clear the kernel's cache by calling clearCache() on the kernel reached
 * through _svm.twoClassSVM().kernel().
 */
229  virtual void clearKernelCache()
230  {
231  _svm.twoClassSVM().kernel().clearCache();
232  }
233 
234 
/**
 * Classify the given feature vector with the current model.
 *
 * Both branches call _svm.classify() with testObject and _model; the first
 * branch additionally passes _detailedResults, the second does not.
 *
 * NOTE(review): listing line 237 (the `if` condition that selects between
 * the two branches) is absent from this listing - restore it from the
 * original header before compiling.
 *
 * @param testObject  feature vector to classify
 * @return the value returned by _svm.classify()
 */
235  virtual double classify( const FV& testObject)
236  {
238  {
239  return _svm.classify( testObject, _model, _detailedResults);
240  }
241  else
242  {
243  return _svm.classify( testObject, _model);
244  }
245  }
246 
247 // virtual double classify( const FV& testObject,
248 // std::vector<double>& detailedResultsVec)
249 // {
250 // typename MCSVMTYPE::DetailedResultType detailedResults;
251 // double result = _svm.classify( testObject, _model,
252 // detailedResults);
253 // detailedResultsVec.resize( detailedResults.size());
254 // std::copy( detailedResults.begin(), detailedResults.end(),
255 // detailedResultsVec.begin());
256 // return result;
257 // }
258 
259 
/**
 * Save classification details to the given StDataASCII object.
 * Passes _model and the stored _detailedResults to
 * _svm.saveClassificationDetails(). _detailedResults is the member that
 * classify() hands to _svm.classify() in its detailed branch.
 *
 * @param stData  StDataASCII object handed to the wrapped SVM
 */
260  virtual void saveClassificationDetailsASCII( StDataASCII& stData) const
261  {
262  _svm.saveClassificationDetails( _model, _detailedResults, stData);
263  }
264 
265 
/**
 * Fill kmatrix with the kernel values of all pairs of the given feature
 * vectors: kmatrix[i][j] = k_function(featureVectors[i], featureVectors[j]),
 * converted to float.
 *
 * Every (i, j) pair is evaluated separately, including both (i, j) and
 * (j, i); the kernel is the one reached through _svm.twoClassSVM().kernel().
 *
 * @param kmatrix         output; written at [i][j] for all
 *                        0 <= i, j < featureVectors.size(), so the caller
 *                        must provide at least that many rows and columns
 * @param featureVectors  feature vectors to evaluate; an empty vector
 *                        leaves kmatrix untouched
 */
266  virtual void computeTrainKernelMatrix(float** kmatrix,
267  const std::vector<FV>& featureVectors)
268  {
269 // return _svm.twoClassSVM().kernel().k_function(fv1,fv2);
270  size_t nfv = featureVectors.size();
271 
272  typename std::vector<FV>::const_iterator it =
273  featureVectors.begin();
// i and `it` advance together: `it` always points at featureVectors[i]
274  for (size_t i = 0; i < nfv; i++, it++)
275  {
// restart the column iterator for every row
276  typename std::vector<FV>::const_iterator jt =
277  featureVectors.begin();
278  for (size_t j = 0; j < nfv; j++, jt++)
279  kmatrix[i][j] = static_cast<float>(
280  _svm.twoClassSVM().kernel().k_function(*it,*jt));
281  }
282  }
283 
/**
 * Number of classes as reported by _model.nClasses().
 *
 * @return the value of _model.nClasses()
 */
284  virtual int nClasses() const
285  {
286  return _model.nClasses();
287  }
288 
/**
 * Map a class index to its label by delegating to
 * _model.classIndexToLabel().
 *
 * @param classIndex  class index handed to the model
 * @return the label returned by the model
 */
289  double classIndexToLabel( unsigned int classIndex) const
290  {
291  return _model.classIndexToLabel( classIndex);
292  }
293 
/**
 * Number of support vectors in the model.
 *
 * @return size() of the container returned by _model.getSupportVectors()
 */
294  size_t nSupportVectors() const
295  {
296  return _model.getSupportVectors().size();
297  }
298 
299 
/**
 * Return the i-th support vector of the model.
 *
 * Indexes the container returned by _model.getSupportVectors() with
 * operator[]; no range check on i is performed in this function.
 *
 * @param i  index into the model's support-vector container
 * @return pointer stored at position i
 */
300  FV* getSupportVector(int i) const
301  {
302  return _model.getSupportVectors().operator[](i);
303  }
304 
/**
 * Collect pointers to all two-class models of the multi-class model
 * (const version).
 *
 * For every index i in [0, _model.nTwoClassModels()) the address of
 * _model.twoClassModel(i) is appended; the models themselves are not copied
 * here.
 *
 * @return vector of pointers-to-const, one per two-class model, in index
 *         order
 */
305  std::vector<const typename svt::Model<FV>* > getTwoClassModels() const
306  {
307  std::vector<const typename svt::Model<FV>* > models;
308  for (unsigned int i = 0; i<_model.nTwoClassModels();i++)
309  {
310  models.push_back(&(_model.twoClassModel(i)));
311  }
312  return models;
313  }
314 
/**
 * Collect pointers to all two-class models of the multi-class model
 * (non-const version, yielding pointers to non-const models).
 *
 * For every index i in [0, _model.nTwoClassModels()) the address of
 * _model.twoClassModel(i) is appended; the models themselves are not copied
 * here.
 *
 * @return vector of pointers, one per two-class model, in index order
 */
315  std::vector<typename svt::Model<FV>* > getTwoClassModels()
316  {
317  std::vector<typename svt::Model<FV>* > models;
318  for (unsigned int i = 0; i<_model.nTwoClassModels();i++)
319  {
320  models.push_back(&(_model.twoClassModel(i)));
321  }
322  return models;
323  }
324 
/* Body of the linear-model optimization: calls optimizeModel(_model) on a
 * local object named `optimizer`.
 * NOTE(review): the signature line (listing line 325) and the declaration
 * of `optimizer` (listing line 330) are absent from this listing - restore
 * them from the original header before compiling. */
326  {
327  /*---------------------------------------------------------------------
328  * Optimize Linear Model: precalc normals for fast classification
329  *---------------------------------------------------------------------*/
331  optimizer.optimizeModel(_model);
332 
333  }
334 
335  protected:
/**
 * Shared implementation behind both saveTrainingInfos() overloads,
 * templated on the structured-data type.
 *
 * _model.saveTrainingInfoStatistics() is always called; only when
 * detailLevel is exactly 2 is _model.saveTCTrainingInfos() called first.
 *
 * @param stData       structured data object handed to the model
 * @param detailLevel  2 additionally triggers saveTCTrainingInfos(); any
 *                     other value saves the statistics only
 */
336  template<typename STDATA2>
337  void saveTrainingInfosTempl(STDATA2& stData, int detailLevel)
338  {
339  if( detailLevel == 2)
340  {
341  _model.saveTCTrainingInfos( stData);
342  }
343  _model.saveTrainingInfoStatistics( stData);
344  }
345 
346  private:
// The wrapped multi-class SVM; parameter I/O, training and classification
// calls in this class are delegated to it.
347  MCSVMTYPE _svm;
// Model object passed to _svm.train()/_svm.classify() and used by
// loadModel()/saveModel(); its type is taken from MCSVMTYPE's Traits<FV>.
348  typename MCSVMTYPE::template Traits<FV>::ModelType _model;
// Passed to _svm.classify() in classify()'s detailed branch and read by
// saveClassificationDetailsASCII().
349  typename MCSVMTYPE::DetailedResultType _detailedResults;
// Progress reporter pointer; initialized to 0 and overwritten by the
// progress-reporter setter. Only stored and forwarded in this class.
350  ProgressReporter* _pr;
351 
352  };
353 
354 }
355 
356 
357 #endif
std::vector< typename svt::Model< FV > *> getTwoClassModels()
The GroupedTrainingData class is a container for feature vectors.
virtual void computeTrainKernelMatrix(float **kmatrix, const std::vector< FV > &featureVectors)
Evaluate kernel matrix of the svm.
FV ** FV_end() const
Definition: SVM_Problem.hh:130
virtual void saveTrainingInfos(STDATA &stData, int detailLevel=1)
save additional training information to the given structured data
virtual void saveParameters(StDataASCII &stData) const
same as saveParameters(), but with fixed StDataASCII class independent from given STDATA template par...
virtual void updateKernelCache(const GroupedTrainingData< FV > &trainData)
call updateCache() of selected kernel, e.g., for Kernel_MATRIX this will setup the internal cache mat...
virtual void saveOnlyKernelParameters(StDataASCII &stData) const
save only Kernel Parameters (this is used for user information and to detect in grid search...
virtual void loadParameters(StDataCmdLine &stData)
Same as loadParameters, but for Parameters from command line.
std::vector< FV * >::const_iterator FV_begin() const
virtual void saveModel(STDATA &stData) const
save resulting model from training process into given map
virtual void saveParameters(STDATA &stData) const
store all parameters of the SVM to given map
virtual void updateKernelCache(const SVM_Problem< FV > &problem)
same as previous updateKernelCache() method, just for feature vectors that are stored within an SVM_P...
virtual void clearKernelCache()
call clearCache() method of selected kernel, e.g., for Kernel_MATRIX this will clear the internal cac...
void optimizeLinearModel()
optimize model for faster classification
virtual void saveClassificationDetailsASCII(StDataASCII &stData) const
save the details of the last classification (requested after classify()) to the given StDataASCII object.
virtual int nClasses() const
get number of Classes (loadModel() or train() must have been called before)
virtual void setProgressReporter(ProgressReporter *pr)
set progress reporter object.
FV ** FV_begin() const
Definition: SVM_Problem.hh:125
virtual void updateTestKernelCache(typename std::vector< FV >::iterator FV_begin, const typename std::vector< FV >::iterator &FV_end)
Load Test Kernel Cache.
virtual void saveTrainingInfos(StDataASCII &stData, int detailLevel=1)
same as saveTrainingInfos(), but with fixed StDataASCII class independent from given STDATA template ...
FV * getSupportVector(int i) const
returns ith SV in model
virtual void train(const GroupedTrainingData< FV > &trainData)
train SVM with given training data.
virtual void loadParameters(StDataASCII &stData)
double classIndexToLabel(unsigned int classIndex) const
map classIndex to label (loadModel() or train() must have been called before)
virtual void loadModel(STDATA &stData)
load model data from stData into the SVMs.
virtual void loadParameters(STDATA &stData)
Read all parameters (e.g.
size_t nSupportVectors() const
returns number of SV in model
virtual void train(const SVM_Problem< FV > &problem)
train SVM with given Feature Vectors.
std::vector< const typename svt::Model< FV > *> getTwoClassModels() const
returns vector containing pointers to all TwoClassModels
std::vector< FV * >::const_iterator FV_end() const
The StDataASCII class is a container for "structured data" that is kept completely in memory...
Definition: StDataASCII.hh:83
virtual double classify(const FV &testObject)
classify the given Feature Vector.
void saveTrainingInfosTempl(STDATA2 &stData, int detailLevel)