iRoCS Toolbox  1.1.0
MultiClassSVMOneVsRest.hh
Go to the documentation of this file.
1 /**************************************************************************
2  *
3  * Copyright (C) 2004-2015 Olaf Ronneberger, Florian Pigorsch, Jörg Mechnich,
4  * Thorsten Falk
5  *
6  * Image Analysis Lab, University of Freiburg, Germany
7  *
8  * This program is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software Foundation,
20  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  *
22  **************************************************************************/
23 
24 /**************************************************************************
25 ** Title: multi class svm using "one versus rest" technique
26 ** $RCSfile$
27 ** $Revision: 4820 $$Name$
28 ** $Date: 2011-11-08 10:57:01 +0100 (Tue, 08 Nov 2011) $
29 ** Copyright: GPL $Author: tschmidt $
30 ** Description:
31 **
32 **
33 **
34 **-------------------------------------------------------------------------
35 **
36 ** $Log$
37 ** Revision 1.7 2005/10/26 07:27:58 ronneber
38 ** - added saveClassificationDetails()
39 ** - corrected some comments where "decision value" was called "alpha"
40 **
41 ** Revision 1.6 2005/03/29 18:01:29 ronneber
42 ** - replaced updateCacheFlag, etc with updateKernelCache() and
43 ** clearKernelCache() methods
44 **
45 ** Revision 1.5 2004/09/13 10:04:04 ronneber
46 ** - documentation update
47 **
48 ** Revision 1.4 2004/09/08 14:32:53 ronneber
49 ** - adapted to new ParamInfo class
50 **
51 ** Revision 1.3 2004/09/03 07:13:36 ronneber
52 ** - adapted to new updateCache() interface of Kernels
53 **
54 ** Revision 1.2 2004/09/01 14:43:36 ronneber
55 ** changed IterToPointerTraits stuff to
56 ** DirectAccessor and DereferencingAccessor, to make code more
57 ** intuitively understandable
58 **
59 ** Revision 1.1 2004/08/26 08:36:59 ronneber
60 ** initial import
61 **
62 ** Revision 1.8 2003/10/01 09:26:49 ronneber
63 ** - added some missing #include's
64 **
65 ** Revision 1.7 2003/05/19 11:04:38 ronneber
66 ** - converted from MapTools to ParamMapWrapper
67 ** - added new train() method where access to the feature vectors
68 ** container is done via a custom functor. This means, you now can give
69 ** your training vectors, e.g., as std::vector<FV> or std::vector<FV*>
70 ** or something completely different. You just have to pass an
71 ** appropriate Accessor to make an FV* from an iterator
72 **
73 ** Revision 1.6 2002/09/05 13:08:27 pigorsch
74 ** -modified to use new MapTools
75 **
76 ** Revision 1.5 2002/07/09 06:41:41 ronneber
77 ** added progress reporter for OneVsRest
78 **
79 ** Revision 1.4 2002/07/08 13:46:50 ronneber
80 ** - added copySVCoef... methods for OneVsRest SVM
81 **
82 ** Revision 1.3 2002/05/10 11:33:40 ronneber
83 ** - added default Constructor and twoClassSVM() method
84 **
85 ** Revision 1.2 2002/05/10 11:07:03 ronneber
86 ** - removed FV template for all public classes, because Feature Vector type
87 ** can be extracted automatically from passed iterators or other
88 ** parameters -- this makes the public interface much more intuitive
89 **
90 ** Revision 1.1 2002/05/08 10:37:31 ronneber
91 ** initial revision
92 **
93 **
94 **
95 **
96 **************************************************************************/
97 
98 #ifndef MULTICLASSSVMONEVSREST_HH
99 #define MULTICLASSSVMONEVSREST_HH
100 
101 #ifdef HAVE_CONFIG_H
102 #include <config.hh>
103 #endif
104 
105 // std includes
106 #include <map>
107 #include <string>
108 
109 // libsvmtl includes
110 
111 #include "ProgressReporter.hh"
112 #include "Model_MC_OneVsRest.hh"
113 #include "GroupedTrainingData.hh"
114 #include "DirectAccessor.hh"
115 #include "DereferencingAccessor.hh"
116 
117 // requirements of template parameters
124 
125 
126 namespace svt
127 {
128 
129  /*======================================================================*/
138  /*======================================================================*/
139  template< typename SVM>
141  {
144  public:
145 
// Traits of this multi-class SVM for a given feature vector type FV.
// Other members of the class refer to the nested type as
// Traits<FV>::ModelType.
// NOTE(review): listing line 149 (the only member) is not visible in this
// extract; the cross-reference at the end of the listing suggests it is
//   typedef Model_MC_OneVsRest<typename SVM::template Traits<FV>::ModelType> ModelType;
// -- confirm against the original header.
146  template< typename FV>
147  struct Traits
148  {
150  };
151 
152  typedef std::vector<double> DetailedResultType;
153 
154 
155  /*---------------------------------------------------------------------
156  * Dummy result vector, if no detailed results are requested
157  *---------------------------------------------------------------------*/
159  {
160  double dummy;
161  public:
/**
 * No-op: the dummy result vector keeps no per-class storage, so the
 * requested size is ignored.
 */
void resize( unsigned int /*newSize*/) {}
164 
165  double& operator[]( int)
166  {
167  return dummy;
168  }
169  };
170 
171  /*---------------------------------------------------------------------
172  * struct to create sorted results list
173  *---------------------------------------------------------------------*/
175  {
177  double label;
178  bool operator<( const DecisionValueAndLabel& rhs) const
179  {
180  return (decisionValue > rhs.decisionValue);
181  }
182  };
183 
184 
185 
186 
187 
188 
189  /*======================================================================*/
205  /*======================================================================*/
206  MultiClassSVMOneVsRest( const SVM& svm)
207  :_twoClassSVM( svm),
208  _pr(0)
209  {
210  }
211 
212 
// Constructor fragment: the declarator on listing line 213 is not visible in
// this extract (NOTE(review): presumably the default constructor -- confirm).
// Only the progress reporter pointer is initialized explicitly (to 0);
// _twoClassSVM is not named in the initializer list and is therefore
// default-initialized.
214  :_pr(0)
215  {
216  }
217 
218  /*======================================================================*/
231  /*======================================================================*/
// Body of the function declared on listing line 232, which is not visible in
// this extract (NOTE(review): the cross-reference at the end of the listing
// names it setProgressReporter(ProgressReporter* pr) -- confirm).
233  {
// keep the pointer in this object and hand the same pointer on to the
// two-class SVM
234  _pr = pr;
235  _twoClassSVM.setProgressReporter( pr);
236  }
237 
238 
239 
240  /*======================================================================*/
250  /*======================================================================*/
251  const SVM& twoClassSVM() const
252  {
253  return _twoClassSVM;
254  }
255 
256  SVM& twoClassSVM()
257  {
258  return _twoClassSVM;
259  }
260 
261 
262  /*======================================================================*/
272  /*======================================================================*/
273  template< typename ForwardIter, typename Accessor>
274  void updateKernelCache( const ForwardIter& fvBegin,
275  const ForwardIter& fvEnd,
276  Accessor accessor) const
277  {
278  _twoClassSVM.updateKernelCache( fvBegin, fvEnd, accessor);
279  }
280 
281 
282  /*======================================================================*/
292  /*======================================================================*/
293  void clearKernelCache() const
294  {
295  _twoClassSVM.clearKernelCache();
296  }
297 
298  /*====================================================================*/
310  /*====================================================================*/
// Trains a multi-class model from grouped training data.
// trainData is taken by const reference; model is a non-const reference and
// is the output of the call. All other train() overloads in this class end
// up calling this one.
// Declared only: the definition is not part of this listing
// (NOTE(review): presumably in MultiClassSVMOneVsRest.icc, which is included
// at the end of this header -- confirm).
311  template<typename FV>
312  void train( const GroupedTrainingData<FV>& trainData,
313  typename Traits<FV>::ModelType& model) const;
314 
315 
316  /*====================================================================*/
326  /*====================================================================*/
327  template<typename FV>
328  void train( const SVM_Problem<FV>& problem,
329  typename Traits<FV>::ModelType& model) const
330  {
331  GroupedTrainingData<FV> trainData( problem);
332  train( trainData, model);
333  }
334 
335  /*======================================================================*/
353  /*======================================================================*/
354  template< typename ForwardIter>
355  void train( ForwardIter FV_begin,
356  const ForwardIter& FV_end,
357  typename Traits<typename std::iterator_traits< ForwardIter>::value_type>::ModelType& model) const
358  {
359  train( FV_begin, FV_end, model, DirectAccessor());
360  }
361 
362  /*======================================================================*/
384  /*======================================================================*/
385  template< typename ForwardIter, typename Accessor>
386  void train( ForwardIter FV_begin,
387  const ForwardIter& FV_end,
388  typename Traits<typename Accessor::template Traits<ForwardIter>::value_type>::ModelType& model,
389  Accessor accessor) const
390  {
391  typedef typename Accessor::template Traits<ForwardIter>::value_type FV;
392  GroupedTrainingData<FV> trainData( FV_begin, FV_end,
393  accessor);
394  train( trainData, model);
395  }
396 
397  /*======================================================================*/
417  /*======================================================================*/
// Declaration fragment: the line carrying the return type and function name
// (listing line 419) is not visible in this extract. NOTE(review): the
// cross-reference at the end of the listing gives it as
//   void retrainWithLeftOutVectors(
// -- confirm against the original header.
// trainData, fullModel and leaveOutFlagsByUID are taken by const reference;
// resultingModel is a non-const reference and receives the result.
// Declared only: the definition is not part of this listing.
418  template<typename FV>
420  const GroupedTrainingData<FV>& trainData,
421  const typename Traits<FV>::ModelType& fullModel,
422  const std::vector<char>& leaveOutFlagsByUID,
423  typename Traits<FV>::ModelType& resultingModel) const;
424 
425  /*======================================================================*/
439  /*======================================================================*/
// Predicts the class index of testObject with the given model and returns it
// as an unsigned int.
// resultVector is a non-const reference of a caller-chosen type; the
// two-argument overload of predictClassIndex() passes a DummyResultVector
// here when no detailed results are wanted.
// Declared only: the definition is not part of this listing.
440  template< typename FV, typename ResultVector>
441  unsigned int predictClassIndex(
442  const FV& testObject,
443  const typename Traits<FV>::ModelType& model,
444  ResultVector& resultVector) const;
445 
446 
447  /*======================================================================*/
456  /*======================================================================*/
457  template< typename FV>
458  unsigned int predictClassIndex(
459  const FV& testObject,
460  const typename Traits<FV>::ModelType& model) const
461  {
462  DummyResultVector dummy;
463  return predictClassIndex( testObject, model, dummy);
464  }
465 
466  /*======================================================================*/
480  /*======================================================================*/
481  template< typename FV, typename ResultVector>
482  double classify( const FV& testObject,
483  const typename Traits<FV>::ModelType& model,
484  ResultVector& resultVector) const
485  {
486  return model.classIndexToLabel(
487  predictClassIndex( testObject, model, resultVector));
488  }
489 
490  /*======================================================================*/
499  /*======================================================================*/
500  template< typename FV>
501  double classify( const FV& testObject,
502  const typename Traits<FV>::ModelType& model) const
503  {
504  DummyResultVector dummy;
505  return classify( testObject, model, dummy);
506  }
507 
508 
// Loads parameters from stData by forwarding to loadParameters() of the
// two-class SVM; this class reads no entry of its own in the visible code.
// NOTE(review): listing line 512 (first statement of the body) is not
// visible in this extract; presumably a compile-time requirement check
// (CHECK_MEMBER_TEMPLATE appears in the cross-reference) -- confirm.
509  template<typename STDATA>
510  void loadParameters( STDATA& stData)
511  {
513  _twoClassSVM.loadParameters( stData);
514  }
515 
// Saves parameters to stData: stores name() under the key "multi_class_type"
// and then forwards to saveParameters() of the two-class SVM.
// NOTE(review): listing line 519 (first statement of the body) is not
// visible in this extract; presumably a compile-time requirement check
// (CHECK_MEMBER_TEMPLATE appears in the cross-reference) -- confirm.
516  template<typename STDATA>
517  void saveParameters( STDATA& stData) const
518  {
520  stData.setValue( "multi_class_type", name());
521  _twoClassSVM.saveParameters( stData);
522  }
523 
524 
525  /*======================================================================*/
534  /*======================================================================*/
// Stores the decision values of one classification together with the
// matching class labels in stData, as the two arrays "labels" and
// "dec_values". Both arrays are ordered by descending decision value and
// have one entry per element of resultVector.
// NOTE(review): listing line 536 (return type, function name and first
// parameter) is not visible in this extract; the cross-reference at the end
// of the listing gives it as
//   void saveClassificationDetails(const ModelType& model,
// -- confirm against the original header.
535  template< typename ModelType, typename STDATA>
537  const DetailedResultType& resultVector,
538  STDATA& stData) const
539  {
540  // create list of decision values with labels
541  std::vector< DecisionValueAndLabel >
542  decValueLabelList( resultVector.size());
543 
544  // copy decision values and labels to that list
545  for( size_t i = 0; i < resultVector.size(); ++i)
546  {
547  decValueLabelList[i].decisionValue = resultVector[i];
// the position in resultVector is used as the class index for the label lookup
// NOTE(review): i is cast to int here, while the cross-reference lists
// classIndexToLabel(unsigned int) -- confirm the intended index type
548  decValueLabelList[i].label = model.classIndexToLabel(
549  static_cast<int>(i));
550  }
551 
// DecisionValueAndLabel::operator< compares with '>', so std::sort puts the
// largest decision value first
552  // sort the list
553  std::sort( decValueLabelList.begin(), decValueLabelList.end());
554 
// split the sorted pairs into two parallel arrays with the same order
555  // store it to structured data
556  std::vector<double> decValueList( decValueLabelList.size());
557  std::vector<double> labelList( decValueLabelList.size());
558  for( size_t i = 0; i < decValueLabelList.size(); ++i)
559  {
560  decValueList[i] = decValueLabelList[i].decisionValue;
561  labelList[i] = decValueLabelList[i].label;
562  }
563 
564  stData.setArray( "labels", labelList.begin(), labelList.size());
565  stData.setArray( "dec_values",
566  decValueList.begin(), decValueList.size());
567  }
568 
569 
570  /*======================================================================*/
579  /*======================================================================*/
580  static void getParamInfos( std::vector<ParamInfo>&)
581  {
582  }
583 
/**
 * Short identifier of this multi-class scheme; saveParameters() stores it
 * under the key "multi_class_type".
 *
 * @return the string "one_vs_rest"
 */
static const char* name()
{
  const char* const identifier = "one_vs_rest";
  return identifier;
}
588 
/**
 * Human-readable one-line description of this multi-class scheme.
 *
 * @return the description string
 */
static const char* description()
{
  const char* const text =
      "multi-class SVM by using the One versus Rest approach";
  return text;
}
593 
594  private:
595  SVM _twoClassSVM;
596  ProgressReporter* _pr;
597 
598  };
599 
600 #include "MultiClassSVMOneVsRest.icc"
601 }
602 
603 #endif
The GroupedTrainingData class is a container for feature vectors.
double classify(const FV &testObject, const typename Traits< FV >::ModelType &model, ResultVector &resultVector) const
classify the given testObject using the model
#define CHECK_MEMBER_TEMPLATE(c)
static void getParamInfos(std::vector< ParamInfo > &)
get information about the parameters, that are used in loadParameters() and saveParameters().
MultiClassSVMOneVsRest(const SVM &svm)
Create a multi class SVM based on the given "two-class-SVM", using the One versus Rest algorithm...
void train(const GroupedTrainingData< FV > &trainData, typename Traits< FV >::ModelType &model) const
train SVM with given Feature Vectors.
double classIndexToLabel(unsigned int classIndex) const
Definition: Model_MC.hh:145
unsigned int predictClassIndex(const FV &testObject, const typename Traits< FV >::ModelType &model) const
classify the given testObject using the model
#define CHECK_CLASS_TEMPLATE1(c)
void saveClassificationDetails(const ModelType &model, const DetailedResultType &resultVector, STDATA &stData) const
save classification details.
std::vector< double > DetailedResultType
#define CHECK_CLASS_TEMPLATE2(c)
void train(ForwardIter FV_begin, const ForwardIter &FV_end, typename Traits< typename Accessor::template Traits< ForwardIter >::value_type >::ModelType &model, Accessor accessor) const
same as train(), but you can specify an Accessor for accessing the elements of your container...
Ensure that TESTCLASS provides a loadParameters() and saveParameters() method.
double classify(const FV &testObject, const typename Traits< FV >::ModelType &model) const
classify the given testObject using the model
const SVM & twoClassSVM() const
get read-only access to the underlying two-class SVM
unsigned int predictClassIndex(const FV &testObject, const typename Traits< FV >::ModelType &model, ResultVector &resultVector) const
classify the given testObject using the model
The MultiClassSVMOneVsRest class provides a one vs. rest multi-class SVM built on a two-class SVM.
void retrainWithLeftOutVectors(const GroupedTrainingData< FV > &trainData, const typename Traits< FV >::ModelType &fullModel, const std::vector< char > &leaveOutFlagsByUID, typename Traits< FV >::ModelType &resultingModel) const
calls the retrainWithLeftOutVectors() for each two-class model, only if the model is affected by the ...
void saveParameters(STDATA &stData) const
void train(const SVM_Problem< FV > &problem, typename Traits< FV >::ModelType &model) const
train SVM with given Feature Vectors.
Model_MC_OneVsRest< typename SVM::template Traits< FV >::ModelType > ModelType
void updateKernelCache(const ForwardIter &fvBegin, const ForwardIter &fvEnd, Accessor accessor) const
call the updateKernelCache() method of selected two-class svm
bool operator<(const DecisionValueAndLabel &rhs) const
void train(ForwardIter FV_begin, const ForwardIter &FV_end, typename Traits< typename std::iterator_traits< ForwardIter >::value_type >::ModelType &model) const
train the Multi Class SVM with the given feature vectors.
Ensure that TESTCLASS provides a setProgressReporter() method.
void setProgressReporter(ProgressReporter *pr)
set a progress reporter object.
void clearKernelCache() const
call the clearKernelCache() method of selected two-class svm