RapidLib  v2.1.0
A simple library for interactive machine learning
svmClassification.h
Go to the documentation of this file.
1 
10 #ifndef svm_h
11 #define svm_h
12 
13 #include <vector>
14 #include "baseModel.h"
15 #include "../dependencies/libsvm/libsvm.h"
16 
/// svmClassification<T>: a baseModel<T> implementation that wraps LIBSVM
/// (it includes libsvm.h and stores LIBSVM::svm_model / svm_parameter /
/// svm_problem members).
///
/// NOTE(review): this rendered listing is missing several header lines:
///   - lines 21-22, the SVMType and KernelType enum declarations (the
///     cross-reference entries place both enums and their enumerators there);
///   - the line that opens the default constructor; its full signature is
///     svmClassification(KernelType kernelType=LINEAR_KERNEL, SVMType svmType=C_SVC, ...);
///   - the destructor declaration ~svmClassification(), defined at
///     svmClassification.cpp:103.
/// Consult the real header before relying on this listing.
17 template<typename T>
18 class svmClassification : public baseModel<T> {
19 
20 public:
23 
    /// Default constructor; every parameter has a default value.
    /// Defined at svmClassification.cpp:17.
    ///   kernelType         default LINEAR_KERNEL
    ///   svmType            default C_SVC
    ///   useScaling         default true  (presumably enables per-input scaling
    ///                      stored in inRanges/inBases - confirm in the .cpp)
    ///   useNullRejection   default false
    ///   useAutoGamma       default true  (presumably derives gamma automatically
    ///                      instead of using the `gamma` argument - confirm)
    ///   gamma              default 0.1
    ///   degree             default 3
    ///   coef0              default 0
    ///   nu                 default 0.5
    ///   C                  default 1
    ///   useCrossValidation default false
    ///   kFoldValue         default 10 (presumably the fold count used when
    ///                      useCrossValidation is true - confirm)
    /// NOTE(review): gamma/degree/coef0/nu/C share names with LIBSVM
    /// svm_parameter fields and are presumably forwarded there - verify.
44  KernelType kernelType = LINEAR_KERNEL,
45  SVMType svmType = C_SVC,
46  bool useScaling = true,
47  bool useNullRejection = false,
48  bool useAutoGamma = true,
49  float gamma = 0.1,
50  unsigned int degree = 3,
51  float coef0 = 0,
52  float nu = 0.5,
53  float C = 1,
54  bool useCrossValidation = false,
55  unsigned int kFoldValue = 10
56  );
57 
    /// Constructs a model for the given number of inputs (assumed meaning of
    /// numInputs; see svmClassification.cpp).
    /// NOTE(review): this single-argument constructor is not `explicit`, so an
    /// int converts implicitly to svmClassification<T>. Consider marking it
    /// explicit if no caller relies on that conversion.
58  svmClassification(int numInputs);
59 
62 
    /// Trains the model on trainingSet. Defined at svmClassification.cpp:172.
68  void train(const std::vector<trainingExampleTemplate<T> > &trainingSet);
69 
    /// Runs the model on one input vector and returns a value of type T.
    /// Defined at svmClassification.cpp:227.
74  T run(const std::vector<T> &inputVector);
75 
    /// Resets the model. Defined at svmClassification.cpp:108.
76  void reset();
77 
    /// (Re)initialises the model with the same parameter set the default
    /// constructor accepts, but with no defaults. Defined at
    /// svmClassification.cpp:113.
    /// @return bool; meaning not visible here (presumably success) - confirm.
95  bool init(KernelType kernelType,SVMType svmType,bool useScaling,bool useNullRejection,bool useAutoGamma,
96  float gamma,
97  unsigned int degree,
98  float coef0,
99  float nu,
100  float C,
101  bool useCrossValidation,
102  unsigned int kFoldValue
103  );
104 
    /// Defined at svmClassification.cpp:250.
105  int getNumInputs() const;
    /// Defined at svmClassification.cpp:255.
106  std::vector<int> getWhichInputs() const;
107 
108 
    // Not declared in EMSCRIPTEN builds.
109 #ifndef EMSCRIPTEN
    /// Writes this model's description into currentModel.
    /// Defined at svmClassification.cpp:262.
110  void getJSONDescription(Json::Value &currentModel);
111 #endif
112 
113 private:
    // Flag; presumably records whether `problem` has been populated - confirm.
114  bool problemSet;
    // Pointer to the LIBSVM model.
    // NOTE(review): this is a raw pointer in a class that declares a destructor
    // but has no visible copy constructor or copy assignment. If the destructor
    // frees `model`, the implicitly generated copies would share it and
    // double-free. Consider deleting or implementing the copy operations.
115  struct LIBSVM::svm_model *model;
    // LIBSVM training parameters.
116  struct LIBSVM::svm_parameter param;
    // LIBSVM training problem (training data in LIBSVM form).
117  struct LIBSVM::svm_problem problem;
118 
    // Number of inputs the model expects.
119  int numInputs;
120 
    // Per-input range and base values; presumably used for input scaling when
    // useScaling is set - confirm in svmClassification.cpp.
122  std::vector<double> inRanges;
123  std::vector<double> inBases;
124 
    // Flag; presumably set once train() has succeeded - confirm.
125  bool trained;
126 };
127 
128 #endif
129 
130 
131 
132 
133 
Definition: svmClassification.h:21
Definition: trainingExample.h:18
std::vector< int > getWhichInputs() const
Definition: svmClassification.cpp:255
void train(const std::vector< trainingExampleTemplate< T > > &trainingSet)
Definition: svmClassification.cpp:172
Definition: svmClassification.h:22
Definition: svmClassification.h:22
Definition: svmClassification.h:18
Definition: svmClassification.h:22
Definition: svmClassification.h:21
T run(const std::vector< T > &inputVector)
Definition: svmClassification.cpp:227
void getJSONDescription(Json::Value &currentModel)
Definition: svmClassification.cpp:262
Definition: svmClassification.h:22
~svmClassification()
Definition: svmClassification.cpp:103
Definition: baseModel.h:23
KernelType
Definition: svmClassification.h:22
Definition: svmClassification.h:21
Definition: svmClassification.h:22
Definition: svmClassification.h:21
Definition: svmClassification.h:21
int getNumInputs() const
Definition: svmClassification.cpp:250
void reset()
Definition: svmClassification.cpp:108
SVMType
Definition: svmClassification.h:21
bool init(KernelType kernelType, SVMType svmType, bool useScaling, bool useNullRejection, bool useAutoGamma, float gamma, unsigned int degree, float coef0, float nu, float C, bool useCrossValidation, unsigned int kFoldValue)
Definition: svmClassification.cpp:113
svmClassification(KernelType kernelType=LINEAR_KERNEL, SVMType svmType=C_SVC, bool useScaling=true, bool useNullRejection=false, bool useAutoGamma=true, float gamma=0.1, unsigned int degree=3, float coef0=0, float nu=0.5, float C=1, bool useCrossValidation=false, unsigned int kFoldValue=10)
Definition: svmClassification.cpp:17