include/opencv2/ml/ml.hpp
Go to the documentation of this file.
00001 /*M///////////////////////////////////////////////////////////////////////////////////////
00002 //
00003 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
00004 //
00005 //  By downloading, copying, installing or using the software you agree to this license.
00006 //  If you do not agree to this license, do not download, install,
00007 //  copy or use the software.
00008 //
00009 //
00010 //                        Intel License Agreement
00011 //
00012 // Copyright (C) 2000, Intel Corporation, all rights reserved.
00013 // Third party copyrights are property of their respective owners.
00014 //
00015 // Redistribution and use in source and binary forms, with or without modification,
00016 // are permitted provided that the following conditions are met:
00017 //
00018 //   * Redistribution's of source code must retain the above copyright notice,
00019 //     this list of conditions and the following disclaimer.
00020 //
00021 //   * Redistribution's in binary form must reproduce the above copyright notice,
00022 //     this list of conditions and the following disclaimer in the documentation
00023 //     and/or other materials provided with the distribution.
00024 //
00025 //   * The name of Intel Corporation may not be used to endorse or promote products
00026 //     derived from this software without specific prior written permission.
00027 //
00028 // This software is provided by the copyright holders and contributors "as is" and
00029 // any express or implied warranties, including, but not limited to, the implied
00030 // warranties of merchantability and fitness for a particular purpose are disclaimed.
00031 // In no event shall the Intel Corporation or contributors be liable for any direct,
00032 // indirect, incidental, special, exemplary, or consequential damages
00033 // (including, but not limited to, procurement of substitute goods or services;
00034 // loss of use, data, or profits; or business interruption) however caused
00035 // and on any theory of liability, whether in contract, strict liability,
00036 // or tort (including negligence or otherwise) arising in any way out of
00037 // the use of this software, even if advised of the possibility of such damage.
00038 //
00039 //M*/
00040 
00041 #ifndef __OPENCV_ML_HPP__
00042 #define __OPENCV_ML_HPP__
00043 
00044 #include "opencv2/core/core.hpp"
00045 #include <limits.h>
00046 
00047 #ifdef __cplusplus
00048 
// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method declaration in this header (CvParamGrid::check)
00051 #undef check
00052 
00053 /****************************************************************************************\
00054 *                               Main struct definitions                                  *
00055 \****************************************************************************************/
00056 
/* Natural logarithm of 2*pi, i.e. log(2*PI). */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* Sample-layout flags for a <trainData> matrix:
   CV_COL_SAMPLE - every column of the matrix is one training sample;
   CV_ROW_SAMPLE - every row of the matrix is one training sample. */
#define CV_COL_SAMPLE 0
#define CV_ROW_SAMPLE 1

/* Evaluates to non-zero iff the CV_ROW_SAMPLE bit is set in <flags>. */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
00067 
00068 struct CvVectors
00069 {
00070     int type;
00071     int dims, count;
00072     CvVectors* next;
00073     union
00074     {
00075         uchar** ptr;
00076         float** fl;
00077         double** db;
00078     } data;
00079 };
00080 
#if 0
/* Compiled out by this #if 0.
   Structure representing the lattice range of statmodel parameters,
   intended for optimizing statmodel parameters by cross-validation.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice with ordered bounds and the step clamped to at least 1. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice lattice;
    lattice.min_val = MIN( min_val, max_val );
    lattice.max_val = MAX( min_val, max_val );
    lattice.step = MAX( log_step, 1. );
    return lattice;
}

/* All-zero lattice. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice lattice = { 0, 0, 0 };
    return lattice;
}
#endif
00109 
/* Variable types. CV_VAR_NUMERICAL and CV_VAR_ORDERED are synonyms (both 0). */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

/* Type-name strings of the ML models (presumably the type tags written by
   the models' persistence code -- confirm in the implementations). */
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

/* Selector of which error to compute (training vs. test subset, by name). */
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1
00128 
00129 class CV_EXPORTS_W CvStatModel
00130 {
00131 public:
00132     CvStatModel();
00133     virtual ~CvStatModel();
00134 
00135     virtual void clear();
00136 
00137     CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
00138     CV_WRAP virtual void load( const char* filename, const char* name=0 );
00139 
00140     virtual void write( CvFileStorage* storage, const char* name ) const;
00141     virtual void read( CvFileStorage* storage, CvFileNode* node );
00142 
00143 protected:
00144     const char* default_model_name;
00145 };
00146 
00147 /****************************************************************************************\
00148 *                                 Normal Bayes Classifier                                *
00149 \****************************************************************************************/
00150 
/* The structure representing the grid range of statmodel parameters.
   It is used to optimize statmodel accuracy by varying model parameters,
   with the accuracy estimate computed by cross-validation.
   The grid is logarithmic, so <step> must be greater than 1. */
00155 
00156 class CvMLData;
00157 
00158 struct CV_EXPORTS_W_MAP CvParamGrid
00159 {
00160     // SVM params type
00161     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
00162 
00163     CvParamGrid()
00164     {
00165         min_val = max_val = step = 0;
00166     }
00167 
00168     CvParamGrid( double _min_val, double _max_val, double log_step )
00169     {
00170         min_val = _min_val;
00171         max_val = _max_val;
00172         step = log_step;
00173     }
00174     //CvParamGrid( int param_id );
00175     bool check() const;
00176 
00177     CV_PROP_RW double min_val;
00178     CV_PROP_RW double max_val;
00179     CV_PROP_RW double step;
00180 };
00181 
00182 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
00183 {
00184 public:
00185     CV_WRAP CvNormalBayesClassifier();
00186     virtual ~CvNormalBayesClassifier();
00187 
00188     CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
00189         const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
00190     
00191     virtual bool train( const CvMat* trainData, const CvMat* responses,
00192         const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
00193    
00194     virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
00195     CV_WRAP virtual void clear();
00196 
00197 #ifndef SWIG
00198     CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
00199                             const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
00200     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00201                        const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00202                        bool update=false );
00203     CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
00204 #endif
00205     
00206     virtual void write( CvFileStorage* storage, const char* name ) const;
00207     virtual void read( CvFileStorage* storage, CvFileNode* node );
00208 
00209 protected:
00210     int     var_count, var_all;
00211     CvMat*  var_idx;
00212     CvMat*  cls_labels;
00213     CvMat** count;
00214     CvMat** sum;
00215     CvMat** productsum;
00216     CvMat** avg;
00217     CvMat** inv_eigen_values;
00218     CvMat** cov_rotate_mats;
00219     CvMat*  c;
00220 };
00221 
00222 
00223 /****************************************************************************************\
00224 *                          K-Nearest Neighbour Classifier                                *
00225 \****************************************************************************************/
00226 
00227 // k Nearest Neighbors
00228 class CV_EXPORTS_W CvKNearest : public CvStatModel
00229 {
00230 public:
00231 
00232     CV_WRAP CvKNearest();
00233     virtual ~CvKNearest();
00234 
00235     CvKNearest( const CvMat* trainData, const CvMat* responses,
00236                 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
00237 
00238     virtual bool train( const CvMat* trainData, const CvMat* responses,
00239                         const CvMat* sampleIdx=0, bool is_regression=false,
00240                         int maxK=32, bool updateBase=false );
00241     
00242     virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
00243         const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
00244     
00245 #ifndef SWIG
00246     CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
00247                const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
00248     
00249     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00250                        const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
00251                        int maxK=32, bool updateBase=false );    
00252     
00253     virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
00254                                 const float** neighbors=0, cv::Mat* neighborResponses=0,
00255                                 cv::Mat* dist=0 ) const;
00256     CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
00257                                         CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
00258 #endif
00259     
00260     virtual void clear();
00261     int get_max_k() const;
00262     int get_var_count() const;
00263     int get_sample_count() const;
00264     bool is_regression() const;
00265     
00266     virtual float write_results( int k, int k1, int start, int end,
00267         const float* neighbor_responses, const float* dist, CvMat* _results,
00268         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
00269 
00270     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
00271         float* neighbor_responses, const float** neighbors, float* dist ) const;
00272 
00273 protected:
00274 
00275     int max_k, var_count;
00276     int total;
00277     bool regression;
00278     CvVectors* samples;
00279 };
00280 
00281 /****************************************************************************************\
00282 *                                   Support Vector Machines                              *
00283 \****************************************************************************************/
00284 
00285 // SVM training parameters
00286 struct CV_EXPORTS_W_MAP CvSVMParams
00287 {
00288     CvSVMParams();
00289     CvSVMParams( int _svm_type, int _kernel_type,
00290                  double _degree, double _gamma, double _coef0,
00291                  double Cvalue, double _nu, double _p,
00292                  CvMat* _class_weights, CvTermCriteria _term_crit );
00293 
00294     CV_PROP_RW int         svm_type;
00295     CV_PROP_RW int         kernel_type;
00296     CV_PROP_RW double      degree; // for poly
00297     CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
00298     CV_PROP_RW double      coef0;  // for poly/sigmoid
00299 
00300     CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
00301     CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
00302     CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
00303     CvMat*      class_weights; // for CV_SVM_C_SVC
00304     CV_PROP_RW CvTermCriteria term_crit; // termination criteria
00305 };
00306 
00307 
00308 struct CV_EXPORTS CvSVMKernel
00309 {
00310     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
00311                                        const float* another, float* results );
00312     CvSVMKernel();
00313     CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
00314     virtual bool create( const CvSVMParams* params, Calc _calc_func );
00315     virtual ~CvSVMKernel();
00316 
00317     virtual void clear();
00318     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
00319 
00320     const CvSVMParams* params;
00321     Calc calc_func;
00322 
00323     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
00324                                     const float* another, float* results,
00325                                     double alpha, double beta );
00326 
00327     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
00328                               const float* another, float* results );
00329     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
00330                            const float* another, float* results );
00331     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
00332                             const float* another, float* results );
00333     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
00334                                const float* another, float* results );
00335 };
00336 
00337 
/* Node of a doubly linked list of kernel rows (CvSVMSolver keeps such a list
   in its lru_list member). */
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;   // previous node
    CvSVMKernelRow* next;   // next node
    float* data;            // row values
};
00344 
00345 
/* Solution summary passed by reference to the CvSVMSolver::solve_* methods. */
struct CvSVMSolutionInfo
{
    double obj;             // objective value (by name)
    double rho;
    double upper_bound_p;
    double upper_bound_n;
    double r;               // for Solver_NU
};
00354 
00355 class CV_EXPORTS CvSVMSolver
00356 {
00357 public:
00358     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
00359     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
00360     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
00361 
00362     CvSVMSolver();
00363 
00364     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
00365                  int alpha_count, double* alpha, double Cp, double Cn,
00366                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00367                  SelectWorkingSet select_working_set, CalcRho calc_rho );
00368     virtual bool create( int count, int var_count, const float** samples, schar* y,
00369                  int alpha_count, double* alpha, double Cp, double Cn,
00370                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00371                  SelectWorkingSet select_working_set, CalcRho calc_rho );
00372     virtual ~CvSVMSolver();
00373 
00374     virtual void clear();
00375     virtual bool solve_generic( CvSVMSolutionInfo& si );
00376 
00377     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
00378                               double Cp, double Cn, CvMemStorage* storage,
00379                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
00380     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
00381                                CvMemStorage* storage, CvSVMKernel* kernel,
00382                                double* alpha, CvSVMSolutionInfo& si );
00383     virtual bool solve_one_class( int count, int var_count, const float** samples,
00384                                   CvMemStorage* storage, CvSVMKernel* kernel,
00385                                   double* alpha, CvSVMSolutionInfo& si );
00386 
00387     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
00388                                 CvMemStorage* storage, CvSVMKernel* kernel,
00389                                 double* alpha, CvSVMSolutionInfo& si );
00390 
00391     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
00392                                CvMemStorage* storage, CvSVMKernel* kernel,
00393                                double* alpha, CvSVMSolutionInfo& si );
00394 
00395     virtual float* get_row_base( int i, bool* _existed );
00396     virtual float* get_row( int i, float* dst );
00397 
00398     int sample_count;
00399     int var_count;
00400     int cache_size;
00401     int cache_line_size;
00402     const float** samples;
00403     const CvSVMParams* params;
00404     CvMemStorage* storage;
00405     CvSVMKernelRow lru_list;
00406     CvSVMKernelRow* rows;
00407 
00408     int alpha_count;
00409 
00410     double* G;
00411     double* alpha;
00412 
00413     // -1 - lower bound, 0 - free, 1 - upper bound
00414     schar* alpha_status;
00415 
00416     schar* y;
00417     double* b;
00418     float* buf[2];
00419     double eps;
00420     int max_iter;
00421     double C[2];  // C[0] == Cn, C[1] == Cp
00422     CvSVMKernel* kernel;
00423 
00424     SelectWorkingSet select_working_set_func;
00425     CalcRho calc_rho_func;
00426     GetRow get_row_func;
00427 
00428     virtual bool select_working_set( int& i, int& j );
00429     virtual bool select_working_set_nu_svm( int& i, int& j );
00430     virtual void calc_rho( double& rho, double& r );
00431     virtual void calc_rho_nu_svm( double& rho, double& r );
00432 
00433     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
00434     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
00435     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
00436 };
00437 
00438 
/* One decision function of a trained SVM: offset <rho> plus <sv_count>
   coefficients <alpha>, presumably paired with support-vector indices
   <sv_index> -- confirm against CvSVM. */
struct CvSVMDecisionFunc
{
    double rho;
    int sv_count;
    double* alpha;
    int* sv_index;
};
00446 
00447 
00448 // SVM model
00449 class CV_EXPORTS_W CvSVM : public CvStatModel
00450 {
00451 public:
00452     // SVM type
00453     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
00454 
00455     // SVM kernel type
00456     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
00457 
00458     // SVM params type
00459     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
00460 
00461     CV_WRAP CvSVM();
00462     virtual ~CvSVM();
00463 
00464     CvSVM( const CvMat* trainData, const CvMat* responses,
00465            const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00466            CvSVMParams params=CvSVMParams() );
00467 
00468     virtual bool train( const CvMat* trainData, const CvMat* responses,
00469                         const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00470                         CvSVMParams params=CvSVMParams() );
00471     
00472     virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
00473         const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
00474         int kfold = 10,
00475         CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
00476         CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
00477         CvParamGrid pGrid      = get_default_grid(CvSVM::P),
00478         CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
00479         CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
00480         CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
00481         bool balanced=false );
00482 
00483     virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
00484     virtual float predict( const CvMat* samples, CvMat* results ) const;
00485     
00486 #ifndef SWIG
00487     CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
00488           const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00489           CvSVMParams params=CvSVMParams() );
00490     
00491     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00492                        const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00493                        CvSVMParams params=CvSVMParams() );
00494     
00495     CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
00496                             const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
00497                             int k_fold = 10,
00498                             CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
00499                             CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
00500                             CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
00501                             CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
00502                             CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
00503                             CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
00504                             bool balanced=false);
00505     CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
00506 #endif
00507     
00508     CV_WRAP virtual int get_support_vector_count() const;
00509     virtual const float* get_support_vector(int i) const;
00510     virtual CvSVMParams get_params() const { return params; };
00511     CV_WRAP virtual void clear();
00512 
00513     static CvParamGrid get_default_grid( int param_id );
00514 
00515     virtual void write( CvFileStorage* storage, const char* name ) const;
00516     virtual void read( CvFileStorage* storage, CvFileNode* node );
00517     CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
00518 
00519 protected:
00520 
00521     virtual bool set_params( const CvSVMParams& params );
00522     virtual bool train1( int sample_count, int var_count, const float** samples,
00523                     const void* responses, double Cp, double Cn,
00524                     CvMemStorage* _storage, double* alpha, double& rho );
00525     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
00526                     const CvMat* responses, CvMemStorage* _storage, double* alpha );
00527     virtual void create_kernel();
00528     virtual void create_solver();
00529 
00530     virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
00531 
00532     virtual void write_params( CvFileStorage* fs ) const;
00533     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00534 
00535     CvSVMParams params;
00536     CvMat* class_labels;
00537     int var_all;
00538     float** sv;
00539     int sv_total;
00540     CvMat* var_idx;
00541     CvMat* class_weights;
00542     CvSVMDecisionFunc* decision_func;
00543     CvMemStorage* storage;
00544 
00545     CvSVMSolver* solver;
00546     CvSVMKernel* kernel;
00547 };
00548 
00549 /****************************************************************************************\
00550 *                              Expectation - Maximization                                *
00551 \****************************************************************************************/
00552 
00553 struct CV_EXPORTS_W_MAP CvEMParams
00554 {
00555     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
00556         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
00557     {
00558         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
00559     }
00560 
00561     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
00562                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
00563                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
00564                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
00565                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
00566                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
00567     {}
00568 
00569     CV_PROP_RW int nclusters;
00570     CV_PROP_RW int cov_mat_type;
00571     CV_PROP_RW int start_step;
00572     const CvMat* probs;
00573     const CvMat* weights;
00574     const CvMat* means;
00575     const CvMat** covs;
00576     CV_PROP_RW CvTermCriteria term_crit;
00577 };
00578 
00579 
00580 class CV_EXPORTS_W CvEM : public CvStatModel
00581 {
00582 public:
00583     // Type of covariation matrices
00584     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
00585 
00586     // The initial step
00587     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
00588 
00589     CV_WRAP CvEM();
00590     CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
00591           CvEMParams params=CvEMParams(), CvMat* labels=0 );
00592     //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights,
00593     // CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
00594 
00595     virtual ~CvEM();
00596 
00597     virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
00598                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
00599 
00600     virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const;
00601 
00602 #ifndef SWIG
00603     CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
00604                   CvEMParams params=CvEMParams() );
00605     
00606     CV_WRAP virtual bool train( const cv::Mat& samples,
00607                                 const cv::Mat& sampleIdx=cv::Mat(),
00608                                 CvEMParams params=CvEMParams(),
00609                                 CV_OUT cv::Mat* labels=0 );
00610     
00611     CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const;
00612     CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const;
00613     
00614     CV_WRAP int  getNClusters() const;
00615     CV_WRAP cv::Mat  getMeans()     const;
00616     CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs)      const;
00617     CV_WRAP cv::Mat  getWeights()   const;
00618     CV_WRAP cv::Mat  getProbs()     const;
00619     
00620     CV_WRAP inline double getLikelihood() const { return log_likelihood; }
00621     CV_WRAP inline double getLikelihoodDelta() const { return log_likelihood_delta; }
00622 #endif
00623     
00624     CV_WRAP virtual void clear();
00625 
00626     int           get_nclusters() const;
00627     const CvMat*  get_means()     const;
00628     const CvMat** get_covs()      const;
00629     const CvMat*  get_weights()   const;
00630     const CvMat*  get_probs()     const;
00631 
00632     inline double get_log_likelihood() const { return log_likelihood; }
00633     inline double get_log_likelihood_delta() const { return log_likelihood_delta; }
00634     
00635 //    inline const CvMat *  get_log_weight_div_det () const { return log_weight_div_det; };
00636 //    inline const CvMat *  get_inv_eigen_values   () const { return inv_eigen_values;   };
00637 //    inline const CvMat ** get_cov_rotate_mats    () const { return cov_rotate_mats;    };
00638 
00639     virtual void read( CvFileStorage* fs, CvFileNode* node );
00640     virtual void write( CvFileStorage* fs, const char* name ) const;
00641 
00642     virtual void write_params( CvFileStorage* fs ) const;
00643     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00644 
00645 protected:
00646 
00647     virtual void set_params( const CvEMParams& params,
00648                              const CvVectors& train_data );
00649     virtual void init_em( const CvVectors& train_data );
00650     virtual double run_em( const CvVectors& train_data );
00651     virtual void init_auto( const CvVectors& samples );
00652     virtual void kmeans( const CvVectors& train_data, int nclusters,
00653                          CvMat* labels, CvTermCriteria criteria,
00654                          const CvMat* means );
00655     CvEMParams params;
00656     double log_likelihood;
00657     double log_likelihood_delta;
00658 
00659     CvMat* means;
00660     CvMat** covs;
00661     CvMat* weights;
00662     CvMat* probs;
00663 
00664     CvMat* log_weight_div_det;
00665     CvMat* inv_eigen_values;
00666     CvMat** cov_rotate_mats;
00667 };
00668 
/****************************************************************************************\
*                                      Decision Tree                                     *
\****************************************************************************************/

/* Pair of pointers: 16-bit unsigned data <u> and 32-bit signed data <i>.
   (The banner above used to end with a stray backslash after its closing
   marker. That backslash is a real line continuation, which spliced this
   declaration onto the comment line.) */
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
00677 
00678 
/* Tests bit <idx> of the bit set <subset> (32 bits per int element; presumably
   the categorical-split bit set): yields -1 when the bit is set, +1 when it
   is clear. The mask is built from 1u because (1 << 31) on a signed int
   shifts into the sign bit, which is undefined behavior in C and C++11. */
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((subset[(idx)>>5]&(1u << ((idx) & 31)))==0)-1)
00681 
/* A decision-tree split; several splits can be chained through <next>. */
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    // Either a 64-bit category set <subset> (tested via CV_DTREE_CAT_DIR) or,
    // presumably for ordered variables, a threshold <c> with its <split_point>.
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        }
        ord;
    };
};
00700 
00701 struct CvDTreeNode
00702 {
00703     int class_idx;
00704     int Tn;
00705     double value;
00706 
00707     CvDTreeNode* parent;
00708     CvDTreeNode* left;
00709     CvDTreeNode* right;
00710 
00711     CvDTreeSplit* split;
00712 
00713     int sample_count;
00714     int depth;
00715     int* num_valid;
00716     int offset;
00717     int buf_idx;
00718     double maxlr;
00719 
00720     // global pruning data
00721     int complexity;
00722     double alpha;
00723     double node_risk, tree_risk, tree_error;
00724 
00725     // cross-validation pruning data
00726     int* cv_Tn;
00727     double* cv_node_risk;
00728     double* cv_node_error;
00729 
00730     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
00731     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
00732 };
00733 
00734 
// Training parameters of a single decision tree. CvRTParams, CvBoostParams and
// CvGBTreesParams in this header derive from it.
// Member order is part of the exported layout; do not reorder.
struct CV_EXPORTS_W_MAP CvDTreeParams
{
    CV_PROP_RW int   max_categories;
    CV_PROP_RW int   max_depth;
    CV_PROP_RW int   min_sample_count;
    CV_PROP_RW int   cv_folds;
    CV_PROP_RW bool  use_surrogates;
    CV_PROP_RW bool  use_1se_rule;
    CV_PROP_RW bool  truncate_pruned_tree;
    CV_PROP_RW float regression_accuracy;
    // Not exposed as a CV_PROP_RW property. The constructors only store the
    // pointer; the array is not copied, so it presumably must outlive its use.
    const float* priors;

    // Defaults: max_categories=10, max_depth=INT_MAX (no practical depth limit),
    // min_sample_count=10, cv_folds=10, use_surrogates=true, use_1se_rule=true,
    // truncate_pruned_tree=true, regression_accuracy=0.01, priors=0 (none).
    CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
        cv_folds(10), use_surrogates(true), use_1se_rule(true),
        truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
    {}

    // Fully specified constructor. NOTE: the argument order differs from the
    // member declaration order (max_depth and min_sample_count come first).
    CvDTreeParams( int _max_depth, int _min_sample_count,
                   float _regression_accuracy, bool _use_surrogates,
                   int _max_categories, int _cv_folds,
                   bool _use_1se_rule, bool _truncate_pruned_tree,
                   const float* _priors ) :
        max_categories(_max_categories), max_depth(_max_depth),
        min_sample_count(_min_sample_count), cv_folds (_cv_folds),
        use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
        truncate_pruned_tree(_truncate_pruned_tree),
        regression_accuracy(_regression_accuracy),
        priors(_priors)
    {}
};
00765 
00766 
// Prepared training set for decision trees. It holds the input matrices plus
// the per-node working buffers and memory heaps used while growing trees.
// It can be passed to CvDTree::train(CvDTreeTrainData*, ...) and to the
// ensemble trees (CvForestTree::train, CvBoostTree::train) in this header.
// Virtual-function and member order are part of the exported ABI.
struct CV_EXPORTS CvDTreeTrainData
{
    CvDTreeTrainData();
    // Takes the same arguments as set_data() except _update_data.
    CvDTreeTrainData( const CvMat* trainData, int tflag,
                      const CvMat* responses, const CvMat* varIdx=0,
                      const CvMat* sampleIdx=0, const CvMat* varType=0,
                      const CvMat* missingDataMask=0,
                      const CvDTreeParams& params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    // (Re)initializes the training data. NOTE(review): tflag presumably selects
    // the sample layout of trainData (CV_ROW_SAMPLE / CV_COL_SAMPLE); _shared
    // presumably marks data shared by several trees (ensembles) -- confirm in tree.cpp.
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void do_responses_copy();

    // Writes the (optionally subsampled) data into caller-provided arrays.
    // NOTE(review): required sizes of values/missing/responses are not visible here.
    virtual void get_vectors( const CvMat* _subsample_idx,
         float* values, uchar* missing, float* responses, bool get_class_idx=false );

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const {return work_var_count;}

    // Per-node accessors. Each takes a node plus caller-provided scratch
    // buffer(s) and returns (or writes through out-pointers) read-only data
    // for the node's samples.
    // NOTE(review): the result may point into the scratch buffer; keep the
    // buffer alive while using it.
    virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                   const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
    virtual int get_child_buf_idx( CvDTreeNode* n );


    virtual bool set_params( const CvDTreeParams& params );
    // Node/split allocation and release helpers. Presumably backed by the
    // node_heap/split_heap sets below.
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count, work_var_count;
    bool have_labels, have_priors;
    bool is_classifier;
    int tflag;

    // Input matrices, held through const pointers (presumably not owned).
    const CvMat* train_data;
    const CvMat* responses;
    CvMat* responses_copy; // used in Boosting

    int buf_count, buf_size;
    bool shared;
    // Flag kept as int. NOTE(review): presumably nonzero when `buf` stores
    // 16-bit indices instead of 32-bit ones.
    int is_buf_16u;
    
    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;
    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     //   k<0  - ordered
                     //   k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;
    CvMat* priors_mult;

    CvDTreeParams params;

    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;

    // Allocation pools for nodes, splits and per-node cv / num_valid arrays
    // (presumably; see new_node / new_split_*).
    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cv_heap;
    CvSet* nv_heap;

    cv::RNG* rng;
};
00868 
// Forward declarations. The cv:: split-finder helpers are declared here so
// that CvDTree and CvForestTree can name them as friends.
class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
00877 
// A single decision tree model for classification or regression. The tree
// ensembles in this header derive from it (CvForestTree, CvBoostTree).
// Virtual-function order is part of the exported ABI; do not reorder.
class CV_EXPORTS_W CvDTree : public CvStatModel
{
public:
    CV_WRAP CvDTree();
    virtual ~CvDTree();

    // Trains from raw CvMat inputs (C API form).
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    // Trains from a CvMLData container.
    virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );

    // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );

    // Trains on already prepared training data, optionally restricted to a
    // subsample (presumably the entry point used by the ensembles).
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );

    // Returns the node the sample ends up in. NOTE(review): presumably a leaf;
    // read its value / class_idx for the prediction.
    virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
                                  bool preprocessedInput=false ) const;

    // cv::Mat overloads, hidden from SWIG. CV_WRAP presumably marks them for
    // the generated language bindings.
#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvDTreeParams params=CvDTreeParams() );
    
    CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
                                  bool preprocessedInput=false ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif
    
    virtual const CvMat* get_var_importance();
    CV_WRAP virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    // special read & write methods for trees in the tree ensembles
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs ) const;

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:
    friend struct cv::DTreeBestSplitFinder;

    virtual bool do_train( const CvMat* _subsample_idx );

    // Tree growing: node splitting and best-split search. The find_split_*
    // variants cover ordered/categorical variables for classification and
    // regression. ext_buf is an optional caller-supplied scratch buffer.
    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    virtual void cluster_categories( const int* vectors, int vector_count,
        int var_count, int* sums, int k, int* cluster_labels );

    virtual void calc_node_value( CvDTreeNode* node );

    // Pruning. NOTE(review): presumably the cross-validation pruning that uses
    // the cv_* / *_risk fields of CvDTreeNode.
    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    // Serialization helpers.
    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs ) const;
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;
    CvMat* var_importance;
    CvDTreeTrainData* data;

public:
    // Public data member; also readable through get_pruned_tree_idx().
    int pruned_tree_idx;
};
00972 
00973 
00974 /****************************************************************************************\
00975 *                                   Random Trees Classifier                              *
00976 \****************************************************************************************/
00977 
class CvRTrees;

// A tree of a random forest (CvRTrees). It is trained on shared
// CvDTreeTrainData and keeps a pointer to its forest (presumably the owner).
class CV_EXPORTS CvForestTree: public CvDTree
{
public:
    CvForestTree();
    virtual ~CvForestTree();

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );

    // var_count of the attached training data, or 0 when no data is attached.
    virtual int get_var_count() const {return data ? data->var_count : 0;}
    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    // Declaring the train/read overloads above would hide the inherited
    // CvDTree overloads; they are re-declared here to silence those warnings.
    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    friend struct cv::ForestTreeBestSplitFinder;

    // NOTE(review): presumably restricts the candidate variables to the
    // forest's random active subset (see CvRTrees::get_active_var_mask).
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    CvRTrees* forest;
};
01010 
01011 
// Random-forest training parameters: the per-tree CvDTreeParams plus
// forest-level settings.
struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
{
    //Parameters for the forest
    CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    // NOTE(review): presumably the size of the random variable subset tried at
    // each node; 0 probably selects a default -- confirm in rtrees.cpp.
    CV_PROP_RW int nactive_vars;
    // Forest growth stopping criterion: max_iter is the maximum number of
    // trees and epsilon the forest accuracy (see the second constructor).
    CV_PROP_RW CvTermCriteria term_crit;

    // Defaults: max_depth=5, min_sample_count=10, regression_accuracy=0,
    // use_surrogates=false, max_categories=10, cv_folds=0, use_1se_rule=false,
    // truncate_pruned_tree=false, priors=0; calc_var_importance=false,
    // nactive_vars=0; term_crit = ITER+EPS with at most 50 trees, epsilon 0.1.
    CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
        calc_var_importance(false), nactive_vars(0)
    {
        term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
    }

    // cv_folds, use_1se_rule and truncate_pruned_tree cannot be set here; they
    // are fixed to 0 / false / false. term_crit is built from termcrit_type,
    // max_num_of_trees_in_the_forest and forest_accuracy.
    CvRTParams( int _max_depth, int _min_sample_count,
                float _regression_accuracy, bool _use_surrogates,
                int _max_categories, const float* _priors, bool _calc_var_importance,
                int _nactive_vars, int max_num_of_trees_in_the_forest,
                float forest_accuracy, int termcrit_type ) :
        CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
                       _use_surrogates, _max_categories, 0,
                       false, false, _priors ),
        calc_var_importance(_calc_var_importance),
        nactive_vars(_nactive_vars)
    {
        term_crit = cvTermCriteria(termcrit_type,
            max_num_of_trees_in_the_forest, forest_accuracy);
    }
};
01040 
01041 
// Random trees (random forest) model: an array of CvForestTree trained on
// one CvDTreeTrainData.
// Virtual-function and member order are part of the exported ABI.
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );
    
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    // NOTE(review): presumably only meaningful for two-class problems -- confirm in rtrees.cpp.
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

    // cv::Mat overloads, hidden from SWIG.
#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif
    
    CV_WRAP virtual void clear();

    virtual const CvMat* get_var_importance();
    // NOTE(review): presumably the fraction of trees in which both samples
    // reach the same leaf -- confirm in rtrees.cpp.
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
    
    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();    

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    // Returns a CvRNG* although the member below is a cv::RNG*;
    // NOTE(review): the conversion happens in the implementation.
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:

    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;
    int nclasses;
    double oob_error;   // NOTE(review): presumably the out-of-bag error estimate
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;
};
01103 
01104 /****************************************************************************************\
01105 *                           Extremely randomized trees Classifier                        *
01106 \****************************************************************************************/
// Training data for extremely randomized trees. It overrides the data
// preparation/access methods of CvDTreeTrainData and keeps a pointer to the
// missing-value mask.
struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
{
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    // Overrides CvDTreeTrainData::get_ord_var_data (same parameter types). The
    // int buffers are named missing_buf / missing here instead of
    // sorted_indices_buf / sorted_indices: NOTE(review) presumably this version
    // returns per-sample missing flags rather than a sort order -- confirm.
    // The `sample_buf = 0` default is used only when the static type is
    // CvERTreeTrainData: default arguments of virtual functions are bound
    // statically, and the base declaration has no default.
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
                                   const float** ord_values, const int** missing, int* sample_buf = 0 );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
                              float* responses, bool get_class_idx=false );
    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    // No constructor is declared here. NOTE(review): presumably assigned in set_data().
    const CvMat* missing_mask;
};
01126 
// A tree of an extremely randomized forest (CvERTrees). It replaces the split
// search and node-data splitting of CvForestTree and declares no public
// members of its own (it uses the implicit constructor).
class CV_EXPORTS CvForestERTree : public CvForestTree
{
protected:
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void split_node_data( CvDTreeNode* n );
};
01141 
// Extremely randomized trees ensemble. It inherits prediction and the rest of
// the CvRTrees interface, and overrides only the training entry points and
// grow_forest().
class CV_EXPORTS_W CvERTrees : public CvRTrees
{
public:
    CV_WRAP CvERTrees();
    virtual ~CvERTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams());
    // cv::Mat overload, hidden from SWIG.
#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams());
#endif
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected:
    virtual bool grow_forest( const CvTermCriteria term_crit );
};
01163 
01164 
01165 /****************************************************************************************\
01166 *                                   Boosted tree classifier                              *
01167 \****************************************************************************************/
01168 
// Boosting parameters: the per-tree CvDTreeParams plus ensemble settings.
// The constructors are defined out of line, so their defaults are not visible here.
struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
{
    CV_PROP_RW int boost_type;         // presumably one of CvBoost::DISCRETE, REAL, LOGIT, GENTLE
    CV_PROP_RW int weak_count;         // presumably the number of weak trees
    CV_PROP_RW int split_criteria;     // presumably one of CvBoost::DEFAULT, GINI, MISCLASS, SQERR
    CV_PROP_RW double weight_trim_rate;

    CvBoostParams();
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};
01180 
01181 
class CvBoost;

// Weak learner of CvBoost: a CvDTree trained on the ensemble's data that keeps
// a pointer to its ensemble.
class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    virtual bool train( CvDTreeTrainData* trainData,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    // NOTE(review): presumably multiplies the node values of the tree by s.
    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    // Re-declared so the overloads above do not hide the inherited CvDTree
    // overloads, which would trigger warnings.
    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble;
};
01229 
01230 
// Boosted tree classifier: a sequence (`weak`) of CvBoostTree weak learners
// trained on one CvDTreeTrainData.
// Virtual-function and member order are part of the exported ABI.
class CV_EXPORTS_W CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria (note: value 2 is not used)
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CV_WRAP CvBoost();
    virtual ~CvBoost();

    // Constructs and trains in one step (same inputs as train()).
    CvBoost( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams() );
    
    // NOTE(review): update=true presumably extends an existing ensemble
    // instead of retraining from scratch -- confirm in boost.cpp.
    virtual bool train( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams(),
             bool update=false );
    
    virtual bool train( CvMLData* data,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    // weak_responses: optional output; slice: range of weak trees to evaluate
    // (CV_WHOLE_SEQ by default). NOTE(review): exact meaning of raw_mode and
    // return_sum is defined in boost.cpp.
    virtual float predict( const CvMat* sample, const CvMat* missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false, bool return_sum=false ) const;

    // cv::Mat overloads, hidden from SWIG. The Mat predict takes a cv::Range
    // instead of a CvSlice.
#ifndef SWIG
    CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
            const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
            const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
            const cv::Mat& missingDataMask=cv::Mat(),
            CvBoostParams params=CvBoostParams() );
    
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvBoostParams params=CvBoostParams(),
                       bool update=false );
    
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                                   const cv::Range& slice=cv::Range::all(), bool rawMode=false,
                                   bool returnSum=false ) const;
#endif
    
    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    // NOTE(review): presumably removes the weak trees in `slice`.
    CV_WRAP virtual void prune( CvSlice slice );

    CV_WRAP virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    virtual const CvMat* get_active_vars(bool absolute_idx=true);

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;
    const CvDTreeTrainData* get_data() const;

protected:

    virtual bool set_params( const CvBoostParams& params );
    virtual void update_weights( CvBoostTree* tree );
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak;                 // the weak trees (returned by get_weak_predictors)

    CvMat* active_vars;          // returned by get_active_vars(); the two matrices presumably
    CvMat* active_vars_abs;      // correspond to absolute_idx=false / true
    bool have_active_cat_vars;

    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};
01325 
01326 
01327 /****************************************************************************************\
01328 *                                   Gradient Boosted Trees                               *
01329 \****************************************************************************************/
01330 
01331 // DataType: STRUCT CvGBTreesParams
01332 // Parameters of GBT (Gradient Boosted trees model), including single
01333 // tree settings and ensemble parameters.
01334 //
01335 // weak_count          - count of trees in the ensemble
01336 // loss_function_type  - loss function used for ensemble training
// subsample_portion   - portion of the whole training set used for
//                       training each single tree.
//                       subsample_portion is in (0.0, 1.0];
//                       subsample_portion == 1.0 means the whole dataset is
//                       used on each step. The number of samples used on
//                       each step is computed as
//                       int(total_samples_count * subsample_portion).
// shrinkage           - regularization parameter.
//                       Each tree's prediction is multiplied by the shrinkage value.
01346 
01347 
// Gradient Boosted Trees parameters: the per-tree CvDTreeParams plus ensemble
// settings. The constructors are defined out of line.
struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams
{
    CV_PROP_RW int weak_count;          // number of trees in the ensemble
    CV_PROP_RW int loss_function_type;  // one of CvGBTrees::SQUARED_LOSS, ABSOLUTE_LOSS, HUBER_LOSS, DEVIANCE_LOSS
    CV_PROP_RW float subsample_portion; // portion of the training set used per tree, in (0, 1]
    CV_PROP_RW float shrinkage;         // each tree's prediction is multiplied by this value

    CvGBTreesParams();
    // NOTE: the argument order (loss_function_type, weak_count, shrinkage,
    // subsample_portion, ...) differs from the member declaration order.
    CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
        float subsample_portion, int max_depth, bool use_surrogates );
};
01359 
01360 // DataType: CLASS CvGBTrees
01361 // Gradient Boosting Trees (GBT) algorithm implementation.
01362 // 
01363 // data             - training dataset
01364 // params           - parameters of the CvGBTrees
01365 // weak             - array[0..(class_count-1)] of CvSeq
01366 //                    for storing tree ensembles
01367 // orig_response    - original responses of the training set samples
// sum_response     - predictions of the current model on the training dataset.
//                    This matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
//                    step. On every iteration the values of sum_response_tmp are
//                    computed from the sum_response values. When the current
//                    step is complete, the sum_response values become equal to
//                    sum_response_tmp.
// sample_idx       - indices of samples used for training the ensemble.
01376 //                    CvGBTrees training procedure takes a set of samples
01377 //                    (train_data) and a set of responses (responses).
01378 //                    Only pairs (train_data[i], responses[i]), where i is
01379 //                    in sample_idx are used for training the ensemble.
01380 // subsample_train  - indices of samples used for training a single decision
//                    tree on the current step. These indices are relative
//                    to sample_idx, so that pairs
//                    (train_data[sample_idx[i]], responses[sample_idx[i]])
//                    are used for training a decision tree.
//                    The training set is randomly split
//                    in two parts (subsample_train and subsample_test)
//                    on every iteration according to the subsample_portion parameter.
01388 // subsample_test   - relative indices of samples from the training set,
01389 //                    which are not used for training a tree on the current
01390 //                    step.
01391 // missing          - mask of the missing values in the training set. This
01392 //                    matrix has the same size as train_data. 1 - missing
01393 //                    value, 0 - not a missing value.
01394 // class_labels     - output class labels map. 
// rng              - random number generator. Used for splitting the
01396 //                    training set.
01397 // class_count      - count of output classes.
01398 //                    class_count == 1 in the case of regression,
01399 //                    and > 1 in the case of classification.
01400 // delta            - Huber loss function parameter.
01401 // base_value       - start point of the gradient descent procedure.
01402 //                    model prediction is
01403 //                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
01404 //                    f_0 is the base value.
01405 
01406 
01407 
01408 class CV_EXPORTS_W CvGBTrees : public CvStatModel
01409 {
01410 public:
01411 
01412     /*
01413     // DataType: ENUM
01414     // Loss functions implemented in CvGBTrees.
01415     //
01416     // SQUARED_LOSS
01417     // problem: regression
01418     // loss = (x - x')^2
01419     // 
01420     // ABSOLUTE_LOSS
01421     // problem: regression
01422     // loss = abs(x - x')
01423     // 
01424     // HUBER_LOSS
01425     // problem: regression
01426     // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
01427     //           1/2*(x - x')^2, if abs(x - x') <= delta,
01428     //           where delta is the alpha-quantile of pseudo responses from
01429     //           the training set.
01430     //
01431     // DEVIANCE_LOSS
01432     // problem: classification
01433     // NOTE: the values are not contiguous (ABSOLUTE_LOSS == 1, HUBER_LOSS == 3, DEVIANCE_LOSS == 4).
01434     */ 
01435     enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
01436  
01437     
01438     /*
01439     // Default constructor. Creates a model only (without training).
01440     // Should be followed by one form of the train(...) function.
01441     //
01442     // API
01443     // CvGBTrees();
01444     
01445     // INPUT
01446     // OUTPUT
01447     // RESULT
01448     */
01449     CV_WRAP CvGBTrees();
01450 
01451 
01452     /*
01453     // Full form constructor. Creates a gradient boosting model and trains it
01454     // on the given data.
01455     //
01456     // API
01457     // CvGBTrees( const CvMat* trainData, int tflag,
01458              const CvMat* responses, const CvMat* varIdx=0,
01459              const CvMat* sampleIdx=0, const CvMat* varType=0,
01460              const CvMat* missingDataMask=0,
01461              CvGBTreesParams params=CvGBTreesParams() );
01462     
01463     // INPUT
01464     // trainData    - a set of input feature vectors.
01465     //                  size of matrix is
01466     //                  <count of samples> x <variables count>
01467     //                  or <variables count> x <count of samples>
01468     //                  depending on the tflag parameter.
01469     //                  matrix values are float.
01470     // tflag         - a flag showing how the samples are stored in the
01471     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
01472     //                  or column by column (tflag=CV_COL_SAMPLE).
01473     // responses     - a vector of responses corresponding to the samples
01474     //                  in trainData.
01475     // varIdx       - indices of used variables. zero value means that all
01476     //                  variables are active.
01477     // sampleIdx    - indices of used samples. zero value means that all
01478     //                  samples from trainData are in the training set.
01479     // varType      - vector of <variables count> length. gives every
01480     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
01481     //                  varType = 0 means all variables are numerical.
01482     // missingDataMask  - a mask of missing values in trainData.
01483     //                  missingDataMask = 0 means that there are no missing
01484     //                  values.
01485     // params         - parameters of the GTB algorithm.
01486     // OUTPUT
01487     // RESULT
01488     */
01489     CvGBTrees( const CvMat* trainData, int tflag,
01490              const CvMat* responses, const CvMat* varIdx=0,
01491              const CvMat* sampleIdx=0, const CvMat* varType=0,
01492              const CvMat* missingDataMask=0,
01493              CvGBTreesParams params=CvGBTreesParams() );
01494 
01495              
01496     /*
01497     // Destructor.
01498     */
01499     virtual ~CvGBTrees();
01500     
01501     
01502     /*
01503     // Gradient tree boosting model training
01504     //
01505     // API
01506     // virtual bool train( const CvMat* trainData, int tflag,
01507              const CvMat* responses, const CvMat* varIdx=0,
01508              const CvMat* sampleIdx=0, const CvMat* varType=0,
01509              const CvMat* missingDataMask=0,
01510              CvGBTreesParams params=CvGBTreesParams(),
01511              bool update=false );
01512     
01513     // INPUT
01514     // trainData    - a set of input feature vectors.
01515     //                  size of matrix is
01516     //                  <count of samples> x <variables count>
01517     //                  or <variables count> x <count of samples>
01518     //                  depending on the tflag parameter.
01519     //                  matrix values are float.
01520     // tflag         - a flag showing how the samples are stored in the
01521     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
01522     //                  or column by column (tflag=CV_COL_SAMPLE).
01523     // responses     - a vector of responses corresponding to the samples
01524     //                  in trainData.
01525     // varIdx       - indices of used variables. zero value means that all
01526     //                  variables are active.
01527     // sampleIdx    - indices of used samples. zero value means that all
01528     //                  samples from trainData are in the training set.
01529     // varType      - vector of <variables count> length. gives every
01530     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
01531     //                  varType = 0 means all variables are numerical.
01532     // missingDataMask  - a mask of missing values in trainData.
01533     //                  missingDataMask = 0 means that there are no missing
01534     //                  values.
01535     // params         - parameters of the GTB algorithm.
01536     // update         - is not supported now. (!)
01537     // OUTPUT
01538     // RESULT
01539     // Error state.
01540     */
01541     virtual bool train( const CvMat* trainData, int tflag,
01542              const CvMat* responses, const CvMat* varIdx=0,
01543              const CvMat* sampleIdx=0, const CvMat* varType=0,
01544              const CvMat* missingDataMask=0,
01545              CvGBTreesParams params=CvGBTreesParams(),
01546              bool update=false );
01547     
01548     
01549     /*
01550     // Gradient tree boosting model training
01551     //
01552     // API
01553     // virtual bool train( CvMLData* data,
01554              CvGBTreesParams params=CvGBTreesParams(),
01555              bool update=false );
01556     
01557     // INPUT
01558     // data          - training set.
01559     // params        - parameters of the GTB algorithm.
01560     // update        - is not supported now. (!)
01561     // OUTPUT
01562     // RESULT
01563     // Error state.
01564     */
01565     virtual bool train( CvMLData* data,
01566              CvGBTreesParams params=CvGBTreesParams(),
01567              bool update=false );
01568 
01569     
01570     /*
01571     // Response value prediction
01572     //
01573     // API
01574     // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
01575              CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01576              int k=-1 ) const;
01577     
01578     // INPUT
01579     // sample         - input sample of the same type as in the training set.
01580     // missing        - missing values mask. missing=0 if there are no
01581     //                   missing values in sample vector.
01582     // weakResponses   - predictions of all of the trees.
01583     //                   not implemented (!)
01584     // slice           - part of the ensemble used for prediction.
01585     //                   slice = CV_WHOLE_SEQ when all trees are used.
01586     // k               - index of the ensemble used.
01587     //                   k is in {-1,0,1,..,<count of output classes-1>}.
01588     //                   in the case of a classification problem
01589     //                   <count of output classes> ensembles are built.
01590     //                   If k = -1 ordinary prediction is the result,
01591     //                   otherwise function gives the prediction of the
01592     //                   k-th ensemble only.
01593     // OUTPUT
01594     // RESULT
01595     // Predicted value.
01596     */
01597     virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
01598             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01599             int k=-1 ) const;
01600             
01601     /*
01602     // Response value prediction.
01603     // Parallel version (used when OpenCV is built with TBB)
01604     //
01605     // API
01606     // virtual float predict( const CvMat* sample, const CvMat* missing=0,
01607              CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01608              int k=-1 ) const;
01609     
01610     // INPUT
01611     // sample         - input sample of the same type as in the training set.
01612     // missing        - missing values mask. missing=0 if there are no
01613     //                   missing values in sample vector.
01614     // weakResponses   - predictions of all of the trees.
01615     //                   not implemented (!)
01616     // slice           - part of the ensemble used for prediction.
01617     //                   slice = CV_WHOLE_SEQ when all trees are used.
01618     // k               - index of the ensemble used.
01619     //                   k is in {-1,0,1,..,<count of output classes-1>}.
01620     //                   in the case of a classification problem
01621     //                   <count of output classes> ensembles are built.
01622     //                   If k = -1 ordinary prediction is the result,
01623     //                   otherwise function gives the prediction of the
01624     //                   k-th ensemble only.
01625     // OUTPUT
01626     // RESULT
01627     // Predicted value.
01628     */        
01629     virtual float predict( const CvMat* sample, const CvMat* missing=0,
01630             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01631             int k=-1 ) const;
01632 
01633     /*
01634     // Deletes all the data.
01635     //
01636     // API
01637     // virtual void clear();
01638     
01639     // INPUT
01640     // OUTPUT
01641     // delete data, weak, orig_response, sum_response,
01642     //        weak_eval, subsample_train, subsample_test,
01643     //        sample_idx, missing, class_labels
01644     // delta = 0.0
01645     // RESULT
01646     */
01647     CV_WRAP virtual void clear();
01648 
01649     /*
01650     // Compute error on the train/test set.
01651     //
01652     // API
01653     // virtual float calc_error( CvMLData* _data, int type,
01654     //        std::vector<float> *resp = 0 );
01655     //
01656     // INPUT
01657     // _data - dataset
01658     // type  - defines which error to compute: train (CV_TRAIN_ERROR) or
01659     //         test (CV_TEST_ERROR).
01660     // OUTPUT
01661     // resp  - vector of predictions
01662     // RESULT
01663     // Error value.
01664     */
01665     virtual float calc_error( CvMLData* _data, int type,
01666             std::vector<float> *resp = 0 );
01667 
01668     /*
01669     // 
01670     // Write parameters of the gtb model and data. Write learned model.
01671     //
01672     // API
01673     // virtual void write( CvFileStorage* fs, const char* name ) const;
01674     //
01675     // INPUT
01676     // fs     - file storage to write the model to.
01677     // name   - model name.
01678     // OUTPUT
01679     // RESULT
01680     */
01681     virtual void write( CvFileStorage* fs, const char* name ) const;
01682 
01683 
01684     /*
01685     // 
01686     // Read parameters of the gtb model and data. Read learned model.
01687     //
01688     // API
01689     // virtual void read( CvFileStorage* fs, CvFileNode* node );
01690     //
01691     // INPUT
01692     // fs     - file storage to read parameters from.
01693     // node   - file node.
01694     // OUTPUT
01695     // RESULT
01696     */
01697     virtual void read( CvFileStorage* fs, CvFileNode* node );
01698 
01699     
01700     // new-style C++ interface (cv::Mat counterparts of the CvMat* methods above)
01701     CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
01702               const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01703               const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01704               const cv::Mat& missingDataMask=cv::Mat(),
01705               CvGBTreesParams params=CvGBTreesParams() );
01706     
01707     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01708                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01709                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01710                        const cv::Mat& missingDataMask=cv::Mat(),
01711                        CvGBTreesParams params=CvGBTreesParams(),
01712                        bool update=false );
01713 
01714     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01715                            const cv::Range& slice = cv::Range::all(),
01716                            int k=-1 ) const;
01717     
01718 protected:
01719 
01720     /*
01721     // Compute the gradient vector components.
01722     //
01723     // API
01724     // virtual void find_gradient( const int k = 0);
01725     
01726     // INPUT
01727     // k        - used for classification problem, determining current
01728     //            tree ensemble.
01729     // OUTPUT
01730     // changes components of data->responses
01731     // which correspond to samples used for training
01732     // on the current step.
01733     // RESULT
01734     */
01735     virtual void find_gradient( const int k = 0);
01736 
01737     
01738     /*
01739     // 
01740     // Change values in tree leaves according to the used loss function.
01741     //
01742     // API
01743     // virtual void change_values(CvDTree* tree, const int k = 0);
01744     //
01745     // INPUT
01746     // tree      - decision tree to change.
01747     // k         - used for classification problem, determining current
01748     //             tree ensemble.
01749     // OUTPUT
01750     // changes 'value' fields of the trees' leaves.
01751     // changes sum_response_tmp.
01752     // RESULT
01753     */
01754     virtual void change_values(CvDTree* tree, const int k = 0);
01755 
01756 
01757     /*
01758     // 
01759     // Find optimal constant prediction value according to the used loss
01760     // function.
01761     // The goal is to find a constant which gives the minimal total loss
01762     // on the _Idx samples.
01763     //
01764     // API
01765     // virtual float find_optimal_value( const CvMat* _Idx );
01766     //
01767     // INPUT
01768     // _Idx        - indices of the samples from the training set.
01769     // OUTPUT
01770     // RESULT
01771     // optimal constant value.
01772     */
01773     virtual float find_optimal_value( const CvMat* _Idx );
01774 
01775     
01776     /*
01777     // 
01778     // Randomly split the whole training set in two parts according
01779     // to params.portion.
01780     //
01781     // API
01782     // virtual void do_subsample();
01783     //
01784     // INPUT
01785     // OUTPUT
01786     // subsample_train - indices of samples used for training
01787     // subsample_test  - indices of samples used for test
01788     // RESULT
01789     */
01790     virtual void do_subsample();
01791 
01792 
01793     /*
01794     // 
01795     // Internal recursive function collecting the leaves of the subtree rooted at node.
01796     //
01797     // API
01798     // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01799     //
01800     // INPUT
01801     // node         - current node (root of the subtree being traversed).
01802     // OUTPUT
01803     // count        - count of leaves in the subtree.
01804     // leaves       - array of pointers to leaves.
01805     // RESULT
01806     */
01807     void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01808     
01809     
01810     /*
01811     // 
01812     // Get leaves of the tree.
01813     //
01814     // API
01815     // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01816     //
01817     // INPUT
01818     // dtree            - decision tree.
01819     // OUTPUT
01820     // len              - count of the leaves.
01821     // RESULT
01822     // CvDTreeNode**    - array of pointers to leaves.
01823     */
01824     CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01825 
01826     
01827     /*
01828     // 
01829     // Is it a regression or a classification.
01830     //
01831     // API
01832     // virtual bool problem_type() const;
01833     //
01834     // INPUT
01835     // OUTPUT
01836     // RESULT
01837     // false if it is a classification problem,
01838     // true - if regression.
01839     */
01840     virtual bool problem_type() const;
01841 
01842 
01843     /*
01844     // 
01845     // Write parameters of the gtb model.
01846     //
01847     // API
01848     // virtual void write_params( CvFileStorage* fs ) const;
01849     //
01850     // INPUT
01851     // fs           - file storage to write parameters to.
01852     // OUTPUT
01853     // RESULT
01854     */
01855     virtual void write_params( CvFileStorage* fs ) const;
01856 
01857 
01858     /*
01859     // 
01860     // Read parameters of the gtb model and data.
01861     //
01862     // API
01863     // virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
01864     //
01865     // INPUT
01866     // fs, fnode    - file storage and file node to read parameters from.
01867     // OUTPUT
01868     // params       - parameters of the gtb model.
01869     // data         - contains information about the structure
01870     //                of the data set (count of variables,
01871     //                their types, etc.).
01872     // class_labels - output class labels map.
01873     // RESULT
01874     */
01875     virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
01876     int get_len(const CvMat* mat) const;
01877 
01878     
01879     CvDTreeTrainData* data;
01880     CvGBTreesParams params;
01881 
01882     CvSeq** weak;
01883     CvMat* orig_response;
01884     CvMat* sum_response;
01885     CvMat* sum_response_tmp;
01886     CvMat* sample_idx;
01887     CvMat* subsample_train;
01888     CvMat* subsample_test;
01889     CvMat* missing;
01890     CvMat* class_labels;
01891 
01892     cv::RNG* rng;
01893 
01894     int class_count;
01895     float delta;
01896     float base_value;
01897 
01898 };
01899 
01900 
01901 
01902 /****************************************************************************************\
01903 *                              Artificial Neural Networks (ANN)                          *
01904 \****************************************************************************************/
01905 
01907 
01908 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams
01909 {
01910     CvANN_MLP_TrainParams();
01911     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
01912                            double param1, double param2=0 );
01913     ~CvANN_MLP_TrainParams();
01914 
01915     enum { BACKPROP=0, RPROP=1 }; // presumably the values accepted for train_method
01916 
01917     CV_PROP_RW CvTermCriteria term_crit;
01918     CV_PROP_RW int train_method;
01919 
01920     // backpropagation (BACKPROP) parameters
01921     CV_PROP_RW double bp_dw_scale, bp_moment_scale;
01922 
01923     // RPROP parameters
01924     CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
01925 };
01926 
01927 
01928 class CV_EXPORTS_W CvANN_MLP : public CvStatModel
01929 {
01930 public:
01931     CV_WRAP CvANN_MLP();
01932     CvANN_MLP( const CvMat* layerSizes,
01933                int activateFunc=CvANN_MLP::SIGMOID_SYM,
01934                double fparam1=0, double fparam2=0 );
01935 
01936     virtual ~CvANN_MLP();
01937 
01938     virtual void create( const CvMat* layerSizes,
01939                          int activateFunc=CvANN_MLP::SIGMOID_SYM,
01940                          double fparam1=0, double fparam2=0 );
01941     
01942     virtual int train( const CvMat* inputs, const CvMat* outputs,
01943                        const CvMat* sampleWeights, const CvMat* sampleIdx=0,
01944                        CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01945                        int flags=0 );
01946     virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
01947     
01948 #ifndef SWIG
01949     CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
01950               int activateFunc=CvANN_MLP::SIGMOID_SYM,
01951               double fparam1=0, double fparam2=0 );
01952     
01953     CV_WRAP virtual void create( const cv::Mat& layerSizes,
01954                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
01955                         double fparam1=0, double fparam2=0 );    
01956     
01957     CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
01958                       const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
01959                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01960                       int flags=0 );    
01961     
01962     CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const;
01963 #endif
01964     
01965     CV_WRAP virtual void clear();
01966 
01967     // possible activation functions (values of the activateFunc argument)
01968     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
01969 
01970     // available training flags (bit values for the flags argument of train())
01971     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
01972 
01973     virtual void read( CvFileStorage* fs, CvFileNode* node );
01974     virtual void write( CvFileStorage* storage, const char* name ) const;
01975 
01976     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; } // 0 if layer_sizes is not set
01977     const CvMat* get_layer_sizes() { return layer_sizes; }
01978     double* get_weights(int layer) // 0 if there are no weights or layer is out of range
01979     {
01980         return layer_sizes && weights &&
01981             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0; // unsigned cast also rejects negative layer
01982     }
01983 
01984     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
01985 
01986 protected:
01987 
01988     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
01989             const CvMat* _sample_weights, const CvMat* sampleIdx,
01990             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
01991 
01992     // sequential random backpropagation
01993     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01994 
01995     // RPROP algorithm
01996     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01997 
01998     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
01999     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
02000                                  double _f_param1=0, double _f_param2=0 );
02001     virtual void init_weights();
02002     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
02003     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
02004     virtual void calc_input_scale( const CvVectors* vecs, int flags );
02005     virtual void calc_output_scale( const CvVectors* vecs, int flags );
02006 
02007     virtual void write_params( CvFileStorage* fs ) const;
02008     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
02009 
02010     CvMat* layer_sizes;
02011     CvMat* wbuf;
02012     CvMat* sample_weights;
02013     double** weights;
02014     double f_param1, f_param2;
02015     double min_val, max_val, min_val1, max_val1;
02016     int activ_func;
02017     int max_count, max_buf_sz;
02018     CvANN_MLP_TrainParams params;
02019     cv::RNG* rng;
02020 };
02021 
02022 /****************************************************************************************\
02023 *                           Auxiliary function declarations                              *
02024 \****************************************************************************************/
02025 
02026 /* Generates <sample> from a multivariate normal distribution, where <mean> is an
02027    average row vector and <cov> is a symmetric covariance matrix */
02028 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
02029                            CvRNG* rng CV_DEFAULT(0) );
02030 
02031 /* Generates a sample from a Gaussian mixture distribution */
02032 CVAPI(void) cvRandGaussMixture( CvMat* means[],
02033                                CvMat* covs[],
02034                                float weights[],
02035                                int clsnum,
02036                                CvMat* sample,
02037                                CvMat* sampClasses CV_DEFAULT(0) );
02038 
02039 #define CV_TS_CONCENTRIC_SPHERES 0 // presumably a <type> value for cvCreateTestSet
02040 
02041 /* Creates a test set */
02042 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
02043                  int num_samples,
02044                  int num_features,
02045                  CvMat** responses,
02046                  int num_classes, ... );
02047 
02048 
02049 #endif
02050 
02051 /****************************************************************************************\
02052 *                                      Data                                             *
02053 \****************************************************************************************/
02054 
02055 #include <map>
02056 #include <string>
02057 #include <iostream>
02058 
02059 #define CV_COUNT     0 // presumably: train_sample_part holds a sample count
02060 #define CV_PORTION   1 // presumably: train_sample_part holds a portion of samples
02061 
02062 struct CV_EXPORTS CvTrainTestSplit
02063 {
02064     CvTrainTestSplit();
02065     CvTrainTestSplit( int train_sample_count, bool mix = true);
02066     CvTrainTestSplit( float train_sample_portion, bool mix = true);
02067 
02068     union // either an absolute count or a portion of the training samples
02069     {
02070         int count;
02071         float portion;
02072     } train_sample_part;
02073     int train_sample_part_mode; // NOTE(review): presumably CV_COUNT or CV_PORTION, selecting the valid union member
02074 
02075     bool mix;
02076 };
02077 
02078 class CV_EXPORTS CvMLData
02079 {
02080 public:
02081     CvMLData();
02082     virtual ~CvMLData();
02083 
02084     // returns:
02085     // 0 - OK
02086     // 1 - the file cannot be opened or is not correct
02087     int read_csv( const char* filename );
02088 
02089     const CvMat* get_values() const;
02090     const CvMat* get_responses();
02091     const CvMat* get_missing() const;
02092 
02093     void set_response_idx( int idx ); // the old response becomes a predictor, new response_idx = idx
02094                                       // if idx < 0 there will be no response
02095     int get_response_idx() const;
02096 
02097     void set_train_test_split( const CvTrainTestSplit * spl );
02098     const CvMat* get_train_sample_idx() const;
02099     const CvMat* get_test_sample_idx() const;
02100     void mix_train_and_test_idx();
02101     
02102     const CvMat* get_var_idx();
02103     void chahge_var_idx( int vi, bool state ); // misspelled name kept for backward compatibility; state == true sets variable vi as a predictor
02104     void change_var_idx( int vi, bool state ) { chahge_var_idx( vi, state ); } // correctly spelled alias of chahge_var_idx
02105     const CvMat* get_var_types();
02106     int get_var_type( int var_idx ) const;
02107     // the following 2 methods allow changing variable types;
02108     // use them to assign the CV_VAR_CATEGORICAL type to a categorical variable
02109     // with numerical labels; in other cases variable types are determined automatically
02110     void set_var_types( const char* str );  // str examples:
02111                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
02112                                             // "cat", "ord" (all vars are categorical/ordered)
02113     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
02114  
02115     void set_delimiter( char ch );
02116     char get_delimiter() const;
02117 
02118     void set_miss_ch( char ch );
02119     char get_miss_ch() const;
02120     
02121     const std::map<std::string, int>& get_class_labels_map() const;
02122 
02123 protected:
02124     virtual void clear();
02125 
02126     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
02127     void free_train_test_idx();
02128     
02129     char delimiter;
02130     char miss_ch;
02131     //char flt_separator;
02132 
02133     CvMat* values;
02134     CvMat* missing;
02135     CvMat* var_types;
02136     CvMat* var_idx_mask;
02137 
02138     CvMat* response_out; // header
02139     CvMat* var_idx_out; // mat
02140     CvMat* var_types_out; // mat
02141 
02142     int response_idx;
02143 
02144     int train_sample_count;
02145     bool mix;
02146    
02147     int total_class_count;
02148     std::map<std::string, int> class_map;
02149 
02150     CvMat* train_sample_idx;
02151     CvMat* test_sample_idx;
02152     int* sample_idx; // data of train_sample_idx and test_sample_idx
02153 
02154     cv::RNG* rng;
02155 };
02156 
02157 
02158 namespace cv
02159 {
02160     
02161 typedef CvStatModel StatModel;
02162 typedef CvParamGrid ParamGrid;
02163 typedef CvNormalBayesClassifier NormalBayesClassifier;
02164 typedef CvKNearest KNearest;
02165 typedef CvSVMParams SVMParams;
02166 typedef CvSVMKernel SVMKernel;
02167 typedef CvSVMSolver SVMSolver;
02168 typedef CvSVM SVM;
02169 typedef CvEMParams EMParams;
02170 typedef CvEM ExpectationMaximization;
02171 typedef CvDTreeParams DTreeParams;
02172 typedef CvMLData TrainData;
02173 typedef CvDTree DecisionTree;
02174 typedef CvForestTree ForestTree;
02175 typedef CvRTParams RandomTreeParams;
02176 typedef CvRTrees RandomTrees;
02177 typedef CvERTreeTrainData ERTreeTRainData; // misspelled name kept for backward compatibility
02178 typedef CvForestERTree ERTree;
02179 typedef CvERTrees ERTrees;
02180 typedef CvBoostParams BoostParams;
02181 typedef CvBoostTree BoostTree;
02182 typedef CvBoost Boost;
02183 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
02184 typedef CvANN_MLP NeuralNet_MLP;
02185 typedef CvGBTreesParams GradientBoostingTreeParams;
02186 typedef CvGBTrees GradientBoostingTrees;
02187 typedef CvERTreeTrainData ERTreeTrainData; // correctly spelled alias of ERTreeTRainData
02188 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
02189     
02190 }
02191 
02192 #endif
02193 /* End of file. */