00001 /*M/////////////////////////////////////////////////////////////////////////////////////// 00002 // 00003 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 00004 // 00005 // By downloading, copying, installing or using the software you agree to this license. 00006 // If you do not agree to this license, do not download, install, 00007 // copy or use the software. 00008 // 00009 // 00010 // Intel License Agreement 00011 // 00012 // Copyright (C) 2000, Intel Corporation, all rights reserved. 00013 // Third party copyrights are property of their respective owners. 00014 // 00015 // Redistribution and use in source and binary forms, with or without modification, 00016 // are permitted provided that the following conditions are met: 00017 // 00018 // * Redistribution's of source code must retain the above copyright notice, 00019 // this list of conditions and the following disclaimer. 00020 // 00021 // * Redistribution's in binary form must reproduce the above copyright notice, 00022 // this list of conditions and the following disclaimer in the documentation 00023 // and/or other materials provided with the distribution. 00024 // 00025 // * The name of Intel Corporation may not be used to endorse or promote products 00026 // derived from this software without specific prior written permission. 00027 // 00028 // This software is provided by the copyright holders and contributors "as is" and 00029 // any express or implied warranties, including, but not limited to, the implied 00030 // warranties of merchantability and fitness for a particular purpose are disclaimed. 00031 // In no event shall the Intel Corporation or contributors be liable for any direct, 00032 // indirect, incidental, special, exemplary, or consequential damages 00033 // (including, but not limited to, procurement of substitute goods or services; 00034 // loss of use, data, or profits; or business interruption) however caused 00035 // and on any theory of liability, whether in contract, strict liability, 00036 // or tort (including negligence or otherwise) arising in any way out of 00037 // the use of this software, even if advised of the possibility of such damage. 00038 // 00039 //M*/ 00040 00041 #ifndef __OPENCV_ML_H__ 00042 #define __OPENCV_ML_H__ 00043 00044 // disable deprecation warning which appears in VisualStudio 8.0 00045 #if _MSC_VER >= 1400 00046 #pragma warning( disable : 4996 ) 00047 #endif 00048 00049 #ifndef SKIP_INCLUDES 00050 00051 #include "cxcore.h" 00052 #include <limits.h> 00053 00054 #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64 00055 #include <windows.h> 00056 #endif 00057 00058 #else // SKIP_INCLUDES 00059 00060 #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64 00061 #define CV_CDECL __cdecl 00062 #define CV_STDCALL __stdcall 00063 #else 00064 #define CV_CDECL 00065 #define CV_STDCALL 00066 #endif 00067 00068 #ifndef CV_EXTERN_C 00069 #ifdef __cplusplus 00070 #define CV_EXTERN_C extern "C" 00071 #define CV_DEFAULT(val) = val 00072 #else 00073 #define CV_EXTERN_C 00074 #define CV_DEFAULT(val) 00075 #endif 00076 #endif 00077 00078 #ifndef CV_EXTERN_C_FUNCPTR 00079 #ifdef __cplusplus 00080 #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; } 00081 #else 00082 #define CV_EXTERN_C_FUNCPTR(x) typedef x 00083 #endif 00084 #endif 00085 00086 #ifndef CV_INLINE 00087 #if defined __cplusplus 00088 #define CV_INLINE inline 00089 #elif (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && !defined __GNUC__ 00090 #define CV_INLINE __inline 00091 #else 00092 #define CV_INLINE static 00093 #endif 00094 #endif /* CV_INLINE */ 00095 00096 #if (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && defined CVAPI_EXPORTS 00097 #define CV_EXPORTS __declspec(dllexport) 00098 #else 00099 #define CV_EXPORTS 00100 #endif 00101 00102 #ifndef CVAPI 00103 #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL 00104 #endif 00105 00106 #endif // SKIP_INCLUDES 00107 00108 00109 #ifdef __cplusplus 00110 00111 // Apple defines a check() macro somewhere in the debug headers 00112 // that interferes with a method definiton in this header 00113 #undef check 00114 00115 #include "cvinternal.h" 00116 00117 /****************************************************************************************\ 00118 * Main struct definitions * 00119 \****************************************************************************************/ 00120 00121 /* log(2*PI) */ 00122 #define CV_LOG2PI (1.8378770664093454835606594728112) 00123 00124 /* columns of <trainData> matrix are training samples */ 00125 #define CV_COL_SAMPLE 0 00126 00127 /* rows of <trainData> matrix are training samples */ 00128 #define CV_ROW_SAMPLE 1 00129 00130 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE) 00131 00132 struct CvVectors 00133 { 00134 int type; 00135 int dims, count; 00136 CvVectors* next; 00137 union 00138 { 00139 uchar** ptr; 00140 float** fl; 00141 double** db; 00142 } data; 00143 }; 00144 00145 #if 0 00146 /* A structure, representing the lattice range of statmodel parameters. 00147 It is used for optimizing statmodel parameters by cross-validation method. 00148 The lattice is logarithmic, so <step> must be greater then 1. */ 00149 typedef struct CvParamLattice 00150 { 00151 double min_val; 00152 double max_val; 00153 double step; 00154 } 00155 CvParamLattice; 00156 00157 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val, 00158 double log_step ) 00159 { 00160 CvParamLattice pl; 00161 pl.min_val = MIN( min_val, max_val ); 00162 pl.max_val = MAX( min_val, max_val ); 00163 pl.step = MAX( log_step, 1. ); 00164 return pl; 00165 } 00166 00167 CV_INLINE CvParamLattice cvDefaultParamLattice( void ) 00168 { 00169 CvParamLattice pl = {0,0,0}; 00170 return pl; 00171 } 00172 #endif 00173 00174 /* Variable type */ 00175 #define CV_VAR_NUMERICAL 0 00176 #define CV_VAR_ORDERED 0 00177 #define CV_VAR_CATEGORICAL 1 00178 00179 #define CV_TYPE_NAME_ML_SVM "opencv-ml-svm" 00180 #define CV_TYPE_NAME_ML_KNN "opencv-ml-knn" 00181 #define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian" 00182 #define CV_TYPE_NAME_ML_EM "opencv-ml-em" 00183 #define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree" 00184 #define CV_TYPE_NAME_ML_TREE "opencv-ml-tree" 00185 #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp" 00186 #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn" 00187 #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees" 00188 00189 #define CV_TRAIN_ERROR 0 00190 #define CV_TEST_ERROR 1 00191 00192 class CV_EXPORTS CvStatModel 00193 { 00194 public: 00195 CvStatModel(); 00196 virtual ~CvStatModel(); 00197 00198 virtual void clear(); 00199 00200 virtual void save( const char* filename, const char* name=0 ) const; 00201 virtual void load( const char* filename, const char* name=0 ); 00202 00203 virtual void write( CvFileStorage* storage, const char* name ) const; 00204 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00205 00206 protected: 00207 const char* default_model_name; 00208 }; 00209 00210 /****************************************************************************************\ 00211 * Normal Bayes Classifier * 00212 \****************************************************************************************/ 00213 00214 /* The structure, representing the grid range of statmodel parameters. 00215 It is used for optimizing statmodel accuracy by varying model parameters, 00216 the accuracy estimate being computed by cross-validation. 00217 The grid is logarithmic, so <step> must be greater then 1. */ 00218 00219 class CvMLData; 00220 00221 struct CV_EXPORTS CvParamGrid 00222 { 00223 // SVM params type 00224 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 }; 00225 00226 CvParamGrid() 00227 { 00228 min_val = max_val = step = 0; 00229 } 00230 00231 CvParamGrid( double _min_val, double _max_val, double log_step ) 00232 { 00233 min_val = _min_val; 00234 max_val = _max_val; 00235 step = log_step; 00236 } 00237 //CvParamGrid( int param_id ); 00238 bool check() const; 00239 00240 double min_val; 00241 double max_val; 00242 double step; 00243 }; 00244 00245 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel 00246 { 00247 public: 00248 CvNormalBayesClassifier(); 00249 virtual ~CvNormalBayesClassifier(); 00250 00251 CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses, 00252 const CvMat* _var_idx=0, const CvMat* _sample_idx=0 ); 00253 00254 virtual bool train( const CvMat* _train_data, const CvMat* _responses, 00255 const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false ); 00256 00257 virtual float predict( const CvMat* _samples, CvMat* results=0 ) const; 00258 virtual void clear(); 00259 00260 #ifndef SWIG 00261 CvNormalBayesClassifier( const cv::Mat& _train_data, const cv::Mat& _responses, 00262 const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat() ); 00263 virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses, 00264 const cv::Mat& _var_idx = cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(), 00265 bool update=false ); 00266 virtual float predict( const cv::Mat& _samples, cv::Mat* results=0 ) const; 00267 #endif 00268 00269 virtual void write( CvFileStorage* storage, const char* name ) const; 00270 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00271 00272 protected: 00273 int var_count, var_all; 00274 CvMat* var_idx; 00275 CvMat* cls_labels; 00276 CvMat** count; 00277 CvMat** sum; 00278 CvMat** productsum; 00279 CvMat** avg; 00280 CvMat** inv_eigen_values; 00281 CvMat** cov_rotate_mats; 00282 CvMat* c; 00283 }; 00284 00285 00286 /****************************************************************************************\ 00287 * K-Nearest Neighbour Classifier * 00288 \****************************************************************************************/ 00289 00290 // k Nearest Neighbors 00291 class CV_EXPORTS CvKNearest : public CvStatModel 00292 { 00293 public: 00294 00295 CvKNearest(); 00296 virtual ~CvKNearest(); 00297 00298 CvKNearest( const CvMat* _train_data, const CvMat* _responses, 00299 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 ); 00300 00301 virtual bool train( const CvMat* _train_data, const CvMat* _responses, 00302 const CvMat* _sample_idx=0, bool is_regression=false, 00303 int _max_k=32, bool _update_base=false ); 00304 00305 virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0, 00306 const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const; 00307 00308 #ifndef SWIG 00309 CvKNearest( const cv::Mat& _train_data, const cv::Mat& _responses, 00310 const cv::Mat& _sample_idx=cv::Mat(), bool _is_regression=false, int max_k=32 ); 00311 00312 virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses, 00313 const cv::Mat& _sample_idx=cv::Mat(), bool is_regression=false, 00314 int _max_k=32, bool _update_base=false ); 00315 00316 virtual float find_nearest( const cv::Mat& _samples, int k, cv::Mat* results=0, 00317 const float** neighbors=0, 00318 cv::Mat* neighbor_responses=0, 00319 cv::Mat* dist=0 ) const; 00320 #endif 00321 00322 virtual void clear(); 00323 int get_max_k() const; 00324 int get_var_count() const; 00325 int get_sample_count() const; 00326 bool is_regression() const; 00327 00328 protected: 00329 00330 virtual float write_results( int k, int k1, int start, int end, 00331 const float* neighbor_responses, const float* dist, CvMat* _results, 00332 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const; 00333 00334 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end, 00335 float* neighbor_responses, const float** neighbors, float* dist ) const; 00336 00337 00338 int max_k, var_count; 00339 int total; 00340 bool regression; 00341 CvVectors* samples; 00342 }; 00343 00344 /****************************************************************************************\ 00345 * Support Vector Machines * 00346 \****************************************************************************************/ 00347 00348 // SVM training parameters 00349 struct CV_EXPORTS CvSVMParams 00350 { 00351 CvSVMParams(); 00352 CvSVMParams( int _svm_type, int _kernel_type, 00353 double _degree, double _gamma, double _coef0, 00354 double Cvalue, double _nu, double _p, 00355 CvMat* _class_weights, CvTermCriteria _term_crit ); 00356 00357 int svm_type; 00358 int kernel_type; 00359 double degree; // for poly 00360 double gamma; // for poly/rbf/sigmoid 00361 double coef0; // for poly/sigmoid 00362 00363 double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR 00364 double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR 00365 double p; // for CV_SVM_EPS_SVR 00366 CvMat* class_weights; // for CV_SVM_C_SVC 00367 CvTermCriteria term_crit; // termination criteria 00368 }; 00369 00370 00371 struct CV_EXPORTS CvSVMKernel 00372 { 00373 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs, 00374 const float* another, float* results ); 00375 CvSVMKernel(); 00376 CvSVMKernel( const CvSVMParams* _params, Calc _calc_func ); 00377 virtual bool create( const CvSVMParams* _params, Calc _calc_func ); 00378 virtual ~CvSVMKernel(); 00379 00380 virtual void clear(); 00381 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results ); 00382 00383 const CvSVMParams* params; 00384 Calc calc_func; 00385 00386 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs, 00387 const float* another, float* results, 00388 double alpha, double beta ); 00389 00390 virtual void calc_linear( int vec_count, int vec_size, const float** vecs, 00391 const float* another, float* results ); 00392 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs, 00393 const float* another, float* results ); 00394 virtual void calc_poly( int vec_count, int vec_size, const float** vecs, 00395 const float* another, float* results ); 00396 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs, 00397 const float* another, float* results ); 00398 }; 00399 00400 00401 struct CvSVMKernelRow 00402 { 00403 CvSVMKernelRow* prev; 00404 CvSVMKernelRow* next; 00405 float* data; 00406 }; 00407 00408 00409 struct CvSVMSolutionInfo 00410 { 00411 double obj; 00412 double rho; 00413 double upper_bound_p; 00414 double upper_bound_n; 00415 double r; // for Solver_NU 00416 }; 00417 00418 class CV_EXPORTS CvSVMSolver 00419 { 00420 public: 00421 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j ); 00422 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed ); 00423 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r ); 00424 00425 CvSVMSolver(); 00426 00427 CvSVMSolver( int count, int var_count, const float** samples, schar* y, 00428 int alpha_count, double* alpha, double Cp, double Cn, 00429 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 00430 SelectWorkingSet select_working_set, CalcRho calc_rho ); 00431 virtual bool create( int count, int var_count, const float** samples, schar* y, 00432 int alpha_count, double* alpha, double Cp, double Cn, 00433 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 00434 SelectWorkingSet select_working_set, CalcRho calc_rho ); 00435 virtual ~CvSVMSolver(); 00436 00437 virtual void clear(); 00438 virtual bool solve_generic( CvSVMSolutionInfo& si ); 00439 00440 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y, 00441 double Cp, double Cn, CvMemStorage* storage, 00442 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si ); 00443 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y, 00444 CvMemStorage* storage, CvSVMKernel* kernel, 00445 double* alpha, CvSVMSolutionInfo& si ); 00446 virtual bool solve_one_class( int count, int var_count, const float** samples, 00447 CvMemStorage* storage, CvSVMKernel* kernel, 00448 double* alpha, CvSVMSolutionInfo& si ); 00449 00450 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y, 00451 CvMemStorage* storage, CvSVMKernel* kernel, 00452 double* alpha, CvSVMSolutionInfo& si ); 00453 00454 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y, 00455 CvMemStorage* storage, CvSVMKernel* kernel, 00456 double* alpha, CvSVMSolutionInfo& si ); 00457 00458 virtual float* get_row_base( int i, bool* _existed ); 00459 virtual float* get_row( int i, float* dst ); 00460 00461 int sample_count; 00462 int var_count; 00463 int cache_size; 00464 int cache_line_size; 00465 const float** samples; 00466 const CvSVMParams* params; 00467 CvMemStorage* storage; 00468 CvSVMKernelRow lru_list; 00469 CvSVMKernelRow* rows; 00470 00471 int alpha_count; 00472 00473 double* G; 00474 double* alpha; 00475 00476 // -1 - lower bound, 0 - free, 1 - upper bound 00477 schar* alpha_status; 00478 00479 schar* y; 00480 double* b; 00481 float* buf[2]; 00482 double eps; 00483 int max_iter; 00484 double C[2]; // C[0] == Cn, C[1] == Cp 00485 CvSVMKernel* kernel; 00486 00487 SelectWorkingSet select_working_set_func; 00488 CalcRho calc_rho_func; 00489 GetRow get_row_func; 00490 00491 virtual bool select_working_set( int& i, int& j ); 00492 virtual bool select_working_set_nu_svm( int& i, int& j ); 00493 virtual void calc_rho( double& rho, double& r ); 00494 virtual void calc_rho_nu_svm( double& rho, double& r ); 00495 00496 virtual float* get_row_svc( int i, float* row, float* dst, bool existed ); 00497 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed ); 00498 virtual float* get_row_svr( int i, float* row, float* dst, bool existed ); 00499 }; 00500 00501 00502 struct CvSVMDecisionFunc 00503 { 00504 double rho; 00505 int sv_count; 00506 double* alpha; 00507 int* sv_index; 00508 }; 00509 00510 00511 // SVM model 00512 class CV_EXPORTS CvSVM : public CvStatModel 00513 { 00514 public: 00515 // SVM type 00516 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 }; 00517 00518 // SVM kernel type 00519 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 }; 00520 00521 // SVM params type 00522 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 }; 00523 00524 CvSVM(); 00525 virtual ~CvSVM(); 00526 00527 CvSVM( const CvMat* _train_data, const CvMat* _responses, 00528 const CvMat* _var_idx=0, const CvMat* _sample_idx=0, 00529 CvSVMParams _params=CvSVMParams() ); 00530 00531 virtual bool train( const CvMat* _train_data, const CvMat* _responses, 00532 const CvMat* _var_idx=0, const CvMat* _sample_idx=0, 00533 CvSVMParams _params=CvSVMParams() ); 00534 00535 virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses, 00536 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, 00537 int k_fold = 10, 00538 CvParamGrid C_grid = get_default_grid(CvSVM::C), 00539 CvParamGrid gamma_grid = get_default_grid(CvSVM::GAMMA), 00540 CvParamGrid p_grid = get_default_grid(CvSVM::P), 00541 CvParamGrid nu_grid = get_default_grid(CvSVM::NU), 00542 CvParamGrid coef_grid = get_default_grid(CvSVM::COEF), 00543 CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) ); 00544 00545 virtual float predict( const CvMat* _sample, bool returnDFVal=false ) const; 00546 00547 #ifndef SWIG 00548 CvSVM( const cv::Mat& _train_data, const cv::Mat& _responses, 00549 const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(), 00550 CvSVMParams _params=CvSVMParams() ); 00551 00552 virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses, 00553 const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(), 00554 CvSVMParams _params=CvSVMParams() ); 00555 00556 virtual bool train_auto( const cv::Mat& _train_data, const cv::Mat& _responses, 00557 const cv::Mat& _var_idx, const cv::Mat& _sample_idx, CvSVMParams _params, 00558 int k_fold = 10, 00559 CvParamGrid C_grid = get_default_grid(CvSVM::C), 00560 CvParamGrid gamma_grid = get_default_grid(CvSVM::GAMMA), 00561 CvParamGrid p_grid = get_default_grid(CvSVM::P), 00562 CvParamGrid nu_grid = get_default_grid(CvSVM::NU), 00563 CvParamGrid coef_grid = get_default_grid(CvSVM::COEF), 00564 CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) ); 00565 virtual float predict( const cv::Mat& _sample, bool returnDFVal=false ) const; 00566 #endif 00567 00568 virtual int get_support_vector_count() const; 00569 virtual const float* get_support_vector(int i) const; 00570 virtual CvSVMParams get_params() const { return params; }; 00571 virtual void clear(); 00572 00573 static CvParamGrid get_default_grid( int param_id ); 00574 00575 virtual void write( CvFileStorage* storage, const char* name ) const; 00576 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00577 int get_var_count() const { return var_idx ? var_idx->cols : var_all; } 00578 00579 protected: 00580 00581 virtual bool set_params( const CvSVMParams& _params ); 00582 virtual bool train1( int sample_count, int var_count, const float** samples, 00583 const void* _responses, double Cp, double Cn, 00584 CvMemStorage* _storage, double* alpha, double& rho ); 00585 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples, 00586 const CvMat* _responses, CvMemStorage* _storage, double* alpha ); 00587 virtual void create_kernel(); 00588 virtual void create_solver(); 00589 00590 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const; 00591 00592 virtual void write_params( CvFileStorage* fs ) const; 00593 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 00594 00595 CvSVMParams params; 00596 CvMat* class_labels; 00597 int var_all; 00598 float** sv; 00599 int sv_total; 00600 CvMat* var_idx; 00601 CvMat* class_weights; 00602 CvSVMDecisionFunc* decision_func; 00603 CvMemStorage* storage; 00604 00605 CvSVMSolver* solver; 00606 CvSVMKernel* kernel; 00607 }; 00608 00609 /****************************************************************************************\ 00610 * Expectation - Maximization * 00611 \****************************************************************************************/ 00612 00613 struct CV_EXPORTS CvEMParams 00614 { 00615 CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/), 00616 start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0) 00617 { 00618 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON ); 00619 } 00620 00621 CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/, 00622 int _start_step=0/*CvEM::START_AUTO_STEP*/, 00623 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON), 00624 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) : 00625 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step), 00626 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit) 00627 {} 00628 00629 int nclusters; 00630 int cov_mat_type; 00631 int start_step; 00632 const CvMat* probs; 00633 const CvMat* weights; 00634 const CvMat* means; 00635 const CvMat** covs; 00636 CvTermCriteria term_crit; 00637 }; 00638 00639 00640 class CV_EXPORTS CvEM : public CvStatModel 00641 { 00642 public: 00643 // Type of covariation matrices 00644 enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 }; 00645 00646 // The initial step 00647 enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 }; 00648 00649 CvEM(); 00650 CvEM( const CvMat* samples, const CvMat* sample_idx=0, 00651 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 00652 //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats); 00653 00654 virtual ~CvEM(); 00655 00656 virtual bool train( const CvMat* samples, const CvMat* sample_idx=0, 00657 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 00658 00659 virtual float predict( const CvMat* sample, CvMat* probs ) const; 00660 00661 #ifndef SWIG 00662 CvEM( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(), 00663 CvEMParams params=CvEMParams(), cv::Mat* labels=0 ); 00664 00665 virtual bool train( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(), 00666 CvEMParams params=CvEMParams(), cv::Mat* labels=0 ); 00667 00668 virtual float predict( const cv::Mat& sample, cv::Mat* probs ) const; 00669 #endif 00670 00671 virtual void clear(); 00672 00673 int get_nclusters() const; 00674 const CvMat* get_means() const; 00675 const CvMat** get_covs() const; 00676 const CvMat* get_weights() const; 00677 const CvMat* get_probs() const; 00678 00679 inline double get_log_likelihood () const { return log_likelihood; }; 00680 00681 // inline const CvMat * get_log_weight_div_det () const { return log_weight_div_det; }; 00682 // inline const CvMat * get_inv_eigen_values () const { return inv_eigen_values; }; 00683 // inline const CvMat ** get_cov_rotate_mats () const { return cov_rotate_mats; }; 00684 00685 protected: 00686 00687 virtual void set_params( const CvEMParams& params, 00688 const CvVectors& train_data ); 00689 virtual void init_em( const CvVectors& train_data ); 00690 virtual double run_em( const CvVectors& train_data ); 00691 virtual void init_auto( const CvVectors& samples ); 00692 virtual void kmeans( const CvVectors& train_data, int nclusters, 00693 CvMat* labels, CvTermCriteria criteria, 00694 const CvMat* means ); 00695 CvEMParams params; 00696 double log_likelihood; 00697 00698 CvMat* means; 00699 CvMat** covs; 00700 CvMat* weights; 00701 CvMat* probs; 00702 00703 CvMat* log_weight_div_det; 00704 CvMat* inv_eigen_values; 00705 CvMat** cov_rotate_mats; 00706 }; 00707 00708 /****************************************************************************************\ 00709 * Decision Tree * 00710 \****************************************************************************************/\ 00711 struct CvPair16u32s 00712 { 00713 unsigned short* u; 00714 int* i; 00715 }; 00716 00717 00718 #define CV_DTREE_CAT_DIR(idx,subset) \ 00719 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1) 00720 00721 struct CvDTreeSplit 00722 { 00723 int var_idx; 00724 int condensed_idx; 00725 int inversed; 00726 float quality; 00727 CvDTreeSplit* next; 00728 union 00729 { 00730 int subset[2]; 00731 struct 00732 { 00733 float c; 00734 int split_point; 00735 } 00736 ord; 00737 }; 00738 }; 00739 00740 00741 struct CvDTreeNode 00742 { 00743 int class_idx; 00744 int Tn; 00745 double value; 00746 00747 CvDTreeNode* parent; 00748 CvDTreeNode* left; 00749 CvDTreeNode* right; 00750 00751 CvDTreeSplit* split; 00752 00753 int sample_count; 00754 int depth; 00755 int* num_valid; 00756 int offset; 00757 int buf_idx; 00758 double maxlr; 00759 00760 // global pruning data 00761 int complexity; 00762 double alpha; 00763 double node_risk, tree_risk, tree_error; 00764 00765 // cross-validation pruning data 00766 int* cv_Tn; 00767 double* cv_node_risk; 00768 double* cv_node_error; 00769 00770 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; } 00771 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; } 00772 }; 00773 00774 00775 struct CV_EXPORTS CvDTreeParams 00776 { 00777 int max_categories; 00778 int max_depth; 00779 int min_sample_count; 00780 int cv_folds; 00781 bool use_surrogates; 00782 bool use_1se_rule; 00783 bool truncate_pruned_tree; 00784 float regression_accuracy; 00785 const float* priors; 00786 00787 CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10), 00788 cv_folds(10), use_surrogates(true), use_1se_rule(true), 00789 truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0) 00790 {} 00791 00792 CvDTreeParams( int _max_depth, int _min_sample_count, 00793 float _regression_accuracy, bool _use_surrogates, 00794 int _max_categories, int _cv_folds, 00795 bool _use_1se_rule, bool _truncate_pruned_tree, 00796 const float* _priors ) : 00797 max_categories(_max_categories), max_depth(_max_depth), 00798 min_sample_count(_min_sample_count), cv_folds (_cv_folds), 00799 use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule), 00800 truncate_pruned_tree(_truncate_pruned_tree), 00801 regression_accuracy(_regression_accuracy), 00802 priors(_priors) 00803 {} 00804 }; 00805 00806 00807 struct CV_EXPORTS CvDTreeTrainData 00808 { 00809 CvDTreeTrainData(); 00810 CvDTreeTrainData( const CvMat* _train_data, int _tflag, 00811 const CvMat* _responses, const CvMat* _var_idx=0, 00812 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 00813 const CvMat* _missing_mask=0, 00814 const CvDTreeParams& _params=CvDTreeParams(), 00815 bool _shared=false, bool _add_labels=false ); 00816 virtual ~CvDTreeTrainData(); 00817 00818 virtual void set_data( const CvMat* _train_data, int _tflag, 00819 const CvMat* _responses, const CvMat* _var_idx=0, 00820 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 00821 const CvMat* _missing_mask=0, 00822 const CvDTreeParams& _params=CvDTreeParams(), 00823 bool _shared=false, bool _add_labels=false, 00824 bool _update_data=false ); 00825 virtual void do_responses_copy(); 00826 00827 virtual void get_vectors( const CvMat* _subsample_idx, 00828 float* values, uchar* missing, float* responses, bool get_class_idx=false ); 00829 00830 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 00831 00832 virtual void write_params( CvFileStorage* fs ) const; 00833 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 00834 00835 // release all the data 00836 virtual void clear(); 00837 00838 int get_num_classes() const; 00839 int get_var_type(int vi) const; 00840 int get_work_var_count() const {return work_var_count;} 00841 00842 virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf ); 00843 virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf ); 00844 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); 00845 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); 00846 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf ); 00847 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf, 00848 const float** ord_values, const int** sorted_indices, int* sample_indices_buf ); 00849 virtual int get_child_buf_idx( CvDTreeNode* n ); 00850 00852 00853 virtual bool set_params( const CvDTreeParams& params ); 00854 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count, 00855 int storage_idx, int offset ); 00856 00857 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val, 00858 int split_point, int inversed, float quality ); 00859 virtual CvDTreeSplit* new_split_cat( int vi, float quality ); 00860 virtual void free_node_data( CvDTreeNode* node ); 00861 virtual void free_train_data(); 00862 virtual void free_node( CvDTreeNode* node ); 00863 00864 int sample_count, var_all, var_count, max_c_count; 00865 int ord_var_count, cat_var_count, work_var_count; 00866 bool have_labels, have_priors; 00867 bool is_classifier; 00868 int tflag; 00869 00870 const CvMat* train_data; 00871 const CvMat* responses; 00872 CvMat* responses_copy; // used in Boosting 00873 00874 int buf_count, buf_size; 00875 bool shared; 00876 int is_buf_16u; 00877 00878 CvMat* cat_count; 00879 CvMat* cat_ofs; 00880 CvMat* cat_map; 00881 00882 CvMat* counts; 00883 CvMat* buf; 00884 CvMat* direction; 00885 CvMat* split_buf; 00886 00887 CvMat* var_idx; 00888 CvMat* var_type; // i-th element = 00889 // k<0 - ordered 00890 // k>=0 - categorical, see k-th element of cat_* arrays 00891 CvMat* priors; 00892 CvMat* priors_mult; 00893 00894 CvDTreeParams params; 00895 00896 CvMemStorage* tree_storage; 00897 CvMemStorage* temp_storage; 00898 00899 CvDTreeNode* data_root; 00900 00901 CvSet* node_heap; 00902 CvSet* split_heap; 00903 CvSet* cv_heap; 00904 CvSet* nv_heap; 00905 00906 CvRNG rng; 00907 }; 00908 00909 class CvDTree; 00910 class CvForestTree; 00911 00912 namespace cv 00913 { 00914 struct DTreeBestSplitFinder; 00915 struct ForestTreeBestSplitFinder; 00916 } 00917 00918 class CV_EXPORTS CvDTree : public CvStatModel 00919 { 00920 public: 00921 CvDTree(); 00922 virtual ~CvDTree(); 00923 00924 virtual bool train( const CvMat* _train_data, int _tflag, 00925 const CvMat* _responses, const CvMat* _var_idx=0, 00926 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 00927 const CvMat* _missing_mask=0, 00928 CvDTreeParams params=CvDTreeParams() ); 00929 00930 virtual bool train( CvMLData* _data, CvDTreeParams _params=CvDTreeParams() ); 00931 00932 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 00933 00934 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx ); 00935 00936 virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0, 00937 bool preprocessed_input=false ) const; 00938 00939 #ifndef SWIG 00940 virtual bool train( const cv::Mat& _train_data, int _tflag, 00941 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(), 00942 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(), 00943 const cv::Mat& _missing_mask=cv::Mat(), 00944 CvDTreeParams params=CvDTreeParams() ); 00945 00946 virtual CvDTreeNode* predict( const cv::Mat& _sample, const cv::Mat& _missing_data_mask=cv::Mat(), 00947 bool preprocessed_input=false ) const; 00948 #endif 00949 00950 virtual const CvMat* get_var_importance(); 00951 virtual void clear(); 00952 00953 virtual void read( CvFileStorage* fs, CvFileNode* node ); 00954 virtual void write( CvFileStorage* fs, const char* name ) const; 00955 00956 // special read & write methods for trees in the tree ensembles 00957 virtual void read( CvFileStorage* fs, CvFileNode* node, 00958 CvDTreeTrainData* data ); 00959 virtual void write( CvFileStorage* fs ) const; 00960 00961 const CvDTreeNode* get_root() const; 00962 int get_pruned_tree_idx() const; 00963 CvDTreeTrainData* get_data(); 00964 00965 protected: 00966 friend struct cv::DTreeBestSplitFinder; 00967 00968 virtual bool do_train( const CvMat* _subsample_idx ); 00969 00970 virtual void try_split_node( CvDTreeNode* n ); 00971 virtual void split_node_data( CvDTreeNode* n ); 00972 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); 00973 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 00974 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00975 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 00976 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00977 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 00978 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00979 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 00980 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 00981 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 00982 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 00983 virtual double calc_node_dir( CvDTreeNode* node ); 00984 virtual void complete_node_dir( CvDTreeNode* node ); 00985 virtual void cluster_categories( const int* vectors, int vector_count, 00986 int var_count, int* sums, int k, int* cluster_labels ); 00987 00988 virtual void calc_node_value( CvDTreeNode* node ); 00989 00990 virtual void prune_cv(); 00991 virtual double update_tree_rnc( int T, int fold ); 00992 virtual int cut_tree( int T, int fold, double min_alpha ); 00993 virtual void free_prune_data(bool cut_tree); 00994 virtual void free_tree(); 00995 00996 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const; 00997 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const; 00998 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent ); 00999 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node ); 01000 virtual void write_tree_nodes( CvFileStorage* fs ) const; 01001 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node ); 01002 01003 CvDTreeNode* root; 01004 CvMat* var_importance; 01005 CvDTreeTrainData* data; 01006 01007 public: 01008 int pruned_tree_idx; 01009 }; 01010 01011 01012 /****************************************************************************************\ 01013 * Random Trees Classifier * 01014 \****************************************************************************************/ 01015 01016 class CvRTrees; 01017 01018 class CV_EXPORTS CvForestTree: public CvDTree 01019 { 01020 public: 01021 CvForestTree(); 01022 virtual ~CvForestTree(); 01023 01024 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest ); 01025 01026 virtual int get_var_count() const {return data ? data->var_count : 0;} 01027 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data ); 01028 01029 /* dummy methods to avoid warnings: BEGIN */ 01030 virtual bool train( const CvMat* _train_data, int _tflag, 01031 const CvMat* _responses, const CvMat* _var_idx=0, 01032 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 01033 const CvMat* _missing_mask=0, 01034 CvDTreeParams params=CvDTreeParams() ); 01035 01036 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx ); 01037 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01038 virtual void read( CvFileStorage* fs, CvFileNode* node, 01039 CvDTreeTrainData* data ); 01040 /* dummy methods to avoid warnings: END */ 01041 01042 protected: 01043 friend struct cv::ForestTreeBestSplitFinder; 01044 01045 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); 01046 CvRTrees* forest; 01047 }; 01048 01049 01050 struct CV_EXPORTS CvRTParams : public CvDTreeParams 01051 { 01052 //Parameters for the forest 01053 bool calc_var_importance; // true <=> RF processes variable importance 01054 int nactive_vars; 01055 CvTermCriteria term_crit; 01056 01057 CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ), 01058 calc_var_importance(false), nactive_vars(0) 01059 { 01060 term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 ); 01061 } 01062 01063 CvRTParams( int _max_depth, int _min_sample_count, 01064 float _regression_accuracy, bool _use_surrogates, 01065 int _max_categories, const float* _priors, bool _calc_var_importance, 01066 int _nactive_vars, int max_num_of_trees_in_the_forest, 01067 float forest_accuracy, int termcrit_type ) : 01068 CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy, 01069 _use_surrogates, _max_categories, 0, 01070 false, false, _priors ), 01071 calc_var_importance(_calc_var_importance), 01072 nactive_vars(_nactive_vars) 01073 { 01074 term_crit = cvTermCriteria(termcrit_type, 01075 max_num_of_trees_in_the_forest, forest_accuracy); 01076 } 01077 }; 01078 01079 01080 class CV_EXPORTS CvRTrees : public CvStatModel 01081 { 01082 public: 01083 CvRTrees(); 01084 virtual ~CvRTrees(); 01085 virtual bool train( const CvMat* _train_data, int _tflag, 01086 const CvMat* _responses, const CvMat* _var_idx=0, 01087 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 01088 const CvMat* _missing_mask=0, 01089 CvRTParams params=CvRTParams() ); 01090 01091 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); 01092 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const; 01093 virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const; 01094 01095 #ifndef SWIG 01096 virtual bool train( const cv::Mat& _train_data, int _tflag, 01097 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(), 01098 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(), 01099 const cv::Mat& _missing_mask=cv::Mat(), 01100 CvRTParams params=CvRTParams() ); 01101 virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const; 01102 virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const; 01103 #endif 01104 01105 virtual void clear(); 01106 01107 virtual const CvMat* get_var_importance(); 01108 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2, 01109 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const; 01110 01111 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 01112 01113 virtual float get_train_error(); 01114 01115 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01116 virtual void write( CvFileStorage* fs, const char* name ) const; 01117 01118 CvMat* get_active_var_mask(); 01119 CvRNG* get_rng(); 01120 01121 int get_tree_count() const; 01122 CvForestTree* get_tree(int i) const; 01123 01124 protected: 01125 01126 virtual bool grow_forest( const CvTermCriteria term_crit ); 01127 01128 // array of the trees of the forest 01129 CvForestTree** trees; 01130 CvDTreeTrainData* data; 01131 int ntrees; 01132 int nclasses; 01133 double oob_error; 01134 CvMat* var_importance; 01135 int nsamples; 01136 01137 CvRNG rng; 01138 CvMat* active_var_mask; 01139 }; 01140 01141 /****************************************************************************************\ 01142 * Extremely randomized trees Classifier * 01143 \****************************************************************************************/ 01144 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData 01145 { 01146 virtual void set_data( const CvMat* _train_data, int _tflag, 01147 const CvMat* _responses, const CvMat* _var_idx=0, 01148 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 01149 const CvMat* _missing_mask=0, 01150 const CvDTreeParams& _params=CvDTreeParams(), 01151 bool _shared=false, bool _add_labels=false, 01152 bool _update_data=false ); 01153 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, 01154 const float** ord_values, const int** missing, int* sample_buf = 0 ); 01155 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); 01156 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); 01157 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf ); 01158 virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing, 01159 float* responses, bool get_class_idx=false ); 01160 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 01161 const CvMat* missing_mask; 01162 }; 01163 01164 class CV_EXPORTS CvForestERTree : public CvForestTree 01165 { 01166 protected: 01167 virtual double calc_node_dir( CvDTreeNode* node ); 01168 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 01169 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01170 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 01171 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01172 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 01173 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01174 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 01175 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01176 virtual void split_node_data( CvDTreeNode* n ); 01177 }; 01178 01179 class CV_EXPORTS CvERTrees : public CvRTrees 01180 { 01181 public: 01182 CvERTrees(); 01183 virtual ~CvERTrees(); 01184 virtual bool train( const CvMat* _train_data, int _tflag, 01185 const CvMat* _responses, const CvMat* _var_idx=0, 01186 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 01187 const CvMat* _missing_mask=0, 01188 CvRTParams params=CvRTParams()); 01189 #ifndef SWIG 01190 virtual bool train( const cv::Mat& _train_data, int _tflag, 01191 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(), 01192 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(), 01193 const cv::Mat& _missing_mask=cv::Mat(), 01194 CvRTParams params=CvRTParams()); 01195 #endif 01196 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); 01197 protected: 01198 virtual bool grow_forest( const CvTermCriteria term_crit ); 01199 }; 01200 01201 01202 /****************************************************************************************\ 01203 * Boosted tree classifier * 01204 \****************************************************************************************/ 01205 01206 struct CV_EXPORTS CvBoostParams : public CvDTreeParams 01207 { 01208 int boost_type; 01209 int weak_count; 01210 int split_criteria; 01211 double weight_trim_rate; 01212 01213 CvBoostParams(); 01214 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate, 01215 int max_depth, bool use_surrogates, const float* priors ); 01216 }; 01217 01218 01219 class CvBoost; 01220 01221 class CV_EXPORTS CvBoostTree: public CvDTree 01222 { 01223 public: 01224 CvBoostTree(); 01225 virtual ~CvBoostTree(); 01226 01227 virtual bool train( CvDTreeTrainData* _train_data, 01228 const CvMat* subsample_idx, CvBoost* ensemble ); 01229 01230 virtual void scale( double s ); 01231 virtual void read( CvFileStorage* fs, CvFileNode* node, 01232 CvBoost* ensemble, CvDTreeTrainData* _data ); 01233 virtual void clear(); 01234 01235 /* dummy methods to avoid warnings: BEGIN */ 01236 virtual bool train( const CvMat* _train_data, int _tflag, 01237 const CvMat* _responses, const CvMat* _var_idx=0, 01238 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 01239 const CvMat* _missing_mask=0, 01240 CvDTreeParams params=CvDTreeParams() ); 01241 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx ); 01242 01243 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01244 virtual void read( CvFileStorage* fs, CvFileNode* node, 01245 CvDTreeTrainData* data ); 01246 /* dummy methods to avoid warnings: END */ 01247 01248 protected: 01249 01250 virtual void try_split_node( CvDTreeNode* n ); 01251 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 01252 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 01253 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 01254 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01255 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 01256 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01257 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 01258 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01259 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 01260 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 01261 virtual void calc_node_value( CvDTreeNode* n ); 01262 virtual double calc_node_dir( CvDTreeNode* n ); 01263 01264 CvBoost* ensemble; 01265 }; 01266 01267 01268 class CV_EXPORTS CvBoost : public CvStatModel 01269 { 01270 public: 01271 // Boosting type 01272 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 }; 01273 01274 // Splitting criteria 01275 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 }; 01276 01277 CvBoost(); 01278 virtual ~CvBoost(); 01279 01280 CvBoost( const CvMat* _train_data, int _tflag, 01281 const CvMat* _responses, const CvMat* _var_idx=0, 01282 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 01283 const CvMat* _missing_mask=0, 01284 CvBoostParams params=CvBoostParams() ); 01285 01286 virtual bool train( const CvMat* _train_data, int _tflag, 01287 const CvMat* _responses, const CvMat* _var_idx=0, 01288 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 01289 const CvMat* _missing_mask=0, 01290 CvBoostParams params=CvBoostParams(), 01291 bool update=false ); 01292 01293 virtual bool train( CvMLData* data, 01294 CvBoostParams params=CvBoostParams(), 01295 bool update=false ); 01296 01297 virtual float predict( const CvMat* _sample, const CvMat* _missing=0, 01298 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ, 01299 bool raw_mode=false, bool return_sum=false ) const; 01300 01301 #ifndef SWIG 01302 CvBoost( const cv::Mat& _train_data, int _tflag, 01303 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(), 01304 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(), 01305 const cv::Mat& _missing_mask=cv::Mat(), 01306 CvBoostParams params=CvBoostParams() ); 01307 01308 virtual bool train( const cv::Mat& _train_data, int _tflag, 01309 const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(), 01310 const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(), 01311 const cv::Mat& _missing_mask=cv::Mat(), 01312 CvBoostParams params=CvBoostParams(), 01313 bool update=false ); 01314 01315 virtual float predict( const cv::Mat& _sample, const cv::Mat& _missing=cv::Mat(), 01316 cv::Mat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ, 01317 bool raw_mode=false, bool return_sum=false ) const; 01318 #endif 01319 01320 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 01321 01322 virtual void prune( CvSlice slice ); 01323 01324 virtual void clear(); 01325 01326 virtual void write( CvFileStorage* storage, const char* name ) const; 01327 virtual void read( CvFileStorage* storage, CvFileNode* node ); 01328 virtual const CvMat* get_active_vars(bool absolute_idx=true); 01329 01330 CvSeq* get_weak_predictors(); 01331 01332 CvMat* get_weights(); 01333 CvMat* get_subtree_weights(); 01334 CvMat* get_weak_response(); 01335 const CvBoostParams& get_params() const; 01336 const CvDTreeTrainData* get_data() const; 01337 01338 protected: 01339 01340 virtual bool set_params( const CvBoostParams& _params ); 01341 virtual void update_weights( CvBoostTree* tree ); 01342 virtual void trim_weights(); 01343 virtual void write_params( CvFileStorage* fs ) const; 01344 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 01345 01346 CvDTreeTrainData* data; 01347 CvBoostParams params; 01348 CvSeq* weak; 01349 01350 CvMat* active_vars; 01351 CvMat* active_vars_abs; 01352 bool have_active_cat_vars; 01353 01354 CvMat* orig_response; 01355 CvMat* sum_response; 01356 CvMat* weak_eval; 01357 CvMat* subsample_mask; 01358 CvMat* weights; 01359 CvMat* subtree_weights; 01360 bool have_subsample; 01361 }; 01362 01363 01364 /****************************************************************************************\ 01365 * Artificial Neural Networks (ANN) * 01366 \****************************************************************************************/ 01367 01369 01370 struct CV_EXPORTS CvANN_MLP_TrainParams 01371 { 01372 CvANN_MLP_TrainParams(); 01373 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method, 01374 double param1, double param2=0 ); 01375 ~CvANN_MLP_TrainParams(); 01376 01377 enum { BACKPROP=0, RPROP=1 }; 01378 01379 CvTermCriteria term_crit; 01380 int train_method; 01381 01382 // backpropagation parameters 01383 double bp_dw_scale, bp_moment_scale; 01384 01385 // rprop parameters 01386 double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max; 01387 }; 01388 01389 01390 class CV_EXPORTS CvANN_MLP : public CvStatModel 01391 { 01392 public: 01393 CvANN_MLP(); 01394 CvANN_MLP( const CvMat* _layer_sizes, 01395 int _activ_func=SIGMOID_SYM, 01396 double _f_param1=0, double _f_param2=0 ); 01397 01398 virtual ~CvANN_MLP(); 01399 01400 virtual void create( const CvMat* _layer_sizes, 01401 int _activ_func=SIGMOID_SYM, 01402 double _f_param1=0, double _f_param2=0 ); 01403 01404 virtual int train( const CvMat* _inputs, const CvMat* _outputs, 01405 const CvMat* _sample_weights, const CvMat* _sample_idx=0, 01406 CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(), 01407 int flags=0 ); 01408 virtual float predict( const CvMat* _inputs, CvMat* _outputs ) const; 01409 01410 #ifndef SWIG 01411 CvANN_MLP( const cv::Mat& _layer_sizes, 01412 int _activ_func=SIGMOID_SYM, 01413 double _f_param1=0, double _f_param2=0 ); 01414 01415 virtual void create( const cv::Mat& _layer_sizes, 01416 int _activ_func=SIGMOID_SYM, 01417 double _f_param1=0, double _f_param2=0 ); 01418 01419 virtual int train( const cv::Mat& _inputs, const cv::Mat& _outputs, 01420 const cv::Mat& _sample_weights, const cv::Mat& _sample_idx=cv::Mat(), 01421 CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(), 01422 int flags=0 ); 01423 01424 virtual float predict( const cv::Mat& _inputs, cv::Mat& _outputs ) const; 01425 #endif 01426 01427 virtual void clear(); 01428 01429 // possible activation functions 01430 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 }; 01431 01432 // available training flags 01433 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 }; 01434 01435 virtual void read( CvFileStorage* fs, CvFileNode* node ); 01436 virtual void write( CvFileStorage* storage, const char* name ) const; 01437 01438 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; } 01439 const CvMat* get_layer_sizes() { return layer_sizes; } 01440 double* get_weights(int layer) 01441 { 01442 return layer_sizes && weights && 01443 (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0; 01444 } 01445 01446 protected: 01447 01448 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs, 01449 const CvMat* _sample_weights, const CvMat* _sample_idx, 01450 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags ); 01451 01452 // sequential random backpropagation 01453 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw ); 01454 01455 // RPROP algorithm 01456 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw ); 01457 01458 virtual void calc_activ_func( CvMat* xf, const double* bias ) const; 01459 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const; 01460 virtual void set_activ_func( int _activ_func=SIGMOID_SYM, 01461 double _f_param1=0, double _f_param2=0 ); 01462 virtual void init_weights(); 01463 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const; 01464 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const; 01465 virtual void calc_input_scale( const CvVectors* vecs, int flags ); 01466 virtual void calc_output_scale( const CvVectors* vecs, int flags ); 01467 01468 virtual void write_params( CvFileStorage* fs ) const; 01469 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 01470 01471 CvMat* layer_sizes; 01472 CvMat* wbuf; 01473 CvMat* sample_weights; 01474 double** weights; 01475 double f_param1, f_param2; 01476 double min_val, max_val, min_val1, max_val1; 01477 int activ_func; 01478 int max_count, max_buf_sz; 01479 CvANN_MLP_TrainParams params; 01480 CvRNG rng; 01481 }; 01482 01483 #if 0 01484 /****************************************************************************************\ 01485 * Convolutional Neural Network * 01486 \****************************************************************************************/ 01487 typedef struct CvCNNLayer CvCNNLayer; 01488 typedef struct CvCNNetwork CvCNNetwork; 01489 01490 #define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY 1 01491 #define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV 2 01492 #define CV_CNN_LEARN_RATE_DECREASE_LOG_INV 3 01493 01494 #define CV_CNN_GRAD_ESTIM_RANDOM 0 01495 #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG 1 01496 01497 #define ICV_CNN_LAYER 0x55550000 01498 #define ICV_CNN_CONVOLUTION_LAYER 0x00001111 01499 #define ICV_CNN_SUBSAMPLING_LAYER 0x00002222 01500 #define ICV_CNN_FULLCONNECT_LAYER 0x00003333 01501 01502 #define ICV_IS_CNN_LAYER( layer ) \ 01503 ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\ 01504 == ICV_CNN_LAYER )) 01505 01506 #define ICV_IS_CNN_CONVOLUTION_LAYER( layer ) \ 01507 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \ 01508 & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER ) 01509 01510 #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) \ 01511 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \ 01512 & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER ) 01513 01514 #define ICV_IS_CNN_FULLCONNECT_LAYER( layer ) \ 01515 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \ 01516 & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER ) 01517 01518 typedef void (CV_CDECL *CvCNNLayerForward) 01519 ( CvCNNLayer* layer, const CvMat* input, CvMat* output ); 01520 01521 typedef void (CV_CDECL *CvCNNLayerBackward) 01522 ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX ); 01523 01524 typedef void (CV_CDECL *CvCNNLayerRelease) 01525 (CvCNNLayer** layer); 01526 01527 typedef void (CV_CDECL *CvCNNetworkAddLayer) 01528 (CvCNNetwork* network, CvCNNLayer* layer); 01529 01530 typedef void (CV_CDECL *CvCNNetworkRelease) 01531 (CvCNNetwork** network); 01532 01533 #define CV_CNN_LAYER_FIELDS() \ 01534 /* Indicator of the layer's type */ \ 01535 int flags; \ 01536 \ 01537 /* Number of input images */ \ 01538 int n_input_planes; \ 01539 /* Height of each input image */ \ 01540 int input_height; \ 01541 /* Width of each input image */ \ 01542 int input_width; \ 01543 \ 01544 /* Number of output images */ \ 01545 int n_output_planes; \ 01546 /* Height of each output image */ \ 01547 int output_height; \ 01548 /* Width of each output image */ \ 01549 int output_width; \ 01550 \ 01551 /* Learning rate at the first iteration */ \ 01552 float init_learn_rate; \ 01553 /* Dynamics of learning rate decreasing */ \ 01554 int learn_rate_decrease_type; \ 01555 /* Trainable weights of the layer (including bias) */ \ 01556 /* i-th row is a set of weights of the i-th output plane */ \ 01557 CvMat* weights; \ 01558 \ 01559 CvCNNLayerForward forward; \ 01560 CvCNNLayerBackward backward; \ 01561 CvCNNLayerRelease release; \ 01562 /* Pointers to the previous and next layers in the network */ \ 01563 CvCNNLayer* prev_layer; \ 01564 CvCNNLayer* next_layer 01565 01566 typedef struct CvCNNLayer 01567 { 01568 CV_CNN_LAYER_FIELDS(); 01569 }CvCNNLayer; 01570 01571 typedef struct CvCNNConvolutionLayer 01572 { 01573 CV_CNN_LAYER_FIELDS(); 01574 // Kernel size (height and width) for convolution. 01575 int K; 01576 // connections matrix, (i,j)-th element is 1 iff there is a connection between 01577 // i-th plane of the current layer and j-th plane of the previous layer; 01578 // (i,j)-th element is equal to 0 otherwise 01579 CvMat *connect_mask; 01580 // value of the learning rate for updating weights at the first iteration 01581 }CvCNNConvolutionLayer; 01582 01583 typedef struct CvCNNSubSamplingLayer 01584 { 01585 CV_CNN_LAYER_FIELDS(); 01586 // ratio between the heights (or widths - ratios are supposed to be equal) 01587 // of the input and output planes 01588 int sub_samp_scale; 01589 // amplitude of sigmoid activation function 01590 float a; 01591 // scale parameter of sigmoid activation function 01592 float s; 01593 // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X 01594 // - is the vector used in computing of the activation function in backward 01595 CvMat* exp2ssumWX; 01596 // (x1+x2+x3+x4), where x1,...x4 are some elements of X 01597 // - is the vector used in computing of the activation function in backward 01598 CvMat* sumX; 01599 }CvCNNSubSamplingLayer; 01600 01601 // Structure of the last layer. 01602 typedef struct CvCNNFullConnectLayer 01603 { 01604 CV_CNN_LAYER_FIELDS(); 01605 // amplitude of sigmoid activation function 01606 float a; 01607 // scale parameter of sigmoid activation function 01608 float s; 01609 // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the 01610 // activation function and it's derivative by the formulae 01611 // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1) 01612 // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2 01613 CvMat* exp2ssumWX; 01614 }CvCNNFullConnectLayer; 01615 01616 typedef struct CvCNNetwork 01617 { 01618 int n_layers; 01619 CvCNNLayer* layers; 01620 CvCNNetworkAddLayer add_layer; 01621 CvCNNetworkRelease release; 01622 }CvCNNetwork; 01623 01624 typedef struct CvCNNStatModel 01625 { 01626 CV_STAT_MODEL_FIELDS(); 01627 CvCNNetwork* network; 01628 // etalons are allocated as rows, the i-th etalon has label cls_labeles[i] 01629 CvMat* etalons; 01630 // classes labels 01631 CvMat* cls_labels; 01632 }CvCNNStatModel; 01633 01634 typedef struct CvCNNStatModelParams 01635 { 01636 CV_STAT_MODEL_PARAM_FIELDS(); 01637 // network must be created by the functions cvCreateCNNetwork and <add_layer> 01638 CvCNNetwork* network; 01639 CvMat* etalons; 01640 // termination criteria 01641 int max_iter; 01642 int start_iter; 01643 int grad_estim_type; 01644 }CvCNNStatModelParams; 01645 01646 CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer( 01647 int n_input_planes, int input_height, int input_width, 01648 int n_output_planes, int K, 01649 float init_learn_rate, int learn_rate_decrease_type, 01650 CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) ); 01651 01652 CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer( 01653 int n_input_planes, int input_height, int input_width, 01654 int sub_samp_scale, float a, float s, 01655 float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) ); 01656 01657 CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer( 01658 int n_inputs, int n_outputs, float a, float s, 01659 float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) ); 01660 01661 CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer ); 01662 01663 CVAPI(CvStatModel*) cvTrainCNNClassifier( 01664 const CvMat* train_data, int tflag, 01665 const CvMat* responses, 01666 const CvStatModelParams* params, 01667 const CvMat* CV_DEFAULT(0), 01668 const CvMat* sample_idx CV_DEFAULT(0), 01669 const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) ); 01670 01671 /****************************************************************************************\ 01672 * Estimate classifiers algorithms * 01673 \****************************************************************************************/ 01674 typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat) 01675 ( const CvStatModel* estimateModel ); 01676 01677 typedef int (CV_CDECL *CvStatModelEstimateNextStep) 01678 ( CvStatModel* estimateModel ); 01679 01680 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier) 01681 ( CvStatModel* estimateModel, 01682 const CvStatModel* model, 01683 const CvMat* features, 01684 int sample_t_flag, 01685 const CvMat* responses ); 01686 01687 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy) 01688 ( CvStatModel* estimateModel, 01689 const CvStatModel* model ); 01690 01691 typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult) 01692 ( const CvStatModel* estimateModel, 01693 float* correlation ); 01694 01695 typedef void (CV_CDECL *CvStatModelEstimateReset) 01696 ( CvStatModel* estimateModel ); 01697 01698 //-------------------------------- Cross-validation -------------------------------------- 01699 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS() \ 01700 CV_STAT_MODEL_PARAM_FIELDS(); \ 01701 int k_fold; \ 01702 int is_regression; \ 01703 CvRNG* rng 01704 01705 typedef struct CvCrossValidationParams 01706 { 01707 CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS(); 01708 } CvCrossValidationParams; 01709 01710 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS() \ 01711 CvStatModelEstimateGetMat getTrainIdxMat; \ 01712 CvStatModelEstimateGetMat getCheckIdxMat; \ 01713 CvStatModelEstimateNextStep nextStep; \ 01714 CvStatModelEstimateCheckClassifier check; \ 01715 CvStatModelEstimateGetCurrentResult getResult; \ 01716 CvStatModelEstimateReset reset; \ 01717 int is_regression; \ 01718 int folds_all; \ 01719 int samples_all; \ 01720 int* sampleIdxAll; \ 01721 int* folds; \ 01722 int max_fold_size; \ 01723 int current_fold; \ 01724 int is_checked; \ 01725 CvMat* sampleIdxTrain; \ 01726 CvMat* sampleIdxEval; \ 01727 CvMat* predict_results; \ 01728 int correct_results; \ 01729 int all_results; \ 01730 double sq_error; \ 01731 double sum_correct; \ 01732 double sum_predict; \ 01733 double sum_cc; \ 01734 double sum_pp; \ 01735 double sum_cp 01736 01737 typedef struct CvCrossValidationModel 01738 { 01739 CV_STAT_MODEL_FIELDS(); 01740 CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS(); 01741 } CvCrossValidationModel; 01742 01743 CVAPI(CvStatModel*) 01744 cvCreateCrossValidationEstimateModel 01745 ( int samples_all, 01746 const CvStatModelParams* estimateParams CV_DEFAULT(0), 01747 const CvMat* sampleIdx CV_DEFAULT(0) ); 01748 01749 CVAPI(float) 01750 cvCrossValidation( const CvMat* trueData, 01751 int tflag, 01752 const CvMat* trueClasses, 01753 CvStatModel* (*createClassifier)( const CvMat*, 01754 int, 01755 const CvMat*, 01756 const CvStatModelParams*, 01757 const CvMat*, 01758 const CvMat*, 01759 const CvMat*, 01760 const CvMat* ), 01761 const CvStatModelParams* estimateParams CV_DEFAULT(0), 01762 const CvStatModelParams* trainParams CV_DEFAULT(0), 01763 const CvMat* compIdx CV_DEFAULT(0), 01764 const CvMat* sampleIdx CV_DEFAULT(0), 01765 CvStatModel** pCrValModel CV_DEFAULT(0), 01766 const CvMat* typeMask CV_DEFAULT(0), 01767 const CvMat* missedMeasurementMask CV_DEFAULT(0) ); 01768 #endif 01769 01770 /****************************************************************************************\ 01771 * Auxilary functions declarations * 01772 \****************************************************************************************/ 01773 01774 /* Generates <sample> from multivariate normal distribution, where <mean> - is an 01775 average row vector, <cov> - symmetric covariation matrix */ 01776 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample, 01777 CvRNG* rng CV_DEFAULT(0) ); 01778 01779 /* Generates sample from gaussian mixture distribution */ 01780 CVAPI(void) cvRandGaussMixture( CvMat* means[], 01781 CvMat* covs[], 01782 float weights[], 01783 int clsnum, 01784 CvMat* sample, 01785 CvMat* sampClasses CV_DEFAULT(0) ); 01786 01787 #define CV_TS_CONCENTRIC_SPHERES 0 01788 01789 /* creates test set */ 01790 CVAPI(void) cvCreateTestSet( int type, CvMat** samples, 01791 int num_samples, 01792 int num_features, 01793 CvMat** responses, 01794 int num_classes, ... ); 01795 01796 01797 #endif 01798 01799 /****************************************************************************************\ 01800 * Data * 01801 \****************************************************************************************/ 01802 01803 #include <map> 01804 #include <string> 01805 #include <iostream> 01806 01807 #define CV_COUNT 0 01808 #define CV_PORTION 1 01809 01810 struct CV_EXPORTS CvTrainTestSplit 01811 { 01812 public: 01813 CvTrainTestSplit(); 01814 CvTrainTestSplit( int _train_sample_count, bool _mix = true); 01815 CvTrainTestSplit( float _train_sample_portion, bool _mix = true); 01816 01817 union 01818 { 01819 int count; 01820 float portion; 01821 } train_sample_part; 01822 int train_sample_part_mode; 01823 01824 union 01825 { 01826 int *count; 01827 float *portion; 01828 } *class_part; 01829 int class_part_mode; 01830 01831 bool mix; 01832 }; 01833 01834 class CV_EXPORTS CvMLData 01835 { 01836 public: 01837 CvMLData(); 01838 virtual ~CvMLData(); 01839 01840 // returns: 01841 // 0 - OK 01842 // 1 - file can not be opened or is not correct 01843 int read_csv(const char* filename); 01844 01845 const CvMat* get_values(){ return values; }; 01846 01847 const CvMat* get_responses(); 01848 01849 const CvMat* get_missing(){ return missing; }; 01850 01851 void set_response_idx( int idx ); // old response become predictors, new response_idx = idx 01852 // if idx < 0 there will be no response 01853 int get_response_idx() { return response_idx; } 01854 01855 const CvMat* get_train_sample_idx() { return train_sample_idx; }; 01856 const CvMat* get_test_sample_idx() { return test_sample_idx; }; 01857 void mix_train_and_test_idx(); 01858 void set_train_test_split( const CvTrainTestSplit * spl); 01859 01860 const CvMat* get_var_idx(); 01861 void chahge_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor 01862 01863 const CvMat* get_var_types(); 01864 int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; }; 01865 // following 2 methods enable to change vars type 01866 // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable 01867 // with numerical labels; in the other cases var types are correctly determined automatically 01868 void set_var_types( const char* str ); // str examples: 01869 // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]", 01870 // "cat", "ord" (all vars are categorical/ordered) 01871 void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL } 01872 01873 void set_delimiter( char ch ); 01874 char get_delimiter() { return delimiter; }; 01875 01876 void set_miss_ch( char ch ); 01877 char get_miss_ch() { return miss_ch; }; 01878 01879 protected: 01880 virtual void clear(); 01881 01882 void str_to_flt_elem( const char* token, float& flt_elem, int& type); 01883 void free_train_test_idx(); 01884 01885 char delimiter; 01886 char miss_ch; 01887 //char flt_separator; 01888 01889 CvMat* values; 01890 CvMat* missing; 01891 CvMat* var_types; 01892 CvMat* var_idx_mask; 01893 01894 CvMat* response_out; // header 01895 CvMat* var_idx_out; // mat 01896 CvMat* var_types_out; // mat 01897 01898 int response_idx; 01899 01900 int train_sample_count; 01901 bool mix; 01902 01903 int total_class_count; 01904 std::map<std::string, int> *class_map; 01905 01906 CvMat* train_sample_idx; 01907 CvMat* test_sample_idx; 01908 int* sample_idx; // data of train_sample_idx and test_sample_idx 01909 01910 CvRNG rng; 01911 }; 01912 01913 01914 namespace cv 01915 { 01916 01917 typedef CvStatModel StatModel; 01918 typedef CvParamGrid ParamGrid; 01919 typedef CvNormalBayesClassifier NormalBayesClassifier; 01920 typedef CvKNearest KNearest; 01921 typedef CvSVMParams SVMParams; 01922 typedef CvSVMKernel SVMKernel; 01923 typedef CvSVMSolver SVMSolver; 01924 typedef CvSVM SVM; 01925 typedef CvEMParams EMParams; 01926 typedef CvEM ExpectationMaximization; 01927 typedef CvDTreeParams DTreeParams; 01928 typedef CvMLData TrainData; 01929 typedef CvDTree DecisionTree; 01930 typedef CvForestTree ForestTree; 01931 typedef CvRTParams RandomTreeParams; 01932 typedef CvRTrees RandomTrees; 01933 typedef CvERTreeTrainData ERTreeTRainData; 01934 typedef CvForestERTree ERTree; 01935 typedef CvERTrees ERTrees; 01936 typedef CvBoostParams BoostParams; 01937 typedef CvBoostTree BoostTree; 01938 typedef CvBoost Boost; 01939 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams; 01940 typedef CvANN_MLP NeuralNet_MLP; 01941 01942 } 01943 01944 #endif 01945 /* End of file. */