00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041 #ifndef __OPENCV_ML_HPP__
00042 #define __OPENCV_ML_HPP__
00043
00044 #include "opencv2/core/core.hpp"
00045 #include <limits.h>
00046
00047 #ifdef __cplusplus
00048
00049 #include <map>
00050 #include <string>
00051 #include <iostream>
00052
00053
00054
// Remove any function-like `check` macro that an earlier header may have
// defined (presumably a platform header) so it cannot mangle the
// CvParamGrid::check() member declared further down.
#undef check

// Natural logarithm of 2*pi.
#define CV_LOG2PI (1.8378770664093454835606594728112)

// Training-sample layout flags: samples stored as columns vs. as rows.
#define CV_COL_SAMPLE 0
#define CV_ROW_SAMPLE 1

// Evaluates to the CV_ROW_SAMPLE bit of `flags` (non-zero for row samples).
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
00071
00072 struct CvVectors
00073 {
00074 int type;
00075 int dims, count;
00076 CvVectors* next;
00077 union
00078 {
00079 uchar** ptr;
00080 float** fl;
00081 double** db;
00082 } data;
00083 };
00084
// NOTE: the CvParamLattice struct and its cvParamLattice()/cvDefaultParamLattice()
// helpers used to live here inside an `#if 0` block, i.e. they were never
// compiled. They have been removed as dead code; CvParamGrid (min_val,
// max_val, log step) describes the same kind of parameter range.
00113
00114
// Variable (feature / response) kinds. CV_VAR_NUMERICAL and CV_VAR_ORDERED
// share the value 0 and are therefore interchangeable.
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

// Type-name strings, one per model kind (presumably used to tag/recognize
// models in persisted storage files).
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

// Data subset an error is measured on (presumably the `type` argument of
// the calc_error() methods).
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1
00133
00134 class CV_EXPORTS_W CvStatModel
00135 {
00136 public:
00137 CvStatModel();
00138 virtual ~CvStatModel();
00139
00140 virtual void clear();
00141
00142 CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
00143 CV_WRAP virtual void load( const char* filename, const char* name=0 );
00144
00145 virtual void write( CvFileStorage* storage, const char* name ) const;
00146 virtual void read( CvFileStorage* storage, CvFileNode* node );
00147
00148 protected:
00149 const char* default_model_name;
00150 };
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161 class CvMLData;
00162
00163 struct CV_EXPORTS_W_MAP CvParamGrid
00164 {
00165
00166 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
00167
00168 CvParamGrid()
00169 {
00170 min_val = max_val = step = 0;
00171 }
00172
00173 CvParamGrid( double min_val, double max_val, double log_step );
00174
00175 bool check() const;
00176
00177 CV_PROP_RW double min_val;
00178 CV_PROP_RW double max_val;
00179 CV_PROP_RW double step;
00180 };
00181
00182 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
00183 {
00184 min_val = _min_val;
00185 max_val = _max_val;
00186 step = _log_step;
00187 }
00188
00189 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
00190 {
00191 public:
00192 CV_WRAP CvNormalBayesClassifier();
00193 virtual ~CvNormalBayesClassifier();
00194
00195 CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
00196 const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
00197
00198 virtual bool train( const CvMat* trainData, const CvMat* responses,
00199 const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
00200
00201 virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
00202 CV_WRAP virtual void clear();
00203
00204 CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
00205 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
00206 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00207 const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00208 bool update=false );
00209 CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
00210
00211 virtual void write( CvFileStorage* storage, const char* name ) const;
00212 virtual void read( CvFileStorage* storage, CvFileNode* node );
00213
00214 protected:
00215 int var_count, var_all;
00216 CvMat* var_idx;
00217 CvMat* cls_labels;
00218 CvMat** count;
00219 CvMat** sum;
00220 CvMat** productsum;
00221 CvMat** avg;
00222 CvMat** inv_eigen_values;
00223 CvMat** cov_rotate_mats;
00224 CvMat* c;
00225 };
00226
00227
00228
00229
00230
00231
00232
00233 class CV_EXPORTS_W CvKNearest : public CvStatModel
00234 {
00235 public:
00236
00237 CV_WRAP CvKNearest();
00238 virtual ~CvKNearest();
00239
00240 CvKNearest( const CvMat* trainData, const CvMat* responses,
00241 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
00242
00243 virtual bool train( const CvMat* trainData, const CvMat* responses,
00244 const CvMat* sampleIdx=0, bool is_regression=false,
00245 int maxK=32, bool updateBase=false );
00246
00247 virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
00248 const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
00249
00250 CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
00251 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
00252
00253 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00254 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
00255 int maxK=32, bool updateBase=false );
00256
00257 virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
00258 const float** neighbors=0, cv::Mat* neighborResponses=0,
00259 cv::Mat* dist=0 ) const;
00260 CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
00261 CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
00262
00263 virtual void clear();
00264 int get_max_k() const;
00265 int get_var_count() const;
00266 int get_sample_count() const;
00267 bool is_regression() const;
00268
00269 virtual float write_results( int k, int k1, int start, int end,
00270 const float* neighbor_responses, const float* dist, CvMat* _results,
00271 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
00272
00273 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
00274 float* neighbor_responses, const float** neighbors, float* dist ) const;
00275
00276 protected:
00277
00278 int max_k, var_count;
00279 int total;
00280 bool regression;
00281 CvVectors* samples;
00282 };
00283
00284
00285
00286
00287
00288
00289 struct CV_EXPORTS_W_MAP CvSVMParams
00290 {
00291 CvSVMParams();
00292 CvSVMParams( int svm_type, int kernel_type,
00293 double degree, double gamma, double coef0,
00294 double Cvalue, double nu, double p,
00295 CvMat* class_weights, CvTermCriteria term_crit );
00296
00297 CV_PROP_RW int svm_type;
00298 CV_PROP_RW int kernel_type;
00299 CV_PROP_RW double degree;
00300 CV_PROP_RW double gamma;
00301 CV_PROP_RW double coef0;
00302
00303 CV_PROP_RW double C;
00304 CV_PROP_RW double nu;
00305 CV_PROP_RW double p;
00306 CvMat* class_weights;
00307 CV_PROP_RW CvTermCriteria term_crit;
00308 };
00309
00310
00311 struct CV_EXPORTS CvSVMKernel
00312 {
00313 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
00314 const float* another, float* results );
00315 CvSVMKernel();
00316 CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
00317 virtual bool create( const CvSVMParams* params, Calc _calc_func );
00318 virtual ~CvSVMKernel();
00319
00320 virtual void clear();
00321 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
00322
00323 const CvSVMParams* params;
00324 Calc calc_func;
00325
00326 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
00327 const float* another, float* results,
00328 double alpha, double beta );
00329
00330 virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
00331 const float* another, float* results );
00332 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
00333 const float* another, float* results );
00334 virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
00335 const float* another, float* results );
00336 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
00337 const float* another, float* results );
00338 };
00339
00340
// Node of a doubly-linked list of cached kernel rows (CvSVMSolver keeps the
// list head in `lru_list`, presumably ordered by recency of use).
struct CvSVMKernelRow
{
    CvSVMKernelRow *prev, *next;  // list neighbours
    float* data;                  // the cached row values
};
00347
00348
// Scalar results reported by CvSVMSolver (filled by the solve_* methods).
struct CvSVMSolutionInfo
{
    double obj;            // objective value (presumably of the dual problem)
    double rho;            // decision-function offset
    double upper_bound_p;  // bound for the positive class
    double upper_bound_n;  // bound for the negative class
    double r;              // extra scalar from calc_rho (nu-SVM variants)
};
00357
00358 class CV_EXPORTS CvSVMSolver
00359 {
00360 public:
00361 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
00362 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
00363 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
00364
00365 CvSVMSolver();
00366
00367 CvSVMSolver( int count, int var_count, const float** samples, schar* y,
00368 int alpha_count, double* alpha, double Cp, double Cn,
00369 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00370 SelectWorkingSet select_working_set, CalcRho calc_rho );
00371 virtual bool create( int count, int var_count, const float** samples, schar* y,
00372 int alpha_count, double* alpha, double Cp, double Cn,
00373 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00374 SelectWorkingSet select_working_set, CalcRho calc_rho );
00375 virtual ~CvSVMSolver();
00376
00377 virtual void clear();
00378 virtual bool solve_generic( CvSVMSolutionInfo& si );
00379
00380 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
00381 double Cp, double Cn, CvMemStorage* storage,
00382 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
00383 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
00384 CvMemStorage* storage, CvSVMKernel* kernel,
00385 double* alpha, CvSVMSolutionInfo& si );
00386 virtual bool solve_one_class( int count, int var_count, const float** samples,
00387 CvMemStorage* storage, CvSVMKernel* kernel,
00388 double* alpha, CvSVMSolutionInfo& si );
00389
00390 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
00391 CvMemStorage* storage, CvSVMKernel* kernel,
00392 double* alpha, CvSVMSolutionInfo& si );
00393
00394 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
00395 CvMemStorage* storage, CvSVMKernel* kernel,
00396 double* alpha, CvSVMSolutionInfo& si );
00397
00398 virtual float* get_row_base( int i, bool* _existed );
00399 virtual float* get_row( int i, float* dst );
00400
00401 int sample_count;
00402 int var_count;
00403 int cache_size;
00404 int cache_line_size;
00405 const float** samples;
00406 const CvSVMParams* params;
00407 CvMemStorage* storage;
00408 CvSVMKernelRow lru_list;
00409 CvSVMKernelRow* rows;
00410
00411 int alpha_count;
00412
00413 double* G;
00414 double* alpha;
00415
00416
00417 schar* alpha_status;
00418
00419 schar* y;
00420 double* b;
00421 float* buf[2];
00422 double eps;
00423 int max_iter;
00424 double C[2];
00425 CvSVMKernel* kernel;
00426
00427 SelectWorkingSet select_working_set_func;
00428 CalcRho calc_rho_func;
00429 GetRow get_row_func;
00430
00431 virtual bool select_working_set( int& i, int& j );
00432 virtual bool select_working_set_nu_svm( int& i, int& j );
00433 virtual void calc_rho( double& rho, double& r );
00434 virtual void calc_rho_nu_svm( double& rho, double& r );
00435
00436 virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
00437 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
00438 virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
00439 };
00440
00441
// One trained decision function: offset `rho` plus `sv_count` coefficients
// `alpha`, each paired with a support-vector index in `sv_index`.
struct CvSVMDecisionFunc
{
    double  rho;
    int     sv_count;
    double* alpha;
    int*    sv_index;
};
00449
00450
00451
00452 class CV_EXPORTS_W CvSVM : public CvStatModel
00453 {
00454 public:
00455
00456 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
00457
00458
00459 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
00460
00461
00462 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
00463
00464 CV_WRAP CvSVM();
00465 virtual ~CvSVM();
00466
00467 CvSVM( const CvMat* trainData, const CvMat* responses,
00468 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00469 CvSVMParams params=CvSVMParams() );
00470
00471 virtual bool train( const CvMat* trainData, const CvMat* responses,
00472 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00473 CvSVMParams params=CvSVMParams() );
00474
00475 virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
00476 const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
00477 int kfold = 10,
00478 CvParamGrid Cgrid = get_default_grid(CvSVM::C),
00479 CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA),
00480 CvParamGrid pGrid = get_default_grid(CvSVM::P),
00481 CvParamGrid nuGrid = get_default_grid(CvSVM::NU),
00482 CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF),
00483 CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
00484 bool balanced=false );
00485
00486 virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
00487 virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;
00488
00489 CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
00490 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00491 CvSVMParams params=CvSVMParams() );
00492
00493 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00494 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00495 CvSVMParams params=CvSVMParams() );
00496
00497 CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
00498 const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
00499 int k_fold = 10,
00500 CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C),
00501 CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA),
00502 CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P),
00503 CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU),
00504 CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
00505 CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
00506 bool balanced=false);
00507 CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
00508 CV_WRAP_AS(predict_all) void predict( cv::InputArray samples, cv::OutputArray results ) const;
00509
00510 CV_WRAP virtual int get_support_vector_count() const;
00511 virtual const float* get_support_vector(int i) const;
00512 virtual CvSVMParams get_params() const { return params; };
00513 CV_WRAP virtual void clear();
00514
00515 static CvParamGrid get_default_grid( int param_id );
00516
00517 virtual void write( CvFileStorage* storage, const char* name ) const;
00518 virtual void read( CvFileStorage* storage, CvFileNode* node );
00519 CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
00520
00521 protected:
00522
00523 virtual bool set_params( const CvSVMParams& params );
00524 virtual bool train1( int sample_count, int var_count, const float** samples,
00525 const void* responses, double Cp, double Cn,
00526 CvMemStorage* _storage, double* alpha, double& rho );
00527 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
00528 const CvMat* responses, CvMemStorage* _storage, double* alpha );
00529 virtual void create_kernel();
00530 virtual void create_solver();
00531
00532 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
00533
00534 virtual void write_params( CvFileStorage* fs ) const;
00535 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00536
00537 CvSVMParams params;
00538 CvMat* class_labels;
00539 int var_all;
00540 float** sv;
00541 int sv_total;
00542 CvMat* var_idx;
00543 CvMat* class_weights;
00544 CvSVMDecisionFunc* decision_func;
00545 CvMemStorage* storage;
00546
00547 CvSVMSolver* solver;
00548 CvSVMKernel* kernel;
00549 };
00550
00551
00552
00553
00554 namespace cv
00555 {
00556 class CV_EXPORTS_W EM : public Algorithm
00557 {
00558 public:
00559
00560 enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
00561
00562
00563 enum {DEFAULT_NCLUSTERS=5, DEFAULT_MAX_ITERS=100};
00564
00565
00566 enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
00567
00568 CV_WRAP EM(int nclusters=EM::DEFAULT_NCLUSTERS, int covMatType=EM::COV_MAT_DIAGONAL,
00569 const TermCriteria& termCrit=TermCriteria(TermCriteria::COUNT+TermCriteria::EPS,
00570 EM::DEFAULT_MAX_ITERS, FLT_EPSILON));
00571
00572 virtual ~EM();
00573 CV_WRAP virtual void clear();
00574
00575 CV_WRAP virtual bool train(InputArray samples,
00576 OutputArray logLikelihoods=noArray(),
00577 OutputArray labels=noArray(),
00578 OutputArray probs=noArray());
00579
00580 CV_WRAP virtual bool trainE(InputArray samples,
00581 InputArray means0,
00582 InputArray covs0=noArray(),
00583 InputArray weights0=noArray(),
00584 OutputArray logLikelihoods=noArray(),
00585 OutputArray labels=noArray(),
00586 OutputArray probs=noArray());
00587
00588 CV_WRAP virtual bool trainM(InputArray samples,
00589 InputArray probs0,
00590 OutputArray logLikelihoods=noArray(),
00591 OutputArray labels=noArray(),
00592 OutputArray probs=noArray());
00593
00594 CV_WRAP Vec2d predict(InputArray sample,
00595 OutputArray probs=noArray()) const;
00596
00597 CV_WRAP bool isTrained() const;
00598
00599 AlgorithmInfo* info() const;
00600 virtual void read(const FileNode& fn);
00601
00602 protected:
00603
00604 virtual void setTrainData(int startStep, const Mat& samples,
00605 const Mat* probs0,
00606 const Mat* means0,
00607 const vector<Mat>* covs0,
00608 const Mat* weights0);
00609
00610 bool doTrain(int startStep,
00611 OutputArray logLikelihoods,
00612 OutputArray labels,
00613 OutputArray probs);
00614 virtual void eStep();
00615 virtual void mStep();
00616
00617 void clusterTrainSamples();
00618 void decomposeCovs();
00619 void computeLogWeightDivDet();
00620
00621 Vec2d computeProbabilities(const Mat& sample, Mat* probs) const;
00622
00623
00624 CV_PROP_RW int nclusters;
00625 CV_PROP_RW int covMatType;
00626 CV_PROP_RW int maxIters;
00627 CV_PROP_RW double epsilon;
00628
00629 Mat trainSamples;
00630 Mat trainProbs;
00631 Mat trainLogLikelihoods;
00632 Mat trainLabels;
00633
00634 CV_PROP Mat weights;
00635 CV_PROP Mat means;
00636 CV_PROP vector<Mat> covs;
00637
00638 vector<Mat> covsEigenValues;
00639 vector<Mat> covsRotateMats;
00640 vector<Mat> invCovsEigenValues;
00641 Mat logWeightDivDet;
00642 };
00643 }
00644
00645
00646
// Pair of parallel buffers: 16-bit unsigned values (`u`) and 32-bit signed
// values (`i`). The stray line-continuation backslash that preceded this
// struct (a leftover of a stripped comment banner) has been removed.
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
00653
00654
// Direction of category `idx` under a categorical split whose bitset is
// `subset` (an int array, 32 categories per element): -1 if the bit for
// `idx` is set, +1 otherwise. The mask uses `1u` because `1 << 31` on a
// signed int is undefined behaviour in C and C++98/11. `subset` is
// parenthesized so that expression arguments are safe.
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((((subset)[(idx)>>5]) & (1u << ((idx) & 31)))==0)-1)
00657
// A decision-tree split, chainable through `next`. Categorical splits use
// the `subset` bitset (tested with CV_DTREE_CAT_DIR); ordered splits use
// `ord` (threshold `c` and `split_point`). The two share storage.
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        } ord;
    };
};
00676
00677 struct CvDTreeNode
00678 {
00679 int class_idx;
00680 int Tn;
00681 double value;
00682
00683 CvDTreeNode* parent;
00684 CvDTreeNode* left;
00685 CvDTreeNode* right;
00686
00687 CvDTreeSplit* split;
00688
00689 int sample_count;
00690 int depth;
00691 int* num_valid;
00692 int offset;
00693 int buf_idx;
00694 double maxlr;
00695
00696
00697 int complexity;
00698 double alpha;
00699 double node_risk, tree_risk, tree_error;
00700
00701
00702 int* cv_Tn;
00703 double* cv_node_risk;
00704 double* cv_node_error;
00705
00706 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
00707 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
00708 };
00709
00710
00711 struct CV_EXPORTS_W_MAP CvDTreeParams
00712 {
00713 CV_PROP_RW int max_categories;
00714 CV_PROP_RW int max_depth;
00715 CV_PROP_RW int min_sample_count;
00716 CV_PROP_RW int cv_folds;
00717 CV_PROP_RW bool use_surrogates;
00718 CV_PROP_RW bool use_1se_rule;
00719 CV_PROP_RW bool truncate_pruned_tree;
00720 CV_PROP_RW float regression_accuracy;
00721 const float* priors;
00722
00723 CvDTreeParams();
00724 CvDTreeParams( int max_depth, int min_sample_count,
00725 float regression_accuracy, bool use_surrogates,
00726 int max_categories, int cv_folds,
00727 bool use_1se_rule, bool truncate_pruned_tree,
00728 const float* priors );
00729 };
00730
00731
00732 struct CV_EXPORTS CvDTreeTrainData
00733 {
00734 CvDTreeTrainData();
00735 CvDTreeTrainData( const CvMat* trainData, int tflag,
00736 const CvMat* responses, const CvMat* varIdx=0,
00737 const CvMat* sampleIdx=0, const CvMat* varType=0,
00738 const CvMat* missingDataMask=0,
00739 const CvDTreeParams& params=CvDTreeParams(),
00740 bool _shared=false, bool _add_labels=false );
00741 virtual ~CvDTreeTrainData();
00742
00743 virtual void set_data( const CvMat* trainData, int tflag,
00744 const CvMat* responses, const CvMat* varIdx=0,
00745 const CvMat* sampleIdx=0, const CvMat* varType=0,
00746 const CvMat* missingDataMask=0,
00747 const CvDTreeParams& params=CvDTreeParams(),
00748 bool _shared=false, bool _add_labels=false,
00749 bool _update_data=false );
00750 virtual void do_responses_copy();
00751
00752 virtual void get_vectors( const CvMat* _subsample_idx,
00753 float* values, uchar* missing, float* responses, bool get_class_idx=false );
00754
00755 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
00756
00757 virtual void write_params( CvFileStorage* fs ) const;
00758 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00759
00760
00761 virtual void clear();
00762
00763 int get_num_classes() const;
00764 int get_var_type(int vi) const;
00765 int get_work_var_count() const {return work_var_count;}
00766
00767 virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
00768 virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
00769 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
00770 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
00771 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
00772 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
00773 const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
00774 virtual int get_child_buf_idx( CvDTreeNode* n );
00775
00777
00778 virtual bool set_params( const CvDTreeParams& params );
00779 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
00780 int storage_idx, int offset );
00781
00782 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
00783 int split_point, int inversed, float quality );
00784 virtual CvDTreeSplit* new_split_cat( int vi, float quality );
00785 virtual void free_node_data( CvDTreeNode* node );
00786 virtual void free_train_data();
00787 virtual void free_node( CvDTreeNode* node );
00788
00789 int sample_count, var_all, var_count, max_c_count;
00790 int ord_var_count, cat_var_count, work_var_count;
00791 bool have_labels, have_priors;
00792 bool is_classifier;
00793 int tflag;
00794
00795 const CvMat* train_data;
00796 const CvMat* responses;
00797 CvMat* responses_copy;
00798
00799 int buf_count, buf_size;
00800 bool shared;
00801 int is_buf_16u;
00802
00803 CvMat* cat_count;
00804 CvMat* cat_ofs;
00805 CvMat* cat_map;
00806
00807 CvMat* counts;
00808 CvMat* buf;
00809 CvMat* direction;
00810 CvMat* split_buf;
00811
00812 CvMat* var_idx;
00813 CvMat* var_type;
00814
00815
00816 CvMat* priors;
00817 CvMat* priors_mult;
00818
00819 CvDTreeParams params;
00820
00821 CvMemStorage* tree_storage;
00822 CvMemStorage* temp_storage;
00823
00824 CvDTreeNode* data_root;
00825
00826 CvSet* node_heap;
00827 CvSet* split_heap;
00828 CvSet* cv_heap;
00829 CvSet* nv_heap;
00830
00831 cv::RNG* rng;
00832 };
00833
// Forward declarations for the tree classes that follow.
class CvDTree;
class CvForestTree;

namespace cv
{
    // Helpers befriended by CvDTree / CvForestTree (defined elsewhere).
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
00842
00843 class CV_EXPORTS_W CvDTree : public CvStatModel
00844 {
00845 public:
00846 CV_WRAP CvDTree();
00847 virtual ~CvDTree();
00848
00849 virtual bool train( const CvMat* trainData, int tflag,
00850 const CvMat* responses, const CvMat* varIdx=0,
00851 const CvMat* sampleIdx=0, const CvMat* varType=0,
00852 const CvMat* missingDataMask=0,
00853 CvDTreeParams params=CvDTreeParams() );
00854
00855 virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
00856
00857
00858 virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
00859
00860 virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
00861
00862 virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
00863 bool preprocessedInput=false ) const;
00864
00865 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
00866 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
00867 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
00868 const cv::Mat& missingDataMask=cv::Mat(),
00869 CvDTreeParams params=CvDTreeParams() );
00870
00871 CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
00872 bool preprocessedInput=false ) const;
00873 CV_WRAP virtual cv::Mat getVarImportance();
00874
00875 virtual const CvMat* get_var_importance();
00876 CV_WRAP virtual void clear();
00877
00878 virtual void read( CvFileStorage* fs, CvFileNode* node );
00879 virtual void write( CvFileStorage* fs, const char* name ) const;
00880
00881
00882 virtual void read( CvFileStorage* fs, CvFileNode* node,
00883 CvDTreeTrainData* data );
00884 virtual void write( CvFileStorage* fs ) const;
00885
00886 const CvDTreeNode* get_root() const;
00887 int get_pruned_tree_idx() const;
00888 CvDTreeTrainData* get_data();
00889
00890 protected:
00891 friend struct cv::DTreeBestSplitFinder;
00892
00893 virtual bool do_train( const CvMat* _subsample_idx );
00894
00895 virtual void try_split_node( CvDTreeNode* n );
00896 virtual void split_node_data( CvDTreeNode* n );
00897 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
00898 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
00899 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00900 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
00901 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00902 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
00903 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00904 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
00905 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00906 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00907 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00908 virtual double calc_node_dir( CvDTreeNode* node );
00909 virtual void complete_node_dir( CvDTreeNode* node );
00910 virtual void cluster_categories( const int* vectors, int vector_count,
00911 int var_count, int* sums, int k, int* cluster_labels );
00912
00913 virtual void calc_node_value( CvDTreeNode* node );
00914
00915 virtual void prune_cv();
00916 virtual double update_tree_rnc( int T, int fold );
00917 virtual int cut_tree( int T, int fold, double min_alpha );
00918 virtual void free_prune_data(bool cut_tree);
00919 virtual void free_tree();
00920
00921 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
00922 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
00923 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
00924 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
00925 virtual void write_tree_nodes( CvFileStorage* fs ) const;
00926 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
00927
00928 CvDTreeNode* root;
00929 CvMat* var_importance;
00930 CvDTreeTrainData* data;
00931
00932 public:
00933 int pruned_tree_idx;
00934 };
00935
00936
00937
00938
00939
00940
00941 class CvRTrees;
00942
00943 class CV_EXPORTS CvForestTree: public CvDTree
00944 {
00945 public:
00946 CvForestTree();
00947 virtual ~CvForestTree();
00948
00949 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
00950
00951 virtual int get_var_count() const {return data ? data->var_count : 0;}
00952 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
00953
00954
00955 virtual bool train( const CvMat* trainData, int tflag,
00956 const CvMat* responses, const CvMat* varIdx=0,
00957 const CvMat* sampleIdx=0, const CvMat* varType=0,
00958 const CvMat* missingDataMask=0,
00959 CvDTreeParams params=CvDTreeParams() );
00960
00961 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
00962 virtual void read( CvFileStorage* fs, CvFileNode* node );
00963 virtual void read( CvFileStorage* fs, CvFileNode* node,
00964 CvDTreeTrainData* data );
00965
00966
00967 protected:
00968 friend struct cv::ForestTreeBestSplitFinder;
00969
00970 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
00971 CvRTrees* forest;
00972 };
00973
00974
00975 struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
00976 {
00977
00978 CV_PROP_RW bool calc_var_importance;
00979 CV_PROP_RW int nactive_vars;
00980 CV_PROP_RW CvTermCriteria term_crit;
00981
00982 CvRTParams();
00983 CvRTParams( int max_depth, int min_sample_count,
00984 float regression_accuracy, bool use_surrogates,
00985 int max_categories, const float* priors, bool calc_var_importance,
00986 int nactive_vars, int max_num_of_trees_in_the_forest,
00987 float forest_accuracy, int termcrit_type );
00988 };
00989
00990
00991 class CV_EXPORTS_W CvRTrees : public CvStatModel // Random-trees ensemble (aliased as cv::RandomTrees): an array of CvForestTree (see `trees`/`ntrees`) trained and queried as one model.
00992 {
00993 public:
00994 CV_WRAP CvRTrees();
00995 virtual ~CvRTrees();
00996 virtual bool train( const CvMat* trainData, int tflag, // tflag: CV_ROW_SAMPLE or CV_COL_SAMPLE layout of trainData (assumed, as for the other CvStatModel trainers).
00997 const CvMat* responses, const CvMat* varIdx=0,
00998 const CvMat* sampleIdx=0, const CvMat* varType=0,
00999 const CvMat* missingDataMask=0,
01000 CvRTParams params=CvRTParams() );
01001 // Same training entry point, taking samples/responses from a CvMLData container.
01002 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01003 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
01004 virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const; // NOTE(review): presumably a class-probability style score — confirm semantics (and supported problem types) in .cpp.
01005 // cv::Mat counterparts of the CvMat API above (CV_WRAP: exposed to the generated bindings).
01006 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01007 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01008 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01009 const cv::Mat& missingDataMask=cv::Mat(),
01010 CvRTParams params=CvRTParams() );
01011 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01012 CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01013 CV_WRAP virtual cv::Mat getVarImportance();
01014 // Resets the model (presumably frees the trees and training data) — confirm in .cpp.
01015 CV_WRAP virtual void clear();
01016 // Variable-importance vector; presumably only populated when trained with CvRTParams::calc_var_importance.
01017 virtual const CvMat* get_var_importance();
01018 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
01019 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
01020 // Error on the train or test subset of `data` (selected by `type`, assumed); resp optionally receives per-sample predictions.
01021 virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 );
01022 // Presumably the error of the trained forest on its own training data — confirm.
01023 virtual float get_train_error();
01024 // Persistence via CvFileStorage.
01025 virtual void read( CvFileStorage* fs, CvFileNode* node );
01026 virtual void write( CvFileStorage* fs, const char* name ) const;
01027 // Helpers presumably used by the per-tree trainers: the active-variable mask and the shared random generator.
01028 CvMat* get_active_var_mask();
01029 CvRNG* get_rng(); // NOTE(review): returns CvRNG* although the member `rng` is a cv::RNG*; presumably exposes the generator's underlying state — confirm in .cpp.
01030 // Read-only access to the trained trees; valid indices presumably 0 .. get_tree_count()-1.
01031 int get_tree_count() const;
01032 CvForestTree* get_tree(int i) const;
01033
01034 protected:
01035 virtual std::string getName() const; // Model type name; overridden by CvERTrees.
01036 // Grows the trees until term_crit is met (assumed); overridden by CvERTrees.
01037 virtual bool grow_forest( const CvTermCriteria term_crit );
01038
01039 // Model state.
01040 CvForestTree** trees; // presumably an array of ntrees tree pointers
01041 CvDTreeTrainData* data; // presumably the training data shared by all trees
01042 int ntrees;
01043 int nclasses;
01044 double oob_error; // presumably the out-of-bag error estimate
01045 CvMat* var_importance;
01046 int nsamples;
01047
01048 cv::RNG* rng;
01049 CvMat* active_var_mask;
01050 };
01051
01052
01053
01054
01055 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData // Training-data container for the extremely-randomized-trees variant (aliased as cv::ERTreeTRainData); overrides the data accessors below.
01056 {
01057 virtual void set_data( const CvMat* trainData, int tflag, // Presumably overrides the CvDTreeTrainData loader; _shared/_add_labels/_update_data follow the base-class semantics — confirm.
01058 const CvMat* responses, const CvMat* varIdx=0,
01059 const CvMat* sampleIdx=0, const CvMat* varType=0,
01060 const CvMat* missingDataMask=0,
01061 const CvDTreeParams& params=CvDTreeParams(),
01062 bool _shared=false, bool _add_labels=false,
01063 bool _update_data=false );
01064 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, // Fetches ordered variable vi for the samples of node n; results are returned through ord_values/missing using the caller's buffers (assumed).
01065 const float** ord_values, const int** missing, int* sample_buf = 0 );
01066 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
01067 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); // presumably cross-validation fold labels — confirm
01068 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
01069 virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
01070 float* responses, bool get_class_idx=false );
01071 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
01072 const CvMat* missing_mask; // Missing-value mask; presumably the missingDataMask passed to set_data (not owned) — confirm.
01073 };
01074
01075 class CV_EXPORTS CvForestERTree : public CvForestTree // Extremely-randomized tree (aliased as cv::ERTree), presumably the per-tree class of CvERTrees; overrides only the split search and node partitioning hooks.
01076 {
01077 protected:
01078 virtual double calc_node_dir( CvDTreeNode* node );
01079 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, // Split finders for ordered/categorical variables in classification/regression; presumably choose randomized split points — confirm in .cpp.
01080 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01081 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01082 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01083 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
01084 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01085 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
01086 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01087 virtual void split_node_data( CvDTreeNode* n ); // Distributes node n's samples to its children after a split (assumed from name).
01088 };
01089
01090 class CV_EXPORTS_W CvERTrees : public CvRTrees // Extremely randomized trees (aliased as cv::ERTrees): reuses the CvRTrees interface and overrides training, getName and grow_forest.
01091 {
01092 public:
01093 CV_WRAP CvERTrees();
01094 virtual ~CvERTrees();
01095 virtual bool train( const CvMat* trainData, int tflag, // All three train overloads are re-declared: declaring any `train` here hides every CvRTrees::train overload otherwise.
01096 const CvMat* responses, const CvMat* varIdx=0,
01097 const CvMat* sampleIdx=0, const CvMat* varType=0,
01098 const CvMat* missingDataMask=0,
01099 CvRTParams params=CvRTParams());
01100 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01101 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01102 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01103 const cv::Mat& missingDataMask=cv::Mat(),
01104 CvRTParams params=CvRTParams());
01105 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01106 protected:
01107 virtual std::string getName() const;
01108 virtual bool grow_forest( const CvTermCriteria term_crit ); // Presumably grows CvForestERTree instances instead of CvForestTree — confirm.
01109 };
01110
01111
01112
01113
01114
01115
01116 struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams // Training parameters for CvBoost: CvDTreeParams for the weak trees plus boosting options.
01117 {
01118 CV_PROP_RW int boost_type; // Presumably one of CvBoost::DISCRETE, REAL, LOGIT, GENTLE.
01119 CV_PROP_RW int weak_count; // Presumably the number of weak trees to train.
01120 CV_PROP_RW int split_criteria; // Presumably one of CvBoost::DEFAULT, GINI, MISCLASS, SQERR.
01121 CV_PROP_RW double weight_trim_rate; // Presumably the threshold used by CvBoost::trim_weights — confirm.
01122 // Note: the full constructor takes no split_criteria argument; assign that field directly when needed.
01123 CvBoostParams();
01124 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
01125 int max_depth, bool use_surrogates, const float* priors );
01126 };
01127
01128
01129 class CvBoost; // Forward declaration: CvBoostTree below stores and receives CvBoost* before CvBoost is defined.
01130
01131 class CV_EXPORTS CvBoostTree: public CvDTree // Weak learner of CvBoost (aliased as cv::BoostTree): a CvDTree that keeps a pointer to its ensemble.
01132 {
01133 public:
01134 CvBoostTree();
01135 virtual ~CvBoostTree();
01136 // Ensemble-aware training on shared train data restricted to subsample_idx (assumed), linked to `ensemble`.
01137 virtual bool train( CvDTreeTrainData* trainData,
01138 const CvMat* subsample_idx, CvBoost* ensemble );
01139
01140 virtual void scale( double s ); // Presumably multiplies the leaf values by s — confirm.
01141 virtual void read( CvFileStorage* fs, CvFileNode* node,
01142 CvBoost* ensemble, CvDTreeTrainData* _data );
01143 virtual void clear();
01144
01145 // Re-declared CvDTree overloads: without them, the train/read overloads declared above would hide the base-class ones (C++ name hiding).
01146 virtual bool train( const CvMat* trainData, int tflag,
01147 const CvMat* responses, const CvMat* varIdx=0,
01148 const CvMat* sampleIdx=0, const CvMat* varType=0,
01149 const CvMat* missingDataMask=0,
01150 CvDTreeParams params=CvDTreeParams() );
01151 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
01152
01153 virtual void read( CvFileStorage* fs, CvFileNode* node );
01154 virtual void read( CvFileStorage* fs, CvFileNode* node,
01155 CvDTreeTrainData* data );
01156
01157
01158 protected:
01159 // Tree-growing hooks overridden for boosting.
01160 virtual void try_split_node( CvDTreeNode* n );
01161 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01162 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01163 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
01164 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01165 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01166 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01167 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
01168 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01169 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
01170 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01171 virtual void calc_node_value( CvDTreeNode* n );
01172 virtual double calc_node_dir( CvDTreeNode* n );
01173
01174 CvBoost* ensemble; // The ensemble this tree belongs to (back-pointer; ownership presumably stays with CvBoost — confirm).
01175 };
01176
01177
01178 class CV_EXPORTS_W CvBoost : public CvStatModel // Boosted ensemble of CvBoostTree weak learners (aliased as cv::Boost).
01179 {
01180 public:
01181 // Boosting variants, presumably the values of CvBoostParams::boost_type.
01182 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
01183
01184 // Split criteria, presumably the values of CvBoostParams::split_criteria. Note there is no value 2.
01185 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
01186
01187 CV_WRAP CvBoost();
01188 virtual ~CvBoost();
01189 // Training constructor; presumably equivalent to default construction followed by train().
01190 CvBoost( const CvMat* trainData, int tflag,
01191 const CvMat* responses, const CvMat* varIdx=0,
01192 const CvMat* sampleIdx=0, const CvMat* varType=0,
01193 const CvMat* missingDataMask=0,
01194 CvBoostParams params=CvBoostParams() );
01195 // update=true presumably extends an existing ensemble instead of training from scratch — confirm.
01196 virtual bool train( const CvMat* trainData, int tflag,
01197 const CvMat* responses, const CvMat* varIdx=0,
01198 const CvMat* sampleIdx=0, const CvMat* varType=0,
01199 const CvMat* missingDataMask=0,
01200 CvBoostParams params=CvBoostParams(),
01201 bool update=false );
01202
01203 virtual bool train( CvMLData* data,
01204 CvBoostParams params=CvBoostParams(),
01205 bool update=false );
01206 // weak_responses: optional output for the individual weak-tree responses; slice: which weak trees to use (CV_WHOLE_SEQ = all); raw_mode / return_sum change the form of the returned value (see .cpp).
01207 virtual float predict( const CvMat* sample, const CvMat* missing=0,
01208 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
01209 bool raw_mode=false, bool return_sum=false ) const;
01210 // cv::Mat counterparts of the CvMat API above (CV_WRAP: exposed to the generated bindings).
01211 CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
01212 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01213 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01214 const cv::Mat& missingDataMask=cv::Mat(),
01215 CvBoostParams params=CvBoostParams() );
01216
01217 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01218 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01219 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01220 const cv::Mat& missingDataMask=cv::Mat(),
01221 CvBoostParams params=CvBoostParams(),
01222 bool update=false );
01223 // Unlike the CvMat version, this predict takes a cv::Range slice and has no weak_responses output.
01224 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01225 const cv::Range& slice=cv::Range::all(), bool rawMode=false,
01226 bool returnSum=false ) const;
01227 // Error on the train or test subset of _data (selected by `type`, assumed); resp optionally receives per-sample predictions.
01228 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 );
01229 // Presumably removes the weak trees in `slice` from the ensemble — confirm.
01230 CV_WRAP virtual void prune( CvSlice slice );
01231
01232 CV_WRAP virtual void clear();
01233
01234 virtual void write( CvFileStorage* storage, const char* name ) const;
01235 virtual void read( CvFileStorage* storage, CvFileNode* node );
01236 virtual const CvMat* get_active_vars(bool absolute_idx=true); // Presumably returns active_vars_abs or active_vars depending on absolute_idx — confirm.
01237
01238 CvSeq* get_weak_predictors(); // Sequence of trained weak trees (presumably the `weak` member).
01239 // Accessors exposing internal training state (presumably the members of the same names, not copies).
01240 CvMat* get_weights();
01241 CvMat* get_subtree_weights();
01242 CvMat* get_weak_response();
01243 const CvBoostParams& get_params() const;
01244 const CvDTreeTrainData* get_data() const;
01245
01246 protected:
01247 // Training helpers (presumably: store/validate params, re-weight samples after each new tree, trim low weights).
01248 virtual bool set_params( const CvBoostParams& params );
01249 virtual void update_weights( CvBoostTree* tree );
01250 virtual void trim_weights(); // presumably driven by params.weight_trim_rate
01251 virtual void write_params( CvFileStorage* fs ) const;
01252 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
01253 // Model state.
01254 CvDTreeTrainData* data;
01255 CvBoostParams params;
01256 CvSeq* weak;
01257 // Active-variable bookkeeping (cf. get_active_vars).
01258 CvMat* active_vars;
01259 CvMat* active_vars_abs;
01260 bool have_active_cat_vars;
01261 // Presumably per-training buffers: original responses, running sums, weak-tree outputs, sample subset and weights.
01262 CvMat* orig_response;
01263 CvMat* sum_response;
01264 CvMat* weak_eval;
01265 CvMat* subsample_mask;
01266 CvMat* weights;
01267 CvMat* subtree_weights;
01268 bool have_subsample;
01269 };
01270
01271
01272
01273
01274
01275
01276
01277
01278
01279
01280
01281
01282
01283
01284
01285
01286
01287
01288
01289
01290
01291
01292
01293 struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams // Training parameters for CvGBTrees (aliased as cv::GradientBoostingTreeParams).
01294 {
01295 CV_PROP_RW int weak_count; // Presumably the number of boosting iterations (trees per class) — confirm.
01296 CV_PROP_RW int loss_function_type; // Presumably one of CvGBTrees::SQUARED_LOSS, ABSOLUTE_LOSS, HUBER_LOSS, DEVIANCE_LOSS.
01297 CV_PROP_RW float subsample_portion; // Presumably the fraction of samples drawn per iteration (cf. CvGBTrees::do_subsample).
01298 CV_PROP_RW float shrinkage; // Presumably the learning-rate factor applied to each tree's contribution.
01299 // Note: the full constructor's argument order (loss_function_type first) differs from the field order above.
01300 CvGBTreesParams();
01301 CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
01302 float subsample_portion, int max_depth, bool use_surrogates );
01303 };
01304
01305
01306
01307
01308
01309
01310
01311
01312
01313
01314
01315
01316
01317
01318
01319
01320
01321
01322
01323
01324
01325
01326
01327
01328
01329
01330
01331
01332
01333
01334
01335
01336
01337
01338
01339
01340
01341
01342
01343
01344
01345
01346
01347
01348
01349
01350
01351
01352
01353 class CV_EXPORTS_W CvGBTrees : public CvStatModel // Gradient boosted trees (aliased as cv::GradientBoostingTrees).
01354 {
01355 public:
01356 // Presumed structure (confirm in the .cpp): `weak` holds one CvSeq of
01357 // trees per class (class_count sequences); each boosting iteration
01358 // computes the loss gradient (find_gradient), fits a tree to it and then
01359 // adjusts that tree's leaf values (change_values / find_optimal_value).
01360
01361
01362
01363
01364
01365
01366
01367
01368
01369
01370
01371
01372
01373
01374
01375
01376 // Loss functions, presumably the values of
01377 // CvGBTreesParams::loss_function_type. By C++ enumerator rules the values
01378 // are SQUARED_LOSS=0, ABSOLUTE_LOSS=1, HUBER_LOSS=3, DEVIANCE_LOSS=4
01379 // (there is no 2); keep them stable, they are presumably persisted.
01380 enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
01381
01382
01383
01384
01385
01386
01387
01388
01389
01390
01391
01392
01393 // Creates an empty model; train() must be called before predicting (assumed).
01394 CV_WRAP CvGBTrees();
01395
01396
01397
01398
01399
01400
01401
01402
01403
01404
01405
01406
01407
01408
01409
01410
01411
01412
01413
01414
01415
01416
01417
01418
01419
01420
01421
01422
01423
01424
01425
01426
01427
01428 // Constructs the model and trains it on the given data (presumably the
01429 // same as default construction followed by train()).
01430 // tflag: CV_ROW_SAMPLE or CV_COL_SAMPLE layout of trainData (assumed);
01431 // varIdx / sampleIdx: optional subsets of variables / samples;
01432 // varType: optional per-variable types (CV_VAR_NUMERICAL / CV_VAR_CATEGORICAL);
01433 // missingDataMask: optional mask marking missing entries of trainData.
01434 CvGBTrees( const CvMat* trainData, int tflag,
01435 const CvMat* responses, const CvMat* varIdx=0,
01436 const CvMat* sampleIdx=0, const CvMat* varType=0,
01437 const CvMat* missingDataMask=0,
01438 CvGBTreesParams params=CvGBTreesParams() );
01439
01440
01441
01442
01443 // Releases the model (presumably via clear()).
01444 virtual ~CvGBTrees();
01445
01446
01447
01448
01449
01450
01451
01452
01453
01454
01455
01456
01457
01458
01459
01460
01461
01462
01463
01464
01465
01466
01467
01468
01469
01470
01471
01472
01473
01474
01475
01476
01477
01478
01479
01480
01481 // Trains the model; the data arguments are as for the training
01482 // constructor above. `update` presumably requests incremental training —
01483 // confirm whether the implementation honours it. Returns a success flag
01484 // (assumed).
01485
01486 virtual bool train( const CvMat* trainData, int tflag,
01487 const CvMat* responses, const CvMat* varIdx=0,
01488 const CvMat* sampleIdx=0, const CvMat* varType=0,
01489 const CvMat* missingDataMask=0,
01490 CvGBTreesParams params=CvGBTreesParams(),
01491 bool update=false );
01492
01493
01494
01495
01496
01497
01498
01499
01500
01501
01502
01503
01504
01505
01506
01507 // Trains from a CvMLData container (presumably using its response column
01508 // and train/test split).
01509
01510 virtual bool train( CvMLData* data,
01511 CvGBTreesParams params=CvGBTreesParams(),
01512 bool update=false );
01513
01514
01515
01516
01517
01518
01519
01520
01521
01522
01523
01524
01525
01526
01527
01528
01529
01530
01531
01532
01533
01534
01535 // Sequential prediction for a single sample.
01536 // missing: optional missing-value mask for `sample`.
01537 // weakResponses: optional output receiving the individual tree responses.
01538 // slice: range of trees (boosting iterations) to use; CV_WHOLE_SEQ = all.
01539 // k: presumably the index of the class whose trees are evaluated, with -1
01540 // meaning all classes (returning the predicted response) — confirm.
01541 // Presumably the single-threaded counterpart of predict() below.
01542 virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
01543 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01544 int k=-1 ) const;
01545
01546
01547
01548
01549
01550
01551
01552
01553
01554
01555
01556
01557
01558
01559
01560
01561
01562
01563
01564
01565
01566
01567
01568
01569
01570 // Predicts the response for a single sample; parameters as for
01571 // predict_serial(). Presumably evaluates the trees in parallel — confirm.
01572
01573
01574 virtual float predict( const CvMat* sample, const CvMat* missing=0,
01575 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01576 int k=-1 ) const;
01577
01578
01579
01580
01581
01582
01583
01584
01585
01586
01587
01588
01589 // Resets the model (presumably frees the trees and all training buffers)
01590 // so it can be retrained.
01591
01592 CV_WRAP virtual void clear();
01593
01594
01595
01596
01597
01598
01599
01600
01601
01602
01603
01604
01605 // Computes the prediction error on the train or test subset of _data,
01606 // selected by `type` (presumably CV_TRAIN_ERROR / CV_TEST_ERROR, declared
01607 // elsewhere). If resp is non-null it presumably receives the per-sample
01608 // predictions.
01609
01610 virtual float calc_error( CvMLData* _data, int type,
01611 std::vector<float> *resp = 0 );
01612
01613
01614
01615
01616
01617
01618
01619
01620
01621
01622
01623 // Serializes the model (presumably parameters and trees) into fs under
01624 // the node name `name`.
01625
01626 virtual void write( CvFileStorage* fs, const char* name ) const;
01627
01628
01629
01630
01631
01632
01633
01634
01635
01636
01637
01638
01639
01640 // Restores a model previously stored with write().
01641
01642 virtual void read( CvFileStorage* fs, CvFileNode* node );
01643
01644
01645 // cv::Mat counterparts of the CvMat API above (CV_WRAP: exposed to the generated bindings).
01646 CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
01647 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01648 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01649 const cv::Mat& missingDataMask=cv::Mat(),
01650 CvGBTreesParams params=CvGBTreesParams() );
01651
01652 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01653 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01654 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01655 const cv::Mat& missingDataMask=cv::Mat(),
01656 CvGBTreesParams params=CvGBTreesParams(),
01657 bool update=false );
01658 // Unlike the CvMat version, this predict takes a cv::Range slice and has no weakResponses output.
01659 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01660 const cv::Range& slice = cv::Range::all(),
01661 int k=-1 ) const;
01662
01663 protected:
01664
01665
01666
01667
01668
01669
01670
01671
01672
01673
01674
01675
01676 // Computes the gradient of the loss at the current predictions
01677 // (presumably from orig_response and sum_response). k presumably selects
01678 // the class in the multi-class (DEVIANCE_LOSS) case.
01679
01680 virtual void find_gradient( const int k = 0);
01681
01682
01683
01684
01685
01686
01687
01688
01689
01690
01691
01692
01693
01694
01695 // Post-processes a freshly trained `tree`: presumably replaces its leaf
01696 // values with loss-optimal constants (see find_optimal_value) and updates
01697 // the running predictions for class k — confirm in the .cpp.
01698
01699 virtual void change_values(CvDTree* tree, const int k = 0);
01700
01701
01702
01703
01704
01705
01706
01707
01708
01709
01710
01711
01712
01713
01714 // Returns the optimal constant prediction for the samples listed in _Idx
01715 // under the configured loss. NOTE(review): _Idx is assumed to be a vector
01716 // of sample indices — confirm.
01717
01718 virtual float find_optimal_value( const CvMat* _Idx );
01719
01720
01721
01722
01723
01724
01725
01726
01727
01728
01729
01730
01731
01732 // Presumably draws a random split of the samples into subsample_train /
01733 // subsample_test according to params.subsample_portion, using rng.
01734
01735 virtual void do_subsample();
01736
01737
01738
01739
01740
01741
01742
01743
01744
01745
01746
01747
01748 // Collects the leaves of the subtree rooted at `node` into `leaves`
01749 // (presumably recursively). `count` is presumably the next free slot on
01750 // entry and the total number of leaves on return — confirm.
01751
01752 void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01753
01754
01755
01756
01757
01758
01759
01760
01761
01762
01763
01764
01765 // Returns an array of the leaves of `dtree` and stores its length in
01766 // `len`. NOTE(review): the array is presumably heap-allocated and owned by
01767 // the caller — confirm how it must be released.
01768
01769 CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01770
01771
01772
01773
01774
01775
01776
01777
01778
01779
01780
01781 // Reports the problem type. The bool return does not say which value is
01782 // which; presumably true means classification (DEVIANCE_LOSS) and false
01783 // regression — confirm before relying on it.
01784
01785 virtual bool problem_type() const;
01786
01787
01788
01789
01790
01791
01792
01793
01794
01795
01796
01797
01798 // Writes the training parameters (`params`) into fs; presumably called
01799 // by write().
01800 virtual void write_params( CvFileStorage* fs ) const;
01801
01802
01803
01804
01805
01806
01807
01808
01809
01810
01811
01812
01813
01814
01815
01816
01817
01818 // Reads the training parameters back from fnode; presumably the
01819 // counterpart of write_params().
01820 virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
01821 int get_len(const CvMat* mat) const; // Presumably the element count of a row/column-vector CvMat — confirm.
01822
01823 // Model state.
01824 CvDTreeTrainData* data;
01825 CvGBTreesParams params;
01826
01827 CvSeq** weak; // presumably class_count sequences of trees
01828 CvMat* orig_response; // presumably the training responses as given
01829 CvMat* sum_response; // presumably the current ensemble prediction per sample
01830 CvMat* sum_response_tmp;
01831 CvMat* sample_idx;
01832 CvMat* subsample_train; // cf. do_subsample()
01833 CvMat* subsample_test; // cf. do_subsample()
01834 CvMat* missing;
01835 CvMat* class_labels; // presumably maps class indices back to the original labels
01836
01837 cv::RNG* rng;
01838
01839 int class_count;
01840 float delta; // presumably the HUBER_LOSS threshold — confirm
01841 float base_value; // presumably the initial constant prediction — confirm
01842
01843 };
01844
01845
01846
01847
01848
01849
01850
01852
01853 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams // Training parameters for CvANN_MLP::train (aliased as cv::ANN_MLP_TrainParams).
01854 {
01855 CvANN_MLP_TrainParams();
01856 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method, // param1/param2 presumably map onto the bp_* or rp_* fields depending on train_method — confirm in .cpp.
01857 double param1, double param2=0 );
01858 ~CvANN_MLP_TrainParams();
01859 // Training algorithms (values for train_method).
01860 enum { BACKPROP=0, RPROP=1 };
01861
01862 CV_PROP_RW CvTermCriteria term_crit; // Presumably the stopping criterion for the training iterations.
01863 CV_PROP_RW int train_method; // BACKPROP or RPROP.
01864
01865 // Back-propagation settings (presumably used when train_method == BACKPROP): weight-step scale and momentum.
01866 CV_PROP_RW double bp_dw_scale, bp_moment_scale;
01867
01868 // RPROP settings (presumably used when train_method == RPROP): initial step, step increase/decrease factors, and step bounds.
01869 CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
01870 };
01871
01872
01873 class CV_EXPORTS_W CvANN_MLP : public CvStatModel // Multi-layer perceptron (aliased as cv::NeuralNet_MLP).
01874 {
01875 public:
01876 CV_WRAP CvANN_MLP();
01877 CvANN_MLP( const CvMat* layerSizes, // Constructs the network; presumably the same as create() with these arguments.
01878 int activateFunc=CvANN_MLP::SIGMOID_SYM, // SIGMOID_SYM may be named before its declaration below because default arguments are a complete-class context.
01879 double fparam1=0, double fparam2=0 );
01880
01881 virtual ~CvANN_MLP();
01882 // Builds the topology: layerSizes presumably holds the neuron count of each layer (get_layer_count() returns layer_sizes->cols); activateFunc is IDENTITY, SIGMOID_SYM or GAUSSIAN, parameterised by fparam1/fparam2.
01883 virtual void create( const CvMat* layerSizes,
01884 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01885 double fparam1=0, double fparam2=0 );
01886 // Trains on inputs/outputs with optional per-sample weights and sample subset; flags is a bitwise OR of UPDATE_WEIGHTS / NO_INPUT_SCALE / NO_OUTPUT_SCALE. Returns an int, presumably the number of iterations performed — confirm.
01887 virtual int train( const CvMat* inputs, const CvMat* outputs,
01888 const CvMat* sampleWeights, const CvMat* sampleIdx=0,
01889 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01890 int flags=0 );
01891 virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const; // Network responses are written to `outputs` (CV_OUT marks it as an output for the bindings).
01892 // cv::Mat counterparts of the CvMat API above (CV_WRAP: exposed to the generated bindings).
01893 CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
01894 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01895 double fparam1=0, double fparam2=0 );
01896
01897 CV_WRAP virtual void create( const cv::Mat& layerSizes,
01898 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01899 double fparam1=0, double fparam2=0 );
01900
01901 CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
01902 const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
01903 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01904 int flags=0 );
01905
01906 CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
01907
01908 CV_WRAP virtual void clear();
01909
01910 // Activation functions (values for activateFunc).
01911 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
01912
01913 // train() flags; distinct powers of two, so they can be combined with |.
01914 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
01915
01916 virtual void read( CvFileStorage* fs, CvFileNode* node );
01917 virtual void write( CvFileStorage* storage, const char* name ) const;
01918 // Inline accessors; the returned pointers refer to this object's storage (no copies).
01919 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; } // 0 while layer_sizes is null.
01920 const CvMat* get_layer_sizes() { return layer_sizes; }
01921 double* get_weights(int layer) // weights[layer], or 0 if layer_sizes/weights are null or layer is outside the accepted range.
01922 {
01923 return layer_sizes && weights && // the unsigned casts below also reject a negative `layer`
01924 (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0; // NOTE(review): `<=` accepts layer == layer_sizes->cols, so `weights` must hold at least cols+1 entries — confirm against the allocation in the .cpp.
01925 }
01926
01927 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
01928
01929 protected:
01930 // Converts the training inputs into CvVectors and a sample-weight array (*_sw); returns false on failure (assumed). Whether *_sw is allocated here is not visible — confirm ownership.
01931 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
01932 const CvMat* _sample_weights, const CvMat* sampleIdx,
01933 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
01934
01935 // Training loops for CvANN_MLP_TrainParams::BACKPROP and ::RPROP (assumed from the names).
01936 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01937
01938
01939 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01940
01941 virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
01942 virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
01943 double _f_param1=0, double _f_param2=0 );
01944 virtual void init_weights();
01945 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
01946 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
01947 virtual void calc_input_scale( const CvVectors* vecs, int flags );
01948 virtual void calc_output_scale( const CvVectors* vecs, int flags );
01949
01950 virtual void write_params( CvFileStorage* fs ) const;
01951 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
01952 // Network state.
01953 CvMat* layer_sizes;
01954 CvMat* wbuf; // presumably the storage backing the `weights` pointers — confirm
01955 CvMat* sample_weights;
01956 double** weights;
01957 double f_param1, f_param2; // activation-function parameters (cf. fparam1/fparam2)
01958 double min_val, max_val, min_val1, max_val1;
01959 int activ_func;
01960 int max_count, max_buf_sz;
01961 CvANN_MLP_TrainParams params;
01962 cv::RNG* rng;
01963 };
01964
01965
01966
01967
01968
01969
01970
01971 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample, // Presumably fills `sample` with draws from the multivariate normal N(mean, cov).
01972 CvRNG* rng CV_DEFAULT(0) ); // rng: optional generator; 0 presumably selects a default one — confirm.
01973
01974
01975 CVAPI(void) cvRandGaussMixture( CvMat* means[], // Presumably draws the rows of `sample` from a mixture of clsnum Gaussians.
01976 CvMat* covs[], // presumably clsnum covariance matrices, parallel to means[]
01977 float weights[], // presumably the mixture weights
01978 int clsnum,
01979 CvMat* sample,
01980 CvMat* sampClasses CV_DEFAULT(0) ); // optional output; presumably the component index of each sample
01981
01982 #define CV_TS_CONCENTRIC_SPHERES 0 // Test-set type, presumably for cvCreateTestSet's `type` argument (the only one defined here).
01983 // Generates a synthetic data set of the given type; samples and responses are output parameters
01984 // (presumably allocated by the function — confirm ownership in the .cpp).
01985 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
01986 int num_samples,
01987 int num_features,
01988 CvMat** responses,
01989 int num_classes, ... ); // Trailing varargs presumably carry type-specific parameters.
01990
01991
01992
01993
01994
01995 #define CV_COUNT 0 // Presumably the CvTrainTestSplit::train_sample_part_mode value selecting train_sample_part.count.
01996 #define CV_PORTION 1 // Presumably the CvTrainTestSplit::train_sample_part_mode value selecting train_sample_part.portion.
01997
01998 struct CV_EXPORTS CvTrainTestSplit // Describes how CvMLData splits its samples into train and test sets (cf. CvMLData::set_train_test_split).
01999 {
02000 CvTrainTestSplit();
02001 CvTrainTestSplit( int train_sample_count, bool mix = true); // Split by number of training samples.
02002 CvTrainTestSplit( float train_sample_portion, bool mix = true); // Split by fraction. NOTE: pass a float (e.g. 0.8f); a double argument is ambiguous between this and the int overload.
02003 // Only one union member is meaningful, presumably selected by train_sample_part_mode (CV_COUNT or CV_PORTION).
02004 union
02005 {
02006 int count;
02007 float portion;
02008 } train_sample_part;
02009 int train_sample_part_mode;
02010
02011 bool mix; // Presumably whether samples are shuffled before splitting.
02012 };
02013
02014 class CV_EXPORTS CvMLData // Tabular data container (aliased as cv::TrainData): values, responses, missing mask, variable types and a train/test split.
02015 {
02016 public:
02017 CvMLData();
02018 virtual ~CvMLData();
02019
02020 // Loads a delimited text file (cf. set_delimiter / set_miss_ch for the
02021 // separator and missing-value marker). Returns an int status; presumably
02022 // 0 on success — confirm in the .cpp.
02023 int read_csv( const char* filename );
02024 // Accessors for the loaded data; presumably pointers into this object's storage, not copies.
02025 const CvMat* get_values() const;
02026 const CvMat* get_responses(); // NOTE(review): non-const unlike its neighbours; presumably (re)builds response_out on demand.
02027 const CvMat* get_missing() const;
02028 // Selects which loaded column is treated as the response (stored in response_idx, presumably).
02029 void set_response_idx( int idx );
02030
02031 int get_response_idx() const;
02032 // Train/test split (cf. CvTrainTestSplit); mix_train_and_test_idx presumably reshuffles the split.
02033 void set_train_test_split( const CvTrainTestSplit * spl );
02034 const CvMat* get_train_sample_idx() const;
02035 const CvMat* get_test_sample_idx() const;
02036 void mix_train_and_test_idx();
02037 // Active-variable selection; change_var_idx presumably includes (state=true) or excludes variable vi.
02038 const CvMat* get_var_idx();
02039 void chahge_var_idx( int vi, bool state ); // NOTE(review): misspelled duplicate of change_var_idx, presumably kept for backward compatibility; prefer change_var_idx.
02040
02041 void change_var_idx( int vi, bool state );
02042 // Per-variable types (cf. CV_VAR_NUMERICAL / CV_VAR_CATEGORICAL).
02043 const CvMat* get_var_types();
02044 int get_var_type( int var_idx ) const;
02045
02046 // Sets all variable types from a textual specification; the accepted
02047 // syntax is defined in the .cpp — TODO document it here.
02048 void set_var_types( const char* str );
02049
02050 // Changes the type of a single variable.
02051 void change_var_type( int var_idx, int type);
02052 // Field separator, presumably used by read_csv.
02053 void set_delimiter( char ch );
02054 char get_delimiter() const;
02055 // Character that presumably marks a missing value in the input file.
02056 void set_miss_ch( char ch );
02057 char get_miss_ch() const;
02058 // Class-label strings mapped to integer codes (presumably class_map).
02059 const std::map<std::string, int>& get_class_labels_map() const;
02060
02061 protected:
02062 virtual void clear();
02063
02064 void str_to_flt_elem( const char* token, float& flt_elem, int& type); // Presumably parses one token into flt_elem and reports its kind via `type`.
02065 void free_train_test_idx();
02066
02067 char delimiter;
02068 char miss_ch;
02069
02070 // Loaded table, missing mask, variable types and active-variable mask.
02071 CvMat* values;
02072 CvMat* missing;
02073 CvMat* var_types;
02074 CvMat* var_idx_mask;
02075 // Presumably cached outputs returned by get_responses / get_var_idx / get_var_types.
02076 CvMat* response_out;
02077 CvMat* var_idx_out;
02078 CvMat* var_types_out;
02079
02080 int response_idx;
02081 // Train/test split settings (cf. CvTrainTestSplit).
02082 int train_sample_count;
02083 bool mix;
02084
02085 int total_class_count;
02086 std::map<std::string, int> class_map;
02087
02088 CvMat* train_sample_idx;
02089 CvMat* test_sample_idx;
02090 int* sample_idx;
02091
02092 cv::RNG* rng;
02093 };
02094
02095
02096 namespace cv
02097 {
02098 // C++-namespace aliases for the Cv-prefixed ML classes declared above.
02099 typedef CvStatModel StatModel;
02100 typedef CvParamGrid ParamGrid;
02101 typedef CvNormalBayesClassifier NormalBayesClassifier;
02102 typedef CvKNearest KNearest;
02103 typedef CvSVMParams SVMParams;
02104 typedef CvSVMKernel SVMKernel;
02105 typedef CvSVMSolver SVMSolver;
02106 typedef CvSVM SVM;
02107 typedef CvDTreeParams DTreeParams;
02108 typedef CvMLData TrainData;
02109 typedef CvDTree DecisionTree;
02110 typedef CvForestTree ForestTree;
02111 typedef CvRTParams RandomTreeParams;
02112 typedef CvRTrees RandomTrees;
02113 typedef CvERTreeTrainData ERTreeTRainData; // Misspelled legacy alias, kept for source compatibility; prefer ERTreeTrainData below.
02114 typedef CvForestERTree ERTree;
02115 typedef CvERTrees ERTrees;
02116 typedef CvBoostParams BoostParams;
02117 typedef CvBoostTree BoostTree;
02118 typedef CvBoost Boost;
02119 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
02120 typedef CvANN_MLP NeuralNet_MLP;
02121 typedef CvGBTreesParams GradientBoostingTreeParams;
02122 typedef CvGBTrees GradientBoostingTrees;
02123 typedef CvERTreeTrainData ERTreeTrainData; // Correctly spelled alias for CvERTreeTrainData (same type as ERTreeTRainData).
02124 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj(); // Explicit specialization of Ptr<T>::delete_obj for CvDTreeSplit, declared here and defined in the library; presumably splits need a custom release — confirm.
02125
02126 CV_EXPORTS bool initModule_ml(void); // Module initialization hook; returns a bool status (meaning defined in the .cpp).
02127
02128 }
02129
02130 #endif // __cplusplus
02131 #endif // __OPENCV_ML_HPP__
02132
02133