iRoCS Toolbox  1.1.0
CrossValidator.hh
Go to the documentation of this file.
1 /**************************************************************************
2  *
3  * Copyright (C) 2004-2015 Olaf Ronneberger, Florian Pigorsch, Jörg Mechnich,
4  * Thorsten Falk
5  *
6  * Image Analysis Lab, University of Freiburg, Germany
7  *
8  * This program is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software Foundation,
20  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  *
22  **************************************************************************/
23 
24 /**************************************************************************
25 ** Title:
26 ** $RCSfile$
27 ** $Revision: 2824 $$Name$
28 ** $Date: 2009-09-14 09:30:46 +0200 (Mon, 14 Sep 2009) $
29 ** Copyright: GPL $Author: ronneber $
30 ** Description:
31 **
32 **
33 **
34 **-------------------------------------------------------------------------
35 **
36 ** $Log$
37 ** Revision 1.3 2005/10/26 07:00:50 ronneber
38 ** - added set/getStoreClassificationDetailsFlag()
39 ** - now collects classification details and can return them via
40 ** classificationDetailsByUID()
41 **
42 ** Revision 1.2 2005/03/29 17:51:41 ronneber
43 ** - added updateKernelCache() and clearKernelCache()
44 **
45 ** Revision 1.1 2004/08/26 08:36:58 ronneber
46 ** initital import
47 **
48 **
49 **
50 **************************************************************************/
51 
52 #ifndef CROSSVALIDATOR_HH
53 #define CROSSVALIDATOR_HH
54 
55 #ifdef HAVE_CONFIG_H
56 #include <config.hh>
57 #endif
58 
59 #include <vector>
60 #include <cmath>
61 #include "GroupedTrainingData.hh"
62 #include "ProgressReporter.hh"
63 #include "StDataASCII.hh"
64 
65 // requirements of template parameters
70 
71 
72 namespace svt
73 {
74 
// Subset-assignment helpers; only declared here, defined elsewhere.
// Both take the number of feature vectors and the number of subsets and
// write their result into the (non-const) output vector
// subsetIndizesForFVs. Judging only from the names -- NOTE(review): confirm
// against the implementation -- each feature vector is assigned a subset
// index, either in sorted/contiguous order or in shuffled order.
void generateSortedSubsets( int nFeatureVectors,
                            int nSubsets,
                            std::vector<int>& subsetIndizesForFVs);

void generateShuffledSubsets( int nFeatureVectors,
                              int nSubsets,
                              std::vector<int>& subsetIndizesForFVs);
80 
81 
82 /*======================================================================*/
113 /*======================================================================*/
114  template< typename FV, typename SVMTYPE, typename PROBLEM = GroupedTrainingData<FV> >
116  {
119  // macros with 3 parameters seem not to work...
121 
122 
123  public:
124  typedef typename SVMTYPE::template Traits<FV>::ModelType ModelType;
125  typedef FV FV_TYPE;
126  typedef PROBLEM PROBLEM_TYPE;
127 
128 
129  CrossValidator( SVMTYPE* svm = 0)
130  : _svm(svm),
131  _owningSVM(false),
132  _problem( 0),
133  _classificationDelta( 0.01),
134  _pr(0),
135  _sum_nSV(0),
136  _sum_nFV(0),
137  _sum_nBSV(0),
138  _storeClassificationDetailsFlag( false)
139 
140  {
141  if( _svm == 0)
142  {
143  _svm = new SVMTYPE;
144  _owningSVM = true;
145  }
146  }
147 
149  {
150  if( _owningSVM)
151  {
152  delete _svm;
153  }
154  }
155 
156  private:
157  // forbid copying
159  void operator=( const CrossValidator<FV, SVMTYPE, PROBLEM>& orig) {}
160  public:
161 
162 
163 
164  SVMTYPE* svm()
165  {
166  return _svm;
167  }
168 
170  {
171  _pr = pr;
172  _svm->setProgressReporter( pr);
173  }
174 
175 
176  /*======================================================================*/
186  /*======================================================================*/
187  void setTrainingData( const PROBLEM* problem)
188  {
189  _problem = problem;
190  }
191 
192  /*======================================================================*/
199  /*======================================================================*/
200  const PROBLEM* trainingData() const
201  {
202  return _problem;
203  }
204 
205 
206  /*======================================================================*/
220  /*======================================================================*/
222  {
223  SVM_ASSERT( _problem != 0);
224  _svm->updateKernelCache( _problem->FV_begin(),
225  _problem->FV_end(),
227  }
228 
229  /*======================================================================*/
239  /*======================================================================*/
241  {
242  _svm->clearKernelCache();
243  }
244 
245 
246 
247  /*======================================================================*/
255  /*======================================================================*/
257  {
258  SVM_ASSERT( _problem != 0);
259 
260  if( _pr != 0)
261  {
263  "preprocess training data", 0, "");
264  }
265 
266  _svm->train( *_problem, _fullModel);
267 
268  if( _pr != 0)
269  {
271  "preprocess training data", 1.0, "");
272  }
273  }
274 
275  /*======================================================================*/
298  /*======================================================================*/
// Do a full cross validation (implemented out of line).
// NOTE(review): meaning of the int return value is defined in the
// implementation -- confirm there.
int doFullCV( const std::vector<int>& subsetIndexByUID,
              std::vector<double>& predictedClassLabelByUID);
301 
302 
303  /*======================================================================*/
345  /*======================================================================*/
346  int doPartialCV( int subsetIndex,
347  const std::vector<int>& subsetIndexByUID,
348  std::vector<double>& predictedClassLabelByUID,
349  ModelType* partialModel = 0);
350 
351 
352  const ModelType& fullModel() const
353  {
354  return _fullModel;
355  }
356 
357  void setClassificationDelta( double d)
358  {
359  _classificationDelta = d;
360  }
361 
362  double classificationDelta() const
363  {
364  return _classificationDelta;
365  }
366 
367 
369  {
370  _storeClassificationDetailsFlag = f;
371  }
372 
374  {
375  return _storeClassificationDetailsFlag;
376  }
377 
378 
379  /*======================================================================*/
389  /*======================================================================*/
390  template< typename STDATA>
391  void saveStatistics( STDATA& statistics,
392  int detailLevel = 1)
393  {
394  if( detailLevel >= 1)
395  {
396  statistics.setValue( "sum_nFV", _sum_nFV);
397  statistics.setValue( "sum_nSV", _sum_nSV);
398  statistics.setValue( "sum_nBSV", _sum_nBSV);
399  statistics.setValue( "nSV_per_nFV", double(_sum_nSV) / _sum_nFV);
400  statistics.setValue( "nBSV_per_nSV", double( _sum_nBSV) / _sum_nSV);
401  }
402  }
403 
404  /*======================================================================*/
414  /*======================================================================*/
415  const std::vector< StDataASCII>& classificationDetailsByUID() const
416  {
417  return _classificationDetailsByUID;
418  }
419 
420 
421 
422  template<typename STDATA>
423  void loadParameters( STDATA& stData)
424  {
426 
427  _svm->loadParameters(stData);
428 
429  }
430 
431 
432  template<typename STDATA>
433  void saveParameters( STDATA& stData)
434  {
436 
437  _svm->saveParameters(stData);
438 
439  }
440 
441 
442 
443 
444 
445  private:
446  SVMTYPE* _svm;
447  bool _owningSVM;
448  const PROBLEM* _problem;
449  ModelType _fullModel;
450  double _classificationDelta;
451  ProgressReporter* _pr;
452  unsigned int _sum_nSV;
453  unsigned int _sum_nFV;
454  unsigned int _sum_nBSV;
455  bool _storeClassificationDetailsFlag;
456  std::vector< StDataASCII > _classificationDetailsByUID;
457 
458  };
459 }
460 
461 
462 #include "CrossValidator.icc"
463 
464 #endif
void preprocessTrainingData()
trains all two-class SVMs with the whole data set.
void saveParameters(STDATA &stData)
virtual void reportProgress(int taskLevel, const std::string &taskName, float completenessPercent, const std::string &completenessPlainText)
This method is called if some progress was made.
#define CHECK_MEMBER_TEMPLATE(c)
void generateShuffledSubsets(int nFeatureVectors, int nSubsets, std::vector< int > &subsetIndizesForFVs)
CrossValidator(SVMTYPE *svm=0)
The CrossValidator class provides a highly optimized cross validation algorithm.
void generateSortedSubsets(int nFeatureVectors, int nSubsets, std::vector< int > &subsetIndizesForFVs)
#define SVM_ASSERT(condition)
Definition: SVMError.hh:176
const ModelType & fullModel() const
#define CHECK_CLASS_TEMPLATE1(c)
void setClassificationDelta(double d)
#define CHECK_CLASS_TEMPLATE2(c)
double classificationDelta() const
void clearKernelCache()
call clearKernelCache() of selected svm.
bool getStoreClassificationDetailsFlag() const
int doFullCV(const std::vector< int > &subsetIndexByUID, std::vector< double > &predictedClassLabelByUID)
do a full cross validation.
SVMTYPE::template Traits< FV >::ModelType ModelType
int doPartialCV(int subsetIndex, const std::vector< int > &subsetIndexByUID, std::vector< double > &predictedClassLabelByUID, ModelType *partialModel=0)
Do one part of a cross validation.
void setTrainingData(const PROBLEM *problem)
set the training data.
const PROBLEM * trainingData() const
get a pointer to the training data that was set with setTrainingData().
const int TASK_LEVEL_CROSS_VAL
void updateKernelCache()
call updateKernelCache() of selected svm with given problem.
Ensure that TESTCLASS provides a setProgressReporter() method.
const std::vector< StDataASCII > & classificationDetailsByUID() const
get the classification details for each UID from the last full CV or partial CV.
void saveStatistics(STDATA &statistics, int detailLevel=1)
save cross validation statistics.
void setProgressReporter(ProgressReporter *pr)
void setStoreClassificationDetailsFlag(bool f)
void loadParameters(STDATA &stData)