00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041 #ifndef __OPENCV_ML_HPP__
00042 #define __OPENCV_ML_HPP__
00043
00044 #include "opencv2/core/core.hpp"
00045 #include <limits.h>
00046
00047 #ifdef __cplusplus
00048
00049
00050
// A function-like macro named `check` (presumably from some platform header)
// would clash with member functions such as CvParamGrid::check() declared
// further down, so it is removed before any declarations.
#undef check

// Natural logarithm of 2*pi: ln(2*pi) ~= 1.8378770664.
#define CV_LOG2PI (1.8378770664093454835606594728112)

// Sample-layout flag (value 0); per its name, training samples are columns.
#define CV_COL_SAMPLE 0

// Sample-layout flag (value 1); per its name, training samples are rows.
#define CV_ROW_SAMPLE 1

// Non-zero iff the CV_ROW_SAMPLE bit is set in `flags`.
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
00067
00068 struct CvVectors
00069 {
00070 int type;
00071 int dims, count;
00072 CvVectors* next;
00073 union
00074 {
00075 uchar** ptr;
00076 float** fl;
00077 double** db;
00078 } data;
00079 };
00080
// Disabled code: everything between `#if 0` and `#endif` is compiled out.
// It describes a parameter lattice (min_val, max_val, step), similar in shape
// to CvParamGrid declared below.
#if 0
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

// Builds a lattice with ordered bounds and a step of at least 1.
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

// Returns an all-zero lattice.
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif
00109
00110
// Variable (feature / response) type codes. NUMERICAL and ORDERED share the
// value 0 and are therefore interchangeable.
#define CV_VAR_NUMERICAL   0
#define CV_VAR_ORDERED     0
#define CV_VAR_CATEGORICAL 1

// Type-name strings, one per model kind (presumably used to tag models when
// they are written through CvFileStorage -- confirm in the .cpp files).
#define CV_TYPE_NAME_ML_SVM      "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN      "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES   "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM       "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE     "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP  "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN      "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES   "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_GBT      "opencv-ml-gradient-boosting-trees"

// Error-set selectors: training subset (0) vs. test subset (1); presumably
// the `type` argument of calc_error()-style functions.
#define CV_TRAIN_ERROR 0
#define CV_TEST_ERROR  1
00128
00129 class CV_EXPORTS_W CvStatModel
00130 {
00131 public:
00132 CvStatModel();
00133 virtual ~CvStatModel();
00134
00135 virtual void clear();
00136
00137 CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
00138 CV_WRAP virtual void load( const char* filename, const char* name=0 );
00139
00140 virtual void write( CvFileStorage* storage, const char* name ) const;
00141 virtual void read( CvFileStorage* storage, CvFileNode* node );
00142
00143 protected:
00144 const char* default_model_name;
00145 };
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156 class CvMLData;
00157
00158 struct CV_EXPORTS_W_MAP CvParamGrid
00159 {
00160
00161 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
00162
00163 CvParamGrid()
00164 {
00165 min_val = max_val = step = 0;
00166 }
00167
00168 CvParamGrid( double _min_val, double _max_val, double log_step )
00169 {
00170 min_val = _min_val;
00171 max_val = _max_val;
00172 step = log_step;
00173 }
00174
00175 bool check() const;
00176
00177 CV_PROP_RW double min_val;
00178 CV_PROP_RW double max_val;
00179 CV_PROP_RW double step;
00180 };
00181
00182 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
00183 {
00184 public:
00185 CV_WRAP CvNormalBayesClassifier();
00186 virtual ~CvNormalBayesClassifier();
00187
00188 CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
00189 const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
00190
00191 virtual bool train( const CvMat* trainData, const CvMat* responses,
00192 const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
00193
00194 virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
00195 CV_WRAP virtual void clear();
00196
00197 #ifndef SWIG
00198 CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
00199 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
00200 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00201 const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00202 bool update=false );
00203 CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
00204 #endif
00205
00206 virtual void write( CvFileStorage* storage, const char* name ) const;
00207 virtual void read( CvFileStorage* storage, CvFileNode* node );
00208
00209 protected:
00210 int var_count, var_all;
00211 CvMat* var_idx;
00212 CvMat* cls_labels;
00213 CvMat** count;
00214 CvMat** sum;
00215 CvMat** productsum;
00216 CvMat** avg;
00217 CvMat** inv_eigen_values;
00218 CvMat** cov_rotate_mats;
00219 CvMat* c;
00220 };
00221
00222
00223
00224
00225
00226
00227
00228 class CV_EXPORTS_W CvKNearest : public CvStatModel
00229 {
00230 public:
00231
00232 CV_WRAP CvKNearest();
00233 virtual ~CvKNearest();
00234
00235 CvKNearest( const CvMat* trainData, const CvMat* responses,
00236 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
00237
00238 virtual bool train( const CvMat* trainData, const CvMat* responses,
00239 const CvMat* sampleIdx=0, bool is_regression=false,
00240 int maxK=32, bool updateBase=false );
00241
00242 virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
00243 const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
00244
00245 #ifndef SWIG
00246 CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
00247 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
00248
00249 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00250 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
00251 int maxK=32, bool updateBase=false );
00252
00253 virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
00254 const float** neighbors=0, cv::Mat* neighborResponses=0,
00255 cv::Mat* dist=0 ) const;
00256 CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
00257 CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
00258 #endif
00259
00260 virtual void clear();
00261 int get_max_k() const;
00262 int get_var_count() const;
00263 int get_sample_count() const;
00264 bool is_regression() const;
00265
00266 virtual float write_results( int k, int k1, int start, int end,
00267 const float* neighbor_responses, const float* dist, CvMat* _results,
00268 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
00269
00270 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
00271 float* neighbor_responses, const float** neighbors, float* dist ) const;
00272
00273 protected:
00274
00275 int max_k, var_count;
00276 int total;
00277 bool regression;
00278 CvVectors* samples;
00279 };
00280
00281
00282
00283
00284
00285
00286 struct CV_EXPORTS_W_MAP CvSVMParams
00287 {
00288 CvSVMParams();
00289 CvSVMParams( int _svm_type, int _kernel_type,
00290 double _degree, double _gamma, double _coef0,
00291 double Cvalue, double _nu, double _p,
00292 CvMat* _class_weights, CvTermCriteria _term_crit );
00293
00294 CV_PROP_RW int svm_type;
00295 CV_PROP_RW int kernel_type;
00296 CV_PROP_RW double degree;
00297 CV_PROP_RW double gamma;
00298 CV_PROP_RW double coef0;
00299
00300 CV_PROP_RW double C;
00301 CV_PROP_RW double nu;
00302 CV_PROP_RW double p;
00303 CvMat* class_weights;
00304 CV_PROP_RW CvTermCriteria term_crit;
00305 };
00306
00307
00308 struct CV_EXPORTS CvSVMKernel
00309 {
00310 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
00311 const float* another, float* results );
00312 CvSVMKernel();
00313 CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
00314 virtual bool create( const CvSVMParams* params, Calc _calc_func );
00315 virtual ~CvSVMKernel();
00316
00317 virtual void clear();
00318 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
00319
00320 const CvSVMParams* params;
00321 Calc calc_func;
00322
00323 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
00324 const float* another, float* results,
00325 double alpha, double beta );
00326
00327 virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
00328 const float* another, float* results );
00329 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
00330 const float* another, float* results );
00331 virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
00332 const float* another, float* results );
00333 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
00334 const float* another, float* results );
00335 };
00336
00337
// Node of a doubly linked list of kernel rows, used by CvSVMSolver (its
// `lru_list` and `rows` members). `data` points at the row's values.
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;
    CvSVMKernelRow* next;
    float* data;
};
00344
00345
// Summary values produced by the CvSVMSolver::solve_* routines (they all
// take a CvSVMSolutionInfo& out-parameter): objective value, rho (decision
// offset), the two upper bounds and r.
struct CvSVMSolutionInfo
{
    double obj;
    double rho;
    double upper_bound_p;
    double upper_bound_n;
    double r;
};
00354
00355 class CV_EXPORTS CvSVMSolver
00356 {
00357 public:
00358 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
00359 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
00360 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
00361
00362 CvSVMSolver();
00363
00364 CvSVMSolver( int count, int var_count, const float** samples, schar* y,
00365 int alpha_count, double* alpha, double Cp, double Cn,
00366 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00367 SelectWorkingSet select_working_set, CalcRho calc_rho );
00368 virtual bool create( int count, int var_count, const float** samples, schar* y,
00369 int alpha_count, double* alpha, double Cp, double Cn,
00370 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00371 SelectWorkingSet select_working_set, CalcRho calc_rho );
00372 virtual ~CvSVMSolver();
00373
00374 virtual void clear();
00375 virtual bool solve_generic( CvSVMSolutionInfo& si );
00376
00377 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
00378 double Cp, double Cn, CvMemStorage* storage,
00379 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
00380 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
00381 CvMemStorage* storage, CvSVMKernel* kernel,
00382 double* alpha, CvSVMSolutionInfo& si );
00383 virtual bool solve_one_class( int count, int var_count, const float** samples,
00384 CvMemStorage* storage, CvSVMKernel* kernel,
00385 double* alpha, CvSVMSolutionInfo& si );
00386
00387 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
00388 CvMemStorage* storage, CvSVMKernel* kernel,
00389 double* alpha, CvSVMSolutionInfo& si );
00390
00391 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
00392 CvMemStorage* storage, CvSVMKernel* kernel,
00393 double* alpha, CvSVMSolutionInfo& si );
00394
00395 virtual float* get_row_base( int i, bool* _existed );
00396 virtual float* get_row( int i, float* dst );
00397
00398 int sample_count;
00399 int var_count;
00400 int cache_size;
00401 int cache_line_size;
00402 const float** samples;
00403 const CvSVMParams* params;
00404 CvMemStorage* storage;
00405 CvSVMKernelRow lru_list;
00406 CvSVMKernelRow* rows;
00407
00408 int alpha_count;
00409
00410 double* G;
00411 double* alpha;
00412
00413
00414 schar* alpha_status;
00415
00416 schar* y;
00417 double* b;
00418 float* buf[2];
00419 double eps;
00420 int max_iter;
00421 double C[2];
00422 CvSVMKernel* kernel;
00423
00424 SelectWorkingSet select_working_set_func;
00425 CalcRho calc_rho_func;
00426 GetRow get_row_func;
00427
00428 virtual bool select_working_set( int& i, int& j );
00429 virtual bool select_working_set_nu_svm( int& i, int& j );
00430 virtual void calc_rho( double& rho, double& r );
00431 virtual void calc_rho_nu_svm( double& rho, double& r );
00432
00433 virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
00434 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
00435 virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
00436 };
00437
00438
// One trained SVM decision function: offset `rho` and `sv_count` support
// vectors. `alpha` holds the coefficients and `sv_index` the support-vector
// indices (presumably parallel arrays of length sv_count).
struct CvSVMDecisionFunc
{
    double rho;
    int sv_count;
    double* alpha;
    int* sv_index;
};
00446
00447
00448
00449 class CV_EXPORTS_W CvSVM : public CvStatModel
00450 {
00451 public:
00452
00453 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
00454
00455
00456 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
00457
00458
00459 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
00460
00461 CV_WRAP CvSVM();
00462 virtual ~CvSVM();
00463
00464 CvSVM( const CvMat* trainData, const CvMat* responses,
00465 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00466 CvSVMParams params=CvSVMParams() );
00467
00468 virtual bool train( const CvMat* trainData, const CvMat* responses,
00469 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00470 CvSVMParams params=CvSVMParams() );
00471
00472 virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
00473 const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
00474 int kfold = 10,
00475 CvParamGrid Cgrid = get_default_grid(CvSVM::C),
00476 CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA),
00477 CvParamGrid pGrid = get_default_grid(CvSVM::P),
00478 CvParamGrid nuGrid = get_default_grid(CvSVM::NU),
00479 CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF),
00480 CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
00481 bool balanced=false );
00482
00483 virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
00484 virtual float predict( const CvMat* samples, CvMat* results ) const;
00485
00486 #ifndef SWIG
00487 CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
00488 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00489 CvSVMParams params=CvSVMParams() );
00490
00491 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00492 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00493 CvSVMParams params=CvSVMParams() );
00494
00495 CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
00496 const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
00497 int k_fold = 10,
00498 CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C),
00499 CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA),
00500 CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P),
00501 CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU),
00502 CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
00503 CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
00504 bool balanced=false);
00505 CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
00506 #endif
00507
00508 CV_WRAP virtual int get_support_vector_count() const;
00509 virtual const float* get_support_vector(int i) const;
00510 virtual CvSVMParams get_params() const { return params; };
00511 CV_WRAP virtual void clear();
00512
00513 static CvParamGrid get_default_grid( int param_id );
00514
00515 virtual void write( CvFileStorage* storage, const char* name ) const;
00516 virtual void read( CvFileStorage* storage, CvFileNode* node );
00517 CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
00518
00519 protected:
00520
00521 virtual bool set_params( const CvSVMParams& params );
00522 virtual bool train1( int sample_count, int var_count, const float** samples,
00523 const void* responses, double Cp, double Cn,
00524 CvMemStorage* _storage, double* alpha, double& rho );
00525 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
00526 const CvMat* responses, CvMemStorage* _storage, double* alpha );
00527 virtual void create_kernel();
00528 virtual void create_solver();
00529
00530 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
00531
00532 virtual void write_params( CvFileStorage* fs ) const;
00533 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00534
00535 CvSVMParams params;
00536 CvMat* class_labels;
00537 int var_all;
00538 float** sv;
00539 int sv_total;
00540 CvMat* var_idx;
00541 CvMat* class_weights;
00542 CvSVMDecisionFunc* decision_func;
00543 CvMemStorage* storage;
00544
00545 CvSVMSolver* solver;
00546 CvSVMKernel* kernel;
00547 };
00548
00549
00550
00551
00552
00553 struct CV_EXPORTS_W_MAP CvEMParams
00554 {
00555 CvEMParams() : nclusters(10), cov_mat_type(1),
00556 start_step(0), probs(0), weights(0), means(0), covs(0)
00557 {
00558 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
00559 }
00560
00561 CvEMParams( int _nclusters, int _cov_mat_type=1,
00562 int _start_step=0,
00563 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
00564 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
00565 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
00566 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
00567 {}
00568
00569 CV_PROP_RW int nclusters;
00570 CV_PROP_RW int cov_mat_type;
00571 CV_PROP_RW int start_step;
00572 const CvMat* probs;
00573 const CvMat* weights;
00574 const CvMat* means;
00575 const CvMat** covs;
00576 CV_PROP_RW CvTermCriteria term_crit;
00577 };
00578
00579
00580 class CV_EXPORTS_W CvEM : public CvStatModel
00581 {
00582 public:
00583
00584 enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
00585
00586
00587 enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
00588
00589 CV_WRAP CvEM();
00590 CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
00591 CvEMParams params=CvEMParams(), CvMat* labels=0 );
00592
00593
00594
00595 virtual ~CvEM();
00596
00597 virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
00598 CvEMParams params=CvEMParams(), CvMat* labels=0 );
00599
00600 virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const;
00601
00602 #ifndef SWIG
00603 CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
00604 CvEMParams params=CvEMParams() );
00605
00606 CV_WRAP virtual bool train( const cv::Mat& samples,
00607 const cv::Mat& sampleIdx=cv::Mat(),
00608 CvEMParams params=CvEMParams(),
00609 CV_OUT cv::Mat* labels=0 );
00610
00611 CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const;
00612 CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const;
00613
00614 CV_WRAP int getNClusters() const;
00615 CV_WRAP cv::Mat getMeans() const;
00616 CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
00617 CV_WRAP cv::Mat getWeights() const;
00618 CV_WRAP cv::Mat getProbs() const;
00619
00620 CV_WRAP inline double getLikelihood() const { return log_likelihood; }
00621 CV_WRAP inline double getLikelihoodDelta() const { return log_likelihood_delta; }
00622 #endif
00623
00624 CV_WRAP virtual void clear();
00625
00626 int get_nclusters() const;
00627 const CvMat* get_means() const;
00628 const CvMat** get_covs() const;
00629 const CvMat* get_weights() const;
00630 const CvMat* get_probs() const;
00631
00632 inline double get_log_likelihood() const { return log_likelihood; }
00633 inline double get_log_likelihood_delta() const { return log_likelihood_delta; }
00634
00635
00636
00637
00638
00639 virtual void read( CvFileStorage* fs, CvFileNode* node );
00640 virtual void write( CvFileStorage* fs, const char* name ) const;
00641
00642 virtual void write_params( CvFileStorage* fs ) const;
00643 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00644
00645 protected:
00646
00647 virtual void set_params( const CvEMParams& params,
00648 const CvVectors& train_data );
00649 virtual void init_em( const CvVectors& train_data );
00650 virtual double run_em( const CvVectors& train_data );
00651 virtual void init_auto( const CvVectors& samples );
00652 virtual void kmeans( const CvVectors& train_data, int nclusters,
00653 CvMat* labels, CvTermCriteria criteria,
00654 const CvMat* means );
00655 CvEMParams params;
00656 double log_likelihood;
00657 double log_likelihood_delta;
00658
00659 CvMat* means;
00660 CvMat** covs;
00661 CvMat* weights;
00662 CvMat* probs;
00663
00664 CvMat* log_weight_div_det;
00665 CvMat* inv_eigen_values;
00666 CvMat** cov_rotate_mats;
00667 };
00668
00669
00670
// Pair of pointers to parallel arrays: 16-bit unsigned values (`u`) and
// 32-bit signed values (`i`). Which member is in use is decided by the code
// that fills it.
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
00677
00678
// Direction of category `idx` in a categorical split bitset `subset` (an int
// array, 32 categories per word): -1 when bit idx is set, +1 when it is clear.
//
// The mask is built from an unsigned `1u`. With a signed `1`, idx%32 == 31
// shifts into the sign bit, which is undefined behaviour before C++14.
// `subset` is parenthesized so expression arguments (e.g. `p + k`) expand
// correctly.
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*(((subset)[(idx)>>5]&(1u << ((idx) & 31)))==0)-1)
00681
// One split of a decision-tree node; splits form a list through `next`.
//
// The anonymous union holds either the categorical bitset `subset` (see
// CV_DTREE_CAT_DIR) or the ordered-split data `ord` (threshold `c` and
// `split_point`). Which member is valid presumably depends on the type of
// variable `var_idx`.
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        }
        ord;
    };
};
00700
00701 struct CvDTreeNode
00702 {
00703 int class_idx;
00704 int Tn;
00705 double value;
00706
00707 CvDTreeNode* parent;
00708 CvDTreeNode* left;
00709 CvDTreeNode* right;
00710
00711 CvDTreeSplit* split;
00712
00713 int sample_count;
00714 int depth;
00715 int* num_valid;
00716 int offset;
00717 int buf_idx;
00718 double maxlr;
00719
00720
00721 int complexity;
00722 double alpha;
00723 double node_risk, tree_risk, tree_error;
00724
00725
00726 int* cv_Tn;
00727 double* cv_node_risk;
00728 double* cv_node_error;
00729
00730 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
00731 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
00732 };
00733
00734
00735 struct CV_EXPORTS_W_MAP CvDTreeParams
00736 {
00737 CV_PROP_RW int max_categories;
00738 CV_PROP_RW int max_depth;
00739 CV_PROP_RW int min_sample_count;
00740 CV_PROP_RW int cv_folds;
00741 CV_PROP_RW bool use_surrogates;
00742 CV_PROP_RW bool use_1se_rule;
00743 CV_PROP_RW bool truncate_pruned_tree;
00744 CV_PROP_RW float regression_accuracy;
00745 const float* priors;
00746
00747 CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
00748 cv_folds(10), use_surrogates(true), use_1se_rule(true),
00749 truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
00750 {}
00751
00752 CvDTreeParams( int _max_depth, int _min_sample_count,
00753 float _regression_accuracy, bool _use_surrogates,
00754 int _max_categories, int _cv_folds,
00755 bool _use_1se_rule, bool _truncate_pruned_tree,
00756 const float* _priors ) :
00757 max_categories(_max_categories), max_depth(_max_depth),
00758 min_sample_count(_min_sample_count), cv_folds (_cv_folds),
00759 use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
00760 truncate_pruned_tree(_truncate_pruned_tree),
00761 regression_accuracy(_regression_accuracy),
00762 priors(_priors)
00763 {}
00764 };
00765
00766
00767 struct CV_EXPORTS CvDTreeTrainData
00768 {
00769 CvDTreeTrainData();
00770 CvDTreeTrainData( const CvMat* trainData, int tflag,
00771 const CvMat* responses, const CvMat* varIdx=0,
00772 const CvMat* sampleIdx=0, const CvMat* varType=0,
00773 const CvMat* missingDataMask=0,
00774 const CvDTreeParams& params=CvDTreeParams(),
00775 bool _shared=false, bool _add_labels=false );
00776 virtual ~CvDTreeTrainData();
00777
00778 virtual void set_data( const CvMat* trainData, int tflag,
00779 const CvMat* responses, const CvMat* varIdx=0,
00780 const CvMat* sampleIdx=0, const CvMat* varType=0,
00781 const CvMat* missingDataMask=0,
00782 const CvDTreeParams& params=CvDTreeParams(),
00783 bool _shared=false, bool _add_labels=false,
00784 bool _update_data=false );
00785 virtual void do_responses_copy();
00786
00787 virtual void get_vectors( const CvMat* _subsample_idx,
00788 float* values, uchar* missing, float* responses, bool get_class_idx=false );
00789
00790 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
00791
00792 virtual void write_params( CvFileStorage* fs ) const;
00793 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00794
00795
00796 virtual void clear();
00797
00798 int get_num_classes() const;
00799 int get_var_type(int vi) const;
00800 int get_work_var_count() const {return work_var_count;}
00801
00802 virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
00803 virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
00804 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
00805 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
00806 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
00807 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
00808 const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
00809 virtual int get_child_buf_idx( CvDTreeNode* n );
00810
00812
00813 virtual bool set_params( const CvDTreeParams& params );
00814 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
00815 int storage_idx, int offset );
00816
00817 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
00818 int split_point, int inversed, float quality );
00819 virtual CvDTreeSplit* new_split_cat( int vi, float quality );
00820 virtual void free_node_data( CvDTreeNode* node );
00821 virtual void free_train_data();
00822 virtual void free_node( CvDTreeNode* node );
00823
00824 int sample_count, var_all, var_count, max_c_count;
00825 int ord_var_count, cat_var_count, work_var_count;
00826 bool have_labels, have_priors;
00827 bool is_classifier;
00828 int tflag;
00829
00830 const CvMat* train_data;
00831 const CvMat* responses;
00832 CvMat* responses_copy;
00833
00834 int buf_count, buf_size;
00835 bool shared;
00836 int is_buf_16u;
00837
00838 CvMat* cat_count;
00839 CvMat* cat_ofs;
00840 CvMat* cat_map;
00841
00842 CvMat* counts;
00843 CvMat* buf;
00844 CvMat* direction;
00845 CvMat* split_buf;
00846
00847 CvMat* var_idx;
00848 CvMat* var_type;
00849
00850
00851 CvMat* priors;
00852 CvMat* priors_mult;
00853
00854 CvDTreeParams params;
00855
00856 CvMemStorage* tree_storage;
00857 CvMemStorage* temp_storage;
00858
00859 CvDTreeNode* data_root;
00860
00861 CvSet* node_heap;
00862 CvSet* split_heap;
00863 CvSet* cv_heap;
00864 CvSet* nv_heap;
00865
00866 cv::RNG* rng;
00867 };
00868
// Forward declarations needed by the tree classes below.
class CvDTree;
class CvForestTree;

namespace cv
{
    // Split-search helpers befriended by CvDTree / CvForestTree.
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
00877
00878 class CV_EXPORTS_W CvDTree : public CvStatModel
00879 {
00880 public:
00881 CV_WRAP CvDTree();
00882 virtual ~CvDTree();
00883
00884 virtual bool train( const CvMat* trainData, int tflag,
00885 const CvMat* responses, const CvMat* varIdx=0,
00886 const CvMat* sampleIdx=0, const CvMat* varType=0,
00887 const CvMat* missingDataMask=0,
00888 CvDTreeParams params=CvDTreeParams() );
00889
00890 virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
00891
00892
00893 virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
00894
00895 virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
00896
00897 virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
00898 bool preprocessedInput=false ) const;
00899
00900 #ifndef SWIG
00901 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
00902 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
00903 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
00904 const cv::Mat& missingDataMask=cv::Mat(),
00905 CvDTreeParams params=CvDTreeParams() );
00906
00907 CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
00908 bool preprocessedInput=false ) const;
00909 CV_WRAP virtual cv::Mat getVarImportance();
00910 #endif
00911
00912 virtual const CvMat* get_var_importance();
00913 CV_WRAP virtual void clear();
00914
00915 virtual void read( CvFileStorage* fs, CvFileNode* node );
00916 virtual void write( CvFileStorage* fs, const char* name ) const;
00917
00918
00919 virtual void read( CvFileStorage* fs, CvFileNode* node,
00920 CvDTreeTrainData* data );
00921 virtual void write( CvFileStorage* fs ) const;
00922
00923 const CvDTreeNode* get_root() const;
00924 int get_pruned_tree_idx() const;
00925 CvDTreeTrainData* get_data();
00926
00927 protected:
00928 friend struct cv::DTreeBestSplitFinder;
00929
00930 virtual bool do_train( const CvMat* _subsample_idx );
00931
00932 virtual void try_split_node( CvDTreeNode* n );
00933 virtual void split_node_data( CvDTreeNode* n );
00934 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
00935 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
00936 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00937 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
00938 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00939 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
00940 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00941 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
00942 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00943 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00944 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00945 virtual double calc_node_dir( CvDTreeNode* node );
00946 virtual void complete_node_dir( CvDTreeNode* node );
00947 virtual void cluster_categories( const int* vectors, int vector_count,
00948 int var_count, int* sums, int k, int* cluster_labels );
00949
00950 virtual void calc_node_value( CvDTreeNode* node );
00951
00952 virtual void prune_cv();
00953 virtual double update_tree_rnc( int T, int fold );
00954 virtual int cut_tree( int T, int fold, double min_alpha );
00955 virtual void free_prune_data(bool cut_tree);
00956 virtual void free_tree();
00957
00958 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
00959 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
00960 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
00961 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
00962 virtual void write_tree_nodes( CvFileStorage* fs ) const;
00963 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
00964
00965 CvDTreeNode* root;
00966 CvMat* var_importance;
00967 CvDTreeTrainData* data;
00968
00969 public:
00970 int pruned_tree_idx;
00971 };
00972
00973
00974
00975
00976
00977 // CvForestTree: a CvDTree that is trained and read as one member of a CvRTrees forest.
00978 class CvRTrees; // forward declaration: the forest is referenced only through pointers in CvForestTree
00979
00980 class CV_EXPORTS CvForestTree: public CvDTree
00981 {
00982 public:
00983 CvForestTree();
00984 virtual ~CvForestTree();
00985 // Trains this tree on trainData, restricted to the samples in _subsample_idx, on behalf of `forest`.
00986 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
00987 // Number of variables of the attached training data, or 0 when no data is attached.
00988 virtual int get_var_count() const {return data ? data->var_count : 0;}
00989 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data ); // deserializes the tree from `node`; `forest`/`_data` presumably become the owning forest and shared training data
00990 // The overloads below repeat the plain CvDTree signatures. NOTE(review): presumably re-declared so that the
00991 // forest-aware train()/read() above do not hide them (C++ name hiding) — confirm they match CvDTree.
00992 virtual bool train( const CvMat* trainData, int tflag,
00993 const CvMat* responses, const CvMat* varIdx=0,
00994 const CvMat* sampleIdx=0, const CvMat* varType=0,
00995 const CvMat* missingDataMask=0,
00996 CvDTreeParams params=CvDTreeParams() );
00997
00998 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
00999 virtual void read( CvFileStorage* fs, CvFileNode* node );
01000 virtual void read( CvFileStorage* fs, CvFileNode* node,
01001 CvDTreeTrainData* data );
01002
01003
01004 protected:
01005 friend struct cv::ForestTreeBestSplitFinder; // lets this helper access the protected members below
01006
01007 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); // best split for node n (presumably overrides CvDTree::find_best_split)
01008 CvRTrees* forest; // forest passed to train()/read(); NOTE(review): presumably non-owning — confirm
01009 };
01010
01011 // CvRTParams: random-trees parameters = per-tree CvDTreeParams plus forest-level settings.
01012 struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
01013 {
01014 // Forest-level settings.
01015 CV_PROP_RW bool calc_var_importance; // whether variable importance is computed (presumably read back via CvRTrees::get_var_importance)
01016 CV_PROP_RW int nactive_vars; // NOTE(review): presumably the size of the random variable subset tried at each node
01017 CV_PROP_RW CvTermCriteria term_crit; // forest stopping rule: max_iter = maximum number of trees, epsilon = forest accuracy
01018 // Defaults: max_depth 5, min_sample_count 10, regression_accuracy 0, no surrogates, max_categories 10, no priors; stop after 50 trees or at accuracy 0.1.
01019 CvRTParams()
01020 : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
01021 calc_var_importance(false),
01022 nactive_vars(0),
01023 term_crit( cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 ) )
01024 {}
01025 // Explicit parameters; the remaining CvDTreeParams arguments are fixed to 0 / false / false.
01026 CvRTParams( int _max_depth, int _min_sample_count,
01027 float _regression_accuracy, bool _use_surrogates,
01028 int _max_categories, const float* _priors, bool _calc_var_importance,
01029 int _nactive_vars, int max_num_of_trees_in_the_forest,
01030 float forest_accuracy, int termcrit_type )
01031 : CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
01032 _use_surrogates, _max_categories, 0,
01033 false, false, _priors ),
01034 calc_var_importance(_calc_var_importance),
01035 nactive_vars(_nactive_vars),
01036 term_crit( cvTermCriteria( termcrit_type,
01037 max_num_of_trees_in_the_forest, forest_accuracy ) )
01038 {}
01039 };
01040
01041 // CvRTrees: random-trees ensemble (CvStatModel) built from CvForestTree models (see get_tree()).
01042 class CV_EXPORTS_W CvRTrees : public CvStatModel
01043 {
01044 public:
01045 CV_WRAP CvRTrees();
01046 virtual ~CvRTrees();
01047 virtual bool train( const CvMat* trainData, int tflag, // trains the forest on trainData/responses; the optional index/type/mask matrices restrict or describe the input
01048 const CvMat* responses, const CvMat* varIdx=0,
01049 const CvMat* sampleIdx=0, const CvMat* varType=0,
01050 const CvMat* missingDataMask=0,
01051 CvRTParams params=CvRTParams() );
01052 // Same as above, taking the training set from a CvMLData container.
01053 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01054 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const; // response for one sample; `missing` optionally marks missing features
01055 virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const; // NOTE(review): presumably a class probability (vote fraction) — confirm which problems it supports
01056 // cv::Mat overloads of the API above; excluded when the header is processed by SWIG.
01057 #ifndef SWIG
01058 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01059 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01060 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01061 const cv::Mat& missingDataMask=cv::Mat(),
01062 CvRTParams params=CvRTParams() );
01063 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01064 CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01065 CV_WRAP virtual cv::Mat getVarImportance(); // presumably the cv::Mat form of get_var_importance()
01066 #endif
01067
01068 CV_WRAP virtual void clear();
01069 // Variable importance (presumably available only when CvRTParams::calc_var_importance was set).
01070 virtual const CvMat* get_var_importance();
01071 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2, // NOTE(review): proximity of two samples as measured by the forest; the exact definition is not visible here
01072 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
01073 // Error on `_data` (`type` presumably selects the train or test subset); per-sample responses presumably go to `resp` when it is non-null.
01074 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 );
01075 // NOTE(review): presumably the out-of-bag estimate kept in `oob_error` — confirm.
01076 virtual float get_train_error();
01077 // Persistence through CvFileStorage.
01078 virtual void read( CvFileStorage* fs, CvFileNode* node );
01079 virtual void write( CvFileStorage* fs, const char* name ) const;
01080 // Per-forest training state (presumably shared with the CvForestTree members during training).
01081 CvMat* get_active_var_mask();
01082 CvRNG* get_rng(); // NOTE(review): returns CvRNG* while the stored generator is cv::RNG* (`rng`); presumably converted in the .cpp
01083 // Number of trees, and tree i (presumably 0 <= i < get_tree_count()).
01084 int get_tree_count() const;
01085 CvForestTree* get_tree(int i) const;
01086
01087 protected:
01088 // Grows the forest; term_crit is presumably CvRTParams::term_crit (tree count and/or accuracy).
01089 virtual bool grow_forest( const CvTermCriteria term_crit );
01090
01091 // Model state.
01092 CvForestTree** trees; // presumably ntrees entries
01093 CvDTreeTrainData* data; // training data shared by the trees (presumably)
01094 int ntrees;
01095 int nclasses;
01096 double oob_error; // out-of-bag error estimate (presumably)
01097 CvMat* var_importance;
01098 int nsamples;
01099
01100 cv::RNG* rng;
01101 CvMat* active_var_mask;
01102 };
01103
01104
01105
01106 // CvERTreeTrainData: CvDTreeTrainData variant with its own data setup and per-node accessors (presumably used by CvERTrees / CvForestERTree).
01107 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
01108 {
01109 virtual void set_data( const CvMat* trainData, int tflag, // stores the training set (parameters presumably as in CvDTreeTrainData::set_data)
01110 const CvMat* responses, const CvMat* varIdx=0,
01111 const CvMat* sampleIdx=0, const CvMat* varType=0,
01112 const CvMat* missingDataMask=0,
01113 const CvDTreeParams& params=CvDTreeParams(),
01114 bool _shared=false, bool _add_labels=false,
01115 bool _update_data=false );
01116 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, // values and missing flags of ordered variable vi for node n; *ord_values / *missing presumably point into the caller buffers or internal storage
01117 const float** ord_values, const int** missing, int* sample_buf = 0 );
01118 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); // sample indices of node n (indices_buf: caller-provided scratch, presumably)
01119 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); // NOTE(review): presumably the cross-validation fold labels of node n's samples
01120 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf ); // values of categorical variable vi for node n's samples
01121 virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing, // copies the selected samples into caller-provided arrays (presumably)
01122 float* responses, bool get_class_idx=false );
01123 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); // NOTE(review): presumably builds a root node over the subsample — confirm
01124 const CvMat* missing_mask; // NOTE(review): presumably the missingDataMask given to set_data (non-owning?) — confirm
01125 };
01126 // CvForestERTree: CvForestTree with its own split search and node-data partitioning (presumably the per-tree model of CvERTrees).
01127 class CV_EXPORTS CvForestERTree : public CvForestTree
01128 {
01129 protected:
01130 virtual double calc_node_dir( CvDTreeNode* node ); // NOTE(review): presumably sends each sample of `node` to one side of its split — confirm
01131 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, // split search for variable vi (ordered/categorical x classification/regression); default arguments on virtual functions are bound statically, so keep them consistent with the base declarations
01132 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01133 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01134 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01135 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
01136 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01137 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
01138 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01139 virtual void split_node_data( CvDTreeNode* n ); // presumably distributes node n's samples to its children
01140 };
01141 // CvERTrees: CvRTrees variant with its own forest growing (presumably "extremely randomized trees").
01142 class CV_EXPORTS_W CvERTrees : public CvRTrees
01143 {
01144 public:
01145 CV_WRAP CvERTrees();
01146 virtual ~CvERTrees();
01147 virtual bool train( const CvMat* trainData, int tflag, // all three CvRTrees::train() overloads are re-declared in this class, so none of them is hidden
01148 const CvMat* responses, const CvMat* varIdx=0,
01149 const CvMat* sampleIdx=0, const CvMat* varType=0,
01150 const CvMat* missingDataMask=0,
01151 CvRTParams params=CvRTParams());
01152 #ifndef SWIG
01153 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01154 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01155 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01156 const cv::Mat& missingDataMask=cv::Mat(),
01157 CvRTParams params=CvRTParams());
01158 #endif
01159 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01160 protected:
01161 virtual bool grow_forest( const CvTermCriteria term_crit ); // overrides CvRTrees::grow_forest()
01162 };
01163
01164
01165
01166
01167
01168 // CvBoostParams: per-weak-tree CvDTreeParams plus boosting settings (used by CvBoost).
01169 struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
01170 {
01171 CV_PROP_RW int boost_type; // presumably one of CvBoost::DISCRETE, REAL, LOGIT, GENTLE
01172 CV_PROP_RW int weak_count; // number of weak trees (presumably)
01173 CV_PROP_RW int split_criteria; // presumably one of CvBoost::DEFAULT, GINI, MISCLASS, SQERR
01174 CV_PROP_RW double weight_trim_rate; // NOTE(review): presumably the threshold used by CvBoost::trim_weights — confirm
01175 // Constructors are defined out of line; note that the second one takes no split_criteria argument.
01176 CvBoostParams();
01177 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
01178 int max_depth, bool use_surrogates, const float* priors );
01179 };
01180
01181 // CvBoostTree: a CvDTree trained and read as one weak learner of a CvBoost ensemble.
01182 class CvBoost; // forward declaration: the ensemble is referenced only through pointers in CvBoostTree
01183
01184 class CV_EXPORTS CvBoostTree: public CvDTree
01185 {
01186 public:
01187 CvBoostTree();
01188 virtual ~CvBoostTree();
01189 // Trains this tree on trainData, restricted to the samples in subsample_idx, on behalf of `ensemble`.
01190 virtual bool train( CvDTreeTrainData* trainData,
01191 const CvMat* subsample_idx, CvBoost* ensemble );
01192 // NOTE(review): presumably multiplies the tree's node values by s — confirm in the .cpp.
01193 virtual void scale( double s );
01194 virtual void read( CvFileStorage* fs, CvFileNode* node,
01195 CvBoost* ensemble, CvDTreeTrainData* _data ); // deserializes the tree; `ensemble`/`_data` presumably become the owning ensemble and shared training data
01196 virtual void clear();
01197 // The overloads below repeat the plain CvDTree signatures. NOTE(review): presumably re-declared so that the
01198 // ensemble-aware train()/read() above do not hide them (C++ name hiding) — confirm they match CvDTree.
01199 virtual bool train( const CvMat* trainData, int tflag,
01200 const CvMat* responses, const CvMat* varIdx=0,
01201 const CvMat* sampleIdx=0, const CvMat* varType=0,
01202 const CvMat* missingDataMask=0,
01203 CvDTreeParams params=CvDTreeParams() );
01204 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
01205
01206 virtual void read( CvFileStorage* fs, CvFileNode* node );
01207 virtual void read( CvFileStorage* fs, CvFileNode* node,
01208 CvDTreeTrainData* data );
01209
01210
01211 protected:
01212 // Tree-growing overrides. Default arguments on these virtual functions are taken from the static type of the call, so keep them consistent with the base-class declarations.
01213 virtual void try_split_node( CvDTreeNode* n );
01214 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01215 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01216 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
01217 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01218 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01219 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01220 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
01221 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01222 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
01223 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01224 virtual void calc_node_value( CvDTreeNode* n );
01225 virtual double calc_node_dir( CvDTreeNode* n );
01226
01227 CvBoost* ensemble; // ensemble passed to train()/read(); NOTE(review): presumably non-owning — confirm
01228 };
01229
01230 // CvBoost: boosted ensemble (CvStatModel) of CvBoostTree weak learners.
01231 class CV_EXPORTS_W CvBoost : public CvStatModel
01232 {
01233 public:
01234 // Boosting variants (presumably the values of CvBoostParams::boost_type).
01235 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
01236 // Split criteria (presumably the values of CvBoostParams::split_criteria). The values are explicit and
01237 // non-contiguous (there is no 2); do not renumber them — they may be stored in saved models.
01238 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
01239
01240 CV_WRAP CvBoost();
01241 virtual ~CvBoost();
01242 // Constructs the model from training data (presumably trains immediately, with the same arguments as train()).
01243 CvBoost( const CvMat* trainData, int tflag,
01244 const CvMat* responses, const CvMat* varIdx=0,
01245 const CvMat* sampleIdx=0, const CvMat* varType=0,
01246 const CvMat* missingDataMask=0,
01247 CvBoostParams params=CvBoostParams() );
01248 // Trains the ensemble. NOTE(review): `update` presumably extends an existing ensemble instead of retraining — confirm.
01249 virtual bool train( const CvMat* trainData, int tflag,
01250 const CvMat* responses, const CvMat* varIdx=0,
01251 const CvMat* sampleIdx=0, const CvMat* varType=0,
01252 const CvMat* missingDataMask=0,
01253 CvBoostParams params=CvBoostParams(),
01254 bool update=false );
01255 // Same as above, taking the training set from a CvMLData container.
01256 virtual bool train( CvMLData* data,
01257 CvBoostParams params=CvBoostParams(),
01258 bool update=false );
01259 // Predicts one sample. weak_responses optionally receives per-tree outputs; slice selects the weak trees used; raw_mode/return_sum presumably request the raw score instead of a class label — confirm.
01260 virtual float predict( const CvMat* sample, const CvMat* missing=0,
01261 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
01262 bool raw_mode=false, bool return_sum=false ) const;
01263 // cv::Mat overloads of the API above (slice is a cv::Range here); excluded when the header is processed by SWIG.
01264 #ifndef SWIG
01265 CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
01266 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01267 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01268 const cv::Mat& missingDataMask=cv::Mat(),
01269 CvBoostParams params=CvBoostParams() );
01270
01271 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01272 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01273 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01274 const cv::Mat& missingDataMask=cv::Mat(),
01275 CvBoostParams params=CvBoostParams(),
01276 bool update=false );
01277
01278 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01279 const cv::Range& slice=cv::Range::all(), bool rawMode=false,
01280 bool returnSum=false ) const;
01281 #endif
01282 // Error on `_data` (`type` presumably selects the train or test subset); per-sample responses presumably go to `resp` when non-null.
01283 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 );
01284 // NOTE(review): presumably removes the weak trees in `slice` from the ensemble — confirm.
01285 CV_WRAP virtual void prune( CvSlice slice );
01286
01287 CV_WRAP virtual void clear();
01288 // Persistence through CvFileStorage, and the variables used by the ensemble (absolute_idx presumably selects original vs. compacted indices).
01289 virtual void write( CvFileStorage* storage, const char* name ) const;
01290 virtual void read( CvFileStorage* storage, CvFileNode* node );
01291 virtual const CvMat* get_active_vars(bool absolute_idx=true);
01292
01293 CvSeq* get_weak_predictors(); // the weak trees (presumably the `weak` sequence itself)
01294 // Accessors to internal training state; NOTE(review): returned pointers presumably refer to members — do not release them.
01295 CvMat* get_weights();
01296 CvMat* get_subtree_weights();
01297 CvMat* get_weak_response();
01298 const CvBoostParams& get_params() const;
01299 const CvDTreeTrainData* get_data() const;
01300
01301 protected:
01302 // Training and persistence helpers.
01303 virtual bool set_params( const CvBoostParams& params );
01304 virtual void update_weights( CvBoostTree* tree ); // presumably re-weights the samples after `tree` is added
01305 virtual void trim_weights(); // presumably uses params.weight_trim_rate
01306 virtual void write_params( CvFileStorage* fs ) const;
01307 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
01308 // Model state.
01309 CvDTreeTrainData* data;
01310 CvBoostParams params;
01311 CvSeq* weak; // weak trees of the ensemble (presumably CvBoostTree* elements)
01312 // Variables used by the ensemble (see get_active_vars()).
01313 CvMat* active_vars;
01314 CvMat* active_vars_abs;
01315 bool have_active_cat_vars;
01316 // Per-sample training buffers (presumably).
01317 CvMat* orig_response;
01318 CvMat* sum_response;
01319 CvMat* weak_eval;
01320 CvMat* subsample_mask;
01321 CvMat* weights;
01322 CvMat* subtree_weights;
01323 bool have_subsample;
01324 };
01325
01326
01327
01328
01329
01330
01331
01332
01333
01334
01335
01336
01337
01338
01339
01340
01341
01342
01343
01344
01345
01346
01347 // CvGBTreesParams: per-tree CvDTreeParams plus gradient-boosting settings (used by CvGBTrees).
01348 struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams
01349 {
01350 CV_PROP_RW int weak_count; // number of boosting iterations (presumably trees per class for classification — confirm)
01351 CV_PROP_RW int loss_function_type; // presumably one of CvGBTrees::SQUARED_LOSS, ABSOLUTE_LOSS, HUBER_LOSS, DEVIANCE_LOSS
01352 CV_PROP_RW float subsample_portion; // fraction of the training samples used for each tree (presumably)
01353 CV_PROP_RW float shrinkage; // factor applied to each tree's contribution (presumably the learning rate)
01354 // Constructors are defined out of line (not in this header).
01355 CvGBTreesParams();
01356 CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
01357 float subsample_portion, int max_depth, bool use_surrogates );
01358 };
01359
01360
01361
01362
01363
01364
01365
01366
01367
01368
01369
01370
01371
01372
01373
01374
01375
01376
01377
01378
01379
01380
01381
01382
01383
01384
01385
01386
01387
01388
01389
01390
01391
01392
01393
01394
01395
01396
01397
01398
01399
01400
01401
01402
01403
01404
01405
01406
01407 // CvGBTrees: gradient-boosted trees model (CvStatModel); training settings come from CvGBTreesParams.
01408 class CV_EXPORTS_W CvGBTrees : public CvStatModel
01409 {
01410 public:
01411
01412
01413
01414
01415
01416
01417
01418
01419
01420
01421
01422
01423
01424
01425
01426
01427
01428
01429
01430
01431 // Loss functions (presumably the values of CvGBTreesParams::loss_function_type). The resulting values are
01432 // SQUARED_LOSS=0, ABSOLUTE_LOSS=1, HUBER_LOSS=3, DEVIANCE_LOSS=4; do not renumber them — they may be stored
01433 // in saved models. NOTE(review): DEVIANCE_LOSS is presumably the classification loss and the others
01434 // regression losses — confirm.
01435 enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
01436
01437
01438
01439
01440
01441
01442
01443
01444
01445
01446
01447
01448 // Default constructor; the model is presumably empty until train() or read() is called.
01449 CV_WRAP CvGBTrees();
01450
01451
01452
01453
01454
01455
01456
01457
01458
01459
01460
01461
01462
01463
01464
01465
01466
01467
01468
01469
01470
01471
01472
01473
01474
01475
01476
01477
01478
01479
01480
01481
01482
01483
01484
01485 // Training constructor. NOTE(review): presumably trains immediately with these arguments, like train() — confirm.
01486 // trainData/tflag/responses: the samples, their layout flag (presumably CV_ROW_SAMPLE or CV_COL_SAMPLE) and
01487 // their responses; varIdx/sampleIdx/varType/missingDataMask: optional variable subset, sample subset,
01488 // variable types and missing-value mask.
01489 CvGBTrees( const CvMat* trainData, int tflag,
01490 const CvMat* responses, const CvMat* varIdx=0,
01491 const CvMat* sampleIdx=0, const CvMat* varType=0,
01492 const CvMat* missingDataMask=0,
01493 CvGBTreesParams params=CvGBTreesParams() );
01494
01495
01496
01497
01498
01499 virtual ~CvGBTrees();
01500
01501
01502
01503
01504
01505
01506
01507
01508
01509
01510
01511
01512
01513
01514
01515
01516
01517
01518
01519
01520
01521
01522
01523
01524
01525
01526
01527
01528
01529
01530
01531
01532
01533
01534
01535
01536
01537
01538 // Trains the model (returns bool, presumably true on success); arguments as for the training constructor.
01539 // NOTE(review): `update` presumably continues training the existing ensemble instead of starting over —
01540 // confirm in the .cpp.
01541 virtual bool train( const CvMat* trainData, int tflag,
01542 const CvMat* responses, const CvMat* varIdx=0,
01543 const CvMat* sampleIdx=0, const CvMat* varType=0,
01544 const CvMat* missingDataMask=0,
01545 CvGBTreesParams params=CvGBTreesParams(),
01546 bool update=false );
01547
01548
01549
01550
01551
01552
01553
01554
01555
01556
01557
01558
01559
01560
01561
01562
01563
01564 // Same as above, taking the training set from a CvMLData container.
01565 virtual bool train( CvMLData* data,
01566 CvGBTreesParams params=CvGBTreesParams(),
01567 bool update=false );
01568
01569
01570
01571
01572
01573
01574
01575
01576
01577
01578
01579
01580
01581
01582
01583
01584
01585
01586
01587
01588
01589
01590
01591
01592
01593 // Predicts one sample. missing: optional missing-feature mask; weakResponses: optional output of per-tree
01594 // responses; slice: range of weak trees to use (CV_WHOLE_SEQ = all); k: NOTE(review): presumably selects
01595 // one class's ensemble, with -1 meaning all classes — confirm. NOTE(review): presumably the serial
01596 // (non-parallel) counterpart of predict().
01597 virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
01598 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01599 int k=-1 ) const;
01600
01601
01602
01603
01604
01605
01606
01607
01608
01609
01610
01611
01612
01613
01614
01615
01616
01617
01618
01619
01620
01621
01622
01623
01624
01625
01626
01627 // Predicts one sample; same parameters and defaults as predict_serial().
01628 // NOTE(review): presumably the (possibly parallel) default entry point — confirm.
01629 virtual float predict( const CvMat* sample, const CvMat* missing=0,
01630 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01631 int k=-1 ) const;
01632
01633
01634
01635
01636
01637
01638
01639
01640
01641
01642
01643
01644
01645
01646 // Releases the trained model (presumably back to the default-constructed state).
01647 CV_WRAP virtual void clear();
01648
01649
01650
01651
01652
01653
01654
01655
01656
01657
01658
01659
01660
01661
01662
01663 // Error of the model on `_data`; `type` presumably selects the train or test subset, and `resp`,
01664 // when non-null, presumably receives the per-sample predictions.
01665 virtual float calc_error( CvMLData* _data, int type,
01666 std::vector<float> *resp = 0 );
01667
01668
01669
01670
01671
01672
01673
01674
01675
01676
01677
01678
01679
01680 // Writes the model to `fs` under the name `name`.
01681 virtual void write( CvFileStorage* fs, const char* name ) const;
01682
01683
01684
01685
01686
01687
01688
01689
01690
01691
01692
01693
01694
01695
01696 // Reads the model from the file-storage node `node` (counterpart of write()).
01697 virtual void read( CvFileStorage* fs, CvFileNode* node );
01698
01699 // cv::Mat overloads of the training constructor, train() and predict() (slice is a cv::Range here).
01700 // NOTE(review): unlike the other models in this header, these are not wrapped in #ifndef SWIG.
01701 CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
01702 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01703 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01704 const cv::Mat& missingDataMask=cv::Mat(),
01705 CvGBTreesParams params=CvGBTreesParams() );
01706
01707 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01708 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01709 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01710 const cv::Mat& missingDataMask=cv::Mat(),
01711 CvGBTreesParams params=CvGBTreesParams(),
01712 bool update=false );
01713
01714 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01715 const cv::Range& slice = cv::Range::all(),
01716 int k=-1 ) const;
01717
01718 protected:
01719
01720
01721
01722
01723
01724
01725
01726
01727
01728
01729
01730
01731
01732
01733
01734 // NOTE(review): presumably computes the loss gradient for class k, used as the target of the next tree — confirm.
01735 virtual void find_gradient( const int k = 0);
01736
01737
01738
01739
01740
01741
01742
01743
01744
01745
01746
01747
01748
01749
01750
01751
01752
01753 // NOTE(review): presumably replaces the leaf values of `tree` (class k) with loss-optimal values — confirm.
01754 virtual void change_values(CvDTree* tree, const int k = 0);
01755
01756
01757
01758
01759
01760
01761
01762
01763
01764
01765
01766
01767
01768
01769
01770
01771
01772 // NOTE(review): presumably the loss-optimal constant for the samples listed in _Idx — confirm.
01773 virtual float find_optimal_value( const CvMat* _Idx );
01774
01775
01776
01777
01778
01779
01780
01781
01782
01783
01784
01785
01786
01787
01788
01789 // Chooses the samples for the next tree (presumably fills subsample_train/subsample_test from params.subsample_portion).
01790 virtual void do_subsample();
01791
01792
01793
01794
01795
01796
01797
01798
01799
01800
01801
01802
01803
01804
01805
01806 // Collects the leaves of the subtree rooted at `node` into `leaves`; `count` is presumably the running number of stored leaves.
01807 void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01808
01809
01810
01811
01812
01813
01814
01815
01816
01817
01818
01819
01820
01821
01822
01823 // Returns the leaves of `dtree` as an array and their number in `len`. NOTE(review): the array is presumably heap-allocated and owned by the caller — confirm.
01824 CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01825
01826
01827
01828
01829
01830
01831
01832
01833
01834
01835
01836
01837
01838
01839 // Problem kind as a bool. NOTE(review): presumably true for classification, false for regression — confirm.
01840 virtual bool problem_type() const;
01841
01842
01843
01844
01845
01846
01847
01848
01849
01850
01851
01852
01853
01854 // Serialization of the training parameters (presumably helpers of write() / read()).
01855 virtual void write_params( CvFileStorage* fs ) const;
01856
01857
01858
01859
01860
01861
01862
01863
01864
01865
01866
01867
01868
01869
01870
01871
01872
01873
01874
01875 virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
01876 int get_len(const CvMat* mat) const; // NOTE(review): presumably the element count of a vector-shaped CvMat — confirm
01877
01878 // Model state.
01879 CvDTreeTrainData* data;
01880 CvGBTreesParams params;
01881 // Ensemble and per-sample training buffers (presumably).
01882 CvSeq** weak; // presumably one sequence of trees per class (class_count entries) — confirm
01883 CvMat* orig_response;
01884 CvMat* sum_response;
01885 CvMat* sum_response_tmp;
01886 CvMat* sample_idx;
01887 CvMat* subsample_train;
01888 CvMat* subsample_test;
01889 CvMat* missing;
01890 CvMat* class_labels;
01891
01892 cv::RNG* rng;
01893
01894 int class_count;
01895 float delta; // NOTE(review): presumably the Huber-loss threshold — confirm
01896 float base_value; // NOTE(review): presumably the initial constant prediction — confirm
01897
01898 };
01899
01900
01901
01902
01903
01904
01905
01907 // CvANN_MLP_TrainParams: training settings passed to CvANN_MLP::train().
01908 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams
01909 {
01910 CvANN_MLP_TrainParams();
01911 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
01912 double param1, double param2=0 ); // NOTE(review): param1/param2 presumably fill the bp_* or rp_* fields depending on train_method — confirm
01913 ~CvANN_MLP_TrainParams();
01914 // Training algorithms (presumably the values of train_method).
01915 enum { BACKPROP=0, RPROP=1 };
01916 // Stopping rule and selected algorithm.
01917 CV_PROP_RW CvTermCriteria term_crit;
01918 CV_PROP_RW int train_method;
01919
01920 // Back-propagation settings (presumably used when train_method == BACKPROP).
01921 CV_PROP_RW double bp_dw_scale, bp_moment_scale;
01922
01923 // RPROP settings (presumably used when train_method == RPROP).
01924 CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
01925 };
01926
01927 // CvANN_MLP: multi-layer perceptron (CvStatModel); the network shape is given by a layer-sizes matrix with one column per layer.
01928 class CV_EXPORTS_W CvANN_MLP : public CvStatModel
01929 {
01930 public:
01931 CV_WRAP CvANN_MLP();
01932 CvANN_MLP( const CvMat* layerSizes,
01933 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01934 double fparam1=0, double fparam2=0 ); // presumably equivalent to default construction followed by create()
01935
01936 virtual ~CvANN_MLP();
01937 // (Re)builds the network: layerSizes gives the layer sizes, activateFunc is IDENTITY/SIGMOID_SYM/GAUSSIAN, fparam1/fparam2 parametrize it.
01938 virtual void create( const CvMat* layerSizes,
01939 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01940 double fparam1=0, double fparam2=0 );
01941 // Trains on inputs/outputs with per-sample weights and an optional sample subset; flags presumably combine UPDATE_WEIGHTS / NO_INPUT_SCALE / NO_OUTPUT_SCALE. NOTE(review): the int result is presumably the iteration count — confirm.
01942 virtual int train( const CvMat* inputs, const CvMat* outputs,
01943 const CvMat* sampleWeights, const CvMat* sampleIdx=0,
01944 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01945 int flags=0 );
01946 virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const; // writes the network response for `inputs` into `outputs`
01947 // cv::Mat overloads of the API above; excluded when the header is processed by SWIG.
01948 #ifndef SWIG
01949 CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
01950 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01951 double fparam1=0, double fparam2=0 );
01952
01953 CV_WRAP virtual void create( const cv::Mat& layerSizes,
01954 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01955 double fparam1=0, double fparam2=0 );
01956
01957 CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
01958 const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
01959 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01960 int flags=0 );
01961
01962 CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const;
01963 #endif
01964
01965 CV_WRAP virtual void clear();
01966
01967 // Activation functions (values of activateFunc; SIGMOID_SYM is the default).
01968 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
01969
01970 // Training flags: distinct bits, presumably OR-ed into train()'s `flags`.
01971 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
01972
01973 virtual void read( CvFileStorage* fs, CvFileNode* node );
01974 virtual void write( CvFileStorage* storage, const char* name ) const;
01975 // Number of layers (columns of layer_sizes), or 0 while layer_sizes is null.
01976 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
01977 const CvMat* get_layer_sizes() { return layer_sizes; } // may be null (e.g. before create())
01978 double* get_weights(int layer) // weights[layer] when layer_sizes and weights are set and 0 <= layer <= get_layer_count() (the unsigned casts also reject negative layers); otherwise 0
01979 {
01980 return layer_sizes && weights &&
01981 (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0; // NOTE(review): the bound is inclusive, so `weights` presumably holds more entries than layers — confirm in create()
01982 }
01983 // Activation of xf (using `bias`) together with its derivative in `deriv`. NOTE(review): exact in/out roles of xf are not visible here.
01984 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
01985
01986 protected:
01987 // Validates the training inputs and converts them into CvVectors plus a sample-weight array (presumably).
01988 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
01989 const CvMat* _sample_weights, const CvMat* sampleIdx,
01990 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
01991
01992 // Back-propagation training (presumably selected by CvANN_MLP_TrainParams::BACKPROP).
01993 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01994
01995 // RPROP training (presumably selected by CvANN_MLP_TrainParams::RPROP).
01996 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01997 // Activation, weight initialization and input/output scaling helpers.
01998 virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
01999 virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
02000 double _f_param1=0, double _f_param2=0 );
02001 virtual void init_weights();
02002 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
02003 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
02004 virtual void calc_input_scale( const CvVectors* vecs, int flags );
02005 virtual void calc_output_scale( const CvVectors* vecs, int flags );
02006
02007 virtual void write_params( CvFileStorage* fs ) const;
02008 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
02009 // Model state.
02010 CvMat* layer_sizes;
02011 CvMat* wbuf; // NOTE(review): presumably the buffer that the weights[] pointers point into — confirm
02012 CvMat* sample_weights;
02013 double** weights; // per-layer weight arrays returned by get_weights()
02014 double f_param1, f_param2; // activation parameters (see set_activ_func())
02015 double min_val, max_val, min_val1, max_val1;
02016 int activ_func; // presumably one of IDENTITY / SIGMOID_SYM / GAUSSIAN
02017 int max_count, max_buf_sz;
02018 CvANN_MLP_TrainParams params;
02019 cv::RNG* rng;
02020 };
02021
02022
02023
02024
02025
02026
02027 // Draws from the multivariate normal distribution with mean `mean` and covariance `cov` into `sample` (presumably one draw per row); rng defaults to 0 (presumably an internal generator).
02028 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
02029 CvRNG* rng CV_DEFAULT(0) );
02030
02031 // Draws `sample` from a mixture of `clsnum` Gaussians (means[i], covs[i], weights[i]); sampClasses, when given, presumably receives each sample's component index.
02032 CVAPI(void) cvRandGaussMixture( CvMat* means[],
02033 CvMat* covs[],
02034 float weights[],
02035 int clsnum,
02036 CvMat* sample,
02037 CvMat* sampClasses CV_DEFAULT(0) );
02038 // Test-set kind for cvCreateTestSet's `type` argument (presumably).
02039 #define CV_TS_CONCENTRIC_SPHERES 0
02040
02041 // Creates a synthetic data set of kind `type` (presumably num_samples x num_features) in *samples, with responses for num_classes classes in *responses; the variadic arguments presumably depend on `type`. NOTE(review): ownership of the created matrices is not visible here — confirm.
02042 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
02043 int num_samples,
02044 int num_features,
02045 CvMat** responses,
02046 int num_classes, ... );
02047
02048
02049 #endif
02050
02051
02052
02053
02054
02055 #include <map>
02056 #include <string>
02057 #include <iostream>
02058 // Values of CvTrainTestSplit::train_sample_part_mode (presumably): the training part is given as a count or as a portion.
02059 #define CV_COUNT 0
02060 #define CV_PORTION 1
02061 // How CvMLData (via set_train_test_split()) divides its samples into training and test subsets.
02062 struct CV_EXPORTS CvTrainTestSplit
02063 {
02064 CvTrainTestSplit();
02065 CvTrainTestSplit( int train_sample_count, bool mix = true); // training part given as a number of samples (presumably CV_COUNT mode)
02066 CvTrainTestSplit( float train_sample_portion, bool mix = true); // training part given as a fraction (presumably CV_PORTION mode); NOTE: pass a float (e.g. 0.8f) — a double argument is ambiguous between this and the int overload
02067 // Size of the training part, interpreted according to train_sample_part_mode.
02068 union
02069 {
02070 int count;
02071 float portion;
02072 } train_sample_part;
02073 int train_sample_part_mode; // presumably CV_COUNT or CV_PORTION
02074 // NOTE(review): presumably whether the samples are shuffled before the split — confirm.
02075 bool mix;
02076 };
02077 // CvMLData: tabular data set read from a delimited text file (read_csv()) and consumed by the models' train(CvMLData*, ...) overloads.
02078 class CV_EXPORTS CvMLData
02079 {
02080 public:
02081 CvMLData();
02082 virtual ~CvMLData();
02083 // Loads the table from `filename`, presumably using get_delimiter() as field separator and get_miss_ch()
02084 // as missing-value marker. NOTE(review): the meaning of the int result (presumably 0 on success) is not
02085 // visible here — confirm in the .cpp.
02086
02087 int read_csv( const char* filename );
02088
02089 const CvMat* get_values() const; // the loaded table (presumably one row per sample)
02090 const CvMat* get_responses(); // the column selected by set_response_idx() (presumably)
02091 const CvMat* get_missing() const; // missing-value mask for get_values() (presumably)
02092 // Index of the response column in the table.
02093 void set_response_idx( int idx );
02094
02095 int get_response_idx() const;
02096 // Train/test split of the samples; the index matrices list the samples of each subset.
02097 void set_train_test_split( const CvTrainTestSplit * spl );
02098 const CvMat* get_train_sample_idx() const;
02099 const CvMat* get_test_sample_idx() const;
02100 void mix_train_and_test_idx();
02101 // Active predictor variables; change_var_idx() presumably marks variable vi as used (state == true) or unused.
02102 const CvMat* get_var_idx();
02103 void chahge_var_idx( int vi, bool state ); // misspelled name, kept for source/binary compatibility; prefer change_var_idx()
02104 void change_var_idx( int vi, bool state ) { chahge_var_idx( vi, state ); } // correctly spelled alias; forwards to chahge_var_idx()
02105 const CvMat* get_var_types();
02106 int get_var_type( int var_idx ) const;
02107 // Sets all variable types from a textual description. NOTE(review): the accepted format is parsed in the
02108 // .cpp — presumably variable ranges tagged ordered/categorical (CV_VAR_ORDERED / CV_VAR_CATEGORICAL); confirm.
02109
02110 void set_var_types( const char* str );
02111 // Sets the type of one variable (presumably CV_VAR_ORDERED / CV_VAR_NUMERICAL or CV_VAR_CATEGORICAL).
02112
02113 void change_var_type( int var_idx, int type);
02114 // Field separator (presumably used by read_csv()).
02115 void set_delimiter( char ch );
02116 char get_delimiter() const;
02117 // Missing-value marker character (presumably used by read_csv()).
02118 void set_miss_ch( char ch );
02119 char get_miss_ch() const;
02120 // Mapping from textual class labels to integer codes (presumably filled by read_csv()).
02121 const std::map<std::string, int>& get_class_labels_map() const;
02122
02123 protected:
02124 virtual void clear();
02125
02126 void str_to_flt_elem( const char* token, float& flt_elem, int& type); // presumably parses one field into flt_elem and reports its kind in `type`
02127 void free_train_test_idx();
02128
02129 char delimiter;
02130 char miss_ch;
02131
02132 // Loaded table, missing-value mask and per-variable metadata.
02133 CvMat* values;
02134 CvMat* missing;
02135 CvMat* var_types;
02136 CvMat* var_idx_mask;
02137 // Matrices handed out by get_responses() / get_var_idx() / get_var_types() (presumably).
02138 CvMat* response_out;
02139 CvMat* var_idx_out;
02140 CvMat* var_types_out;
02141
02142 int response_idx;
02143
02144 int train_sample_count;
02145 bool mix;
02146
02147 int total_class_count;
02148 std::map<std::string, int> class_map;
02149
02150 CvMat* train_sample_idx;
02151 CvMat* test_sample_idx;
02152 int* sample_idx;
02153
02154 cv::RNG* rng;
02155 };
02156
02157 // C++ (cv::) aliases for the legacy Cv* machine-learning types of this header.
02158 namespace cv
02159 {
02160
02161 typedef CvStatModel StatModel;
02162 typedef CvParamGrid ParamGrid;
02163 typedef CvNormalBayesClassifier NormalBayesClassifier;
02164 typedef CvKNearest KNearest;
02165 typedef CvSVMParams SVMParams;
02166 typedef CvSVMKernel SVMKernel;
02167 typedef CvSVMSolver SVMSolver;
02168 typedef CvSVM SVM;
02169 typedef CvEMParams EMParams;
02170 typedef CvEM ExpectationMaximization;
02171 typedef CvDTreeParams DTreeParams;
02172 typedef CvMLData TrainData;
02173 typedef CvDTree DecisionTree;
02174 typedef CvForestTree ForestTree;
02175 typedef CvRTParams RandomTreeParams;
02176 typedef CvRTrees RandomTrees;
02177 typedef CvERTreeTrainData ERTreeTRainData; // misspelled alias, kept for source compatibility; prefer ERTreeTrainData
02178 typedef CvForestERTree ERTree;
02179 typedef CvERTrees ERTrees;
02180 typedef CvBoostParams BoostParams;
02181 typedef CvBoostTree BoostTree;
02182 typedef CvBoost Boost;
02183 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
02184 typedef CvANN_MLP NeuralNet_MLP;
02185 typedef CvGBTreesParams GradientBoostingTreeParams;
02186 typedef CvGBTrees GradientBoostingTrees;
02187 typedef CvERTreeTrainData ERTreeTrainData; // correctly spelled alias of CvERTreeTrainData
02188 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj(); // declares the library-provided deleter specialization for Ptr<CvDTreeSplit> (defined out of line)
02189
02190 }
02191
02192 #endif
02193