00001 /*M/////////////////////////////////////////////////////////////////////////////////////// 00002 // 00003 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 00004 // 00005 // By downloading, copying, installing or using the software you agree to this license. 00006 // If you do not agree to this license, do not download, install, 00007 // copy or use the software. 00008 // 00009 // 00010 // Intel License Agreement 00011 // 00012 // Copyright (C) 2000, Intel Corporation, all rights reserved. 00013 // Third party copyrights are property of their respective owners. 00014 // 00015 // Redistribution and use in source and binary forms, with or without modification, 00016 // are permitted provided that the following conditions are met: 00017 // 00018 // * Redistribution's of source code must retain the above copyright notice, 00019 // this list of conditions and the following disclaimer. 00020 // 00021 // * Redistribution's in binary form must reproduce the above copyright notice, 00022 // this list of conditions and the following disclaimer in the documentation 00023 // and/or other materials provided with the distribution. 00024 // 00025 // * The name of Intel Corporation may not be used to endorse or promote products 00026 // derived from this software without specific prior written permission. 00027 // 00028 // This software is provided by the copyright holders and contributors "as is" and 00029 // any express or implied warranties, including, but not limited to, the implied 00030 // warranties of merchantability and fitness for a particular purpose are disclaimed. 
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

// Declarations of the OpenCV machine-learning (ml) module: statistical models
// derived from CvStatModel, each with CvMat*-based and cv::Mat-based entry points.
// NOTE(review): identifiers starting with a double underscore (like this include
// guard) are reserved for the implementation; the name is kept as-is because
// other code may test for it.
#ifndef __OPENCV_ML_HPP__
#define __OPENCV_ML_HPP__

#include "opencv2/core/core.hpp"
#include <limits.h>

#ifdef __cplusplus

// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
// (e.g. the check() member of CvParamGrid), so it is undefined here.
#undef check

/****************************************************************************************\
*                               Main struct definitions                                  *
\****************************************************************************************/

/* log(2*PI) */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* columns of <trainData> matrix are training samples */
#define CV_COL_SAMPLE 0

/* rows of <trainData> matrix are training samples */
#define CV_ROW_SAMPLE 1

/* Nonzero when the CV_ROW_SAMPLE bit is set in <flags>, i.e. samples are rows. */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)

/* A set of vectors that can be chained into a singly linked list through <next>.
   <data> is a union, so the same pointer array can be read as uchar**, float**
   or double**. <type> presumably records which view is valid, and <dims>/<count>
   the vector length and number of vectors -- TODO confirm against the .cpp users. */
struct CvVectors
{
    int type;
    int dims, count;
    CvVectors* next;
    union
    {
        uchar** ptr;
        float** fl;
        double** db;
    } data;
};

/* Compiled out (#if 0): retained for reference only, not part of the API. */
#if 0
/* A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice with ordered bounds and the step clamped to at least 1. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

/* All-zero lattice. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif

/* Variable type. CV_VAR_NUMERICAL and CV_VAR_ORDERED are aliases (both 0). */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

/* Per-model type-name strings; presumably used as type tags when models are
   serialized -- confirm in the implementation files. */
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

/* Error-kind selector, e.g. the <type> argument of CvDTree::calc_error(). */
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1

/* Common base class of the statistical models declared in this header.
   save()/load() take a file name plus an optional node name (0 by default);
   write()/read() operate on an already-open CvFileStorage. */
class CV_EXPORTS_W CvStatModel
{
public:
    CvStatModel();
    virtual ~CvStatModel();

    // Presumably resets/releases the model contents; derived models override it.
    virtual void clear();

    CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
    CV_WRAP virtual void load( const char* filename, const char* name=0 );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    // Presumably the node name used when no explicit name is given -- TODO confirm.
    const char* default_model_name;
};

/****************************************************************************************\
*                                 Normal
Bayes Classifier * 00149 \****************************************************************************************/ 00150 00151 /* The structure, representing the grid range of statmodel parameters. 00152 It is used for optimizing statmodel accuracy by varying model parameters, 00153 the accuracy estimate being computed by cross-validation. 00154 The grid is logarithmic, so <step> must be greater then 1. */ 00155 00156 class CvMLData; 00157 00158 struct CV_EXPORTS_W_MAP CvParamGrid 00159 { 00160 // SVM params type 00161 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 }; 00162 00163 CvParamGrid() 00164 { 00165 min_val = max_val = step = 0; 00166 } 00167 00168 CvParamGrid( double _min_val, double _max_val, double log_step ) 00169 { 00170 min_val = _min_val; 00171 max_val = _max_val; 00172 step = log_step; 00173 } 00174 //CvParamGrid( int param_id ); 00175 bool check() const; 00176 00177 CV_PROP_RW double min_val; 00178 CV_PROP_RW double max_val; 00179 CV_PROP_RW double step; 00180 }; 00181 00182 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel 00183 { 00184 public: 00185 CV_WRAP CvNormalBayesClassifier(); 00186 virtual ~CvNormalBayesClassifier(); 00187 00188 CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses, 00189 const CvMat* varIdx=0, const CvMat* sampleIdx=0 ); 00190 00191 virtual bool train( const CvMat* trainData, const CvMat* responses, 00192 const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false ); 00193 00194 virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const; 00195 CV_WRAP virtual void clear(); 00196 00197 #ifndef SWIG 00198 CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses, 00199 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() ); 00200 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 00201 const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 00202 bool 
update=false ); 00203 CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const; 00204 #endif 00205 00206 virtual void write( CvFileStorage* storage, const char* name ) const; 00207 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00208 00209 protected: 00210 int var_count, var_all; 00211 CvMat* var_idx; 00212 CvMat* cls_labels; 00213 CvMat** count; 00214 CvMat** sum; 00215 CvMat** productsum; 00216 CvMat** avg; 00217 CvMat** inv_eigen_values; 00218 CvMat** cov_rotate_mats; 00219 CvMat* c; 00220 }; 00221 00222 00223 /****************************************************************************************\ 00224 * K-Nearest Neighbour Classifier * 00225 \****************************************************************************************/ 00226 00227 // k Nearest Neighbors 00228 class CV_EXPORTS_W CvKNearest : public CvStatModel 00229 { 00230 public: 00231 00232 CV_WRAP CvKNearest(); 00233 virtual ~CvKNearest(); 00234 00235 CvKNearest( const CvMat* trainData, const CvMat* responses, 00236 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 ); 00237 00238 virtual bool train( const CvMat* trainData, const CvMat* responses, 00239 const CvMat* sampleIdx=0, bool is_regression=false, 00240 int maxK=32, bool updateBase=false ); 00241 00242 virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0, 00243 const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const; 00244 00245 #ifndef SWIG 00246 CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses, 00247 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 ); 00248 00249 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 00250 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, 00251 int maxK=32, bool updateBase=false ); 00252 00253 virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0, 00254 const float** 
neighbors=0, cv::Mat* neighborResponses=0, 00255 cv::Mat* dist=0 ) const; 00256 CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results, 00257 CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const; 00258 #endif 00259 00260 virtual void clear(); 00261 int get_max_k() const; 00262 int get_var_count() const; 00263 int get_sample_count() const; 00264 bool is_regression() const; 00265 00266 virtual float write_results( int k, int k1, int start, int end, 00267 const float* neighbor_responses, const float* dist, CvMat* _results, 00268 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const; 00269 00270 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end, 00271 float* neighbor_responses, const float** neighbors, float* dist ) const; 00272 00273 protected: 00274 00275 int max_k, var_count; 00276 int total; 00277 bool regression; 00278 CvVectors* samples; 00279 }; 00280 00281 /****************************************************************************************\ 00282 * Support Vector Machines * 00283 \****************************************************************************************/ 00284 00285 // SVM training parameters 00286 struct CV_EXPORTS_W_MAP CvSVMParams 00287 { 00288 CvSVMParams(); 00289 CvSVMParams( int _svm_type, int _kernel_type, 00290 double _degree, double _gamma, double _coef0, 00291 double Cvalue, double _nu, double _p, 00292 CvMat* _class_weights, CvTermCriteria _term_crit ); 00293 00294 CV_PROP_RW int svm_type; 00295 CV_PROP_RW int kernel_type; 00296 CV_PROP_RW double degree; // for poly 00297 CV_PROP_RW double gamma; // for poly/rbf/sigmoid 00298 CV_PROP_RW double coef0; // for poly/sigmoid 00299 00300 CV_PROP_RW double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR 00301 CV_PROP_RW double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR 00302 CV_PROP_RW double p; // for CV_SVM_EPS_SVR 00303 CvMat* class_weights; // for CV_SVM_C_SVC 
00304 CV_PROP_RW CvTermCriteria term_crit; // termination criteria 00305 }; 00306 00307 00308 struct CV_EXPORTS CvSVMKernel 00309 { 00310 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs, 00311 const float* another, float* results ); 00312 CvSVMKernel(); 00313 CvSVMKernel( const CvSVMParams* params, Calc _calc_func ); 00314 virtual bool create( const CvSVMParams* params, Calc _calc_func ); 00315 virtual ~CvSVMKernel(); 00316 00317 virtual void clear(); 00318 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results ); 00319 00320 const CvSVMParams* params; 00321 Calc calc_func; 00322 00323 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs, 00324 const float* another, float* results, 00325 double alpha, double beta ); 00326 00327 virtual void calc_linear( int vec_count, int vec_size, const float** vecs, 00328 const float* another, float* results ); 00329 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs, 00330 const float* another, float* results ); 00331 virtual void calc_poly( int vec_count, int vec_size, const float** vecs, 00332 const float* another, float* results ); 00333 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs, 00334 const float* another, float* results ); 00335 }; 00336 00337 00338 struct CvSVMKernelRow 00339 { 00340 CvSVMKernelRow* prev; 00341 CvSVMKernelRow* next; 00342 float* data; 00343 }; 00344 00345 00346 struct CvSVMSolutionInfo 00347 { 00348 double obj; 00349 double rho; 00350 double upper_bound_p; 00351 double upper_bound_n; 00352 double r; // for Solver_NU 00353 }; 00354 00355 class CV_EXPORTS CvSVMSolver 00356 { 00357 public: 00358 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j ); 00359 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed ); 00360 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r ); 00361 00362 CvSVMSolver(); 00363 00364 
CvSVMSolver( int count, int var_count, const float** samples, schar* y, 00365 int alpha_count, double* alpha, double Cp, double Cn, 00366 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 00367 SelectWorkingSet select_working_set, CalcRho calc_rho ); 00368 virtual bool create( int count, int var_count, const float** samples, schar* y, 00369 int alpha_count, double* alpha, double Cp, double Cn, 00370 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 00371 SelectWorkingSet select_working_set, CalcRho calc_rho ); 00372 virtual ~CvSVMSolver(); 00373 00374 virtual void clear(); 00375 virtual bool solve_generic( CvSVMSolutionInfo& si ); 00376 00377 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y, 00378 double Cp, double Cn, CvMemStorage* storage, 00379 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si ); 00380 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y, 00381 CvMemStorage* storage, CvSVMKernel* kernel, 00382 double* alpha, CvSVMSolutionInfo& si ); 00383 virtual bool solve_one_class( int count, int var_count, const float** samples, 00384 CvMemStorage* storage, CvSVMKernel* kernel, 00385 double* alpha, CvSVMSolutionInfo& si ); 00386 00387 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y, 00388 CvMemStorage* storage, CvSVMKernel* kernel, 00389 double* alpha, CvSVMSolutionInfo& si ); 00390 00391 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y, 00392 CvMemStorage* storage, CvSVMKernel* kernel, 00393 double* alpha, CvSVMSolutionInfo& si ); 00394 00395 virtual float* get_row_base( int i, bool* _existed ); 00396 virtual float* get_row( int i, float* dst ); 00397 00398 int sample_count; 00399 int var_count; 00400 int cache_size; 00401 int cache_line_size; 00402 const float** samples; 00403 const CvSVMParams* params; 00404 CvMemStorage* storage; 00405 CvSVMKernelRow lru_list; 00406 
CvSVMKernelRow* rows; 00407 00408 int alpha_count; 00409 00410 double* G; 00411 double* alpha; 00412 00413 // -1 - lower bound, 0 - free, 1 - upper bound 00414 schar* alpha_status; 00415 00416 schar* y; 00417 double* b; 00418 float* buf[2]; 00419 double eps; 00420 int max_iter; 00421 double C[2]; // C[0] == Cn, C[1] == Cp 00422 CvSVMKernel* kernel; 00423 00424 SelectWorkingSet select_working_set_func; 00425 CalcRho calc_rho_func; 00426 GetRow get_row_func; 00427 00428 virtual bool select_working_set( int& i, int& j ); 00429 virtual bool select_working_set_nu_svm( int& i, int& j ); 00430 virtual void calc_rho( double& rho, double& r ); 00431 virtual void calc_rho_nu_svm( double& rho, double& r ); 00432 00433 virtual float* get_row_svc( int i, float* row, float* dst, bool existed ); 00434 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed ); 00435 virtual float* get_row_svr( int i, float* row, float* dst, bool existed ); 00436 }; 00437 00438 00439 struct CvSVMDecisionFunc 00440 { 00441 double rho; 00442 int sv_count; 00443 double* alpha; 00444 int* sv_index; 00445 }; 00446 00447 00448 // SVM model 00449 class CV_EXPORTS_W CvSVM : public CvStatModel 00450 { 00451 public: 00452 // SVM type 00453 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 }; 00454 00455 // SVM kernel type 00456 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 }; 00457 00458 // SVM params type 00459 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 }; 00460 00461 CV_WRAP CvSVM(); 00462 virtual ~CvSVM(); 00463 00464 CvSVM( const CvMat* trainData, const CvMat* responses, 00465 const CvMat* varIdx=0, const CvMat* sampleIdx=0, 00466 CvSVMParams params=CvSVMParams() ); 00467 00468 virtual bool train( const CvMat* trainData, const CvMat* responses, 00469 const CvMat* varIdx=0, const CvMat* sampleIdx=0, 00470 CvSVMParams params=CvSVMParams() ); 00471 00472 virtual bool train_auto( const CvMat* trainData, const CvMat* responses, 00473 const CvMat* varIdx, const CvMat* 
sampleIdx, CvSVMParams params, 00474 int kfold = 10, 00475 CvParamGrid Cgrid = get_default_grid(CvSVM::C), 00476 CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA), 00477 CvParamGrid pGrid = get_default_grid(CvSVM::P), 00478 CvParamGrid nuGrid = get_default_grid(CvSVM::NU), 00479 CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF), 00480 CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE), 00481 bool balanced=false ); 00482 00483 virtual float predict( const CvMat* sample, bool returnDFVal=false ) const; 00484 virtual float predict( const CvMat* samples, CvMat* results ) const; 00485 00486 #ifndef SWIG 00487 CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses, 00488 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 00489 CvSVMParams params=CvSVMParams() ); 00490 00491 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 00492 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 00493 CvSVMParams params=CvSVMParams() ); 00494 00495 CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses, 00496 const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params, 00497 int k_fold = 10, 00498 CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C), 00499 CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA), 00500 CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P), 00501 CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU), 00502 CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF), 00503 CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE), 00504 bool balanced=false); 00505 CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const; 00506 #endif 00507 00508 CV_WRAP virtual int get_support_vector_count() const; 00509 virtual const float* get_support_vector(int i) const; 00510 virtual CvSVMParams get_params() const { return params; }; 00511 CV_WRAP virtual void clear(); 00512 00513 static CvParamGrid 
get_default_grid( int param_id ); 00514 00515 virtual void write( CvFileStorage* storage, const char* name ) const; 00516 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00517 CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; } 00518 00519 protected: 00520 00521 virtual bool set_params( const CvSVMParams& params ); 00522 virtual bool train1( int sample_count, int var_count, const float** samples, 00523 const void* responses, double Cp, double Cn, 00524 CvMemStorage* _storage, double* alpha, double& rho ); 00525 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples, 00526 const CvMat* responses, CvMemStorage* _storage, double* alpha ); 00527 virtual void create_kernel(); 00528 virtual void create_solver(); 00529 00530 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const; 00531 00532 virtual void write_params( CvFileStorage* fs ) const; 00533 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 00534 00535 CvSVMParams params; 00536 CvMat* class_labels; 00537 int var_all; 00538 float** sv; 00539 int sv_total; 00540 CvMat* var_idx; 00541 CvMat* class_weights; 00542 CvSVMDecisionFunc* decision_func; 00543 CvMemStorage* storage; 00544 00545 CvSVMSolver* solver; 00546 CvSVMKernel* kernel; 00547 }; 00548 00549 /****************************************************************************************\ 00550 * Expectation - Maximization * 00551 \****************************************************************************************/ 00552 00553 struct CV_EXPORTS_W_MAP CvEMParams 00554 { 00555 CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/), 00556 start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0) 00557 { 00558 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON ); 00559 } 00560 00561 CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/, 00562 int 
_start_step=0/*CvEM::START_AUTO_STEP*/, 00563 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON), 00564 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) : 00565 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step), 00566 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit) 00567 {} 00568 00569 CV_PROP_RW int nclusters; 00570 CV_PROP_RW int cov_mat_type; 00571 CV_PROP_RW int start_step; 00572 const CvMat* probs; 00573 const CvMat* weights; 00574 const CvMat* means; 00575 const CvMat** covs; 00576 CV_PROP_RW CvTermCriteria term_crit; 00577 }; 00578 00579 00580 class CV_EXPORTS_W CvEM : public CvStatModel 00581 { 00582 public: 00583 // Type of covariation matrices 00584 enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 }; 00585 00586 // The initial step 00587 enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 }; 00588 00589 CV_WRAP CvEM(); 00590 CvEM( const CvMat* samples, const CvMat* sampleIdx=0, 00591 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 00592 //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, 00593 // CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats); 00594 00595 virtual ~CvEM(); 00596 00597 virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0, 00598 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 00599 00600 virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const; 00601 00602 #ifndef SWIG 00603 CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(), 00604 CvEMParams params=CvEMParams() ); 00605 00606 CV_WRAP virtual bool train( const cv::Mat& samples, 00607 const cv::Mat& sampleIdx=cv::Mat(), 00608 CvEMParams params=CvEMParams(), 00609 CV_OUT cv::Mat* labels=0 ); 00610 00611 CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const; 00612 
CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const; 00613 00614 CV_WRAP int getNClusters() const; 00615 CV_WRAP cv::Mat getMeans() const; 00616 CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const; 00617 CV_WRAP cv::Mat getWeights() const; 00618 CV_WRAP cv::Mat getProbs() const; 00619 00620 CV_WRAP inline double getLikelihood() const { return log_likelihood; } 00621 CV_WRAP inline double getLikelihoodDelta() const { return log_likelihood_delta; } 00622 #endif 00623 00624 CV_WRAP virtual void clear(); 00625 00626 int get_nclusters() const; 00627 const CvMat* get_means() const; 00628 const CvMat** get_covs() const; 00629 const CvMat* get_weights() const; 00630 const CvMat* get_probs() const; 00631 00632 inline double get_log_likelihood() const { return log_likelihood; } 00633 inline double get_log_likelihood_delta() const { return log_likelihood_delta; } 00634 00635 // inline const CvMat * get_log_weight_div_det () const { return log_weight_div_det; }; 00636 // inline const CvMat * get_inv_eigen_values () const { return inv_eigen_values; }; 00637 // inline const CvMat ** get_cov_rotate_mats () const { return cov_rotate_mats; }; 00638 00639 virtual void read( CvFileStorage* fs, CvFileNode* node ); 00640 virtual void write( CvFileStorage* fs, const char* name ) const; 00641 00642 virtual void write_params( CvFileStorage* fs ) const; 00643 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 00644 00645 protected: 00646 00647 virtual void set_params( const CvEMParams& params, 00648 const CvVectors& train_data ); 00649 virtual void init_em( const CvVectors& train_data ); 00650 virtual double run_em( const CvVectors& train_data ); 00651 virtual void init_auto( const CvVectors& samples ); 00652 virtual void kmeans( const CvVectors& train_data, int nclusters, 00653 CvMat* labels, CvTermCriteria criteria, 00654 const CvMat* means ); 00655 CvEMParams params; 00656 double log_likelihood; 00657 double log_likelihood_delta; 00658 00659 
CvMat* means; 00660 CvMat** covs; 00661 CvMat* weights; 00662 CvMat* probs; 00663 00664 CvMat* log_weight_div_det; 00665 CvMat* inv_eigen_values; 00666 CvMat** cov_rotate_mats; 00667 }; 00668 00669 /****************************************************************************************\ 00670 * Decision Tree * 00671 \****************************************************************************************/\ 00672 struct CvPair16u32s 00673 { 00674 unsigned short* u; 00675 int* i; 00676 }; 00677 00678 00679 #define CV_DTREE_CAT_DIR(idx,subset) \ 00680 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1) 00681 00682 struct CvDTreeSplit 00683 { 00684 int var_idx; 00685 int condensed_idx; 00686 int inversed; 00687 float quality; 00688 CvDTreeSplit* next; 00689 union 00690 { 00691 int subset[2]; 00692 struct 00693 { 00694 float c; 00695 int split_point; 00696 } 00697 ord; 00698 }; 00699 }; 00700 00701 struct CvDTreeNode 00702 { 00703 int class_idx; 00704 int Tn; 00705 double value; 00706 00707 CvDTreeNode* parent; 00708 CvDTreeNode* left; 00709 CvDTreeNode* right; 00710 00711 CvDTreeSplit* split; 00712 00713 int sample_count; 00714 int depth; 00715 int* num_valid; 00716 int offset; 00717 int buf_idx; 00718 double maxlr; 00719 00720 // global pruning data 00721 int complexity; 00722 double alpha; 00723 double node_risk, tree_risk, tree_error; 00724 00725 // cross-validation pruning data 00726 int* cv_Tn; 00727 double* cv_node_risk; 00728 double* cv_node_error; 00729 00730 int get_num_valid(int vi) { return num_valid ? 
num_valid[vi] : sample_count; } 00731 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; } 00732 }; 00733 00734 00735 struct CV_EXPORTS_W_MAP CvDTreeParams 00736 { 00737 CV_PROP_RW int max_categories; 00738 CV_PROP_RW int max_depth; 00739 CV_PROP_RW int min_sample_count; 00740 CV_PROP_RW int cv_folds; 00741 CV_PROP_RW bool use_surrogates; 00742 CV_PROP_RW bool use_1se_rule; 00743 CV_PROP_RW bool truncate_pruned_tree; 00744 CV_PROP_RW float regression_accuracy; 00745 const float* priors; 00746 00747 CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10), 00748 cv_folds(10), use_surrogates(true), use_1se_rule(true), 00749 truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0) 00750 {} 00751 00752 CvDTreeParams( int _max_depth, int _min_sample_count, 00753 float _regression_accuracy, bool _use_surrogates, 00754 int _max_categories, int _cv_folds, 00755 bool _use_1se_rule, bool _truncate_pruned_tree, 00756 const float* _priors ) : 00757 max_categories(_max_categories), max_depth(_max_depth), 00758 min_sample_count(_min_sample_count), cv_folds (_cv_folds), 00759 use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule), 00760 truncate_pruned_tree(_truncate_pruned_tree), 00761 regression_accuracy(_regression_accuracy), 00762 priors(_priors) 00763 {} 00764 }; 00765 00766 00767 struct CV_EXPORTS CvDTreeTrainData 00768 { 00769 CvDTreeTrainData(); 00770 CvDTreeTrainData( const CvMat* trainData, int tflag, 00771 const CvMat* responses, const CvMat* varIdx=0, 00772 const CvMat* sampleIdx=0, const CvMat* varType=0, 00773 const CvMat* missingDataMask=0, 00774 const CvDTreeParams& params=CvDTreeParams(), 00775 bool _shared=false, bool _add_labels=false ); 00776 virtual ~CvDTreeTrainData(); 00777 00778 virtual void set_data( const CvMat* trainData, int tflag, 00779 const CvMat* responses, const CvMat* varIdx=0, 00780 const CvMat* sampleIdx=0, const CvMat* varType=0, 00781 const CvMat* missingDataMask=0, 00782 const 
CvDTreeParams& params=CvDTreeParams(), 00783 bool _shared=false, bool _add_labels=false, 00784 bool _update_data=false ); 00785 virtual void do_responses_copy(); 00786 00787 virtual void get_vectors( const CvMat* _subsample_idx, 00788 float* values, uchar* missing, float* responses, bool get_class_idx=false ); 00789 00790 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 00791 00792 virtual void write_params( CvFileStorage* fs ) const; 00793 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 00794 00795 // release all the data 00796 virtual void clear(); 00797 00798 int get_num_classes() const; 00799 int get_var_type(int vi) const; 00800 int get_work_var_count() const {return work_var_count;} 00801 00802 virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf ); 00803 virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf ); 00804 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); 00805 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); 00806 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf ); 00807 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf, 00808 const float** ord_values, const int** sorted_indices, int* sample_indices_buf ); 00809 virtual int get_child_buf_idx( CvDTreeNode* n ); 00810 00812 00813 virtual bool set_params( const CvDTreeParams& params ); 00814 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count, 00815 int storage_idx, int offset ); 00816 00817 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val, 00818 int split_point, int inversed, float quality ); 00819 virtual CvDTreeSplit* new_split_cat( int vi, float quality ); 00820 virtual void free_node_data( CvDTreeNode* node ); 00821 virtual void free_train_data(); 00822 virtual void free_node( CvDTreeNode* node ); 00823 00824 int sample_count, var_all, var_count, 
max_c_count; 00825 int ord_var_count, cat_var_count, work_var_count; 00826 bool have_labels, have_priors; 00827 bool is_classifier; 00828 int tflag; 00829 00830 const CvMat* train_data; 00831 const CvMat* responses; 00832 CvMat* responses_copy; // used in Boosting 00833 00834 int buf_count, buf_size; 00835 bool shared; 00836 int is_buf_16u; 00837 00838 CvMat* cat_count; 00839 CvMat* cat_ofs; 00840 CvMat* cat_map; 00841 00842 CvMat* counts; 00843 CvMat* buf; 00844 CvMat* direction; 00845 CvMat* split_buf; 00846 00847 CvMat* var_idx; 00848 CvMat* var_type; // i-th element = 00849 // k<0 - ordered 00850 // k>=0 - categorical, see k-th element of cat_* arrays 00851 CvMat* priors; 00852 CvMat* priors_mult; 00853 00854 CvDTreeParams params; 00855 00856 CvMemStorage* tree_storage; 00857 CvMemStorage* temp_storage; 00858 00859 CvDTreeNode* data_root; 00860 00861 CvSet* node_heap; 00862 CvSet* split_heap; 00863 CvSet* cv_heap; 00864 CvSet* nv_heap; 00865 00866 cv::RNG* rng; 00867 }; 00868 00869 class CvDTree; 00870 class CvForestTree; 00871 00872 namespace cv 00873 { 00874 struct DTreeBestSplitFinder; 00875 struct ForestTreeBestSplitFinder; 00876 } 00877 00878 class CV_EXPORTS_W CvDTree : public CvStatModel 00879 { 00880 public: 00881 CV_WRAP CvDTree(); 00882 virtual ~CvDTree(); 00883 00884 virtual bool train( const CvMat* trainData, int tflag, 00885 const CvMat* responses, const CvMat* varIdx=0, 00886 const CvMat* sampleIdx=0, const CvMat* varType=0, 00887 const CvMat* missingDataMask=0, 00888 CvDTreeParams params=CvDTreeParams() ); 00889 00890 virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() ); 00891 00892 // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 00893 virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 ); 00894 00895 virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx ); 00896 00897 virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0, 00898 bool 
preprocessedInput=false ) const; 00899 00900 #ifndef SWIG 00901 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 00902 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 00903 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 00904 const cv::Mat& missingDataMask=cv::Mat(), 00905 CvDTreeParams params=CvDTreeParams() ); 00906 00907 CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(), 00908 bool preprocessedInput=false ) const; 00909 CV_WRAP virtual cv::Mat getVarImportance(); 00910 #endif 00911 00912 virtual const CvMat* get_var_importance(); 00913 CV_WRAP virtual void clear(); 00914 00915 virtual void read( CvFileStorage* fs, CvFileNode* node ); 00916 virtual void write( CvFileStorage* fs, const char* name ) const; 00917 00918 // special read & write methods for trees in the tree ensembles 00919 virtual void read( CvFileStorage* fs, CvFileNode* node, 00920 CvDTreeTrainData* data ); 00921 virtual void write( CvFileStorage* fs ) const; 00922 00923 const CvDTreeNode* get_root() const; 00924 int get_pruned_tree_idx() const; 00925 CvDTreeTrainData* get_data(); 00926 00927 protected: 00928 friend struct cv::DTreeBestSplitFinder; 00929 00930 virtual bool do_train( const CvMat* _subsample_idx ); 00931 00932 virtual void try_split_node( CvDTreeNode* n ); 00933 virtual void split_node_data( CvDTreeNode* n ); 00934 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); 00935 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 00936 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00937 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 00938 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00939 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 00940 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00941 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int 
vi, 00942 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00943 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 00944 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 00945 virtual double calc_node_dir( CvDTreeNode* node ); 00946 virtual void complete_node_dir( CvDTreeNode* node ); 00947 virtual void cluster_categories( const int* vectors, int vector_count, 00948 int var_count, int* sums, int k, int* cluster_labels ); 00949 00950 virtual void calc_node_value( CvDTreeNode* node ); 00951 00952 virtual void prune_cv(); 00953 virtual double update_tree_rnc( int T, int fold ); 00954 virtual int cut_tree( int T, int fold, double min_alpha ); 00955 virtual void free_prune_data(bool cut_tree); 00956 virtual void free_tree(); 00957 00958 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const; 00959 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const; 00960 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent ); 00961 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node ); 00962 virtual void write_tree_nodes( CvFileStorage* fs ) const; 00963 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node ); 00964 00965 CvDTreeNode* root; 00966 CvMat* var_importance; 00967 CvDTreeTrainData* data; 00968 00969 public: 00970 int pruned_tree_idx; 00971 }; 00972 00973 00974 /****************************************************************************************\ 00975 * Random Trees Classifier * 00976 \****************************************************************************************/ 00977 00978 class CvRTrees; 00979 00980 class CV_EXPORTS CvForestTree: public CvDTree 00981 { 00982 public: 00983 CvForestTree(); 00984 virtual ~CvForestTree(); 00985 00986 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest ); 00987 00988 virtual 
int get_var_count() const {return data ? data->var_count : 0;} 00989 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data ); 00990 00991 /* dummy methods to avoid warnings: BEGIN */ 00992 virtual bool train( const CvMat* trainData, int tflag, 00993 const CvMat* responses, const CvMat* varIdx=0, 00994 const CvMat* sampleIdx=0, const CvMat* varType=0, 00995 const CvMat* missingDataMask=0, 00996 CvDTreeParams params=CvDTreeParams() ); 00997 00998 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx ); 00999 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01000 virtual void read( CvFileStorage* fs, CvFileNode* node, 01001 CvDTreeTrainData* data ); 01002 /* dummy methods to avoid warnings: END */ 01003 01004 protected: 01005 friend struct cv::ForestTreeBestSplitFinder; 01006 01007 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); 01008 CvRTrees* forest; 01009 }; 01010 01011 01012 struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams 01013 { 01014 //Parameters for the forest 01015 CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance 01016 CV_PROP_RW int nactive_vars; 01017 CV_PROP_RW CvTermCriteria term_crit; 01018 01019 CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ), 01020 calc_var_importance(false), nactive_vars(0) 01021 { 01022 term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 ); 01023 } 01024 01025 CvRTParams( int _max_depth, int _min_sample_count, 01026 float _regression_accuracy, bool _use_surrogates, 01027 int _max_categories, const float* _priors, bool _calc_var_importance, 01028 int _nactive_vars, int max_num_of_trees_in_the_forest, 01029 float forest_accuracy, int termcrit_type ) : 01030 CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy, 01031 _use_surrogates, _max_categories, 0, 01032 false, false, _priors ), 01033 calc_var_importance(_calc_var_importance), 01034 
nactive_vars(_nactive_vars) 01035 { 01036 term_crit = cvTermCriteria(termcrit_type, 01037 max_num_of_trees_in_the_forest, forest_accuracy); 01038 } 01039 }; 01040 01041 01042 class CV_EXPORTS_W CvRTrees : public CvStatModel 01043 { 01044 public: 01045 CV_WRAP CvRTrees(); 01046 virtual ~CvRTrees(); 01047 virtual bool train( const CvMat* trainData, int tflag, 01048 const CvMat* responses, const CvMat* varIdx=0, 01049 const CvMat* sampleIdx=0, const CvMat* varType=0, 01050 const CvMat* missingDataMask=0, 01051 CvRTParams params=CvRTParams() ); 01052 01053 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); 01054 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const; 01055 virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const; 01056 01057 #ifndef SWIG 01058 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 01059 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 01060 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 01061 const cv::Mat& missingDataMask=cv::Mat(), 01062 CvRTParams params=CvRTParams() ); 01063 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const; 01064 CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const; 01065 CV_WRAP virtual cv::Mat getVarImportance(); 01066 #endif 01067 01068 CV_WRAP virtual void clear(); 01069 01070 virtual const CvMat* get_var_importance(); 01071 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2, 01072 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const; 01073 01074 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 01075 01076 virtual float get_train_error(); 01077 01078 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01079 virtual void write( CvFileStorage* fs, const char* name ) const; 01080 01081 CvMat* 
get_active_var_mask(); 01082 CvRNG* get_rng(); 01083 01084 int get_tree_count() const; 01085 CvForestTree* get_tree(int i) const; 01086 01087 protected: 01088 01089 virtual bool grow_forest( const CvTermCriteria term_crit ); 01090 01091 // array of the trees of the forest 01092 CvForestTree** trees; 01093 CvDTreeTrainData* data; 01094 int ntrees; 01095 int nclasses; 01096 double oob_error; 01097 CvMat* var_importance; 01098 int nsamples; 01099 01100 cv::RNG* rng; 01101 CvMat* active_var_mask; 01102 }; 01103 01104 /****************************************************************************************\ 01105 * Extremely randomized trees Classifier * 01106 \****************************************************************************************/ 01107 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData 01108 { 01109 virtual void set_data( const CvMat* trainData, int tflag, 01110 const CvMat* responses, const CvMat* varIdx=0, 01111 const CvMat* sampleIdx=0, const CvMat* varType=0, 01112 const CvMat* missingDataMask=0, 01113 const CvDTreeParams& params=CvDTreeParams(), 01114 bool _shared=false, bool _add_labels=false, 01115 bool _update_data=false ); 01116 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, 01117 const float** ord_values, const int** missing, int* sample_buf = 0 ); 01118 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); 01119 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); 01120 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf ); 01121 virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing, 01122 float* responses, bool get_class_idx=false ); 01123 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 01124 const CvMat* missing_mask; 01125 }; 01126 01127 class CV_EXPORTS CvForestERTree : public CvForestTree 01128 { 01129 protected: 01130 virtual double calc_node_dir( CvDTreeNode* 
node ); 01131 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 01132 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01133 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 01134 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01135 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 01136 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01137 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 01138 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01139 virtual void split_node_data( CvDTreeNode* n ); 01140 }; 01141 01142 class CV_EXPORTS_W CvERTrees : public CvRTrees 01143 { 01144 public: 01145 CV_WRAP CvERTrees(); 01146 virtual ~CvERTrees(); 01147 virtual bool train( const CvMat* trainData, int tflag, 01148 const CvMat* responses, const CvMat* varIdx=0, 01149 const CvMat* sampleIdx=0, const CvMat* varType=0, 01150 const CvMat* missingDataMask=0, 01151 CvRTParams params=CvRTParams()); 01152 #ifndef SWIG 01153 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 01154 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 01155 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 01156 const cv::Mat& missingDataMask=cv::Mat(), 01157 CvRTParams params=CvRTParams()); 01158 #endif 01159 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); 01160 protected: 01161 virtual bool grow_forest( const CvTermCriteria term_crit ); 01162 }; 01163 01164 01165 /****************************************************************************************\ 01166 * Boosted tree classifier * 01167 \****************************************************************************************/ 01168 01169 struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams 01170 { 01171 CV_PROP_RW int boost_type; 01172 CV_PROP_RW int weak_count; 01173 CV_PROP_RW int split_criteria; 01174 CV_PROP_RW 
double weight_trim_rate; 01175 01176 CvBoostParams(); 01177 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate, 01178 int max_depth, bool use_surrogates, const float* priors ); 01179 }; 01180 01181 01182 class CvBoost; 01183 01184 class CV_EXPORTS CvBoostTree: public CvDTree 01185 { 01186 public: 01187 CvBoostTree(); 01188 virtual ~CvBoostTree(); 01189 01190 virtual bool train( CvDTreeTrainData* trainData, 01191 const CvMat* subsample_idx, CvBoost* ensemble ); 01192 01193 virtual void scale( double s ); 01194 virtual void read( CvFileStorage* fs, CvFileNode* node, 01195 CvBoost* ensemble, CvDTreeTrainData* _data ); 01196 virtual void clear(); 01197 01198 /* dummy methods to avoid warnings: BEGIN */ 01199 virtual bool train( const CvMat* trainData, int tflag, 01200 const CvMat* responses, const CvMat* varIdx=0, 01201 const CvMat* sampleIdx=0, const CvMat* varType=0, 01202 const CvMat* missingDataMask=0, 01203 CvDTreeParams params=CvDTreeParams() ); 01204 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx ); 01205 01206 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01207 virtual void read( CvFileStorage* fs, CvFileNode* node, 01208 CvDTreeTrainData* data ); 01209 /* dummy methods to avoid warnings: END */ 01210 01211 protected: 01212 01213 virtual void try_split_node( CvDTreeNode* n ); 01214 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 01215 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 01216 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 01217 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01218 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 01219 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01220 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 01221 float init_quality = 0, CvDTreeSplit* _split = 0, 
uchar* ext_buf = 0 ); 01222 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 01223 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01224 virtual void calc_node_value( CvDTreeNode* n ); 01225 virtual double calc_node_dir( CvDTreeNode* n ); 01226 01227 CvBoost* ensemble; 01228 }; 01229 01230 01231 class CV_EXPORTS_W CvBoost : public CvStatModel 01232 { 01233 public: 01234 // Boosting type 01235 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 }; 01236 01237 // Splitting criteria 01238 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 }; 01239 01240 CV_WRAP CvBoost(); 01241 virtual ~CvBoost(); 01242 01243 CvBoost( const CvMat* trainData, int tflag, 01244 const CvMat* responses, const CvMat* varIdx=0, 01245 const CvMat* sampleIdx=0, const CvMat* varType=0, 01246 const CvMat* missingDataMask=0, 01247 CvBoostParams params=CvBoostParams() ); 01248 01249 virtual bool train( const CvMat* trainData, int tflag, 01250 const CvMat* responses, const CvMat* varIdx=0, 01251 const CvMat* sampleIdx=0, const CvMat* varType=0, 01252 const CvMat* missingDataMask=0, 01253 CvBoostParams params=CvBoostParams(), 01254 bool update=false ); 01255 01256 virtual bool train( CvMLData* data, 01257 CvBoostParams params=CvBoostParams(), 01258 bool update=false ); 01259 01260 virtual float predict( const CvMat* sample, const CvMat* missing=0, 01261 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ, 01262 bool raw_mode=false, bool return_sum=false ) const; 01263 01264 #ifndef SWIG 01265 CV_WRAP CvBoost( const cv::Mat& trainData, int tflag, 01266 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 01267 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 01268 const cv::Mat& missingDataMask=cv::Mat(), 01269 CvBoostParams params=CvBoostParams() ); 01270 01271 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 01272 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 01273 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& 
varType=cv::Mat(), 01274 const cv::Mat& missingDataMask=cv::Mat(), 01275 CvBoostParams params=CvBoostParams(), 01276 bool update=false ); 01277 01278 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(), 01279 const cv::Range& slice=cv::Range::all(), bool rawMode=false, 01280 bool returnSum=false ) const; 01281 #endif 01282 01283 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 01284 01285 CV_WRAP virtual void prune( CvSlice slice ); 01286 01287 CV_WRAP virtual void clear(); 01288 01289 virtual void write( CvFileStorage* storage, const char* name ) const; 01290 virtual void read( CvFileStorage* storage, CvFileNode* node ); 01291 virtual const CvMat* get_active_vars(bool absolute_idx=true); 01292 01293 CvSeq* get_weak_predictors(); 01294 01295 CvMat* get_weights(); 01296 CvMat* get_subtree_weights(); 01297 CvMat* get_weak_response(); 01298 const CvBoostParams& get_params() const; 01299 const CvDTreeTrainData* get_data() const; 01300 01301 protected: 01302 01303 virtual bool set_params( const CvBoostParams& params ); 01304 virtual void update_weights( CvBoostTree* tree ); 01305 virtual void trim_weights(); 01306 virtual void write_params( CvFileStorage* fs ) const; 01307 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 01308 01309 CvDTreeTrainData* data; 01310 CvBoostParams params; 01311 CvSeq* weak; 01312 01313 CvMat* active_vars; 01314 CvMat* active_vars_abs; 01315 bool have_active_cat_vars; 01316 01317 CvMat* orig_response; 01318 CvMat* sum_response; 01319 CvMat* weak_eval; 01320 CvMat* subsample_mask; 01321 CvMat* weights; 01322 CvMat* subtree_weights; 01323 bool have_subsample; 01324 }; 01325 01326 01327 /****************************************************************************************\ 01328 * Gradient Boosted Trees * 01329 \****************************************************************************************/ 01330 01331 
// DataType: STRUCT CvGBTreesParams 01332 // Parameters of GBT (Gradient Boosted trees model), including single 01333 // tree settings and ensemble parameters. 01334 // 01335 // weak_count - count of trees in the ensemble 01336 // loss_function_type - loss function used for ensemble training 01337 // subsample_portion - portion of whole training set used for 01338 // every single tree training. 01339 // subsample_portion value is in (0.0, 1.0]. 01340 // subsample_portion == 1.0 when whole dataset is 01341 // used on each step. Count of sample used on each 01342 // step is computed as 01343 // int(total_samples_count * subsample_portion). 01344 // shrinkage - regularization parameter. 01345 // Each tree prediction is multiplied on shrinkage value. 01346 01347 01348 struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams 01349 { 01350 CV_PROP_RW int weak_count; 01351 CV_PROP_RW int loss_function_type; 01352 CV_PROP_RW float subsample_portion; 01353 CV_PROP_RW float shrinkage; 01354 01355 CvGBTreesParams(); 01356 CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage, 01357 float subsample_portion, int max_depth, bool use_surrogates ); 01358 }; 01359 01360 // DataType: CLASS CvGBTrees 01361 // Gradient Boosting Trees (GBT) algorithm implementation. 01362 // 01363 // data - training dataset 01364 // params - parameters of the CvGBTrees 01365 // weak - array[0..(class_count-1)] of CvSeq 01366 // for storing tree ensembles 01367 // orig_response - original responses of the training set samples 01368 // sum_response - predicitons of the current model on the training dataset. 01369 // this matrix is updated on every iteration. 01370 // sum_response_tmp - predicitons of the model on the training set on the next 01371 // step. On every iteration values of sum_responses_tmp are 01372 // computed via sum_responses values. When the current 01373 // step is complete sum_response values become equal to 01374 // sum_responses_tmp. 
01375 // sampleIdx - indices of samples used for training the ensemble. 01376 // CvGBTrees training procedure takes a set of samples 01377 // (train_data) and a set of responses (responses). 01378 // Only pairs (train_data[i], responses[i]), where i is 01379 // in sample_idx are used for training the ensemble. 01380 // subsample_train - indices of samples used for training a single decision 01381 // tree on the current step. This indices are countered 01382 // relatively to the sample_idx, so that pairs 01383 // (train_data[sample_idx[i]], responses[sample_idx[i]]) 01384 // are used for training a decision tree. 01385 // Training set is randomly splited 01386 // in two parts (subsample_train and subsample_test) 01387 // on every iteration accordingly to the portion parameter. 01388 // subsample_test - relative indices of samples from the training set, 01389 // which are not used for training a tree on the current 01390 // step. 01391 // missing - mask of the missing values in the training set. This 01392 // matrix has the same size as train_data. 1 - missing 01393 // value, 0 - not a missing value. 01394 // class_labels - output class labels map. 01395 // rng - random number generator. Used for spliting the 01396 // training set. 01397 // class_count - count of output classes. 01398 // class_count == 1 in the case of regression, 01399 // and > 1 in the case of classification. 01400 // delta - Huber loss function parameter. 01401 // base_value - start point of the gradient descent procedure. 01402 // model prediction is 01403 // f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where 01404 // f_0 is the base value. 01405 01406 01407 01408 class CV_EXPORTS_W CvGBTrees : public CvStatModel 01409 { 01410 public: 01411 01412 /* 01413 // DataType: ENUM 01414 // Loss functions implemented in CvGBTrees. 
01415 // 01416 // SQUARED_LOSS 01417 // problem: regression 01418 // loss = (x - x')^2 01419 // 01420 // ABSOLUTE_LOSS 01421 // problem: regression 01422 // loss = abs(x - x') 01423 // 01424 // HUBER_LOSS 01425 // problem: regression 01426 // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta 01427 // 1/2*(x - x')^2, if abs(x - x') <= delta, 01428 // where delta is the alpha-quantile of pseudo responses from 01429 // the training set. 01430 // 01431 // DEVIANCE_LOSS 01432 // problem: classification 01433 // 01434 */ 01435 enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS}; 01436 01437 01438 /* 01439 // Default constructor. Creates a model only (without training). 01440 // Should be followed by one form of the train(...) function. 01441 // 01442 // API 01443 // CvGBTrees(); 01444 01445 // INPUT 01446 // OUTPUT 01447 // RESULT 01448 */ 01449 CV_WRAP CvGBTrees(); 01450 01451 01452 /* 01453 // Full form constructor. Creates a gradient boosting model and does the 01454 // train. 01455 // 01456 // API 01457 // CvGBTrees( const CvMat* trainData, int tflag, 01458 const CvMat* responses, const CvMat* varIdx=0, 01459 const CvMat* sampleIdx=0, const CvMat* varType=0, 01460 const CvMat* missingDataMask=0, 01461 CvGBTreesParams params=CvGBTreesParams() ); 01462 01463 // INPUT 01464 // trainData - a set of input feature vectors. 01465 // size of matrix is 01466 // <count of samples> x <variables count> 01467 // or <variables count> x <count of samples> 01468 // depending on the tflag parameter. 01469 // matrix values are float. 01470 // tflag - a flag showing how do samples stored in the 01471 // trainData matrix row by row (tflag=CV_ROW_SAMPLE) 01472 // or column by column (tflag=CV_COL_SAMPLE). 01473 // responses - a vector of responses corresponding to the samples 01474 // in trainData. 01475 // varIdx - indices of used variables. zero value means that all 01476 // variables are active. 01477 // sampleIdx - indices of used samples. 
zero value means that all 01478 // samples from trainData are in the training set. 01479 // varType - vector of <variables count> length. gives every 01480 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED. 01481 // varType = 0 means all variables are numerical. 01482 // missingDataMask - a mask of misiing values in trainData. 01483 // missingDataMask = 0 means that there are no missing 01484 // values. 01485 // params - parameters of GTB algorithm. 01486 // OUTPUT 01487 // RESULT 01488 */ 01489 CvGBTrees( const CvMat* trainData, int tflag, 01490 const CvMat* responses, const CvMat* varIdx=0, 01491 const CvMat* sampleIdx=0, const CvMat* varType=0, 01492 const CvMat* missingDataMask=0, 01493 CvGBTreesParams params=CvGBTreesParams() ); 01494 01495 01496 /* 01497 // Destructor. 01498 */ 01499 virtual ~CvGBTrees(); 01500 01501 01502 /* 01503 // Gradient tree boosting model training 01504 // 01505 // API 01506 // virtual bool train( const CvMat* trainData, int tflag, 01507 const CvMat* responses, const CvMat* varIdx=0, 01508 const CvMat* sampleIdx=0, const CvMat* varType=0, 01509 const CvMat* missingDataMask=0, 01510 CvGBTreesParams params=CvGBTreesParams(), 01511 bool update=false ); 01512 01513 // INPUT 01514 // trainData - a set of input feature vectors. 01515 // size of matrix is 01516 // <count of samples> x <variables count> 01517 // or <variables count> x <count of samples> 01518 // depending on the tflag parameter. 01519 // matrix values are float. 01520 // tflag - a flag showing how do samples stored in the 01521 // trainData matrix row by row (tflag=CV_ROW_SAMPLE) 01522 // or column by column (tflag=CV_COL_SAMPLE). 01523 // responses - a vector of responses corresponding to the samples 01524 // in trainData. 01525 // varIdx - indices of used variables. zero value means that all 01526 // variables are active. 01527 // sampleIdx - indices of used samples. zero value means that all 01528 // samples from trainData are in the training set. 
01529 // varType - vector of <variables count> length. gives every 01530 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED. 01531 // varType = 0 means all variables are numerical. 01532 // missingDataMask - a mask of misiing values in trainData. 01533 // missingDataMask = 0 means that there are no missing 01534 // values. 01535 // params - parameters of GTB algorithm. 01536 // update - is not supported now. (!) 01537 // OUTPUT 01538 // RESULT 01539 // Error state. 01540 */ 01541 virtual bool train( const CvMat* trainData, int tflag, 01542 const CvMat* responses, const CvMat* varIdx=0, 01543 const CvMat* sampleIdx=0, const CvMat* varType=0, 01544 const CvMat* missingDataMask=0, 01545 CvGBTreesParams params=CvGBTreesParams(), 01546 bool update=false ); 01547 01548 01549 /* 01550 // Gradient tree boosting model training 01551 // 01552 // API 01553 // virtual bool train( CvMLData* data, 01554 CvGBTreesParams params=CvGBTreesParams(), 01555 bool update=false ) {return false;}; 01556 01557 // INPUT 01558 // data - training set. 01559 // params - parameters of GTB algorithm. 01560 // update - is not supported now. (!) 01561 // OUTPUT 01562 // RESULT 01563 // Error state. 01564 */ 01565 virtual bool train( CvMLData* data, 01566 CvGBTreesParams params=CvGBTreesParams(), 01567 bool update=false ); 01568 01569 01570 /* 01571 // Response value prediction 01572 // 01573 // API 01574 // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0, 01575 CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ, 01576 int k=-1 ) const; 01577 01578 // INPUT 01579 // sample - input sample of the same type as in the training set. 01580 // missing - missing values mask. missing=0 if there are no 01581 // missing values in sample vector. 01582 // weak_responses - predictions of all of the trees. 01583 // not implemented (!) 01584 // slice - part of the ensemble used for prediction. 01585 // slice = CV_WHOLE_SEQ when all trees are used. 
01586 // k - number of ensemble used. 01587 // k is in {-1,0,1,..,<count of output classes-1>}. 01588 // in the case of classification problem 01589 // <count of output classes-1> ensembles are built. 01590 // If k = -1 ordinary prediction is the result, 01591 // otherwise function gives the prediction of the 01592 // k-th ensemble only. 01593 // OUTPUT 01594 // RESULT 01595 // Predicted value. 01596 */ 01597 virtual float predict_serial( const CvMat* sample, const CvMat* missing=0, 01598 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ, 01599 int k=-1 ) const; 01600 01601 /* 01602 // Response value prediction. 01603 // Parallel version (in the case of TBB existence) 01604 // 01605 // API 01606 // virtual float predict( const CvMat* sample, const CvMat* missing=0, 01607 CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ, 01608 int k=-1 ) const; 01609 01610 // INPUT 01611 // sample - input sample of the same type as in the training set. 01612 // missing - missing values mask. missing=0 if there are no 01613 // missing values in sample vector. 01614 // weak_responses - predictions of all of the trees. 01615 // not implemented (!) 01616 // slice - part of the ensemble used for prediction. 01617 // slice = CV_WHOLE_SEQ when all trees are used. 01618 // k - number of ensemble used. 01619 // k is in {-1,0,1,..,<count of output classes-1>}. 01620 // in the case of classification problem 01621 // <count of output classes-1> ensembles are built. 01622 // If k = -1 ordinary prediction is the result, 01623 // otherwise function gives the prediction of the 01624 // k-th ensemble only. 01625 // OUTPUT 01626 // RESULT 01627 // Predicted value. 01628 */ 01629 virtual float predict( const CvMat* sample, const CvMat* missing=0, 01630 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ, 01631 int k=-1 ) const; 01632 01633 /* 01634 // Deletes all the data. 
01635 // 01636 // API 01637 // virtual void clear(); 01638 01639 // INPUT 01640 // OUTPUT 01641 // delete data, weak, orig_response, sum_response, 01642 // weak_eval, subsample_train, subsample_test, 01643 // sample_idx, missing, lass_labels 01644 // delta = 0.0 01645 // RESULT 01646 */ 01647 CV_WRAP virtual void clear(); 01648 01649 /* 01650 // Compute error on the train/test set. 01651 // 01652 // API 01653 // virtual float calc_error( CvMLData* _data, int type, 01654 // std::vector<float> *resp = 0 ); 01655 // 01656 // INPUT 01657 // data - dataset 01658 // type - defines which error is to compute: train (CV_TRAIN_ERROR) or 01659 // test (CV_TEST_ERROR). 01660 // OUTPUT 01661 // resp - vector of predicitons 01662 // RESULT 01663 // Error value. 01664 */ 01665 virtual float calc_error( CvMLData* _data, int type, 01666 std::vector<float> *resp = 0 ); 01667 01668 /* 01669 // 01670 // Write parameters of the gtb model and data. Write learned model. 01671 // 01672 // API 01673 // virtual void write( CvFileStorage* fs, const char* name ) const; 01674 // 01675 // INPUT 01676 // fs - file storage to read parameters from. 01677 // name - model name. 01678 // OUTPUT 01679 // RESULT 01680 */ 01681 virtual void write( CvFileStorage* fs, const char* name ) const; 01682 01683 01684 /* 01685 // 01686 // Read parameters of the gtb model and data. Read learned model. 01687 // 01688 // API 01689 // virtual void read( CvFileStorage* fs, CvFileNode* node ); 01690 // 01691 // INPUT 01692 // fs - file storage to read parameters from. 01693 // node - file node. 
01694 // OUTPUT 01695 // RESULT 01696 */ 01697 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01698 01699 01700 // new-style C++ interface 01701 CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag, 01702 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 01703 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 01704 const cv::Mat& missingDataMask=cv::Mat(), 01705 CvGBTreesParams params=CvGBTreesParams() ); 01706 01707 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 01708 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 01709 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 01710 const cv::Mat& missingDataMask=cv::Mat(), 01711 CvGBTreesParams params=CvGBTreesParams(), 01712 bool update=false ); 01713 01714 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(), 01715 const cv::Range& slice = cv::Range::all(), 01716 int k=-1 ) const; 01717 01718 protected: 01719 01720 /* 01721 // Compute the gradient vector components. 01722 // 01723 // API 01724 // virtual void find_gradient( const int k = 0); 01725 01726 // INPUT 01727 // k - used for classification problem, determining current 01728 // tree ensemble. 01729 // OUTPUT 01730 // changes components of data->responses 01731 // which correspond to samples used for training 01732 // on the current step. 01733 // RESULT 01734 */ 01735 virtual void find_gradient( const int k = 0); 01736 01737 01738 /* 01739 // 01740 // Change values in tree leaves according to the used loss function. 01741 // 01742 // API 01743 // virtual void change_values(CvDTree* tree, const int k = 0); 01744 // 01745 // INPUT 01746 // tree - decision tree to change. 01747 // k - used for classification problem, determining current 01748 // tree ensemble. 01749 // OUTPUT 01750 // changes 'value' fields of the trees' leaves. 01751 // changes sum_response_tmp. 
01752 // RESULT 01753 */ 01754 virtual void change_values(CvDTree* tree, const int k = 0); 01755 01756 01757 /* 01758 // 01759 // Find optimal constant prediction value according to the used loss 01760 // function. 01761 // The goal is to find a constant which gives the minimal summary loss 01762 // on the _Idx samples. 01763 // 01764 // API 01765 // virtual float find_optimal_value( const CvMat* _Idx ); 01766 // 01767 // INPUT 01768 // _Idx - indices of the samples from the training set. 01769 // OUTPUT 01770 // RESULT 01771 // optimal constant value. 01772 */ 01773 virtual float find_optimal_value( const CvMat* _Idx ); 01774 01775 01776 /* 01777 // 01778 // Randomly split the whole training set in two parts according 01779 // to params.portion. 01780 // 01781 // API 01782 // virtual void do_subsample(); 01783 // 01784 // INPUT 01785 // OUTPUT 01786 // subsample_train - indices of samples used for training 01787 // subsample_test - indices of samples used for test 01788 // RESULT 01789 */ 01790 virtual void do_subsample(); 01791 01792 01793 /* 01794 // 01795 // Internal recursive function giving an array of subtree tree leaves. 01796 // 01797 // API 01798 // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node ); 01799 // 01800 // INPUT 01801 // node - current leaf. 01802 // OUTPUT 01803 // count - count of leaves in the subtree. 01804 // leaves - array of pointers to leaves. 01805 // RESULT 01806 */ 01807 void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node ); 01808 01809 01810 /* 01811 // 01812 // Get leaves of the tree. 01813 // 01814 // API 01815 // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len ); 01816 // 01817 // INPUT 01818 // dtree - decision tree. 01819 // OUTPUT 01820 // len - count of the leaves. 01821 // RESULT 01822 // CvDTreeNode** - array of pointers to leaves. 
01823 */ 01824 CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len ); 01825 01826 01827 /* 01828 // 01829 // Is it a regression or a classification. 01830 // 01831 // API 01832 // bool problem_type(); 01833 // 01834 // INPUT 01835 // OUTPUT 01836 // RESULT 01837 // false if it is a classification problem, 01838 // true - if regression. 01839 */ 01840 virtual bool problem_type() const; 01841 01842 01843 /* 01844 // 01845 // Write parameters of the gtb model. 01846 // 01847 // API 01848 // virtual void write_params( CvFileStorage* fs ) const; 01849 // 01850 // INPUT 01851 // fs - file storage to write parameters to. 01852 // OUTPUT 01853 // RESULT 01854 */ 01855 virtual void write_params( CvFileStorage* fs ) const; 01856 01857 01858 /* 01859 // 01860 // Read parameters of the gtb model and data. 01861 // 01862 // API 01863 // virtual void read_params( CvFileStorage* fs ); 01864 // 01865 // INPUT 01866 // fs - file storage to read parameters from. 01867 // OUTPUT 01868 // params - parameters of the gtb model. 01869 // data - contains information about the structure 01870 // of the data set (count of variables, 01871 // their types, etc.). 01872 // class_labels - output class labels map. 
01873 // RESULT 01874 */ 01875 virtual void read_params( CvFileStorage* fs, CvFileNode* fnode ); 01876 int get_len(const CvMat* mat) const; 01877 01878 01879 CvDTreeTrainData* data; 01880 CvGBTreesParams params; 01881 01882 CvSeq** weak; 01883 CvMat* orig_response; 01884 CvMat* sum_response; 01885 CvMat* sum_response_tmp; 01886 CvMat* sample_idx; 01887 CvMat* subsample_train; 01888 CvMat* subsample_test; 01889 CvMat* missing; 01890 CvMat* class_labels; 01891 01892 cv::RNG* rng; 01893 01894 int class_count; 01895 float delta; 01896 float base_value; 01897 01898 }; 01899 01900 01901 01902 /****************************************************************************************\ 01903 * Artificial Neural Networks (ANN) * 01904 \****************************************************************************************/ 01905 01907 01908 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams 01909 { 01910 CvANN_MLP_TrainParams(); 01911 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method, 01912 double param1, double param2=0 ); 01913 ~CvANN_MLP_TrainParams(); 01914 01915 enum { BACKPROP=0, RPROP=1 }; 01916 01917 CV_PROP_RW CvTermCriteria term_crit; 01918 CV_PROP_RW int train_method; 01919 01920 // backpropagation parameters 01921 CV_PROP_RW double bp_dw_scale, bp_moment_scale; 01922 01923 // rprop parameters 01924 CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max; 01925 }; 01926 01927 01928 class CV_EXPORTS_W CvANN_MLP : public CvStatModel 01929 { 01930 public: 01931 CV_WRAP CvANN_MLP(); 01932 CvANN_MLP( const CvMat* layerSizes, 01933 int activateFunc=CvANN_MLP::SIGMOID_SYM, 01934 double fparam1=0, double fparam2=0 ); 01935 01936 virtual ~CvANN_MLP(); 01937 01938 virtual void create( const CvMat* layerSizes, 01939 int activateFunc=CvANN_MLP::SIGMOID_SYM, 01940 double fparam1=0, double fparam2=0 ); 01941 01942 virtual int train( const CvMat* inputs, const CvMat* outputs, 01943 const CvMat* sampleWeights, const CvMat* sampleIdx=0, 01944 
CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(), 01945 int flags=0 ); 01946 virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const; 01947 01948 #ifndef SWIG 01949 CV_WRAP CvANN_MLP( const cv::Mat& layerSizes, 01950 int activateFunc=CvANN_MLP::SIGMOID_SYM, 01951 double fparam1=0, double fparam2=0 ); 01952 01953 CV_WRAP virtual void create( const cv::Mat& layerSizes, 01954 int activateFunc=CvANN_MLP::SIGMOID_SYM, 01955 double fparam1=0, double fparam2=0 ); 01956 01957 CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs, 01958 const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(), 01959 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(), 01960 int flags=0 ); 01961 01962 CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const; 01963 #endif 01964 01965 CV_WRAP virtual void clear(); 01966 01967 // possible activation functions 01968 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 }; 01969 01970 // available training flags 01971 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 }; 01972 01973 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01974 virtual void write( CvFileStorage* storage, const char* name ) const; 01975 01976 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; } 01977 const CvMat* get_layer_sizes() { return layer_sizes; } 01978 double* get_weights(int layer) 01979 { 01980 return layer_sizes && weights && 01981 (unsigned)layer <= (unsigned)layer_sizes->cols ? 
weights[layer] : 0; 01982 } 01983 01984 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const; 01985 01986 protected: 01987 01988 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs, 01989 const CvMat* _sample_weights, const CvMat* sampleIdx, 01990 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags ); 01991 01992 // sequential random backpropagation 01993 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw ); 01994 01995 // RPROP algorithm 01996 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw ); 01997 01998 virtual void calc_activ_func( CvMat* xf, const double* bias ) const; 01999 virtual void set_activ_func( int _activ_func=SIGMOID_SYM, 02000 double _f_param1=0, double _f_param2=0 ); 02001 virtual void init_weights(); 02002 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const; 02003 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const; 02004 virtual void calc_input_scale( const CvVectors* vecs, int flags ); 02005 virtual void calc_output_scale( const CvVectors* vecs, int flags ); 02006 02007 virtual void write_params( CvFileStorage* fs ) const; 02008 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 02009 02010 CvMat* layer_sizes; 02011 CvMat* wbuf; 02012 CvMat* sample_weights; 02013 double** weights; 02014 double f_param1, f_param2; 02015 double min_val, max_val, min_val1, max_val1; 02016 int activ_func; 02017 int max_count, max_buf_sz; 02018 CvANN_MLP_TrainParams params; 02019 cv::RNG* rng; 02020 }; 02021 02022 /****************************************************************************************\ 02023 * Auxilary functions declarations * 02024 \****************************************************************************************/ 02025 02026 /* Generates <sample> from multivariate normal distribution, where <mean> - is an 02027 average row vector, <cov> - symmetric covariation matrix */ 
02028 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample, 02029 CvRNG* rng CV_DEFAULT(0) ); 02030 02031 /* Generates sample from gaussian mixture distribution */ 02032 CVAPI(void) cvRandGaussMixture( CvMat* means[], 02033 CvMat* covs[], 02034 float weights[], 02035 int clsnum, 02036 CvMat* sample, 02037 CvMat* sampClasses CV_DEFAULT(0) ); 02038 02039 #define CV_TS_CONCENTRIC_SPHERES 0 02040 02041 /* creates test set */ 02042 CVAPI(void) cvCreateTestSet( int type, CvMat** samples, 02043 int num_samples, 02044 int num_features, 02045 CvMat** responses, 02046 int num_classes, ... ); 02047 02048 02049 #endif 02050 02051 /****************************************************************************************\ 02052 * Data * 02053 \****************************************************************************************/ 02054 02055 #include <map> 02056 #include <string> 02057 #include <iostream> 02058 02059 #define CV_COUNT 0 02060 #define CV_PORTION 1 02061 02062 struct CV_EXPORTS CvTrainTestSplit 02063 { 02064 CvTrainTestSplit(); 02065 CvTrainTestSplit( int train_sample_count, bool mix = true); 02066 CvTrainTestSplit( float train_sample_portion, bool mix = true); 02067 02068 union 02069 { 02070 int count; 02071 float portion; 02072 } train_sample_part; 02073 int train_sample_part_mode; 02074 02075 bool mix; 02076 }; 02077 02078 class CV_EXPORTS CvMLData 02079 { 02080 public: 02081 CvMLData(); 02082 virtual ~CvMLData(); 02083 02084 // returns: 02085 // 0 - OK 02086 // 1 - file can not be opened or is not correct 02087 int read_csv( const char* filename ); 02088 02089 const CvMat* get_values() const; 02090 const CvMat* get_responses(); 02091 const CvMat* get_missing() const; 02092 02093 void set_response_idx( int idx ); // old response become predictors, new response_idx = idx 02094 // if idx < 0 there will be no response 02095 int get_response_idx() const; 02096 02097 void set_train_test_split( const CvTrainTestSplit * spl ); 02098 const CvMat* 
get_train_sample_idx() const; 02099 const CvMat* get_test_sample_idx() const; 02100 void mix_train_and_test_idx(); 02101 02102 const CvMat* get_var_idx(); 02103 void chahge_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor 02104 02105 const CvMat* get_var_types(); 02106 int get_var_type( int var_idx ) const; 02107 // following 2 methods enable to change vars type 02108 // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable 02109 // with numerical labels; in the other cases var types are correctly determined automatically 02110 void set_var_types( const char* str ); // str examples: 02111 // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]", 02112 // "cat", "ord" (all vars are categorical/ordered) 02113 void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL } 02114 02115 void set_delimiter( char ch ); 02116 char get_delimiter() const; 02117 02118 void set_miss_ch( char ch ); 02119 char get_miss_ch() const; 02120 02121 const std::map<std::string, int>& get_class_labels_map() const; 02122 02123 protected: 02124 virtual void clear(); 02125 02126 void str_to_flt_elem( const char* token, float& flt_elem, int& type); 02127 void free_train_test_idx(); 02128 02129 char delimiter; 02130 char miss_ch; 02131 //char flt_separator; 02132 02133 CvMat* values; 02134 CvMat* missing; 02135 CvMat* var_types; 02136 CvMat* var_idx_mask; 02137 02138 CvMat* response_out; // header 02139 CvMat* var_idx_out; // mat 02140 CvMat* var_types_out; // mat 02141 02142 int response_idx; 02143 02144 int train_sample_count; 02145 bool mix; 02146 02147 int total_class_count; 02148 std::map<std::string, int> class_map; 02149 02150 CvMat* train_sample_idx; 02151 CvMat* test_sample_idx; 02152 int* sample_idx; // data of train_sample_idx and test_sample_idx 02153 02154 cv::RNG* rng; 02155 }; 02156 02157 02158 namespace cv 02159 { 02160 02161 typedef CvStatModel StatModel; 02162 typedef CvParamGrid 
ParamGrid; 02163 typedef CvNormalBayesClassifier NormalBayesClassifier; 02164 typedef CvKNearest KNearest; 02165 typedef CvSVMParams SVMParams; 02166 typedef CvSVMKernel SVMKernel; 02167 typedef CvSVMSolver SVMSolver; 02168 typedef CvSVM SVM; 02169 typedef CvEMParams EMParams; 02170 typedef CvEM ExpectationMaximization; 02171 typedef CvDTreeParams DTreeParams; 02172 typedef CvMLData TrainData; 02173 typedef CvDTree DecisionTree; 02174 typedef CvForestTree ForestTree; 02175 typedef CvRTParams RandomTreeParams; 02176 typedef CvRTrees RandomTrees; 02177 typedef CvERTreeTrainData ERTreeTRainData; 02178 typedef CvForestERTree ERTree; 02179 typedef CvERTrees ERTrees; 02180 typedef CvBoostParams BoostParams; 02181 typedef CvBoostTree BoostTree; 02182 typedef CvBoost Boost; 02183 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams; 02184 typedef CvANN_MLP NeuralNet_MLP; 02185 typedef CvGBTreesParams GradientBoostingTreeParams; 02186 typedef CvGBTrees GradientBoostingTrees; 02187 02188 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj(); 02189 02190 } 02191 02192 #endif 02193 /* End of file. */