include/opencv2/ml/ml.hpp
Go to the documentation of this file.
00001 /*M///////////////////////////////////////////////////////////////////////////////////////
00002 //
00003 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
00004 //
00005 //  By downloading, copying, installing or using the software you agree to this license.
00006 //  If you do not agree to this license, do not download, install,
00007 //  copy or use the software.
00008 //
00009 //
00010 //                        Intel License Agreement
00011 //
00012 // Copyright (C) 2000, Intel Corporation, all rights reserved.
00013 // Third party copyrights are property of their respective owners.
00014 //
00015 // Redistribution and use in source and binary forms, with or without modification,
00016 // are permitted provided that the following conditions are met:
00017 //
00018 //   * Redistribution's of source code must retain the above copyright notice,
00019 //     this list of conditions and the following disclaimer.
00020 //
00021 //   * Redistribution's in binary form must reproduce the above copyright notice,
00022 //     this list of conditions and the following disclaimer in the documentation
00023 //     and/or other materials provided with the distribution.
00024 //
00025 //   * The name of Intel Corporation may not be used to endorse or promote products
00026 //     derived from this software without specific prior written permission.
00027 //
00028 // This software is provided by the copyright holders and contributors "as is" and
00029 // any express or implied warranties, including, but not limited to, the implied
00030 // warranties of merchantability and fitness for a particular purpose are disclaimed.
00031 // In no event shall the Intel Corporation or contributors be liable for any direct,
00032 // indirect, incidental, special, exemplary, or consequential damages
00033 // (including, but not limited to, procurement of substitute goods or services;
00034 // loss of use, data, or profits; or business interruption) however caused
00035 // and on any theory of liability, whether in contract, strict liability,
00036 // or tort (including negligence or otherwise) arising in any way out of
00037 // the use of this software, even if advised of the possibility of such damage.
00038 //
00039 //M*/
00040 
00041 #ifndef __OPENCV_ML_HPP__
00042 #define __OPENCV_ML_HPP__
00043 
00044 #include "opencv2/core/core.hpp"
00045 #include <limits.h>
00046 
00047 #ifdef __cplusplus
00048 
00049 #include <map>
00050 #include <string>
00051 #include <iostream>
00052 
00053 // Apple defines a check() macro somewhere in the debug headers
00054 // that interferes with a method definiton in this header
00055 #undef check
00056 
/****************************************************************************************\
*                               Main struct definitions                                  *
\****************************************************************************************/

/* log(2*PI) */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* columns of <trainData> matrix are training samples */
#define CV_COL_SAMPLE 0

/* rows of <trainData> matrix are training samples */
#define CV_ROW_SAMPLE 1

/* non-zero iff the layout flag says samples are stored as rows */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
00071 
00072 struct CvVectors
00073 {
00074     int type;
00075     int dims, count;
00076     CvVectors* next;
00077     union
00078     {
00079         uchar** ptr;
00080         float** fl;
00081         double** db;
00082     } data;
00083 };
00084 
#if 0
/* A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1.
   NOTE: legacy API, compiled out (CvParamGrid below provides the same data). */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* builds a lattice, normalizing the range (min <= max) and clamping step to >= 1 */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

/* returns an all-zero (empty) lattice */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif
00113 
/* Variable type: numerical and ordered variables share the value 0,
   categorical variables are marked with 1 */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

/* model type-name strings — presumably used as type identifiers when
   persisting models through CvFileStorage (see CvStatModel::write/read) */
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

/* selector for the kind of error to compute */
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1
00133 
00134 class CV_EXPORTS_W CvStatModel
00135 {
00136 public:
00137     CvStatModel();
00138     virtual ~CvStatModel();
00139 
00140     virtual void clear();
00141 
00142     CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
00143     CV_WRAP virtual void load( const char* filename, const char* name=0 );
00144 
00145     virtual void write( CvFileStorage* storage, const char* name ) const;
00146     virtual void read( CvFileStorage* storage, CvFileNode* node );
00147 
00148 protected:
00149     const char* default_model_name;
00150 };
00151 
00152 /****************************************************************************************\
00153 *                                 Normal Bayes Classifier                                *
00154 \****************************************************************************************/
00155 
00156 /* The structure, representing the grid range of statmodel parameters.
00157    It is used for optimizing statmodel accuracy by varying model parameters,
00158    the accuracy estimate being computed by cross-validation.
00159    The grid is logarithmic, so <step> must be greater then 1. */
00160 
00161 class CvMLData;
00162 
00163 struct CV_EXPORTS_W_MAP CvParamGrid
00164 {
00165     // SVM params type
00166     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
00167 
00168     CvParamGrid()
00169     {
00170         min_val = max_val = step = 0;
00171     }
00172 
00173     CvParamGrid( double min_val, double max_val, double log_step );
00174     //CvParamGrid( int param_id );
00175     bool check() const;
00176 
00177     CV_PROP_RW double min_val;
00178     CV_PROP_RW double max_val;
00179     CV_PROP_RW double step;
00180 };
00181 
00182 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
00183 {
00184     min_val = _min_val;
00185     max_val = _max_val;
00186     step = _log_step;
00187 }
00188 
00189 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
00190 {
00191 public:
00192     CV_WRAP CvNormalBayesClassifier();
00193     virtual ~CvNormalBayesClassifier();
00194 
00195     CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
00196         const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
00197 
00198     virtual bool train( const CvMat* trainData, const CvMat* responses,
00199         const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
00200 
00201     virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
00202     CV_WRAP virtual void clear();
00203 
00204     CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
00205                             const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
00206     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00207                        const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00208                        bool update=false );
00209     CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
00210 
00211     virtual void write( CvFileStorage* storage, const char* name ) const;
00212     virtual void read( CvFileStorage* storage, CvFileNode* node );
00213 
00214 protected:
00215     int     var_count, var_all;
00216     CvMat*  var_idx;
00217     CvMat*  cls_labels;
00218     CvMat** count;
00219     CvMat** sum;
00220     CvMat** productsum;
00221     CvMat** avg;
00222     CvMat** inv_eigen_values;
00223     CvMat** cov_rotate_mats;
00224     CvMat*  c;
00225 };
00226 
00227 
00228 /****************************************************************************************\
00229 *                          K-Nearest Neighbour Classifier                                *
00230 \****************************************************************************************/
00231 
00232 // k Nearest Neighbors
00233 class CV_EXPORTS_W CvKNearest : public CvStatModel
00234 {
00235 public:
00236 
00237     CV_WRAP CvKNearest();
00238     virtual ~CvKNearest();
00239 
00240     CvKNearest( const CvMat* trainData, const CvMat* responses,
00241                 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
00242 
00243     virtual bool train( const CvMat* trainData, const CvMat* responses,
00244                         const CvMat* sampleIdx=0, bool is_regression=false,
00245                         int maxK=32, bool updateBase=false );
00246 
00247     virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
00248         const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
00249 
00250     CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
00251                const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
00252 
00253     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00254                        const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
00255                        int maxK=32, bool updateBase=false );
00256 
00257     virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
00258                                 const float** neighbors=0, cv::Mat* neighborResponses=0,
00259                                 cv::Mat* dist=0 ) const;
00260     CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
00261                                         CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
00262 
00263     virtual void clear();
00264     int get_max_k() const;
00265     int get_var_count() const;
00266     int get_sample_count() const;
00267     bool is_regression() const;
00268 
00269     virtual float write_results( int k, int k1, int start, int end,
00270         const float* neighbor_responses, const float* dist, CvMat* _results,
00271         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
00272 
00273     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
00274         float* neighbor_responses, const float** neighbors, float* dist ) const;
00275 
00276 protected:
00277 
00278     int max_k, var_count;
00279     int total;
00280     bool regression;
00281     CvVectors* samples;
00282 };
00283 
00284 /****************************************************************************************\
00285 *                                   Support Vector Machines                              *
00286 \****************************************************************************************/
00287 
00288 // SVM training parameters
00289 struct CV_EXPORTS_W_MAP CvSVMParams
00290 {
00291     CvSVMParams();
00292     CvSVMParams( int svm_type, int kernel_type,
00293                  double degree, double gamma, double coef0,
00294                  double Cvalue, double nu, double p,
00295                  CvMat* class_weights, CvTermCriteria term_crit );
00296 
00297     CV_PROP_RW int         svm_type;
00298     CV_PROP_RW int         kernel_type;
00299     CV_PROP_RW double      degree; // for poly
00300     CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
00301     CV_PROP_RW double      coef0;  // for poly/sigmoid
00302 
00303     CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
00304     CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
00305     CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
00306     CvMat*      class_weights; // for CV_SVM_C_SVC
00307     CV_PROP_RW CvTermCriteria term_crit; // termination criteria
00308 };
00309 
00310 
00311 struct CV_EXPORTS CvSVMKernel
00312 {
00313     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
00314                                        const float* another, float* results );
00315     CvSVMKernel();
00316     CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
00317     virtual bool create( const CvSVMParams* params, Calc _calc_func );
00318     virtual ~CvSVMKernel();
00319 
00320     virtual void clear();
00321     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
00322 
00323     const CvSVMParams* params;
00324     Calc calc_func;
00325 
00326     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
00327                                     const float* another, float* results,
00328                                     double alpha, double beta );
00329 
00330     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
00331                               const float* another, float* results );
00332     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
00333                            const float* another, float* results );
00334     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
00335                             const float* another, float* results );
00336     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
00337                                const float* another, float* results );
00338 };
00339 
00340 
/* A cached kernel-matrix row; nodes form the doubly-linked LRU list
   maintained by CvSVMSolver (see CvSVMSolver::lru_list). */
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;   // previous row in the LRU list
    CvSVMKernelRow* next;   // next row in the LRU list
    float* data;            // cached row values
};
00347 
00348 
/* Summary of a solved SVM optimization problem, filled in by the
   CvSVMSolver::solve_* functions. */
struct CvSVMSolutionInfo
{
    double obj;             // objective function value
    double rho;             // decision-function offset
    double upper_bound_p;   // presumably the box constraint for the positive class
    double upper_bound_n;   // presumably the box constraint for the negative class
    double r;   // for Solver_NU
};
00357 
00358 class CV_EXPORTS CvSVMSolver
00359 {
00360 public:
00361     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
00362     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
00363     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
00364 
00365     CvSVMSolver();
00366 
00367     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
00368                  int alpha_count, double* alpha, double Cp, double Cn,
00369                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00370                  SelectWorkingSet select_working_set, CalcRho calc_rho );
00371     virtual bool create( int count, int var_count, const float** samples, schar* y,
00372                  int alpha_count, double* alpha, double Cp, double Cn,
00373                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00374                  SelectWorkingSet select_working_set, CalcRho calc_rho );
00375     virtual ~CvSVMSolver();
00376 
00377     virtual void clear();
00378     virtual bool solve_generic( CvSVMSolutionInfo& si );
00379 
00380     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
00381                               double Cp, double Cn, CvMemStorage* storage,
00382                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
00383     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
00384                                CvMemStorage* storage, CvSVMKernel* kernel,
00385                                double* alpha, CvSVMSolutionInfo& si );
00386     virtual bool solve_one_class( int count, int var_count, const float** samples,
00387                                   CvMemStorage* storage, CvSVMKernel* kernel,
00388                                   double* alpha, CvSVMSolutionInfo& si );
00389 
00390     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
00391                                 CvMemStorage* storage, CvSVMKernel* kernel,
00392                                 double* alpha, CvSVMSolutionInfo& si );
00393 
00394     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
00395                                CvMemStorage* storage, CvSVMKernel* kernel,
00396                                double* alpha, CvSVMSolutionInfo& si );
00397 
00398     virtual float* get_row_base( int i, bool* _existed );
00399     virtual float* get_row( int i, float* dst );
00400 
00401     int sample_count;
00402     int var_count;
00403     int cache_size;
00404     int cache_line_size;
00405     const float** samples;
00406     const CvSVMParams* params;
00407     CvMemStorage* storage;
00408     CvSVMKernelRow lru_list;
00409     CvSVMKernelRow* rows;
00410 
00411     int alpha_count;
00412 
00413     double* G;
00414     double* alpha;
00415 
00416     // -1 - lower bound, 0 - free, 1 - upper bound
00417     schar* alpha_status;
00418 
00419     schar* y;
00420     double* b;
00421     float* buf[2];
00422     double eps;
00423     int max_iter;
00424     double C[2];  // C[0] == Cn, C[1] == Cp
00425     CvSVMKernel* kernel;
00426 
00427     SelectWorkingSet select_working_set_func;
00428     CalcRho calc_rho_func;
00429     GetRow get_row_func;
00430 
00431     virtual bool select_working_set( int& i, int& j );
00432     virtual bool select_working_set_nu_svm( int& i, int& j );
00433     virtual void calc_rho( double& rho, double& r );
00434     virtual void calc_rho_nu_svm( double& rho, double& r );
00435 
00436     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
00437     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
00438     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
00439 };
00440 
00441 
/* Storage for one SVM decision function: its offset (rho), the coefficients
   (alpha) and the indices of the support vectors it uses. */
struct CvSVMDecisionFunc
{
    double rho;       // decision-function offset
    int sv_count;     // number of support vectors referenced
    double* alpha;    // sv_count coefficients
    int* sv_index;    // sv_count indices into the model's support-vector array
};
00449 
00450 
00451 // SVM model
00452 class CV_EXPORTS_W CvSVM : public CvStatModel
00453 {
00454 public:
00455     // SVM type
00456     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
00457 
00458     // SVM kernel type
00459     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
00460 
00461     // SVM params type
00462     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
00463 
00464     CV_WRAP CvSVM();
00465     virtual ~CvSVM();
00466 
00467     CvSVM( const CvMat* trainData, const CvMat* responses,
00468            const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00469            CvSVMParams params=CvSVMParams() );
00470 
00471     virtual bool train( const CvMat* trainData, const CvMat* responses,
00472                         const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00473                         CvSVMParams params=CvSVMParams() );
00474 
00475     virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
00476         const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
00477         int kfold = 10,
00478         CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
00479         CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
00480         CvParamGrid pGrid      = get_default_grid(CvSVM::P),
00481         CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
00482         CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
00483         CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
00484         bool balanced=false );
00485 
00486     virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
00487     virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;
00488 
00489     CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
00490           const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00491           CvSVMParams params=CvSVMParams() );
00492 
00493     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00494                        const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00495                        CvSVMParams params=CvSVMParams() );
00496 
00497     CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
00498                             const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
00499                             int k_fold = 10,
00500                             CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
00501                             CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
00502                             CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
00503                             CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
00504                             CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
00505                             CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
00506                             bool balanced=false);
00507     CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
00508     CV_WRAP_AS(predict_all) void predict( cv::InputArray samples, cv::OutputArray results ) const;
00509 
00510     CV_WRAP virtual int get_support_vector_count() const;
00511     virtual const float* get_support_vector(int i) const;
00512     virtual CvSVMParams get_params() const { return params; };
00513     CV_WRAP virtual void clear();
00514 
00515     static CvParamGrid get_default_grid( int param_id );
00516 
00517     virtual void write( CvFileStorage* storage, const char* name ) const;
00518     virtual void read( CvFileStorage* storage, CvFileNode* node );
00519     CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
00520 
00521 protected:
00522 
00523     virtual bool set_params( const CvSVMParams& params );
00524     virtual bool train1( int sample_count, int var_count, const float** samples,
00525                     const void* responses, double Cp, double Cn,
00526                     CvMemStorage* _storage, double* alpha, double& rho );
00527     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
00528                     const CvMat* responses, CvMemStorage* _storage, double* alpha );
00529     virtual void create_kernel();
00530     virtual void create_solver();
00531 
00532     virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
00533 
00534     virtual void write_params( CvFileStorage* fs ) const;
00535     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00536 
00537     CvSVMParams params;
00538     CvMat* class_labels;
00539     int var_all;
00540     float** sv;
00541     int sv_total;
00542     CvMat* var_idx;
00543     CvMat* class_weights;
00544     CvSVMDecisionFunc* decision_func;
00545     CvMemStorage* storage;
00546 
00547     CvSVMSolver* solver;
00548     CvSVMKernel* kernel;
00549 };
00550 
00551 /****************************************************************************************\
00552 *                              Expectation - Maximization                                *
00553 \****************************************************************************************/
00554 namespace cv
00555 {
00556 class CV_EXPORTS_W EM : public Algorithm
00557 {
00558 public:
00559     // Type of covariation matrices
00560     enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
00561 
00562     // Default parameters
00563     enum {DEFAULT_NCLUSTERS=5, DEFAULT_MAX_ITERS=100};
00564 
00565     // The initial step
00566     enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
00567 
00568     CV_WRAP EM(int nclusters=EM::DEFAULT_NCLUSTERS, int covMatType=EM::COV_MAT_DIAGONAL,
00569        const TermCriteria& termCrit=TermCriteria(TermCriteria::COUNT+TermCriteria::EPS,
00570                                                  EM::DEFAULT_MAX_ITERS, FLT_EPSILON));
00571 
00572     virtual ~EM();
00573     CV_WRAP virtual void clear();
00574 
00575     CV_WRAP virtual bool train(InputArray samples,
00576                        OutputArray logLikelihoods=noArray(),
00577                        OutputArray labels=noArray(),
00578                        OutputArray probs=noArray());
00579 
00580     CV_WRAP virtual bool trainE(InputArray samples,
00581                         InputArray means0,
00582                         InputArray covs0=noArray(),
00583                         InputArray weights0=noArray(),
00584                         OutputArray logLikelihoods=noArray(),
00585                         OutputArray labels=noArray(),
00586                         OutputArray probs=noArray());
00587 
00588     CV_WRAP virtual bool trainM(InputArray samples,
00589                         InputArray probs0,
00590                         OutputArray logLikelihoods=noArray(),
00591                         OutputArray labels=noArray(),
00592                         OutputArray probs=noArray());
00593 
00594     CV_WRAP Vec2d predict(InputArray sample,
00595                 OutputArray probs=noArray()) const;
00596 
00597     CV_WRAP bool isTrained() const;
00598 
00599     AlgorithmInfo* info() const;
00600     virtual void read(const FileNode& fn);
00601 
00602 protected:
00603 
00604     virtual void setTrainData(int startStep, const Mat& samples,
00605                               const Mat* probs0,
00606                               const Mat* means0,
00607                               const vector<Mat>* covs0,
00608                               const Mat* weights0);
00609 
00610     bool doTrain(int startStep,
00611                  OutputArray logLikelihoods,
00612                  OutputArray labels,
00613                  OutputArray probs);
00614     virtual void eStep();
00615     virtual void mStep();
00616 
00617     void clusterTrainSamples();
00618     void decomposeCovs();
00619     void computeLogWeightDivDet();
00620 
00621     Vec2d computeProbabilities(const Mat& sample, Mat* probs) const;
00622 
00623     // all inner matrices have type CV_64FC1
00624     CV_PROP_RW int nclusters;
00625     CV_PROP_RW int covMatType;
00626     CV_PROP_RW int maxIters;
00627     CV_PROP_RW double epsilon;
00628 
00629     Mat trainSamples;
00630     Mat trainProbs;
00631     Mat trainLogLikelihoods;
00632     Mat trainLabels;
00633 
00634     CV_PROP Mat weights;
00635     CV_PROP Mat means;
00636     CV_PROP vector<Mat> covs;
00637 
00638     vector<Mat> covsEigenValues;
00639     vector<Mat> covsRotateMats;
00640     vector<Mat> invCovsEigenValues;
00641     Mat logWeightDivDet;
00642 };
00643 } // namespace cv
00644 
/****************************************************************************************\
*                                      Decision Tree                                     *
\****************************************************************************************/

/* Two views of the same index buffer: 16-bit (u) or 32-bit (i). Presumably
   which pointer is valid depends on CvDTreeTrainData::is_buf_16u — confirm. */
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
00653 
00654 
/* Direction for categorical value <idx> under a split's category bit mask
   <subset> (an int array, 32 categories per element):
   bit set -> -1, bit clear -> +1 */
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
00657 
/* One decision-tree split on variable <var_idx>. For an ordered variable the
   <ord> member holds the threshold; for a categorical variable the <subset>
   bit mask holds the category partition (tested via CV_DTREE_CAT_DIR). */
struct CvDTreeSplit
{
    int var_idx;         // index of the variable being split on
    int condensed_idx;
    int inversed;        // non-zero presumably flips the split direction — confirm
    float quality;       // split quality score
    CvDTreeSplit* next;  // next split in a linked list
    union
    {
        int subset[2];   // 64-bit category mask (categorical variables)
        struct
        {
            float c;            // threshold (ordered variables)
            int split_point;
        }
        ord;
    };
};
00676 
00677 struct CvDTreeNode
00678 {
00679     int class_idx;
00680     int Tn;
00681     double value;
00682 
00683     CvDTreeNode* parent;
00684     CvDTreeNode* left;
00685     CvDTreeNode* right;
00686 
00687     CvDTreeSplit* split;
00688 
00689     int sample_count;
00690     int depth;
00691     int* num_valid;
00692     int offset;
00693     int buf_idx;
00694     double maxlr;
00695 
00696     // global pruning data
00697     int complexity;
00698     double alpha;
00699     double node_risk, tree_risk, tree_error;
00700 
00701     // cross-validation pruning data
00702     int* cv_Tn;
00703     double* cv_node_risk;
00704     double* cv_node_error;
00705 
00706     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
00707     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
00708 };
00709 
00710 
00711 struct CV_EXPORTS_W_MAP CvDTreeParams
00712 {
00713     CV_PROP_RW int   max_categories;
00714     CV_PROP_RW int   max_depth;
00715     CV_PROP_RW int   min_sample_count;
00716     CV_PROP_RW int   cv_folds;
00717     CV_PROP_RW bool  use_surrogates;
00718     CV_PROP_RW bool  use_1se_rule;
00719     CV_PROP_RW bool  truncate_pruned_tree;
00720     CV_PROP_RW float regression_accuracy;
00721     const float* priors;
00722 
00723     CvDTreeParams();
00724     CvDTreeParams( int max_depth, int min_sample_count,
00725                    float regression_accuracy, bool use_surrogates,
00726                    int max_categories, int cv_folds,
00727                    bool use_1se_rule, bool truncate_pruned_tree,
00728                    const float* priors );
00729 };
00730 
00731 
// Preprocessed training data shared by the tree-based models (CvDTree,
// CvRTrees, CvBoost, ...). The user-supplied matrices are converted once
// into an internal buffer representation, so several trees can be trained
// on the same data ("shared" mode, see the `shared` flag below).
struct CV_EXPORTS CvDTreeTrainData
{
    CvDTreeTrainData();
    // Equivalent to default construction followed by set_data() with the
    // same arguments.
    CvDTreeTrainData( const CvMat* trainData, int tflag,
                      const CvMat* responses, const CvMat* varIdx=0,
                      const CvMat* sampleIdx=0, const CvMat* varType=0,
                      const CvMat* missingDataMask=0,
                      const CvDTreeParams& params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    // (Re)initializes the internal buffers from the raw matrices.
    // tflag selects row-wise vs column-wise sample layout.
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    // Makes a private copy of the responses into responses_copy (see below).
    virtual void do_responses_copy();

    // Extracts the (optionally subsampled) feature vectors, missing-value
    // mask and responses back into plain float/uchar arrays.
    virtual void get_vectors( const CvMat* _subsample_idx,
         float* values, uchar* missing, float* responses, bool get_class_idx=false );

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const {return work_var_count;}

    // Per-node data accessors; the *_buf arguments are caller-provided
    // scratch buffers the result may be materialized into.
    virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                   const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
    virtual int get_child_buf_idx( CvDTreeNode* n );

    virtual bool set_params( const CvDTreeParams& params );
    // Node/split allocation from the internal storages.
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    // Layout counters (samples, total/used/ordered/categorical variables).
    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count, work_var_count;
    bool have_labels, have_priors;
    bool is_classifier;
    int tflag;

    const CvMat* train_data;   // input matrices as passed to set_data()
    const CvMat* responses;
    CvMat* responses_copy; // used in Boosting

    int buf_count, buf_size;
    bool shared;       // true when the data is shared between several trees
    int is_buf_16u;    // nonzero <=> internal buffer uses 16-bit storage (per the name)

    // Categorical-variable bookkeeping (counts, offsets, value map).
    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;        // main internal work buffer
    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     //   k<0  - ordered
                     //   k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;
    CvMat* priors_mult;

    CvDTreeParams params;

    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;

    // Storages for tree nodes, splits and per-node auxiliary data.
    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cv_heap;
    CvSet* nv_heap;

    cv::RNG* rng;
};
00833 
// Forward declarations: the best-split-finder helpers declared in the cv
// namespace (defined in the implementation) are friends of the tree
// classes that follow.
class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
00842 
// Single decision tree model. Handles both classification and regression,
// surrogate splits for samples with missing values, and cross-validation
// based pruning (see prune_cv). Also serves as the base class for the
// ensemble members CvForestTree and CvBoostTree.
class CV_EXPORTS_W CvDTree : public CvStatModel
{
public:
    CV_WRAP CvDTree();
    virtual ~CvDTree();

    // Trains on raw matrices; tflag selects row-wise vs column-wise sample
    // layout (see CvDTreeTrainData::set_data).
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );

    // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );

    // Trains on already-preprocessed (possibly shared) training data.
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );

    // Returns the tree node the sample is routed to; preprocessedInput
    // indicates the sample is already in the internal representation.
    virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
                                  bool preprocessedInput=false ) const;

    // cv::Mat overloads of the C-style API above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvDTreeParams params=CvDTreeParams() );

    CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
                                  bool preprocessedInput=false ) const;
    CV_WRAP virtual cv::Mat getVarImportance();

    virtual const CvMat* get_var_importance();
    CV_WRAP virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    // special read & write methods for trees in the tree ensembles
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs ) const;

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:
    friend struct cv::DTreeBestSplitFinder;

    virtual bool do_train( const CvMat* _subsample_idx );

    // Recursive tree-construction helpers.
    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    // Split search, one routine per (variable kind x problem kind):
    // ordered/categorical variable, classification/regression.
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    // Surrogate splits used to route samples whose split variable is missing.
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    virtual void cluster_categories( const int* vectors, int vector_count,
        int var_count, int* sums, int k, int* cluster_labels );

    virtual void calc_node_value( CvDTreeNode* node );

    // Cross-validation based pruning.
    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    // Serialization helpers used by read()/write().
    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs ) const;
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;        // tree root (see get_root())
    CvMat* var_importance;    // cached variable importance (see get_var_importance())
    CvDTreeTrainData* data;

public:
    int pruned_tree_idx;
};
00935 
00936 
00937 /****************************************************************************************\
00938 *                                   Random Trees Classifier                              *
00939 \****************************************************************************************/
00940 
class CvRTrees;

// A single tree inside a CvRTrees forest. Overrides find_best_split()
// and keeps a back-pointer to the owning forest.
class CV_EXPORTS CvForestTree: public CvDTree
{
public:
    CvForestTree();
    virtual ~CvForestTree();

    // Trains this tree on (shared) preprocessed data; `forest` is stored
    // in the member of the same name.
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );

    virtual int get_var_count() const {return data ? data->var_count : 0;}
    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    friend struct cv::ForestTreeBestSplitFinder;

    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    CvRTrees* forest;   // back-pointer to the owning ensemble
};
00973 
00974 
// Training parameters for random trees (CvRTrees); extends the single
// tree parameters with forest-level settings.
struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
{
    //Parameters for the forest
    CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    CV_PROP_RW int nactive_vars;         // number of active (randomly selected) variables per split
    CV_PROP_RW CvTermCriteria term_crit; // forest-growing stop criterion (tree count and/or accuracy)

    CvRTParams();
    // max_num_of_trees_in_the_forest / forest_accuracy / termcrit_type are
    // presumably packed into term_crit — verify against the implementation.
    CvRTParams( int max_depth, int min_sample_count,
                float regression_accuracy, bool use_surrogates,
                int max_categories, const float* priors, bool calc_var_importance,
                int nactive_vars, int max_num_of_trees_in_the_forest,
                float forest_accuracy, int termcrit_type );
};
00989 
00990 
// Random trees model: an ensemble of CvForestTree instances grown by
// grow_forest() according to CvRTParams.
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );

    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

    // cv::Mat overloads of the C-style API above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();

    CV_WRAP virtual void clear();

    virtual const CvMat* get_var_importance();
    // Proximity of two samples, measured over the trees of the forest.
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:
    virtual std::string getName() const;

    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;      // preprocessed training data shared by all trees
    int ntrees;
    int nclasses;
    double oob_error;            // out-of-bag error estimate
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;
};
01051 
01052 /****************************************************************************************\
01053 *                           Extremely randomized trees Classifier                        *
01054 \****************************************************************************************/
// Training data for extremely randomized trees. Keeps the original
// missing-value mask and overrides the per-node data accessors of
// CvDTreeTrainData (note: get_ord_var_data reports missing-value info
// instead of sorted indices).
struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
{
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
                                   const float** ord_values, const int** missing, int* sample_buf = 0 );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
                              float* responses, bool get_class_idx=false );
    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    const CvMat* missing_mask;   // missing-value mask as supplied to set_data()
};
01074 
// A single tree of the extremely randomized trees ensemble (CvERTrees).
// Overrides the split-search and node-splitting routines of CvForestTree.
class CV_EXPORTS CvForestERTree : public CvForestTree
{
protected:
    virtual double calc_node_dir( CvDTreeNode* node );
    // Split search, one routine per (variable kind x problem kind).
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void split_node_data( CvDTreeNode* n );
};
01089 
// Extremely randomized trees ensemble. Reuses the CvRTrees public
// interface and overrides forest growing — presumably to build
// CvForestERTree members over CvERTreeTrainData (verify in the
// implementation).
class CV_EXPORTS_W CvERTrees : public CvRTrees
{
public:
    CV_WRAP CvERTrees();
    virtual ~CvERTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams());
    // cv::Mat overload of the C-style API above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams());
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected:
    virtual std::string getName() const;
    virtual bool grow_forest( const CvTermCriteria term_crit );
};
01110 
01111 
01112 /****************************************************************************************\
01113 *                                   Boosted tree classifier                              *
01114 \****************************************************************************************/
01115 
// Training parameters for boosted tree classifiers (CvBoost); extends the
// single tree parameters.
struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
{
    CV_PROP_RW int boost_type;          // one of CvBoost::{DISCRETE, REAL, LOGIT, GENTLE}
    CV_PROP_RW int weak_count;          // number of weak classifiers (trees) to train
    CV_PROP_RW int split_criteria;      // one of CvBoost::{DEFAULT, GINI, MISCLASS, SQERR}
    CV_PROP_RW double weight_trim_rate; // weight-trimming rate; presumably consumed by CvBoost::trim_weights()

    CvBoostParams();
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};
01127 
01128 
class CvBoost;

// A weak learner inside a CvBoost ensemble. Overrides the split-search
// and node-value routines of CvDTree and keeps a back-pointer to the
// owning ensemble.
class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    // Trains this weak tree on the shared data; `ensemble` is stored in
    // the member of the same name.
    virtual bool train( CvDTreeTrainData* trainData,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    // Scales the tree's output by factor s.
    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    // Split search, one routine per (variable kind x problem kind).
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble;   // back-pointer to the owning ensemble (set by train/read)
};
01176 
01177 
// Boosted classifier: an ensemble of CvBoostTree weak learners trained
// with one of the boosting variants below.
class CV_EXPORTS_W CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CV_WRAP CvBoost();
    virtual ~CvBoost();

    // Constructs and immediately trains the model (see train()).
    CvBoost( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams() );

    // Trains the ensemble; update==true presumably adds new weak
    // classifiers to the existing ones rather than retraining from
    // scratch — verify against the implementation.
    virtual bool train( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    virtual bool train( CvMLData* data,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    // Predicts the response for one sample. `slice` selects the range of
    // weak classifiers used; `weak_responses` optionally receives the
    // individual weak outputs; return_sum requests the raw weighted sum.
    virtual float predict( const CvMat* sample, const CvMat* missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false, bool return_sum=false ) const;

    // cv::Mat overloads of the C-style API above.
    CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
            const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
            const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
            const cv::Mat& missingDataMask=cv::Mat(),
            CvBoostParams params=CvBoostParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvBoostParams params=CvBoostParams(),
                       bool update=false );

    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                                   const cv::Range& slice=cv::Range::all(), bool rawMode=false,
                                   bool returnSum=false ) const;

    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    // Removes the weak classifiers selected by `slice` (per the name).
    CV_WRAP virtual void prune( CvSlice slice );

    CV_WRAP virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    virtual const CvMat* get_active_vars(bool absolute_idx=true);

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;
    const CvDTreeTrainData* get_data() const;

protected:

    virtual bool set_params( const CvBoostParams& params );
    // Boosting-iteration helpers: reweight samples after a tree is added,
    // and drop low-weight samples (weight trimming).
    virtual void update_weights( CvBoostTree* tree );
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak;                 // sequence of trained weak classifiers (see get_weak_predictors())

    CvMat* active_vars;
    CvMat* active_vars_abs;
    bool have_active_cat_vars;

    // Per-sample state of the boosting loop.
    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};
01270 
01271 
01272 /****************************************************************************************\
01273 *                                   Gradient Boosted Trees                               *
01274 \****************************************************************************************/
01275 
01276 // DataType: STRUCT CvGBTreesParams
01277 // Parameters of GBT (Gradient Boosted trees model), including single
01278 // tree settings and ensemble parameters.
01279 //
01280 // weak_count          - count of trees in the ensemble
01281 // loss_function_type  - loss function used for ensemble training
01282 // subsample_portion   - portion of whole training set used for
01283 //                       every single tree training.
01284 //                       subsample_portion value is in (0.0, 1.0].
01285 //                       subsample_portion == 1.0 when whole dataset is
//                       used on each step. Count of samples used on each
01287 //                       step is computed as
01288 //                       int(total_samples_count * subsample_portion).
01289 // shrinkage           - regularization parameter.
01290 //                       Each tree prediction is multiplied on shrinkage value.
01291 
01292 
// Training parameters of the Gradient Boosted Trees (GBT) model; extends
// the single tree parameters (see the descriptive comment block above).
struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams
{
    CV_PROP_RW int weak_count;          // number of trees in the ensemble
    CV_PROP_RW int loss_function_type;  // loss function used for training
    CV_PROP_RW float subsample_portion; // fraction of the training set used per tree, in (0.0, 1.0]
    CV_PROP_RW float shrinkage;         // regularization: each tree's prediction is scaled by this

    CvGBTreesParams();
    CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
        float subsample_portion, int max_depth, bool use_surrogates );
};
01304 
01305 // DataType: CLASS CvGBTrees
01306 // Gradient Boosting Trees (GBT) algorithm implementation.
01307 //
01308 // data             - training dataset
01309 // params           - parameters of the CvGBTrees
01310 // weak             - array[0..(class_count-1)] of CvSeq
01311 //                    for storing tree ensembles
01312 // orig_response    - original responses of the training set samples
// sum_response     - predictions of the current model on the training dataset.
//                    this matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
//                    step. On every iteration values of sum_responses_tmp are
//                    computed via sum_responses values. When the current
//                    step is complete sum_response values become equal to
//                    sum_responses_tmp.
01320 // sampleIdx       - indices of samples used for training the ensemble.
01321 //                    CvGBTrees training procedure takes a set of samples
01322 //                    (train_data) and a set of responses (responses).
01323 //                    Only pairs (train_data[i], responses[i]), where i is
01324 //                    in sample_idx are used for training the ensemble.
// subsample_train  - indices of samples used for training a single decision
//                    tree on the current step. These indices are counted
//                    relative to sample_idx, so that pairs
//                    (train_data[sample_idx[i]], responses[sample_idx[i]])
//                    are used for training a decision tree.
//                    The training set is randomly split
//                    into two parts (subsample_train and subsample_test)
//                    on every iteration according to the portion parameter.
01333 // subsample_test   - relative indices of samples from the training set,
01334 //                    which are not used for training a tree on the current
01335 //                    step.
01336 // missing          - mask of the missing values in the training set. This
01337 //                    matrix has the same size as train_data. 1 - missing
01338 //                    value, 0 - not a missing value.
01339 // class_labels     - output class labels map.
// rng              - random number generator. Used for splitting the
01341 //                    training set.
01342 // class_count      - count of output classes.
01343 //                    class_count == 1 in the case of regression,
01344 //                    and > 1 in the case of classification.
01345 // delta            - Huber loss function parameter.
01346 // base_value       - start point of the gradient descent procedure.
01347 //                    model prediction is
01348 //                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
01349 //                    f_0 is the base value.
01350 
01351 
01352 
01353 class CV_EXPORTS_W CvGBTrees : public CvStatModel
01354 {
01355 public:
01356 
01357     /*
01358     // DataType: ENUM
01359     // Loss functions implemented in CvGBTrees.
01360     //
01361     // SQUARED_LOSS
01362     // problem: regression
01363     // loss = (x - x')^2
01364     //
01365     // ABSOLUTE_LOSS
01366     // problem: regression
01367     // loss = abs(x - x')
01368     //
01369     // HUBER_LOSS
01370     // problem: regression
01371     // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
01372     //           1/2*(x - x')^2, if abs(x - x') <= delta,
01373     //           where delta is the alpha-quantile of pseudo responses from
01374     //           the training set.
01375     //
01376     // DEVIANCE_LOSS
01377     // problem: classification
01378     //
01379     */
01380     enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
01381 
01382 
01383     /*
01384     // Default constructor. Creates a model only (without training).
01385     // Should be followed by one form of the train(...) function.
01386     //
01387     // API
01388     // CvGBTrees();
01389 
01390     // INPUT
01391     // OUTPUT
01392     // RESULT
01393     */
01394     CV_WRAP CvGBTrees();
01395 
01396 
01397     /*
01398     // Full form constructor. Creates a gradient boosting model and does the
01399     // train.
01400     //
01401     // API
01402     // CvGBTrees( const CvMat* trainData, int tflag,
01403              const CvMat* responses, const CvMat* varIdx=0,
01404              const CvMat* sampleIdx=0, const CvMat* varType=0,
01405              const CvMat* missingDataMask=0,
01406              CvGBTreesParams params=CvGBTreesParams() );
01407 
01408     // INPUT
01409     // trainData    - a set of input feature vectors.
01410     //                  size of matrix is
01411     //                  <count of samples> x <variables count>
01412     //                  or <variables count> x <count of samples>
01413     //                  depending on the tflag parameter.
01414     //                  matrix values are float.
    // tflag         - a flag showing how the samples are stored in the
    //                  trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
    //                  or column by column (tflag=CV_COL_SAMPLE).
01418     // responses     - a vector of responses corresponding to the samples
01419     //                  in trainData.
01420     // varIdx       - indices of used variables. zero value means that all
01421     //                  variables are active.
01422     // sampleIdx    - indices of used samples. zero value means that all
01423     //                  samples from trainData are in the training set.
01424     // varType      - vector of <variables count> length. gives every
01425     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
01426     //                  varType = 0 means all variables are numerical.
    // missingDataMask  - a mask of missing values in trainData.
01428     //                  missingDataMask = 0 means that there are no missing
01429     //                  values.
01430     // params         - parameters of GTB algorithm.
01431     // OUTPUT
01432     // RESULT
01433     */
01434     CvGBTrees( const CvMat* trainData, int tflag,
01435              const CvMat* responses, const CvMat* varIdx=0,
01436              const CvMat* sampleIdx=0, const CvMat* varType=0,
01437              const CvMat* missingDataMask=0,
01438              CvGBTreesParams params=CvGBTreesParams() );
01439 
01440 
01441     /*
01442     // Destructor.
01443     */
01444     virtual ~CvGBTrees();
01445 
01446 
01447     /*
01448     // Gradient tree boosting model training
01449     //
01450     // API
01451     // virtual bool train( const CvMat* trainData, int tflag,
01452              const CvMat* responses, const CvMat* varIdx=0,
01453              const CvMat* sampleIdx=0, const CvMat* varType=0,
01454              const CvMat* missingDataMask=0,
01455              CvGBTreesParams params=CvGBTreesParams(),
01456              bool update=false );
01457 
01458     // INPUT
01459     // trainData    - a set of input feature vectors.
01460     //                  size of matrix is
01461     //                  <count of samples> x <variables count>
01462     //                  or <variables count> x <count of samples>
01463     //                  depending on the tflag parameter.
01464     //                  matrix values are float.
01465     // tflag         - a flag showing how the samples are stored in the
01466     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
01467     //                  or column by column (tflag=CV_COL_SAMPLE).
01468     // responses     - a vector of responses corresponding to the samples
01469     //                  in trainData.
01470     // varIdx       - indices of used variables. zero value means that all
01471     //                  variables are active.
01472     // sampleIdx    - indices of used samples. zero value means that all
01473     //                  samples from trainData are in the training set.
01474     // varType      - vector of <variables count> length. gives every
01475     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
01476     //                  varType = 0 means all variables are numerical.
01477     // missingDataMask  - a mask of missing values in trainData.
01478     //                  missingDataMask = 0 means that there are no missing
01479     //                  values.
01480     // params         - parameters of GTB algorithm.
01481     // update         - is not supported now. (!)
01482     // OUTPUT
01483     // RESULT
01484     // Error state.
01485     */
01486     virtual bool train( const CvMat* trainData, int tflag,
01487              const CvMat* responses, const CvMat* varIdx=0,
01488              const CvMat* sampleIdx=0, const CvMat* varType=0,
01489              const CvMat* missingDataMask=0,
01490              CvGBTreesParams params=CvGBTreesParams(),
01491              bool update=false );
01492 
01493 
01494     /*
01495     // Gradient tree boosting model training
01496     //
01497     // API
01498     // virtual bool train( CvMLData* data,
01499              CvGBTreesParams params=CvGBTreesParams(),
01500              bool update=false ) {return false;};
01501 
01502     // INPUT
01503     // data          - training set.
01504     // params        - parameters of GTB algorithm.
01505     // update        - is not supported now. (!)
01506     // OUTPUT
01507     // RESULT
01508     // Error state.
01509     */
01510     virtual bool train( CvMLData* data,
01511              CvGBTreesParams params=CvGBTreesParams(),
01512              bool update=false );
01513 
01514 
01515     /*
01516     // Response value prediction
01517     //
01518     // API
01519     // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
01520              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
01521              int k=-1 ) const;
01522 
01523     // INPUT
01524     // sample         - input sample of the same type as in the training set.
01525     // missing        - missing values mask. missing=0 if there are no
01526     //                   missing values in sample vector.
01527     // weak_responses  - predictions of all of the trees.
01528     //                   not implemented (!)
01529     // slice           - part of the ensemble used for prediction.
01530     //                   slice = CV_WHOLE_SEQ when all trees are used.
01531     // k               - number of ensemble used.
01532     //                   k is in {-1,0,1,..,<count of output classes-1>}.
01533     //                   in the case of classification problem
01534     //                   <count of output classes-1> ensembles are built.
01535     //                   If k = -1 ordinary prediction is the result,
01536     //                   otherwise function gives the prediction of the
01537     //                   k-th ensemble only.
01538     // OUTPUT
01539     // RESULT
01540     // Predicted value.
01541     */
01542     virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
01543             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01544             int k=-1 ) const;
01545 
01546     /*
01547     // Response value prediction.
01548     // Parallel version (in the case of TBB existence)
01549     //
01550     // API
01551     // virtual float predict( const CvMat* sample, const CvMat* missing=0,
01552              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
01553              int k=-1 ) const;
01554 
01555     // INPUT
01556     // sample         - input sample of the same type as in the training set.
01557     // missing        - missing values mask. missing=0 if there are no
01558     //                   missing values in sample vector.
01559     // weak_responses  - predictions of all of the trees.
01560     //                   not implemented (!)
01561     // slice           - part of the ensemble used for prediction.
01562     //                   slice = CV_WHOLE_SEQ when all trees are used.
01563     // k               - number of ensemble used.
01564     //                   k is in {-1,0,1,..,<count of output classes-1>}.
01565     //                   in the case of classification problem
01566     //                   <count of output classes-1> ensembles are built.
01567     //                   If k = -1 ordinary prediction is the result,
01568     //                   otherwise function gives the prediction of the
01569     //                   k-th ensemble only.
01570     // OUTPUT
01571     // RESULT
01572     // Predicted value.
01573     */
01574     virtual float predict( const CvMat* sample, const CvMat* missing=0,
01575             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01576             int k=-1 ) const;
01577 
01578     /*
01579     // Deletes all the data.
01580     //
01581     // API
01582     // virtual void clear();
01583 
01584     // INPUT
01585     // OUTPUT
01586     // delete data, weak, orig_response, sum_response,
01587     //        weak_eval, subsample_train, subsample_test,
01588     //        sample_idx, missing, class_labels
01589     // delta = 0.0
01590     // RESULT
01591     */
01592     CV_WRAP virtual void clear();
01593 
01594     /*
01595     // Compute error on the train/test set.
01596     //
01597     // API
01598     // virtual float calc_error( CvMLData* _data, int type,
01599     //        std::vector<float> *resp = 0 );
01600     //
01601     // INPUT
01602     // data  - dataset
01603     // type  - defines which error is to compute: train (CV_TRAIN_ERROR) or
01604     //         test (CV_TEST_ERROR).
01605     // OUTPUT
01606     // resp  - vector of predictions
01607     // RESULT
01608     // Error value.
01609     */
01610     virtual float calc_error( CvMLData* _data, int type,
01611             std::vector<float> *resp = 0 );
01612 
01613     /*
01614     //
01615     // Write parameters of the gtb model and data. Write learned model.
01616     //
01617     // API
01618     // virtual void write( CvFileStorage* fs, const char* name ) const;
01619     //
01620     // INPUT
01621     // fs     - file storage to read parameters from.
01622     // name   - model name.
01623     // OUTPUT
01624     // RESULT
01625     */
01626     virtual void write( CvFileStorage* fs, const char* name ) const;
01627 
01628 
01629     /*
01630     //
01631     // Read parameters of the gtb model and data. Read learned model.
01632     //
01633     // API
01634     // virtual void read( CvFileStorage* fs, CvFileNode* node );
01635     //
01636     // INPUT
01637     // fs     - file storage to read parameters from.
01638     // node   - file node.
01639     // OUTPUT
01640     // RESULT
01641     */
01642     virtual void read( CvFileStorage* fs, CvFileNode* node );
01643 
01644 
01645     // new-style C++ interface
01646     CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
01647               const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01648               const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01649               const cv::Mat& missingDataMask=cv::Mat(),
01650               CvGBTreesParams params=CvGBTreesParams() );
01651 
01652     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01653                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01654                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01655                        const cv::Mat& missingDataMask=cv::Mat(),
01656                        CvGBTreesParams params=CvGBTreesParams(),
01657                        bool update=false );
01658 
01659     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01660                            const cv::Range& slice = cv::Range::all(),
01661                            int k=-1 ) const;
01662 
01663 protected:
01664 
01665     /*
01666     // Compute the gradient vector components.
01667     //
01668     // API
01669     // virtual void find_gradient( const int k = 0);
01670 
01671     // INPUT
01672     // k        - used for classification problem, determining current
01673     //            tree ensemble.
01674     // OUTPUT
01675     // changes components of data->responses
01676     // which correspond to samples used for training
01677     // on the current step.
01678     // RESULT
01679     */
01680     virtual void find_gradient( const int k = 0);
01681 
01682 
01683     /*
01684     //
01685     // Change values in tree leaves according to the used loss function.
01686     //
01687     // API
01688     // virtual void change_values(CvDTree* tree, const int k = 0);
01689     //
01690     // INPUT
01691     // tree      - decision tree to change.
01692     // k         - used for classification problem, determining current
01693     //             tree ensemble.
01694     // OUTPUT
01695     // changes 'value' fields of the trees' leaves.
01696     // changes sum_response_tmp.
01697     // RESULT
01698     */
01699     virtual void change_values(CvDTree* tree, const int k = 0);
01700 
01701 
01702     /*
01703     //
01704     // Find optimal constant prediction value according to the used loss
01705     // function.
01706     // The goal is to find a constant which gives the minimal summary loss
01707     // on the _Idx samples.
01708     //
01709     // API
01710     // virtual float find_optimal_value( const CvMat* _Idx );
01711     //
01712     // INPUT
01713     // _Idx        - indices of the samples from the training set.
01714     // OUTPUT
01715     // RESULT
01716     // optimal constant value.
01717     */
01718     virtual float find_optimal_value( const CvMat* _Idx );
01719 
01720 
01721     /*
01722     //
01723     // Randomly split the whole training set in two parts according
01724     // to params.portion.
01725     //
01726     // API
01727     // virtual void do_subsample();
01728     //
01729     // INPUT
01730     // OUTPUT
01731     // subsample_train - indices of samples used for training
01732     // subsample_test  - indices of samples used for test
01733     // RESULT
01734     */
01735     virtual void do_subsample();
01736 
01737 
01738     /*
01739     //
01740     // Internal recursive function giving an array of subtree tree leaves.
01741     //
01742     // API
01743     // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01744     //
01745     // INPUT
01746     // node         - current leaf.
01747     // OUTPUT
01748     // count        - count of leaves in the subtree.
01749     // leaves       - array of pointers to leaves.
01750     // RESULT
01751     */
01752     void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01753 
01754 
01755     /*
01756     //
01757     // Get leaves of the tree.
01758     //
01759     // API
01760     // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01761     //
01762     // INPUT
01763     // dtree            - decision tree.
01764     // OUTPUT
01765     // len              - count of the leaves.
01766     // RESULT
01767     // CvDTreeNode**    - array of pointers to leaves.
01768     */
01769     CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01770 
01771 
01772     /*
01773     //
01774     // Is it a regression or a classification.
01775     //
01776     // API
01777     // bool problem_type();
01778     //
01779     // INPUT
01780     // OUTPUT
01781     // RESULT
01782     // false if it is a classification problem,
01783     // true - if regression.
01784     */
01785     virtual bool problem_type() const;
01786 
01787 
01788     /*
01789     //
01790     // Write parameters of the gtb model.
01791     //
01792     // API
01793     // virtual void write_params( CvFileStorage* fs ) const;
01794     //
01795     // INPUT
01796     // fs           - file storage to write parameters to.
01797     // OUTPUT
01798     // RESULT
01799     */
01800     virtual void write_params( CvFileStorage* fs ) const;
01801 
01802 
01803     /*
01804     //
01805     // Read parameters of the gtb model and data.
01806     //
01807     // API
01808     // virtual void read_params( CvFileStorage* fs );
01809     //
01810     // INPUT
01811     // fs           - file storage to read parameters from.
01812     // OUTPUT
01813     // params       - parameters of the gtb model.
01814     // data         - contains information about the structure
01815     //                of the data set (count of variables,
01816     //                their types, etc.).
01817     // class_labels - output class labels map.
01818     // RESULT
01819     */
01820     virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
01821     int get_len(const CvMat* mat) const; // element count of the vector 'mat' (presumably rows*cols) — TODO confirm
01822 
01823 
01824     CvDTreeTrainData* data; // structure of the data set: variable count, types, etc. (see read_params)
01825     CvGBTreesParams params; // parameters of the GTB model
01826 
01827     CvSeq** weak; // the weak tree ensembles — presumably one sequence per output class; confirm
01828     CvMat* orig_response; // presumably the original training responses — TODO confirm
01829     CvMat* sum_response; // running sum of ensemble responses — presumably; confirm
01830     CvMat* sum_response_tmp; // scratch sum, updated by change_values()
01831     CvMat* sample_idx; // indices of the samples used for training
01832     CvMat* subsample_train; // indices of samples used for training on the current step (see do_subsample)
01833     CvMat* subsample_test; // indices of samples used for test on the current step (see do_subsample)
01834     CvMat* missing; // mask of missing values in the training data
01835     CvMat* class_labels; // output class labels map (see read_params)
01836 
01837     cv::RNG* rng; // random generator — presumably used by do_subsample; confirm
01838 
01839     int class_count; // number of output classes — presumably; confirm
01840     float delta; // reset to 0.0 by clear()
01841     float base_value; // presumably the initial constant prediction — TODO confirm
01842 
01843 };
01844 
01845 
01846 
01847 /****************************************************************************************\
01848 *                              Artificial Neural Networks (ANN)                          *
01849 \****************************************************************************************/
01850 
01852 
01853 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams // training parameters of the MLP (see CvANN_MLP::train)
01854 {
01855     CvANN_MLP_TrainParams(); // constructs library-default parameters
01856     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
01857                            double param1, double param2=0 ); // param1/param2 are method-dependent — presumably map onto the fields below; confirm
01858     ~CvANN_MLP_TrainParams();
01859 
01860     enum { BACKPROP=0, RPROP=1 }; // supported training algorithms
01861 
01862     CV_PROP_RW CvTermCriteria term_crit; // termination criteria of the training procedure
01863     CV_PROP_RW int train_method; // BACKPROP or RPROP
01864 
01865     // backpropagation parameters
01866     CV_PROP_RW double bp_dw_scale, bp_moment_scale;
01867 
01868     // rprop parameters
01869     CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
01870 };
01871 
01872 
01873 class CV_EXPORTS_W CvANN_MLP : public CvStatModel // MLP: multi-layer perceptron (feed-forward artificial neural network)
01874 {
01875 public:
01876     CV_WRAP CvANN_MLP(); // constructs an empty (uncreated) model
01877     CvANN_MLP( const CvMat* layerSizes,
01878                int activateFunc=CvANN_MLP::SIGMOID_SYM,
01879                double fparam1=0, double fparam2=0 ); // constructs and builds the network; see create()
01880 
01881     virtual ~CvANN_MLP();
01882 
01883     virtual void create( const CvMat* layerSizes, // integer row vector: its column count is the layer count (see get_layer_count)
01884                          int activateFunc=CvANN_MLP::SIGMOID_SYM, // IDENTITY, SIGMOID_SYM or GAUSSIAN
01885                          double fparam1=0, double fparam2=0 ); // activation-function parameters (stored in f_param1/f_param2)
01886 
01887     virtual int train( const CvMat* inputs, const CvMat* outputs, // one training sample per row — presumably; confirm
01888                        const CvMat* sampleWeights, const CvMat* sampleIdx=0, // sampleIdx=0 means use all samples
01889                        CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01890                        int flags=0 ); // combination of UPDATE_WEIGHTS / NO_INPUT_SCALE / NO_OUTPUT_SCALE
01891     virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const; // computes network responses for 'inputs' into 'outputs'
01892 
01893     CV_WRAP CvANN_MLP( const cv::Mat& layerSizes, // new-style C++ interface, mirrors the CvMat* overloads above
01894               int activateFunc=CvANN_MLP::SIGMOID_SYM,
01895               double fparam1=0, double fparam2=0 );
01896 
01897     CV_WRAP virtual void create( const cv::Mat& layerSizes,
01898                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
01899                         double fparam1=0, double fparam2=0 );
01900 
01901     CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
01902                       const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
01903                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01904                       int flags=0 );
01905 
01906     CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
01907 
01908     CV_WRAP virtual void clear(); // releases all model data
01909 
01910     // possible activation functions
01911     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
01912 
01913     // available training flags
01914     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
01915 
01916     virtual void read( CvFileStorage* fs, CvFileNode* node ); // loads the model from file storage
01917     virtual void write( CvFileStorage* storage, const char* name ) const; // saves the model to file storage
01918 
01919     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; } // 0 if the network has not been created
01920     const CvMat* get_layer_sizes() { return layer_sizes; }
01921     double* get_weights(int layer)
01922     {
01923         return layer_sizes && weights &&
01924             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0; // NOTE(review): '<=' looks intentional — 'weights' appears to hold extra scaling entries beyond the layer count; confirm against the implementation
01925     }
01926 
01927     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const; // derivative of the activation function (used by training)
01928 
01929 protected:
01930 
01931     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs, // validates and repacks the training data
01932             const CvMat* _sample_weights, const CvMat* sampleIdx,
01933             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
01934 
01935     // sequential random backpropagation
01936     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01937 
01938     // RPROP algorithm
01939     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01940 
01941     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
01942     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
01943                                  double _f_param1=0, double _f_param2=0 );
01944     virtual void init_weights(); // random weight initialization — presumably uses 'rng'; confirm
01945     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const; // applies input scaling (skipped with NO_INPUT_SCALE)
01946     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const; // applies output scaling (skipped with NO_OUTPUT_SCALE)
01947     virtual void calc_input_scale( const CvVectors* vecs, int flags );
01948     virtual void calc_output_scale( const CvVectors* vecs, int flags );
01949 
01950     virtual void write_params( CvFileStorage* fs ) const;
01951     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
01952 
01953     CvMat* layer_sizes; // row vector of neuron counts per layer (see get_layer_count)
01954     CvMat* wbuf; // presumably the contiguous storage backing 'weights' — TODO confirm
01955     CvMat* sample_weights;
01956     double** weights; // per-layer weight pointers, indexable 0..layer_sizes->cols (see get_weights)
01957     double f_param1, f_param2; // activation-function parameters (fparam1/fparam2 of create)
01958     double min_val, max_val, min_val1, max_val1; // presumably input/output scaling bounds — TODO confirm
01959     int activ_func; // IDENTITY, SIGMOID_SYM or GAUSSIAN
01960     int max_count, max_buf_sz; // internal buffer sizing — presumably; confirm
01961     CvANN_MLP_TrainParams params;
01962     cv::RNG* rng; // random number generator — presumably for init_weights; confirm
01963 };
01964 
01965 /****************************************************************************************\
01966 *                          Auxiliary function declarations                               *
01967 \****************************************************************************************/
01968 
01969 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
01970    average row vector, <cov> - symmetric covariance matrix */
01971 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
01972                            CvRNG* rng CV_DEFAULT(0) ); // rng=0 means use a default generator — presumably; confirm
01973 
01974 /* Generates sample from gaussian mixture distribution */
01975 CVAPI(void) cvRandGaussMixture( CvMat* means[], // per-component mean vectors
01976                                CvMat* covs[], // per-component covariance matrices
01977                                float weights[], // mixture weights of the components
01978                                int clsnum, // number of mixture components
01979                                CvMat* sample,
01980                                CvMat* sampClasses CV_DEFAULT(0) ); // optionally receives each sample's component index — presumably; confirm
01981 
01982 #define CV_TS_CONCENTRIC_SPHERES 0 // test-set type accepted by cvCreateTestSet
01983 
01984 /* creates test set */
01985 CVAPI(void) cvCreateTestSet( int type, CvMat** samples, // type: CV_TS_CONCENTRIC_SPHERES
01986                  int num_samples,
01987                  int num_features,
01988                  CvMat** responses,
01989                  int num_classes, ... );
01990 
01991 /****************************************************************************************\
01992 *                                      Data                                             *
01993 \****************************************************************************************/
01994 
01995 #define CV_COUNT     0 // train_sample_part holds an absolute count (see CvTrainTestSplit)
01996 #define CV_PORTION   1 // train_sample_part holds a fraction (see CvTrainTestSplit)
01997 
01998 struct CV_EXPORTS CvTrainTestSplit // describes how CvMLData splits its samples into train and test subsets
01999 {
02000     CvTrainTestSplit();
02001     CvTrainTestSplit( int train_sample_count, bool mix = true); // fixed number of training samples (CV_COUNT mode)
02002     CvTrainTestSplit( float train_sample_portion, bool mix = true); // fraction of training samples (CV_PORTION mode)
02003 
02004     union
02005     {
02006         int count; // used when train_sample_part_mode == CV_COUNT
02007         float portion; // used when train_sample_part_mode == CV_PORTION
02008     } train_sample_part;
02009     int train_sample_part_mode; // CV_COUNT or CV_PORTION
02010 
02011     bool mix; // shuffle samples before splitting — presumably; see CvMLData::mix_train_and_test_idx
02012 };
02013 
02014 class CV_EXPORTS CvMLData // loads a dataset from a CSV file and manages variables, responses and train/test splits
02015 {
02016 public:
02017     CvMLData();
02018     virtual ~CvMLData();
02019 
02020     // returns:
02021     // 0 - OK
02022     // -1 - file can not be opened or is not correct
02023     int read_csv( const char* filename );
02024 
02025     const CvMat* get_values() const; // all loaded values
02026     const CvMat* get_responses(); // response column selected by set_response_idx
02027     const CvMat* get_missing() const; // mask of missing values
02028 
02029     void set_response_idx( int idx ); // the old response becomes a predictor, new response_idx = idx;
02030                                       // if idx < 0 there will be no response
02031     int get_response_idx() const;
02032 
02033     void set_train_test_split( const CvTrainTestSplit * spl );
02034     const CvMat* get_train_sample_idx() const;
02035     const CvMat* get_test_sample_idx() const;
02036     void mix_train_and_test_idx(); // reshuffles samples between the train and test subsets
02037 
02038     const CvMat* get_var_idx();
02039     void chahge_var_idx( int vi, bool state ); // misspelled (kept for backward compatibility),
02040                                                // use change_var_idx
02041     void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
02042 
02043     const CvMat* get_var_types();
02044     int get_var_type( int var_idx ) const;
02045     // the following 2 methods allow changing a variable's type;
02046     // use these methods to assign CV_VAR_CATEGORICAL type for a categorical variable
02047     // with numerical labels; in the other cases var types are correctly determined automatically
02048     void set_var_types( const char* str );  // str examples:
02049                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
02050                                             // "cat", "ord" (all vars are categorical/ordered)
02051     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
02052 
02053     void set_delimiter( char ch ); // CSV field delimiter used by read_csv
02054     char get_delimiter() const;
02055 
02056     void set_miss_ch( char ch ); // character that marks a missing value in the CSV file
02057     char get_miss_ch() const;
02058 
02059     const std::map<std::string, int>& get_class_labels_map() const; // string class label -> integer id
02060 
02061 protected:
02062     virtual void clear();
02063 
02064     void str_to_flt_elem( const char* token, float& flt_elem, int& type); // parses one CSV token into a float; 'type' presumably reports whether it was numeric/categorical/missing — confirm
02065     void free_train_test_idx();
02066 
02067     char delimiter;
02068     char miss_ch;
02069     //char flt_separator;
02070 
02071     CvMat* values; // the full loaded data matrix
02072     CvMat* missing; // missing-value mask for 'values'
02073     CvMat* var_types;
02074     CvMat* var_idx_mask; // per-variable active/inactive flags — presumably; confirm
02075 
02076     CvMat* response_out; // header
02077     CvMat* var_idx_out; // mat
02078     CvMat* var_types_out; // mat
02079 
02080     int response_idx; // column index of the response, < 0 if none
02081 
02082     int train_sample_count;
02083     bool mix;
02084 
02085     int total_class_count;
02086     std::map<std::string, int> class_map; // string class label -> integer id (see get_class_labels_map)
02087 
02088     CvMat* train_sample_idx;
02089     CvMat* test_sample_idx;
02090     int* sample_idx; // data of train_sample_idx and test_sample_idx
02091 
02092     cv::RNG* rng; // random generator — presumably used by mix_train_and_test_idx; confirm
02093 };
02094 
02095 
02096 namespace cv // short new-style aliases for the C-style ML classes
02097 {
02098 
02099 typedef CvStatModel StatModel;
02100 typedef CvParamGrid ParamGrid;
02101 typedef CvNormalBayesClassifier NormalBayesClassifier;
02102 typedef CvKNearest KNearest;
02103 typedef CvSVMParams SVMParams;
02104 typedef CvSVMKernel SVMKernel;
02105 typedef CvSVMSolver SVMSolver;
02106 typedef CvSVM SVM;
02107 typedef CvDTreeParams DTreeParams;
02108 typedef CvMLData TrainData;
02109 typedef CvDTree DecisionTree;
02110 typedef CvForestTree ForestTree;
02111 typedef CvRTParams RandomTreeParams;
02112 typedef CvRTrees RandomTrees;
02113 typedef CvERTreeTrainData ERTreeTRainData; // note: the 'TRain' typo is part of the public name and must be kept
02114 typedef CvForestERTree ERTree;
02115 typedef CvERTrees ERTrees;
02116 typedef CvBoostParams BoostParams;
02117 typedef CvBoostTree BoostTree;
02118 typedef CvBoost Boost;
02119 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
02120 typedef CvANN_MLP NeuralNet_MLP;
02121 typedef CvGBTreesParams GradientBoostingTreeParams;
02122 typedef CvGBTrees GradientBoostingTrees;
02123 
02124 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj(); // custom deleter specialization for Ptr-held tree splits
02125 
02126 CV_EXPORTS bool initModule_ml(void); // registers the ML module — presumably with the cv::Algorithm framework; confirm
02127 
02128 }
02129 
02130 #endif // __cplusplus
02131 #endif // __OPENCV_ML_HPP__
02132 
02133 /* End of file. */