Cinder

  • Main Page
  • Related Pages
  • Namespaces
  • Classes
  • Files
  • File List
  • File Members

include/OpenCV/ml.h

Go to the documentation of this file.
00001 /*M///////////////////////////////////////////////////////////////////////////////////////
00002 //
00003 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
00004 //
00005 //  By downloading, copying, installing or using the software you agree to this license.
00006 //  If you do not agree to this license, do not download, install,
00007 //  copy or use the software.
00008 //
00009 //
00010 //                        Intel License Agreement
00011 //
00012 // Copyright (C) 2000, Intel Corporation, all rights reserved.
00013 // Third party copyrights are property of their respective owners.
00014 //
00015 // Redistribution and use in source and binary forms, with or without modification,
00016 // are permitted provided that the following conditions are met:
00017 //
00018 //   * Redistribution's of source code must retain the above copyright notice,
00019 //     this list of conditions and the following disclaimer.
00020 //
00021 //   * Redistribution's in binary form must reproduce the above copyright notice,
00022 //     this list of conditions and the following disclaimer in the documentation
00023 //     and/or other materials provided with the distribution.
00024 //
00025 //   * The name of Intel Corporation may not be used to endorse or promote products
00026 //     derived from this software without specific prior written permission.
00027 //
00028 // This software is provided by the copyright holders and contributors "as is" and
00029 // any express or implied warranties, including, but not limited to, the implied
00030 // warranties of merchantability and fitness for a particular purpose are disclaimed.
00031 // In no event shall the Intel Corporation or contributors be liable for any direct,
00032 // indirect, incidental, special, exemplary, or consequential damages
00033 // (including, but not limited to, procurement of substitute goods or services;
00034 // loss of use, data, or profits; or business interruption) however caused
00035 // and on any theory of liability, whether in contract, strict liability,
00036 // or tort (including negligence or otherwise) arising in any way out of
00037 // the use of this software, even if advised of the possibility of such damage.
00038 //
00039 //M*/
00040 
00041 #ifndef __OPENCV_ML_H__
00042 #define __OPENCV_ML_H__
00043 
00044 // Disable the deprecation warning (C4996) which appears in Visual Studio 8.0
// (_MSC_VER 1400) and in every later compiler version matched by the test below.
// NOTE(review): no warning(push)/warning(pop) pair is visible around this pragma, so the
// setting presumably remains in effect for code that includes this header - confirm.
00045 #if _MSC_VER >= 1400
00046 #pragma warning( disable : 4996 )
00047 #endif
00048 
/* Environment setup.
   Without SKIP_INCLUDES: include cxcore.h, <limits.h> and, on Windows, <windows.h>.
   With SKIP_INCLUDES: define locally the calling-convention, linkage, inline and export
   macros (CV_CDECL, CV_STDCALL, CV_EXTERN_C, CV_DEFAULT, CV_EXTERN_C_FUNCPTR, CV_INLINE,
   CV_EXPORTS, CVAPI) instead of getting them from included headers. */
00049 #ifndef SKIP_INCLUDES
00050 
00051   #include "cxcore.h"
00052   #include <limits.h>
00053 
00054   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
00055     #include <windows.h>
00056   #endif
00057 
00058 #else // SKIP_INCLUDES
00059 
00060   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
00061     #define CV_CDECL __cdecl
00062     #define CV_STDCALL __stdcall
00063   #else
00064     #define CV_CDECL /* expands to nothing outside Windows */
00065     #define CV_STDCALL /* expands to nothing outside Windows */
00066   #endif
00067 
00068   #ifndef CV_EXTERN_C /* note: CV_DEFAULT is only defined here when CV_EXTERN_C was not already defined */
00069     #ifdef __cplusplus
00070       #define CV_EXTERN_C extern "C"
00071       #define CV_DEFAULT(val) = val /* C++: emit the default argument */
00072     #else
00073       #define CV_EXTERN_C
00074       #define CV_DEFAULT(val) /* C: default arguments are dropped */
00075     #endif
00076   #endif
00077 
00078   #ifndef CV_EXTERN_C_FUNCPTR /* typedef helper; wraps the typedef in extern "C" when compiled as C++ */
00079     #ifdef __cplusplus
00080       #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
00081     #else
00082       #define CV_EXTERN_C_FUNCPTR(x) typedef x
00083     #endif
00084   #endif
00085 
00086   #ifndef CV_INLINE /* inline keyword: "inline" in C++, "__inline" for non-GNU Windows C, "static" otherwise */
00087     #if defined __cplusplus
00088       #define CV_INLINE inline
00089     #elif (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && !defined __GNUC__
00090       #define CV_INLINE __inline
00091     #else
00092       #define CV_INLINE static
00093     #endif
00094   #endif /* CV_INLINE */
00095 
      /* CV_EXPORTS is defined unconditionally (no #ifndef guard): dllexport only on Windows
         builds that define CVAPI_EXPORTS, empty in every other case. */
00096   #if (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && defined CVAPI_EXPORTS
00097     #define CV_EXPORTS __declspec(dllexport)
00098   #else
00099     #define CV_EXPORTS
00100   #endif
00101 
00102   #ifndef CVAPI /* declaration prefix for C API functions: linkage + export + return type + calling convention */
00103     #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
00104   #endif
00105 
00106 #endif // SKIP_INCLUDES
00107 
00108 
00109 #ifdef __cplusplus
00110 
00111 // Apple defines a check() macro somewhere in the debug headers
00112 // that interferes with a method definition in this header
00113 #undef check
00114 
00115 #include "cvinternal.h"
00116 
00117 /****************************************************************************************\
00118 *                               Main struct definitions                                  *
00119 \****************************************************************************************/
00120 
00121 /* log(2*PI) - natural logarithm of 2*PI */
00122 #define CV_LOG2PI (1.8378770664093454835606594728112)
00123 
00124 /* columns of <trainData> matrix are training samples */
00125 #define CV_COL_SAMPLE 0
00126 
00127 /* rows of <trainData> matrix are training samples */
00128 #define CV_ROW_SAMPLE 1
00129 
/* non-zero when the CV_ROW_SAMPLE bit is set in <flags>, i.e. samples are stored as rows */
00130 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
00131 
// Descriptor of a set of vectors that can be chained to further sets through <next>.
// Used in this header for CvKNearest::samples and as the training-data argument of the
// protected CvEM methods.
// NOTE(review): the exact meaning of <type>, <dims> and <count> is not shown in this
// header; the notes below are inferred from the member names - confirm.
00132 struct CvVectors
00133 {
00134     int type;              // presumably selects which member of <data> is valid - confirm
00135     int dims, count;       // presumably vector dimensionality and number of vectors - confirm
00136     CvVectors* next;       // link to another CvVectors node
00137     union
00138     {
00139         uchar** ptr;       // the same array of row pointers, viewed as bytes,
00140         float** fl;        // as floats,
00141         double** db;       // or as doubles
00142     } data;
00143 };
00144 
// NOTE: everything between this "#if 0" and the matching "#endif" is disabled and never compiled.
00145 #if 0
00146 /* A structure, representing the lattice range of statmodel parameters.
00147    It is used for optimizing statmodel parameters by cross-validation method.
00148    The lattice is logarithmic, so <step> must be greater than 1. */
00149 typedef struct CvParamLattice
00150 {
00151     double min_val;
00152     double max_val;
00153     double step;
00154 }
00155 CvParamLattice;
00156 
/* Builds a lattice from two bounds given in any order; a <log_step> below 1 is raised to 1. */
00157 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
00158                                          double log_step )
00159 {
00160     CvParamLattice pl;
00161     pl.min_val = MIN( min_val, max_val );
00162     pl.max_val = MAX( min_val, max_val );
00163     pl.step = MAX( log_step, 1. );
00164     return pl;
00165 }
00166 
/* Returns a lattice with all three fields set to zero. */
00167 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
00168 {
00169     CvParamLattice pl = {0,0,0};
00170     return pl;
00171 }
00172 #endif
00173 
00174 /* Variable type. CV_VAR_NUMERICAL and CV_VAR_ORDERED are two names for the same value (0). */
00175 #define CV_VAR_NUMERICAL    0
00176 #define CV_VAR_ORDERED      0
00177 #define CV_VAR_CATEGORICAL  1
00178 
/* Type-name strings, one per model kind.
   NOTE(review): presumably used to tag models in file storage - where they are consumed
   is not visible in this header; confirm. */
00179 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
00180 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
00181 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
00182 #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
00183 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
00184 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
00185 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
00186 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
00187 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
00188 
/* Error-kind selectors: training-set error vs. test-set error (users are not visible here). */
00189 #define CV_TRAIN_ERROR  0
00190 #define CV_TEST_ERROR   1
00191 
// Common base class of the statistical models declared in this header
// (CvNormalBayesClassifier, CvKNearest, CvSVM and CvEM derive from it).
// It declares the shared clear / save / load / write / read interface; all of these are
// virtual and only declared here, so their behavior is defined by the implementation.
00192 class CV_EXPORTS CvStatModel
00193 {
00194 public:
00195     CvStatModel();
00196     virtual ~CvStatModel();  // virtual: models may be destroyed through a CvStatModel pointer
00197 
00198     virtual void clear();  // overridden by the derived models in this header
00199 
00200     virtual void save( const char* filename, const char* name=0 ) const;  // <name> is optional (defaults to 0)
00201     virtual void load( const char* filename, const char* name=0 );        // <name> is optional (defaults to 0)
00202 
00203     virtual void write( CvFileStorage* storage, const char* name ) const;  // storage-based counterpart of save()
00204     virtual void read( CvFileStorage* storage, CvFileNode* node );         // storage-based counterpart of load()
00205 
00206 protected:
00207     const char* default_model_name;  // presumably the name used when <name> is 0 - TODO confirm
00208 };
00209 
00210 /****************************************************************************************\
00211 *                                 Normal Bayes Classifier                                *
00212 \****************************************************************************************/
00213 
00214 /* The structure, representing the grid range of statmodel parameters.
00215    It is used for optimizing statmodel accuracy by varying model parameters,
00216    the accuracy estimate being computed by cross-validation.
00217    The grid is logarithmic, so <step> must be greater than 1. */
00218 
00219 class CvMLData;  // forward declaration only; the class is not defined in the visible part of this header
00220 
// Range of values (minimum, maximum, step) for one model parameter.
// It is the type of the *_grid arguments of CvSVM::train_auto and the return type of
// CvSVM::get_default_grid.
00221 struct CV_EXPORTS CvParamGrid
00222 {
00223     // SVM params type (same numeric values as the C/GAMMA/P/NU/COEF/DEGREE enum in CvSVM)
00224     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
00225 
      // Default grid: all three fields are zero.
00226     CvParamGrid()
00227     {
00228         min_val = max_val = step = 0;
00229     }
00230 
      // Stores the three values exactly as given; no ordering or range validation happens here.
00231     CvParamGrid( double _min_val, double _max_val, double log_step )
00232     {
00233         min_val = _min_val;
00234         max_val = _max_val;
00235         step = log_step;
00236     }
00237     //CvParamGrid( int param_id );
00238     bool check() const;  // declared only; presumably validates the grid values - TODO confirm
00239 
00240     double min_val;
00241     double max_val;
00242     double step;
00243 };
00244 
// Normal Bayes classifier model.
// The construct / train / predict interface is declared twice: once taking CvMat pointers
// and, unless SWIG is defined, once taking cv::Mat references.
// Only declarations are visible here; argument semantics noted below are inferred from
// the parameter names and should be confirmed against the implementation.
00245 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
00246 {
00247 public:
00248     CvNormalBayesClassifier();
00249     virtual ~CvNormalBayesClassifier();
00250 
00251     CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
00252         const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );  // index arguments are optional (default 0)
00253     
00254     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
00255         const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );  // <update>: presumably extend an existing model instead of retraining - confirm
00256    
00257     virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;  // <results> is an optional output (default 0)
00258     virtual void clear();
00259 
00260 #ifndef SWIG  // cv::Mat overloads are hidden when SWIG is defined
00261     CvNormalBayesClassifier( const cv::Mat& _train_data, const cv::Mat& _responses,
00262                             const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat() );
00263     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
00264                        const cv::Mat& _var_idx = cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
00265                        bool update=false );
00266     virtual float predict( const cv::Mat& _samples, cv::Mat* results=0 ) const;
00267 #endif
00268     
00269     virtual void write( CvFileStorage* storage, const char* name ) const;
00270     virtual void read( CvFileStorage* storage, CvFileNode* node );
00271 
00272 protected:
00273     int     var_count, var_all;
00274     CvMat*  var_idx;
00275     CvMat*  cls_labels;
      // The CvMat** members below are arrays of matrix pointers.
      // NOTE(review): presumably one entry per class - confirm against the implementation.
00276     CvMat** count;
00277     CvMat** sum;
00278     CvMat** productsum;
00279     CvMat** avg;
00280     CvMat** inv_eigen_values;
00281     CvMat** cov_rotate_mats;
00282     CvMat*  c;
00283 };
00284 
00285 
00286 /****************************************************************************************\
00287 *                          K-Nearest Neighbour Classifier                                *
00288 \****************************************************************************************/
00289 
00290 // k Nearest Neighbors model.
// The construct / train / find_nearest interface is declared twice: once taking CvMat
// pointers and, unless SWIG is defined, once taking cv::Mat. The maximum neighbor count
// defaults to 32 in both the constructors and train().
// Only declarations are visible here; argument semantics noted below are inferred from
// the parameter names and should be confirmed against the implementation.
00291 class CV_EXPORTS CvKNearest : public CvStatModel
00292 {
00293 public:
00294 
00295     CvKNearest();
00296     virtual ~CvKNearest();
00297 
00298     CvKNearest( const CvMat* _train_data, const CvMat* _responses,
00299                 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
00300 
00301     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
00302                         const CvMat* _sample_idx=0, bool is_regression=false,
00303                         int _max_k=32, bool _update_base=false );  // <_update_base>: presumably add samples to the existing base - confirm
00304     
00305     virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
00306         const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;  // every argument after <k> is optional (default 0)
00307     
00308 #ifndef SWIG  // cv::Mat overloads are hidden when SWIG is defined
00309     CvKNearest( const cv::Mat& _train_data, const cv::Mat& _responses,
00310                const cv::Mat& _sample_idx=cv::Mat(), bool _is_regression=false, int max_k=32 );
00311     
00312     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
00313                        const cv::Mat& _sample_idx=cv::Mat(), bool is_regression=false,
00314                        int _max_k=32, bool _update_base=false );    
00315     
00316     virtual float find_nearest( const cv::Mat& _samples, int k, cv::Mat* results=0,
00317                                 const float** neighbors=0,
00318                                 cv::Mat* neighbor_responses=0,
00319                                 cv::Mat* dist=0 ) const;
00320 #endif
00321     
00322     virtual void clear();
00323     int get_max_k() const;         // non-virtual accessors; defined outside this header
00324     int get_var_count() const;
00325     int get_sample_count() const;
00326     bool is_regression() const;
00327 
00328 protected:
00329 
      // Internal helpers; both work on the sample range given by <start> and <end>.
00330     virtual float write_results( int k, int k1, int start, int end,
00331         const float* neighbor_responses, const float* dist, CvMat* _results,
00332         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
00333 
00334     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
00335         float* neighbor_responses, const float** neighbors, float* dist ) const;
00336 
00337 
00338     int max_k, var_count;
00339     int total;
00340     bool regression;
00341     CvVectors* samples;  // stored training samples as a CvVectors chain
00342 };
00343 
00344 /****************************************************************************************\
00345 *                                   Support Vector Machines                              *
00346 \****************************************************************************************/
00347 
00348 // SVM training parameters.
// Plain parameter bundle passed by value to the CvSVM constructors, train() and
// train_auto(), and returned by CvSVM::get_params(). The comments on the members name
// the kernel or SVM types each value applies to.
00349 struct CV_EXPORTS CvSVMParams
00350 {
00351     CvSVMParams();  // default values are set in the implementation, not visible here
00352     CvSVMParams( int _svm_type, int _kernel_type,
00353                  double _degree, double _gamma, double _coef0,
00354                  double Cvalue, double _nu, double _p,
00355                  CvMat* _class_weights, CvTermCriteria _term_crit );
00356 
00357     int         svm_type;     // presumably one of the CvSVM SVM-type enum values (C_SVC ... NU_SVR) - confirm
00358     int         kernel_type;  // presumably one of the CvSVM kernel enum values (LINEAR ... SIGMOID) - confirm
00359     double      degree; // for poly
00360     double      gamma;  // for poly/rbf/sigmoid
00361     double      coef0;  // for poly/sigmoid
00362 
00363     double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
00364     double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
00365     double      p; // for CV_SVM_EPS_SVR
00366     CvMat*      class_weights; // for CV_SVM_C_SVC; raw pointer - ownership is not visible in this header
00367     CvTermCriteria term_crit; // termination criteria
00368 };
00369 
00370 
// SVM kernel evaluator.
// Holds a pointer to the SVM parameters and a pointer-to-member (<calc_func>) of type Calc.
// calc_linear, calc_rbf, calc_poly and calc_sigmoid all have the Calc signature, so any of
// them can be stored in <calc_func>.
// NOTE(review): that calc() dispatches through <calc_func> is presumed from the
// declarations - confirm against the implementation.
00371 struct CV_EXPORTS CvSVMKernel
00372 {
      // Signature shared by the kernel routines: <vec_count> vectors of <vec_size> floats
      // in <vecs> are processed against <another>, with output written to <results>.
00373     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
00374                                        const float* another, float* results );
00375     CvSVMKernel();
00376     CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
00377     virtual bool create( const CvSVMParams* _params, Calc _calc_func );  // same arguments as the two-argument constructor
00378     virtual ~CvSVMKernel();
00379 
00380     virtual void clear();
00381     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
00382 
00383     const CvSVMParams* params;  // not owned as far as the type shows (pointer to const) - confirm lifetime with callers
00384     Calc calc_func;
00385 
      // Takes two extra coefficients (<alpha>, <beta>) in addition to the Calc arguments.
00386     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
00387                                     const float* another, float* results,
00388                                     double alpha, double beta );
00389 
00390     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
00391                               const float* another, float* results );
00392     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
00393                            const float* another, float* results );
00394     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
00395                             const float* another, float* results );
00396     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
00397                                const float* another, float* results );
00398 };
00399 
00400 
// Doubly linked list node carrying a float buffer.
// CvSVMSolver keeps one as <lru_list> and an array of them as <rows>.
// NOTE(review): presumably each node caches one computed kernel row - confirm.
00401 struct CvSVMKernelRow
00402 {
00403     CvSVMKernelRow* prev;
00404     CvSVMKernelRow* next;
00405     float* data;
00406 };
00407 
00408 
// Result values of a solver run; passed by reference (as <si>) to CvSVMSolver::solve_generic
// and the CvSVMSolver::solve_* methods.
// NOTE(review): field meanings beyond the names are not documented in this header.
00409 struct CvSVMSolutionInfo
00410 {
00411     double obj;
00412     double rho;
00413     double upper_bound_p;
00414     double upper_bound_n;
00415     double r;   // for Solver_NU
00416 };
00417 
// Solver used for SVM training.
// The strategy is configured through three pointers-to-member given to the constructor or
// create(): working-set selection, row retrieval and rho calculation. The virtual methods
// at the end of the class have signatures matching those three typedefs.
// All members are public. Only declarations are visible here; behavior notes are inferred
// from names and signatures and should be confirmed against the implementation.
00418 class CV_EXPORTS CvSVMSolver
00419 {
00420 public:
00421     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );                       // matches select_working_set*()
00422     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );  // matches get_row_svc/one_class/svr()
00423     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );                        // matches calc_rho*()
00424 
00425     CvSVMSolver();
00426 
00427     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
00428                  int alpha_count, double* alpha, double Cp, double Cn,
00429                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00430                  SelectWorkingSet select_working_set, CalcRho calc_rho );
00431     virtual bool create( int count, int var_count, const float** samples, schar* y,
00432                  int alpha_count, double* alpha, double Cp, double Cn,
00433                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00434                  SelectWorkingSet select_working_set, CalcRho calc_rho );  // same argument list as the constructor above
00435     virtual ~CvSVMSolver();
00436 
00437     virtual void clear();
00438     virtual bool solve_generic( CvSVMSolutionInfo& si );
00439 
      // One entry point per SVM formulation; each fills <alpha> and <si>.
      // The classification variants take labels as schar*, the regression variants take
      // responses as const float*, and solve_one_class takes no labels at all.
00440     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
00441                               double Cp, double Cn, CvMemStorage* storage,
00442                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
00443     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
00444                                CvMemStorage* storage, CvSVMKernel* kernel,
00445                                double* alpha, CvSVMSolutionInfo& si );
00446     virtual bool solve_one_class( int count, int var_count, const float** samples,
00447                                   CvMemStorage* storage, CvSVMKernel* kernel,
00448                                   double* alpha, CvSVMSolutionInfo& si );
00449 
00450     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
00451                                 CvMemStorage* storage, CvSVMKernel* kernel,
00452                                 double* alpha, CvSVMSolutionInfo& si );
00453 
00454     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
00455                                CvMemStorage* storage, CvSVMKernel* kernel,
00456                                double* alpha, CvSVMSolutionInfo& si );
00457 
00458     virtual float* get_row_base( int i, bool* _existed );
00459     virtual float* get_row( int i, float* dst );
00460 
00461     int sample_count;
00462     int var_count;
00463     int cache_size;
00464     int cache_line_size;
00465     const float** samples;
00466     const CvSVMParams* params;
00467     CvMemStorage* storage;
00468     CvSVMKernelRow lru_list;  // list head node (by value); presumably least-recently-used ordering, per the name - confirm
00469     CvSVMKernelRow* rows;
00470 
00471     int alpha_count;
00472 
00473     double* G;
00474     double* alpha;
00475 
00476     // -1 - lower bound, 0 - free, 1 - upper bound
00477     schar* alpha_status;
00478 
00479     schar* y;
00480     double* b;
00481     float* buf[2];
00482     double eps;
00483     int max_iter;
00484     double C[2];  // C[0] == Cn, C[1] == Cp
00485     CvSVMKernel* kernel;
00486 
00487     SelectWorkingSet select_working_set_func;  // strategy pointers; types are the typedefs at the top of the class
00488     CalcRho calc_rho_func;
00489     GetRow get_row_func;
00490 
00491     virtual bool select_working_set( int& i, int& j );
00492     virtual bool select_working_set_nu_svm( int& i, int& j );
00493     virtual void calc_rho( double& rho, double& r );
00494     virtual void calc_rho_nu_svm( double& rho, double& r );
00495 
00496     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
00497     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
00498     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
00499 };
00500 
00501 
// Data of one decision function; CvSVM stores a pointer to these in <decision_func>.
// NOTE(review): presumably <alpha> and <sv_index> are arrays of <sv_count> entries
// (coefficients and indices of the support vectors used) - confirm.
00502 struct CvSVMDecisionFunc
00503 {
00504     double rho;
00505     int sv_count;
00506     double* alpha;
00507     int* sv_index;
00508 };
00509 
00510 
00511 // SVM model.
// The construct / train / train_auto / predict interface is declared twice: once taking
// CvMat pointers and, unless SWIG is defined, once taking cv::Mat references.
// train_auto() additionally takes a fold count (default 10) and one CvParamGrid per
// tunable parameter, each defaulting to get_default_grid() for that parameter.
// Only declarations (plus two inline accessors) are visible here; other behavior notes
// are inferred from the signatures and should be confirmed against the implementation.
00512 class CV_EXPORTS CvSVM : public CvStatModel
00513 {
00514 public:
00515     // SVM type
00516     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
00517 
00518     // SVM kernel type
00519     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
00520 
00521     // SVM params type (argument of get_default_grid; same values as the CvParamGrid::SVM_* enum)
00522     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
00523 
00524     CvSVM();
00525     virtual ~CvSVM();
00526 
00527     CvSVM( const CvMat* _train_data, const CvMat* _responses,
00528            const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
00529            CvSVMParams _params=CvSVMParams() );
00530 
00531     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
00532                         const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
00533                         CvSVMParams _params=CvSVMParams() );
00534     
      // Unlike train(), <_var_idx>, <_sample_idx> and <_params> have no default values here.
00535     virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
00536         const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
00537         int k_fold = 10,
00538         CvParamGrid C_grid      = get_default_grid(CvSVM::C),
00539         CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
00540         CvParamGrid p_grid      = get_default_grid(CvSVM::P),
00541         CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
00542         CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
00543         CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
00544 
00545     virtual float predict( const CvMat* _sample, bool returnDFVal=false ) const;  // <returnDFVal>: presumably return the decision-function value instead of the label - confirm
00546 
00547 #ifndef SWIG  // cv::Mat overloads are hidden when SWIG is defined
00548     CvSVM( const cv::Mat& _train_data, const cv::Mat& _responses,
00549           const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
00550           CvSVMParams _params=CvSVMParams() );
00551     
00552     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
00553                        const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
00554                        CvSVMParams _params=CvSVMParams() );
00555     
00556     virtual bool train_auto( const cv::Mat& _train_data, const cv::Mat& _responses,
00557                             const cv::Mat& _var_idx, const cv::Mat& _sample_idx, CvSVMParams _params,
00558                             int k_fold = 10,
00559                             CvParamGrid C_grid      = get_default_grid(CvSVM::C),
00560                             CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
00561                             CvParamGrid p_grid      = get_default_grid(CvSVM::P),
00562                             CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
00563                             CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
00564                             CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
00565     virtual float predict( const cv::Mat& _sample, bool returnDFVal=false ) const;    
00566 #endif
00567     
00568     virtual int get_support_vector_count() const;
00569     virtual const float* get_support_vector(int i) const;
00570     virtual CvSVMParams get_params() const { return params; };  // returns a copy of the stored parameters
00571     virtual void clear();
00572 
00573     static CvParamGrid get_default_grid( int param_id );  // <param_id>: one of C, GAMMA, P, NU, COEF, DEGREE (as used in the defaults above)
00574 
00575     virtual void write( CvFileStorage* storage, const char* name ) const;
00576     virtual void read( CvFileStorage* storage, CvFileNode* node );
00577     int get_var_count() const { return var_idx ? var_idx->cols : var_all; }  // column count of <var_idx> when it is set, otherwise <var_all>
00578 
00579 protected:
00580 
00581     virtual bool set_params( const CvSVMParams& _params );
00582     virtual bool train1( int sample_count, int var_count, const float** samples,
00583                     const void* _responses, double Cp, double Cn,
00584                     CvMemStorage* _storage, double* alpha, double& rho );
00585     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
00586                     const CvMat* _responses, CvMemStorage* _storage, double* alpha );
00587     virtual void create_kernel();
00588     virtual void create_solver();
00589 
00590     virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;  // raw-array overload, protected
00591 
00592     virtual void write_params( CvFileStorage* fs ) const;
00593     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00594 
00595     CvSVMParams params;
00596     CvMat* class_labels;
00597     int var_all;
00598     float** sv;
00599     int sv_total;
00600     CvMat* var_idx;  // may be null; see get_var_count()
00601     CvMat* class_weights;
00602     CvSVMDecisionFunc* decision_func;
00603     CvMemStorage* storage;
00604 
00605     CvSVMSolver* solver;
00606     CvSVMKernel* kernel;
00607 };
00608 
00609 /****************************************************************************************\
00610 *                              Expectation - Maximization                                *
00611 \****************************************************************************************/
00612 
// Parameters of the EM algorithm, passed by value to the CvEM constructors and train().
// Defaults (both constructors agree on them): covariance matrix type 1
// (CvEM::COV_MAT_DIAGONAL), start step 0 (CvEM::START_AUTO_STEP), termination after 100
// iterations or FLT_EPSILON accuracy, and null initial probs/weights/means/covs.
// The default constructor additionally uses 10 clusters.
00613 struct CV_EXPORTS CvEMParams
00614 {
00615     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
00616         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
00617     {
00618         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
00619     }
00620 
      // Not declared explicit, so an int converts implicitly to CvEMParams (cluster count
      // with all other values defaulted).
00621     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
00622                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
00623                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
00624                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
00625                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
00626                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
00627     {}
00628 
00629     int nclusters;
00630     int cov_mat_type;          // one of the CvEM COV_MAT_* values
00631     int start_step;            // one of the CvEM START_* values
00632     const CvMat* probs;        // optional initial values (may be 0); pointers are stored, not copied, by this struct
00633     const CvMat* weights;
00634     const CvMat* means;
00635     const CvMat** covs;
00636     CvTermCriteria term_crit;
00637 };
00638 
00639 
// Expectation-Maximization model.
// The construct / train / predict interface is declared twice: once taking CvMat pointers
// and, unless SWIG is defined, once taking cv::Mat. Getters expose the cluster count,
// means, covariances, weights, probabilities and log-likelihood.
// Only declarations (plus one inline accessor) are visible here; other behavior notes are
// inferred from the signatures and should be confirmed against the implementation.
00640 class CV_EXPORTS CvEM : public CvStatModel
00641 {
00642 public:
00643     // Type of covariance matrices
00644     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
00645 
00646     // The initial step
00647     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
00648 
00649     CvEM();
00650     CvEM( const CvMat* samples, const CvMat* sample_idx=0,
00651           CvEMParams params=CvEMParams(), CvMat* labels=0 );  // <labels> is an optional output (default 0)
00652     //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
00653 
00654     virtual ~CvEM();
00655 
00656     virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
00657                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
00658 
00659     virtual float predict( const CvMat* sample, CvMat* probs ) const;  // <probs> has no default value here
00660 
00661 #ifndef SWIG  // cv::Mat overloads are hidden when SWIG is defined
00662     CvEM( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
00663          CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
00664     
00665     virtual bool train( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
00666                        CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
00667     
00668     virtual float predict( const cv::Mat& sample, cv::Mat* probs ) const;
00669 #endif
00670     
00671     virtual void clear();
00672 
00673     int           get_nclusters() const;
00674     const CvMat*  get_means()     const;
00675     const CvMat** get_covs()      const;
00676     const CvMat*  get_weights()   const;
00677     const CvMat*  get_probs()     const;
00678 
00679     inline double         get_log_likelihood     () const { return log_likelihood;     };  // returns the stored <log_likelihood> member
00680     
00681 //    inline const CvMat *  get_log_weight_div_det () const { return log_weight_div_det; };
00682 //    inline const CvMat *  get_inv_eigen_values   () const { return inv_eigen_values;   };
00683 //    inline const CvMat ** get_cov_rotate_mats    () const { return cov_rotate_mats;    };
00684 
00685 protected:
00686 
      // Internal steps; training data is passed as CvVectors.
00687     virtual void set_params( const CvEMParams& params,
00688                              const CvVectors& train_data );
00689     virtual void init_em( const CvVectors& train_data );
00690     virtual double run_em( const CvVectors& train_data );
00691     virtual void init_auto( const CvVectors& samples );
00692     virtual void kmeans( const CvVectors& train_data, int nclusters,
00693                          CvMat* labels, CvTermCriteria criteria,
00694                          const CvMat* means );
00695     CvEMParams params;
00696     double log_likelihood;
00697 
00698     CvMat* means;
00699     CvMat** covs;
00700     CvMat* weights;
00701     CvMat* probs;
00702 
00703     CvMat* log_weight_div_det;
00704     CvMat* inv_eigen_values;
00705     CvMat** cov_rotate_mats;
00706 };
00707 
00708 /****************************************************************************************\
00709 *                                      Decision Tree                                     *
00710 \****************************************************************************************/
// Pair of pointers: one to unsigned 16-bit values and one to int values.
// NOTE(review): its users are not visible in this part of the header.
00711 struct CvPair16u32s
00712 {
00713     unsigned short* u;
00714     int* i;
00715 };
00716 
00717 
/* Tests bit <idx> of the bit set <subset> (an array of 32-bit words: word idx>>5,
   bit idx&31). Evaluates to -1 when the bit is set and to +1 when it is clear.
   NOTE: <subset> is not parenthesized in the expansion, so pass a simple array or
   pointer expression. */
00718 #define CV_DTREE_CAT_DIR(idx,subset) \
00719     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
00720 
// Description of one split of a decision tree node; splits are chained through <next>
// and referenced from CvDTreeNode::split.
// The anonymous union stores either a two-word bit set (<subset>) or a threshold record (<ord>).
// NOTE(review): presumably <subset> is used for categorical variables (see
// CV_DTREE_CAT_DIR) and <ord> for ordered ones - confirm against the implementation.
00721 struct CvDTreeSplit
00722 {
00723     int var_idx;
00724     int condensed_idx;
00725     int inversed;
00726     float quality;
00727     CvDTreeSplit* next;
00728     union
00729     {
00730         int subset[2];  // 2 x int bit set
00731         struct
00732         {
00733             float c;
00734             int split_point;
00735         }
00736         ord;
00737     };
00738 };
00739 
00740 
// Node of a decision tree: links to parent and children, the node's split, sample
// statistics, and the data used by pruning.
// NOTE(review): field meanings beyond the existing group comments are not documented in
// this header; only the two inline accessors at the end are fully visible.
00741 struct CvDTreeNode
00742 {
00743     int class_idx;
00744     int Tn;
00745     double value;
00746 
00747     CvDTreeNode* parent;
00748     CvDTreeNode* left;
00749     CvDTreeNode* right;
00750 
00751     CvDTreeSplit* split;
00752 
00753     int sample_count;
00754     int depth;
00755     int* num_valid;  // optional per-variable counts; may be null (see the accessors below)
00756     int offset;
00757     int buf_idx;
00758     double maxlr;
00759 
00760     // global pruning data
00761     int complexity;
00762     double alpha;
00763     double node_risk, tree_risk, tree_error;
00764 
00765     // cross-validation pruning data
00766     int* cv_Tn;
00767     double* cv_node_risk;
00768     double* cv_node_error;
00769 
00770     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }  // falls back to <sample_count> when <num_valid> is null
00771     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }        // does nothing when <num_valid> is null
00772 };
00773 
00774 
// Decision tree training parameters; used as the <_params> argument of CvDTreeTrainData.
// Defaults set by the default constructor: max_categories 10, max_depth INT_MAX,
// min_sample_count 10, cv_folds 10, use_surrogates / use_1se_rule / truncate_pruned_tree
// all true, regression_accuracy 0.01f and priors 0 (null).
00775 struct CV_EXPORTS CvDTreeParams
00776 {
00777     int   max_categories;
00778     int   max_depth;
00779     int   min_sample_count;
00780     int   cv_folds;
00781     bool  use_surrogates;
00782     bool  use_1se_rule;
00783     bool  truncate_pruned_tree;
00784     float regression_accuracy;
00785     const float* priors;  // pointer is stored, not copied, by this struct
00786 
00787     CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
00788         cv_folds(10), use_surrogates(true), use_1se_rule(true),
00789         truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
00790     {}
00791 
      // Note: the argument order differs from the member declaration order; the
      // initializer list below follows the member order.
00792     CvDTreeParams( int _max_depth, int _min_sample_count,
00793                    float _regression_accuracy, bool _use_surrogates,
00794                    int _max_categories, int _cv_folds,
00795                    bool _use_1se_rule, bool _truncate_pruned_tree,
00796                    const float* _priors ) :
00797         max_categories(_max_categories), max_depth(_max_depth),
00798         min_sample_count(_min_sample_count), cv_folds (_cv_folds),
00799         use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
00800         truncate_pruned_tree(_truncate_pruned_tree),
00801         regression_accuracy(_regression_accuracy),
00802         priors(_priors)
00803     {}
00804 };
00805 
00806 
// Training-data container used by the tree models in this header: CvDTree,
// CvRTrees and CvBoost each keep a CvDTreeTrainData* member.
// It stores pointers to the caller's training matrix and responses, the
// CvDTreeParams in effect, working matrices/buffers, and the CvSet heaps and
// CvMemStorage blocks that nodes and splits are created from (new_node,
// new_split_ord, new_split_cat, free_node).
// Most methods are virtual; CvERTreeTrainData (declared later in this header)
// overrides several of them.
// NOTE(review): ownership of the CvMat* members is not visible in this
// header; `train_data` and `responses` are const pointers, which suggests they
// are borrowed from the caller -- confirm in the implementation.
00807 struct CV_EXPORTS CvDTreeTrainData
00808 {
00809     CvDTreeTrainData();
00810     CvDTreeTrainData( const CvMat* _train_data, int _tflag,
00811                       const CvMat* _responses, const CvMat* _var_idx=0,
00812                       const CvMat* _sample_idx=0, const CvMat* _var_type=0,
00813                       const CvMat* _missing_mask=0,
00814                       const CvDTreeParams& _params=CvDTreeParams(),
00815                       bool _shared=false, bool _add_labels=false );
00816     virtual ~CvDTreeTrainData();
00817 
          // Same arguments as the constructor above plus `_update_data`.
00818     virtual void set_data( const CvMat* _train_data, int _tflag,
00819                           const CvMat* _responses, const CvMat* _var_idx=0,
00820                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
00821                           const CvMat* _missing_mask=0,
00822                           const CvDTreeParams& _params=CvDTreeParams(),
00823                           bool _shared=false, bool _add_labels=false,
00824                           bool _update_data=false );
00825     virtual void do_responses_copy();
00826 
00827     virtual void get_vectors( const CvMat* _subsample_idx,
00828          float* values, uchar* missing, float* responses, bool get_class_idx=false );
00829 
00830     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
00831 
00832     virtual void write_params( CvFileStorage* fs ) const;
00833     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00834 
00835     // release all the data
00836     virtual void clear();
00837 
00838     int get_num_classes() const;
00839     int get_var_type(int vi) const;
00840     int get_work_var_count() const {return work_var_count;}
00841 
          // Per-node data accessors. Each takes caller-provided scratch buffers.
          // NOTE(review): whether the returned pointer aliases the scratch buffer
          // or internal storage is not visible here -- confirm before caching it.
00842     virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
00843     virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
00844     virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
00845     virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
00846     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
00847     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
00848                                    const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
00849     virtual int get_child_buf_idx( CvDTreeNode* n );
00850 
00852 
00853     virtual bool set_params( const CvDTreeParams& params );
00854     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
00855                                    int storage_idx, int offset );
00856 
00857     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
00858                 int split_point, int inversed, float quality );
00859     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
00860     virtual void free_node_data( CvDTreeNode* node );
00861     virtual void free_train_data();
00862     virtual void free_node( CvDTreeNode* node );
00863 
00864     int sample_count, var_all, var_count, max_c_count;
00865     int ord_var_count, cat_var_count, work_var_count;
00866     bool have_labels, have_priors;
00867     bool is_classifier;
00868     int tflag;
00869 
00870     const CvMat* train_data;
00871     const CvMat* responses;
00872     CvMat* responses_copy; // used in Boosting
00873 
00874     int buf_count, buf_size;
00875     bool shared;
00876     int is_buf_16u;
00877     
00878     CvMat* cat_count;
00879     CvMat* cat_ofs;
00880     CvMat* cat_map;
00881 
00882     CvMat* counts;
00883     CvMat* buf;
00884     CvMat* direction;
00885     CvMat* split_buf;
00886 
00887     CvMat* var_idx;
00888     CvMat* var_type; // i-th element =
00889                      //   k<0  - ordered
00890                      //   k>=0 - categorical, see k-th element of cat_* arrays
00891     CvMat* priors;
00892     CvMat* priors_mult;
00893 
00894     CvDTreeParams params;
00895 
00896     CvMemStorage* tree_storage;
00897     CvMemStorage* temp_storage;
00898 
00899     CvDTreeNode* data_root;
00900 
00901     CvSet* node_heap;
00902     CvSet* split_heap;
00903     CvSet* cv_heap;
00904     CvSet* nv_heap;
00905 
00906     CvRNG rng;
00907 };
00908 
// Forward declarations. The two cv:: split-finder structs are only named here
// so that CvDTree and CvForestTree can declare them as friends further down.
00909 class CvDTree;
00910 class CvForestTree;
00911 
00912 namespace cv
00913 {
00914     struct DTreeBestSplitFinder;
00915     struct ForestTreeBestSplitFinder;
00916 }
00917 
// Single decision-tree model, derived from CvStatModel.
//
// Public API visible here:
//   - train(): overloads taking CvMat inputs, a CvMLData set, an existing
//     CvDTreeTrainData plus a subsample index, and (outside SWIG builds)
//     cv::Mat inputs;
//   - predict(): returns a CvDTreeNode*;
//   - calc_error(), get_var_importance(), clear();
//   - read()/write() against a CvFileStorage, with extra overloads meant for
//     trees stored inside ensembles.
//
// The protected virtuals (find_split_*, calc_node_value, prune_cv, ...) are
// customization points; CvForestTree, CvForestERTree and CvBoostTree in this
// header override several of them.
// NOTE(review): `data` is a raw CvDTreeTrainData*; who owns it is not visible
// in this header -- confirm in the implementation.
00918 class CV_EXPORTS CvDTree : public CvStatModel
00919 {
00920 public:
00921     CvDTree();
00922     virtual ~CvDTree();
00923 
00924     virtual bool train( const CvMat* _train_data, int _tflag,
00925                         const CvMat* _responses, const CvMat* _var_idx=0,
00926                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
00927                         const CvMat* _missing_mask=0,
00928                         CvDTreeParams params=CvDTreeParams() );
00929 
00930     virtual bool train( CvMLData* _data, CvDTreeParams _params=CvDTreeParams() );
00931 
00932     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
00933 
00934     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
00935 
00936     virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
00937                                   bool preprocessed_input=false ) const;
00938 
00939 #ifndef SWIG
00940     virtual bool train( const cv::Mat& _train_data, int _tflag,
00941                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
00942                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
00943                        const cv::Mat& _missing_mask=cv::Mat(),
00944                        CvDTreeParams params=CvDTreeParams() );
00945     
00946     virtual CvDTreeNode* predict( const cv::Mat& _sample, const cv::Mat& _missing_data_mask=cv::Mat(),
00947                                   bool preprocessed_input=false ) const;
00948 #endif
00949     
00950     virtual const CvMat* get_var_importance();
00951     virtual void clear();
00952 
00953     virtual void read( CvFileStorage* fs, CvFileNode* node );
00954     virtual void write( CvFileStorage* fs, const char* name ) const;
00955 
00956     // special read & write methods for trees in the tree ensembles
00957     virtual void read( CvFileStorage* fs, CvFileNode* node,
00958                        CvDTreeTrainData* data );
00959     virtual void write( CvFileStorage* fs ) const;
00960 
00961     const CvDTreeNode* get_root() const;
00962     int get_pruned_tree_idx() const;
00963     CvDTreeTrainData* get_data();
00964 
00965 protected:
00966     friend struct cv::DTreeBestSplitFinder;
00967 
00968     virtual bool do_train( const CvMat* _subsample_idx );
00969 
00970     virtual void try_split_node( CvDTreeNode* n );
00971     virtual void split_node_data( CvDTreeNode* n );
00972     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
00973     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
00974                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00975     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
00976                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00977     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
00978                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00979     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
00980                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00981     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00982     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00983     virtual double calc_node_dir( CvDTreeNode* node );
00984     virtual void complete_node_dir( CvDTreeNode* node );
00985     virtual void cluster_categories( const int* vectors, int vector_count,
00986         int var_count, int* sums, int k, int* cluster_labels );
00987 
00988     virtual void calc_node_value( CvDTreeNode* node );
00989 
00990     virtual void prune_cv();
00991     virtual double update_tree_rnc( int T, int fold );
00992     virtual int cut_tree( int T, int fold, double min_alpha );
00993     virtual void free_prune_data(bool cut_tree);
00994     virtual void free_tree();
00995 
00996     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
00997     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
00998     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
00999     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
01000     virtual void write_tree_nodes( CvFileStorage* fs ) const;
01001     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
01002 
01003     CvDTreeNode* root;
01004     CvMat* var_importance;
01005     CvDTreeTrainData* data;
01006 
          // Deliberately public data member, also readable via get_pruned_tree_idx().
01007 public:
01008     int pruned_tree_idx;
01009 };
01010 
01011 
01012 /****************************************************************************************\
01013 *                                   Random Trees Classifier                              *
01014 \****************************************************************************************/
01015 
// CvRTrees is forward-declared because CvForestTree stores and receives a
// pointer to the forest it belongs to.
01016 class CvRTrees;
01017 
// Decision tree used as a member of a CvRTrees forest.
// Adds train()/read() overloads that take the owning CvRTrees*, overrides
// find_best_split(), and keeps the forest pointer in `forest`.
// The inherited train()/read() signatures are re-declared (see the
// "dummy methods" section) so that the new overloads do not hide them.
01018 class CV_EXPORTS CvForestTree: public CvDTree
01019 {
01020 public:
01021     CvForestTree();
01022     virtual ~CvForestTree();
01023 
01024     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
01025 
          // Returns data->var_count, or 0 when no training data is attached.
01026     virtual int get_var_count() const {return data ? data->var_count : 0;}
01027     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
01028 
01029     /* dummy methods to avoid warnings: BEGIN */
01030     virtual bool train( const CvMat* _train_data, int _tflag,
01031                         const CvMat* _responses, const CvMat* _var_idx=0,
01032                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
01033                         const CvMat* _missing_mask=0,
01034                         CvDTreeParams params=CvDTreeParams() );
01035 
01036     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
01037     virtual void read( CvFileStorage* fs, CvFileNode* node );
01038     virtual void read( CvFileStorage* fs, CvFileNode* node,
01039                        CvDTreeTrainData* data );
01040     /* dummy methods to avoid warnings: END */
01041 
01042 protected:
01043     friend struct cv::ForestTreeBestSplitFinder;
01044 
01045     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
01046     CvRTrees* forest;
01047 };
01048 
01049 
// Training parameters for CvRTrees: the CvDTreeParams base plus forest-level
// settings (calc_var_importance, nactive_vars, term_crit).
//
// Default constructor: base initialised with max_depth=5,
// min_sample_count=10, regression_accuracy=0, use_surrogates=false,
// max_categories=10, cv_folds=0, use_1se_rule=false,
// truncate_pruned_tree=false, priors=0; calc_var_importance=false,
// nactive_vars=0; term_crit = (CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1).
//
// Parameterised constructor: always passes cv_folds=0, use_1se_rule=false and
// truncate_pruned_tree=false to the base, and builds term_crit from
// (termcrit_type, max_num_of_trees_in_the_forest, forest_accuracy).
01050 struct CV_EXPORTS CvRTParams : public CvDTreeParams
01051 {
01052     //Parameters for the forest
01053     bool calc_var_importance; // true <=> RF processes variable importance
01054     int nactive_vars;
01055     CvTermCriteria term_crit;
01056 
01057     CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
01058         calc_var_importance(false), nactive_vars(0)
01059     {
01060         term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
01061     }
01062 
01063     CvRTParams( int _max_depth, int _min_sample_count,
01064                 float _regression_accuracy, bool _use_surrogates,
01065                 int _max_categories, const float* _priors, bool _calc_var_importance,
01066                 int _nactive_vars, int max_num_of_trees_in_the_forest,
01067                 float forest_accuracy, int termcrit_type ) :
01068         CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
01069                        _use_surrogates, _max_categories, 0,
01070                        false, false, _priors ),
01071         calc_var_importance(_calc_var_importance),
01072         nactive_vars(_nactive_vars)
01073     {
01074         term_crit = cvTermCriteria(termcrit_type,
01075             max_num_of_trees_in_the_forest, forest_accuracy);
01076     }
01077 };
01078 
01079 
// Random-trees (forest) model, derived from CvStatModel.
// Holds an array of CvForestTree* (`trees`, `ntrees` entries) together with
// the shared CvDTreeTrainData, and is configured through CvRTParams.
// Public API visible here: train() overloads (CvMat, CvMLData and, outside
// SWIG builds, cv::Mat), predict()/predict_prob(), get_var_importance(),
// get_proximity(), calc_error(), get_train_error(), read()/write(), and
// accessors for the individual trees, the RNG and the active-variable mask.
// grow_forest() is the protected virtual that CvERTrees overrides.
01080 class CV_EXPORTS CvRTrees : public CvStatModel
01081 {
01082 public:
01083     CvRTrees();
01084     virtual ~CvRTrees();
01085     virtual bool train( const CvMat* _train_data, int _tflag,
01086                         const CvMat* _responses, const CvMat* _var_idx=0,
01087                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
01088                         const CvMat* _missing_mask=0,
01089                         CvRTParams params=CvRTParams() );
01090     
01091     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01092     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
01093     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
01094 
01095 #ifndef SWIG
01096     virtual bool train( const cv::Mat& _train_data, int _tflag,
01097                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
01098                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
01099                        const cv::Mat& _missing_mask=cv::Mat(),
01100                        CvRTParams params=CvRTParams() );
01101     virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01102     virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01103 #endif
01104     
01105     virtual void clear();
01106 
01107     virtual const CvMat* get_var_importance();
01108     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
01109         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
01110     
01111     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
01112 
01113     virtual float get_train_error();    
01114 
01115     virtual void read( CvFileStorage* fs, CvFileNode* node );
01116     virtual void write( CvFileStorage* fs, const char* name ) const;
01117 
01118     CvMat* get_active_var_mask();
01119     CvRNG* get_rng();
01120 
01121     int get_tree_count() const;
01122     CvForestTree* get_tree(int i) const;
01123 
01124 protected:
01125 
01126     virtual bool grow_forest( const CvTermCriteria term_crit );
01127 
01128     // array of the trees of the forest
01129     CvForestTree** trees;
01130     CvDTreeTrainData* data;
01131     int ntrees;
01132     int nclasses;
01133     double oob_error;
01134     CvMat* var_importance;
01135     int nsamples;
01136 
01137     CvRNG rng;
01138     CvMat* active_var_mask;
01139 };
01140 
01141 /****************************************************************************************\
01142 *                           Extremely randomized trees Classifier                        *
01143 \****************************************************************************************/
// Training-data container for extremely randomized trees. Overrides the data
// accessors of CvDTreeTrainData and additionally keeps a pointer to the
// caller's missing-value mask (`missing_mask`).
// NOTE(review): get_ord_var_data() here gives its last parameter a default
// (`sample_buf = 0`) that the base-class declaration does not have. Default
// arguments are bound by the static type, so calls through a
// CvDTreeTrainData* must still pass all seven arguments.
// NOTE(review): in this override the 4th/6th parameters are named
// missing_buf/missing rather than sorted_indices_buf/sorted_indices, which
// suggests they carry different data than in the base class -- confirm.
01144 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
01145 {
01146     virtual void set_data( const CvMat* _train_data, int _tflag,
01147                           const CvMat* _responses, const CvMat* _var_idx=0,
01148                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
01149                           const CvMat* _missing_mask=0,
01150                           const CvDTreeParams& _params=CvDTreeParams(),
01151                           bool _shared=false, bool _add_labels=false,
01152                           bool _update_data=false );
01153     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
01154                                    const float** ord_values, const int** missing, int* sample_buf = 0 );
01155     virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
01156     virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
01157     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
01158     virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
01159                               float* responses, bool get_class_idx=false );
01160     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
01161     const CvMat* missing_mask;
01162 };
01163 
// Forest tree variant for extremely randomized trees. It only overrides
// protected hooks of CvDTree/CvForestTree: calc_node_dir(), the four
// find_split_* routines and split_node_data(). No public members are added.
01164 class CV_EXPORTS CvForestERTree : public CvForestTree
01165 {
01166 protected:
01167     virtual double calc_node_dir( CvDTreeNode* node );
01168     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
01169         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01170     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01171         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01172     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
01173         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01174     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
01175         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01176     virtual void split_node_data( CvDTreeNode* n );
01177 };
01178 
// Extremely-randomized-trees model. Derives from CvRTrees, re-declares the
// three train() overloads (CvMat, cv::Mat outside SWIG builds, CvMLData) and
// overrides the protected grow_forest() hook. Uses the same CvRTParams.
01179 class CV_EXPORTS CvERTrees : public CvRTrees
01180 {
01181 public:
01182     CvERTrees();
01183     virtual ~CvERTrees();
01184     virtual bool train( const CvMat* _train_data, int _tflag,
01185                         const CvMat* _responses, const CvMat* _var_idx=0,
01186                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
01187                         const CvMat* _missing_mask=0,
01188                         CvRTParams params=CvRTParams());
01189 #ifndef SWIG
01190     virtual bool train( const cv::Mat& _train_data, int _tflag,
01191                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
01192                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
01193                        const cv::Mat& _missing_mask=cv::Mat(),
01194                        CvRTParams params=CvRTParams());
01195 #endif
01196     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01197 protected:
01198     virtual bool grow_forest( const CvTermCriteria term_crit );
01199 };
01200 
01201 
01202 /****************************************************************************************\
01203 *                                   Boosted tree classifier                              *
01204 \****************************************************************************************/
01205 
// Training parameters for CvBoost: the CvDTreeParams base plus four boosting
// settings. Both constructors are only declared here, so their default values
// are not visible in this header.
// NOTE(review): `boost_type` and `split_criteria` presumably take the
// CvBoost::{DISCRETE,...} and CvBoost::{DEFAULT,GINI,...} enum values declared
// further down -- confirm in the implementation.
01206 struct CV_EXPORTS CvBoostParams : public CvDTreeParams
01207 {
01208     int boost_type;
01209     int weak_count;
01210     int split_criteria;
01211     double weight_trim_rate;
01212 
01213     CvBoostParams();
01214     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
01215                    int max_depth, bool use_surrogates, const float* priors );
01216 };
01217 
01218 
// CvBoost is forward-declared because CvBoostTree stores and receives a
// pointer to the ensemble it belongs to.
01219 class CvBoost;
01220 
// Decision tree used as a weak predictor inside a CvBoost ensemble.
// Adds train()/read() overloads that take the owning CvBoost*, a scale()
// method, and overrides the split-search and node-value hooks of CvDTree.
// The inherited train()/read() signatures are re-declared (see the
// "dummy methods" section) so that the new overloads do not hide them.
01221 class CV_EXPORTS CvBoostTree: public CvDTree
01222 {
01223 public:
01224     CvBoostTree();
01225     virtual ~CvBoostTree();
01226 
01227     virtual bool train( CvDTreeTrainData* _train_data,
01228                         const CvMat* subsample_idx, CvBoost* ensemble );
01229 
01230     virtual void scale( double s );
01231     virtual void read( CvFileStorage* fs, CvFileNode* node,
01232                        CvBoost* ensemble, CvDTreeTrainData* _data );
01233     virtual void clear();
01234 
01235     /* dummy methods to avoid warnings: BEGIN */
01236     virtual bool train( const CvMat* _train_data, int _tflag,
01237                         const CvMat* _responses, const CvMat* _var_idx=0,
01238                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
01239                         const CvMat* _missing_mask=0,
01240                         CvDTreeParams params=CvDTreeParams() );
01241     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
01242 
01243     virtual void read( CvFileStorage* fs, CvFileNode* node );
01244     virtual void read( CvFileStorage* fs, CvFileNode* node,
01245                        CvDTreeTrainData* data );
01246     /* dummy methods to avoid warnings: END */
01247 
01248 protected:
01249 
01250     virtual void try_split_node( CvDTreeNode* n );
01251     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01252     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01253     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
01254         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01255     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01256         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01257     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
01258         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01259     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
01260         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01261     virtual void calc_node_value( CvDTreeNode* n );
01262     virtual double calc_node_dir( CvDTreeNode* n );
01263 
01264     CvBoost* ensemble;
01265 };
01266 
01267 
// Boosted-tree model, derived from CvStatModel.
// Keeps its weak predictors in a CvSeq (`weak`), the shared CvDTreeTrainData
// in `data`, and the CvBoostParams in effect in `params`.
// Public API visible here: training constructors and train() overloads
// (CvMat, CvMLData and, outside SWIG builds, cv::Mat) with an `update` flag,
// predict() with optional per-weak-predictor responses and a CvSlice,
// calc_error(), prune(), read()/write(), and accessors for the weights,
// weak predictors, parameters and training data.
// NOTE(review): the splitting-criteria enum skips the value 2
// (GINI=1, MISCLASS=3); keep the numeric values as they are, since they are
// presumably stored in saved models -- confirm in write_params().
01268 class CV_EXPORTS CvBoost : public CvStatModel
01269 {
01270 public:
01271     // Boosting type
01272     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
01273 
01274     // Splitting criteria
01275     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
01276 
01277     CvBoost();
01278     virtual ~CvBoost();
01279 
01280     CvBoost( const CvMat* _train_data, int _tflag,
01281              const CvMat* _responses, const CvMat* _var_idx=0,
01282              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
01283              const CvMat* _missing_mask=0,
01284              CvBoostParams params=CvBoostParams() );
01285     
01286     virtual bool train( const CvMat* _train_data, int _tflag,
01287              const CvMat* _responses, const CvMat* _var_idx=0,
01288              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
01289              const CvMat* _missing_mask=0,
01290              CvBoostParams params=CvBoostParams(),
01291              bool update=false );
01292     
01293     virtual bool train( CvMLData* data,
01294              CvBoostParams params=CvBoostParams(),
01295              bool update=false );
01296 
01297     virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
01298                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
01299                            bool raw_mode=false, bool return_sum=false ) const;
01300 
01301 #ifndef SWIG
01302     CvBoost( const cv::Mat& _train_data, int _tflag,
01303             const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
01304             const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
01305             const cv::Mat& _missing_mask=cv::Mat(),
01306             CvBoostParams params=CvBoostParams() );
01307     
01308     virtual bool train( const cv::Mat& _train_data, int _tflag,
01309                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
01310                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
01311                        const cv::Mat& _missing_mask=cv::Mat(),
01312                        CvBoostParams params=CvBoostParams(),
01313                        bool update=false );
01314     
01315     virtual float predict( const cv::Mat& _sample, const cv::Mat& _missing=cv::Mat(),
01316                           cv::Mat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
01317                           bool raw_mode=false, bool return_sum=false ) const;
01318 #endif
01319     
01320     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
01321 
01322     virtual void prune( CvSlice slice );
01323 
01324     virtual void clear();
01325 
01326     virtual void write( CvFileStorage* storage, const char* name ) const;
01327     virtual void read( CvFileStorage* storage, CvFileNode* node );
01328     virtual const CvMat* get_active_vars(bool absolute_idx=true);
01329 
01330     CvSeq* get_weak_predictors();
01331 
01332     CvMat* get_weights();
01333     CvMat* get_subtree_weights();
01334     CvMat* get_weak_response();
01335     const CvBoostParams& get_params() const;
01336     const CvDTreeTrainData* get_data() const;
01337 
01338 protected:
01339 
01340     virtual bool set_params( const CvBoostParams& _params );
01341     virtual void update_weights( CvBoostTree* tree );
01342     virtual void trim_weights();
01343     virtual void write_params( CvFileStorage* fs ) const;
01344     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
01345 
01346     CvDTreeTrainData* data;
01347     CvBoostParams params;
01348     CvSeq* weak;
01349 
01350     CvMat* active_vars;
01351     CvMat* active_vars_abs;
01352     bool have_active_cat_vars;
01353 
01354     CvMat* orig_response;
01355     CvMat* sum_response;
01356     CvMat* weak_eval;
01357     CvMat* subsample_mask;
01358     CvMat* weights;
01359     CvMat* subtree_weights;
01360     bool have_subsample;
01361 };
01362 
01363 
01364 /****************************************************************************************\
01365 *                              Artificial Neural Networks (ANN)                          *
01366 \****************************************************************************************/
01367 
01369 
// Training parameters for CvANN_MLP: a termination criterion, the training
// method (BACKPROP or RPROP) and the method-specific tuning values.
// The constructors are only declared here, so default values are not visible
// in this header.
// NOTE(review): how `param1`/`param2` of the second constructor map onto the
// bp_* / rp_* members presumably depends on `train_method` -- confirm in the
// implementation.
01370 struct CV_EXPORTS CvANN_MLP_TrainParams
01371 {
01372     CvANN_MLP_TrainParams();
01373     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
01374                            double param1, double param2=0 );
01375     ~CvANN_MLP_TrainParams();
01376 
01377     enum { BACKPROP=0, RPROP=1 };
01378 
01379     CvTermCriteria term_crit;
01380     int train_method;
01381 
01382     // backpropagation parameters
01383     double bp_dw_scale, bp_moment_scale;
01384 
01385     // rprop parameters
01386     double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
01387 };
01388 
01389 
// Multi-layer perceptron model, derived from CvStatModel.
// The network is configured with create() (or the equivalent constructors)
// from a layer-size matrix and an activation function, trained with train()
// and evaluated with predict(). CvMat and (outside SWIG builds) cv::Mat
// variants are provided.
// train_backprop() and train_rprop() are the protected virtuals matching the
// two CvANN_MLP_TrainParams training methods.
01390 class CV_EXPORTS CvANN_MLP : public CvStatModel
01391 {
01392 public:
01393     CvANN_MLP();
01394     CvANN_MLP( const CvMat* _layer_sizes,
01395                int _activ_func=SIGMOID_SYM,
01396                double _f_param1=0, double _f_param2=0 );
01397 
01398     virtual ~CvANN_MLP();
01399 
01400     virtual void create( const CvMat* _layer_sizes,
01401                          int _activ_func=SIGMOID_SYM,
01402                          double _f_param1=0, double _f_param2=0 );
01403     
01404     virtual int train( const CvMat* _inputs, const CvMat* _outputs,
01405                        const CvMat* _sample_weights, const CvMat* _sample_idx=0,
01406                        CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
01407                        int flags=0 );
01408     virtual float predict( const CvMat* _inputs, CvMat* _outputs ) const;
01409     
01410 #ifndef SWIG
01411     CvANN_MLP( const cv::Mat& _layer_sizes,
01412               int _activ_func=SIGMOID_SYM,
01413               double _f_param1=0, double _f_param2=0 );
01414     
01415     virtual void create( const cv::Mat& _layer_sizes,
01416                         int _activ_func=SIGMOID_SYM,
01417                         double _f_param1=0, double _f_param2=0 );    
01418     
01419     virtual int train( const cv::Mat& _inputs, const cv::Mat& _outputs,
01420                       const cv::Mat& _sample_weights, const cv::Mat& _sample_idx=cv::Mat(),
01421                       CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
01422                       int flags=0 );    
01423     
01424     virtual float predict( const cv::Mat& _inputs, cv::Mat& _outputs ) const;
01425 #endif
01426     
01427     virtual void clear();
01428 
01429     // possible activation functions
01430     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
01431 
01432     // available training flags
01433     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
01434 
01435     virtual void read( CvFileStorage* fs, CvFileNode* node );
01436     virtual void write( CvFileStorage* storage, const char* name ) const;
01437 
          // Number of layers = column count of layer_sizes; 0 before the network is created.
01438     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
01439     const CvMat* get_layer_sizes() { return layer_sizes; }
          // Returns weights[layer], or 0 when the network has not been created or
          // `layer` is out of range. The unsigned casts make negative indices fail
          // the range test.
          // NOTE(review): the test is `<=`, so layer == layer_sizes->cols is
          // accepted; presumably `weights` holds extra entries beyond one per
          // layer -- confirm against the allocation in create().
01440     double* get_weights(int layer)
01441     {
01442         return layer_sizes && weights &&
01443             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
01444     }
01445 
01446 protected:
01447 
01448     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
01449             const CvMat* _sample_weights, const CvMat* _sample_idx,
01450             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
01451 
01452     // sequential random backpropagation
01453     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01454 
01455     // RPROP algorithm
01456     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
01457 
01458     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
01459     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
01460     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
01461                                  double _f_param1=0, double _f_param2=0 );
01462     virtual void init_weights();
01463     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
01464     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
01465     virtual void calc_input_scale( const CvVectors* vecs, int flags );
01466     virtual void calc_output_scale( const CvVectors* vecs, int flags );
01467 
01468     virtual void write_params( CvFileStorage* fs ) const;
01469     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
01470 
01471     CvMat* layer_sizes;
01472     CvMat* wbuf;
01473     CvMat* sample_weights;
01474     double** weights;
01475     double f_param1, f_param2;
01476     double min_val, max_val, min_val1, max_val1;
01477     int activ_func;
01478     int max_count, max_buf_sz;
01479     CvANN_MLP_TrainParams params;
01480     CvRNG rng;
01481 };
01482 
01483 #if 0
01484 /****************************************************************************************\
01485 *                            Convolutional Neural Network                                *
01486 \****************************************************************************************/
01487 typedef struct CvCNNLayer CvCNNLayer;
01488 typedef struct CvCNNetwork CvCNNetwork;
01489 
01490 #define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY  1
01491 #define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV        2
01492 #define CV_CNN_LEARN_RATE_DECREASE_LOG_INV         3
01493 
01494 #define CV_CNN_GRAD_ESTIM_RANDOM        0
01495 #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG  1
01496 
01497 #define ICV_CNN_LAYER                0x55550000
01498 #define ICV_CNN_CONVOLUTION_LAYER    0x00001111
01499 #define ICV_CNN_SUBSAMPLING_LAYER    0x00002222
01500 #define ICV_CNN_FULLCONNECT_LAYER    0x00003333
01501 
01502 #define ICV_IS_CNN_LAYER( layer )                                          \
01503     ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\
01504         == ICV_CNN_LAYER )) /* true for a non-NULL pointer whose CV_MAGIC_MASK bits of flags equal ICV_CNN_LAYER */
01505 
01506 #define ICV_IS_CNN_CONVOLUTION_LAYER( layer )                              \
01507     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
01508         & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER ) /* CNN layer whose non-magic flag bits equal ICV_CNN_CONVOLUTION_LAYER */
01509 
01510 #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer )                              \
01511     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
01512         & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER ) /* CNN layer whose non-magic flag bits equal ICV_CNN_SUBSAMPLING_LAYER */
01513 
01514 #define ICV_IS_CNN_FULLCONNECT_LAYER( layer )                              \
01515     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
01516         & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER ) /* CNN layer whose non-magic flag bits equal ICV_CNN_FULLCONNECT_LAYER */
01517 
01518 typedef void (CV_CDECL *CvCNNLayerForward) // type of a layer's `forward` callback (see CV_CNN_LAYER_FIELDS)
01519     ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
01520 
01521 typedef void (CV_CDECL *CvCNNLayerBackward) // type of a layer's `backward` callback; t is presumably the iteration number - confirm
01522     ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
01523 
01524 typedef void (CV_CDECL *CvCNNLayerRelease) // type of a layer's `release` callback; takes the address of the layer pointer
01525     (CvCNNLayer** layer);
01526 
01527 typedef void (CV_CDECL *CvCNNetworkAddLayer) // type of CvCNNetwork::add_layer
01528     (CvCNNetwork* network, CvCNNLayer* layer);
01529 
01530 typedef void (CV_CDECL *CvCNNetworkRelease) // type of CvCNNetwork::release; takes the address of the network pointer
01531     (CvCNNetwork** network);
01532 
01533 #define CV_CNN_LAYER_FIELDS()           \
01534     /* Indicator of the layer's type */ \
01535     int flags;                          \
01536                                         \
01537     /* Number of input images */        \
01538     int n_input_planes;                 \
01539     /* Height of each input image */    \
01540     int input_height;                   \
01541     /* Width of each input image */     \
01542     int input_width;                    \
01543                                         \
01544     /* Number of output images */       \
01545     int n_output_planes;                \
01546     /* Height of each output image */   \
01547     int output_height;                  \
01548     /* Width of each output image */    \
01549     int output_width;                   \
01550                                         \
01551     /* Learning rate at the first iteration */                      \
01552     float init_learn_rate;                                          \
01553     /* Dynamics of learning rate decreasing */                      \
01554     int learn_rate_decrease_type;                                   \
01555     /* Trainable weights of the layer (including bias) */           \
01556     /* i-th row is a set of weights of the i-th output plane */     \
01557     CvMat* weights;                                                 \
01558                                                                     \
01559     CvCNNLayerForward  forward;                                     \
01560     CvCNNLayerBackward backward;                                    \
01561     CvCNNLayerRelease  release;                                     \
01562     /* Pointers to the previous and next layers in the network */   \
01563     CvCNNLayer* prev_layer;                                         \
01564     CvCNNLayer* next_layer
01565 
01566 typedef struct CvCNNLayer // generic layer: holds only the fields that every layer structure starts with
01567 {
01568     CV_CNN_LAYER_FIELDS(); // flags, plane geometry, learning-rate settings, weights, callbacks, prev/next links
01569 }CvCNNLayer;
01570 
01571 typedef struct CvCNNConvolutionLayer // convolution layer: common layer fields plus kernel size and connection mask
01572 {
01573     CV_CNN_LAYER_FIELDS();
01574     // Kernel size (both height and width) for convolution.
01575     int K;
01576     // Connection matrix: the (i,j)-th element is 1 iff there is a connection between
01577     // the i-th plane of the current layer and the j-th plane of the previous layer;
01578     // the (i,j)-th element is equal to 0 otherwise.
01579     CvMat *connect_mask;
01580     // (no field follows here: the first-iteration learning rate is init_learn_rate in CV_CNN_LAYER_FIELDS)
01581 }CvCNNConvolutionLayer;
01582 
01583 typedef struct CvCNNSubSamplingLayer // sub-sampling layer: common layer fields plus scale and sigmoid parameters
01584 {
01585     CV_CNN_LAYER_FIELDS();
01586     // ratio between the heights (or widths - the two ratios are supposed to be equal)
01587     // of the input and output planes
01588     int sub_samp_scale;
01589     // amplitude of the sigmoid activation function
01590     float a;
01591     // scale parameter of the sigmoid activation function
01592     float s;
01593     // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...,x4 are some elements of X;
01594     // this vector is used when computing the activation function in the backward step
01595     CvMat* exp2ssumWX;
01596     // sumX = (x1+x2+x3+x4), where x1,...,x4 are some elements of X;
01597     // this vector is used when computing the activation function in the backward step
01598     CvMat* sumX;
01599 }CvCNNSubSamplingLayer;
01600 
01601 // Structure of the fully connected layer (the last layer).
01602 typedef struct CvCNNFullConnectLayer
01603 {
01604     CV_CNN_LAYER_FIELDS();
01605     // amplitude of the sigmoid activation function
01606     float a;
01607     // scale parameter of the sigmoid activation function
01608     float s;
01609     // exp2ssumWX = exp(2*<s>*(W*X)) is the vector used in computing the
01610     // activation function and its derivative by the formulae
01611     // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
01612     // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
01613     CvMat* exp2ssumWX;
01614 }CvCNNFullConnectLayer;
01615 
01616 typedef struct CvCNNetwork // a network: layer count, layer pointer and its add/release callbacks
01617 {
01618     int n_layers; // number of layers
01619     CvCNNLayer* layers; // presumably the first layer of the prev_layer/next_layer chain - confirm
01620     CvCNNetworkAddLayer add_layer; // callback taking (network, layer)
01621     CvCNNetworkRelease release; // callback taking the address of the network pointer
01622 }CvCNNetwork;
01623 
01624 typedef struct CvCNNStatModel // statistical model wrapping a CvCNNetwork
01625 {
01626     CV_STAT_MODEL_FIELDS();
01627     CvCNNetwork* network;
01628     // etalons are allocated as rows; the i-th etalon has label cls_labels[i]
01629     CvMat* etalons;
01630     // class labels
01631     CvMat* cls_labels;
01632 }CvCNNStatModel;
01633 
01634 typedef struct CvCNNStatModelParams // parameter block for the CNN model (see cvTrainCNNClassifier)
01635 {
01636     CV_STAT_MODEL_PARAM_FIELDS();
01637     // the network must be created by the function cvCreateCNNetwork and extended with <add_layer>
01638     CvCNNetwork* network;
01639     CvMat* etalons;
01640     // termination criteria
01641     int max_iter;
01642     int start_iter;
01643     int grad_estim_type; // presumably one of the CV_CNN_GRAD_ESTIM_* values - confirm
01644 }CvCNNStatModelParams;
01645 
01646 CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
01647     int n_input_planes, int input_height, int input_width,
01648     int n_output_planes, int K,
01649     float init_learn_rate, int learn_rate_decrease_type,
01650     CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );
01651 
01652 CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
01653     int n_input_planes, int input_height, int input_width,
01654     int sub_samp_scale, float a, float s,
01655     float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );
01656 
01657 CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
01658     int n_inputs, int n_outputs, float a, float s,
01659     float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );
01660 
01661 CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );
01662 
01663 CVAPI(CvStatModel*) cvTrainCNNClassifier(
01664             const CvMat* train_data, int tflag,
01665             const CvMat* responses,
01666             const CvStatModelParams* params,
01667             const CvMat* CV_DEFAULT(0),
01668             const CvMat* sample_idx CV_DEFAULT(0),
01669             const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
01670 
01671 /****************************************************************************************\
01672 *                               Estimate classifiers algorithms                          *
01673 \****************************************************************************************/
01674 typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
01675                     ( const CvStatModel* estimateModel );
01676 
01677 typedef int (CV_CDECL *CvStatModelEstimateNextStep)
01678                     ( CvStatModel* estimateModel );
01679 
01680 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
01681                     ( CvStatModel* estimateModel,
01682                 const CvStatModel* model,
01683                 const CvMat*       features,
01684                       int          sample_t_flag,
01685                 const CvMat*       responses );
01686 
01687 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
01688                     ( CvStatModel* estimateModel,
01689                 const CvStatModel* model );
01690 
01691 typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
01692                     ( const CvStatModel* estimateModel,
01693                             float*       correlation );
01694 
01695 typedef void (CV_CDECL *CvStatModelEstimateReset)
01696                     ( CvStatModel* estimateModel );
01697 
01698 //-------------------------------- Cross-validation --------------------------------------
01699 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    \
01700     CV_STAT_MODEL_PARAM_FIELDS();                                 \
01701     int     k_fold;                                               \
01702     int     is_regression;                                        \
01703     CvRNG*  rng
01704 
01705 typedef struct CvCrossValidationParams // cross-validation parameter block
01706 {
01707     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS(); // CV_STAT_MODEL_PARAM_FIELDS plus k_fold, is_regression, rng
01708 } CvCrossValidationParams;
01709 
01710 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    \
01711     CvStatModelEstimateGetMat               getTrainIdxMat; \
01712     CvStatModelEstimateGetMat               getCheckIdxMat; \
01713     CvStatModelEstimateNextStep             nextStep;       \
01714     CvStatModelEstimateCheckClassifier      check;          \
01715     CvStatModelEstimateGetCurrentResult     getResult;      \
01716     CvStatModelEstimateReset                reset;          \
01717     int     is_regression;                                  \
01718     int     folds_all;                                      \
01719     int     samples_all;                                    \
01720     int*    sampleIdxAll;                                   \
01721     int*    folds;                                          \
01722     int     max_fold_size;                                  \
01723     int         current_fold;                               \
01724     int         is_checked;                                 \
01725     CvMat*      sampleIdxTrain;                             \
01726     CvMat*      sampleIdxEval;                              \
01727     CvMat*      predict_results;                            \
01728     int     correct_results;                                \
01729     int     all_results;                                    \
01730     double  sq_error;                                       \
01731     double  sum_correct;                                    \
01732     double  sum_predict;                                    \
01733     double  sum_cc;                                         \
01734     double  sum_pp;                                         \
01735     double  sum_cp
01736 
01737 typedef struct CvCrossValidationModel // cross-validation estimate model
01738 {
01739     CV_STAT_MODEL_FIELDS();
01740     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS(); // estimate callbacks, fold bookkeeping and accumulated result sums
01741 } CvCrossValidationModel;
01742 
01743 CVAPI(CvStatModel*)
01744 cvCreateCrossValidationEstimateModel
01745            ( int                samples_all,
01746        const CvStatModelParams* estimateParams CV_DEFAULT(0),
01747        const CvMat*             sampleIdx CV_DEFAULT(0) );
01748 
01749 CVAPI(float)
01750 cvCrossValidation( const CvMat*             trueData,
01751                          int                tflag,
01752                    const CvMat*             trueClasses,
01753                          CvStatModel*     (*createClassifier)( const CvMat*,
01754                                                                      int,
01755                                                                const CvMat*,
01756                                                                const CvStatModelParams*,
01757                                                                const CvMat*,
01758                                                                const CvMat*,
01759                                                                const CvMat*,
01760                                                                const CvMat* ),
01761                    const CvStatModelParams* estimateParams CV_DEFAULT(0),
01762                    const CvStatModelParams* trainParams CV_DEFAULT(0),
01763                    const CvMat*             compIdx CV_DEFAULT(0),
01764                    const CvMat*             sampleIdx CV_DEFAULT(0),
01765                          CvStatModel**      pCrValModel CV_DEFAULT(0),
01766                    const CvMat*             typeMask CV_DEFAULT(0),
01767                    const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
01768 #endif
01769 
01770 /****************************************************************************************\
01771 *                           Auxiliary function declarations                              *
01772 \****************************************************************************************/
01773 
01774 /* Generates <sample> from a multivariate normal distribution, where <mean> is the
01775    mean row vector and <cov> is the symmetric covariance matrix */
01776 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
01777                            CvRNG* rng CV_DEFAULT(0) );
01778 
01779 /* Generates <sample> from a Gaussian mixture distribution */
01780 CVAPI(void) cvRandGaussMixture( CvMat* means[], // presumably one mean per mixture component (clsnum entries) - confirm
01781                                CvMat* covs[], // presumably one covariance matrix per component - confirm
01782                                float weights[], // presumably one weight per component - confirm
01783                                int clsnum,
01784                                CvMat* sample,
01785                                CvMat* sampClasses CV_DEFAULT(0) ); // optional (defaults to 0)
01786 
01787 #define CV_TS_CONCENTRIC_SPHERES 0
01788 
01789 /* Creates a test set; CV_TS_CONCENTRIC_SPHERES is the only <type> constant defined in this header */
01790 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
01791                  int num_samples,
01792                  int num_features,
01793                  CvMat** responses,
01794                  int num_classes, ... ); // variadic tail: extra arguments are not described in this header
01795 
01796 
01797 #endif
01798 
01799 /****************************************************************************************\
01800 *                                      Data                                             *
01801 \****************************************************************************************/
01802 
01803 #include <map>
01804 #include <string>
01805 #include <iostream>
01806 
01807 #define CV_COUNT     0
01808 #define CV_PORTION   1
01809 
01810 struct CV_EXPORTS CvTrainTestSplit // train/test split description; passed to CvMLData::set_train_test_split
01811 {
01812 public:
01813     CvTrainTestSplit();
01814     CvTrainTestSplit( int _train_sample_count, bool _mix = true); // train part given as a sample count
01815     CvTrainTestSplit( float _train_sample_portion, bool _mix = true); // train part given as a portion; NOTE(review): a double argument (e.g. 0.5) is ambiguous between the int and float overloads
01816 
01817     union
01818     {
01819         int count; // train part as a number of samples
01820         float portion; // train part as a portion of the samples
01821     } train_sample_part;
01822     int train_sample_part_mode; // presumably CV_COUNT or CV_PORTION, selecting the active union member - confirm
01823 
01824     union
01825     {
01826         int *count;
01827         float *portion;
01828     } *class_part; // pointer to a union of per-class pointers; ownership is not visible in this header
01829     int class_part_mode; // presumably CV_COUNT or CV_PORTION for class_part - confirm
01830 
01831     bool mix;    // presumably requests shuffling of the train and test indices - confirm
01832 };
01833 
01834 class CV_EXPORTS CvMLData // data set loaded via read_csv(), with response / variable / train-test selection state
01835 {
01836 public:
01837     CvMLData();
01838     virtual ~CvMLData();
01839 
01840     // Reads the data from the CSV file <filename>. Returns:
01841     // 0 - OK  
01842     // 1 - the file cannot be opened or is not correct
01843     int read_csv(const char* filename);
01844 
01845     const CvMat* get_values(){ return values; }; // returns the values member as is
01846 
01847     const CvMat* get_responses();
01848 
01849     const CvMat* get_missing(){ return missing; }; // returns the missing member as is
01850 
01851     void set_response_idx( int idx ); // the old response becomes a predictor, new response_idx = idx;
01852                                       // if idx < 0 there will be no response
01853     int get_response_idx() { return response_idx; }
01854 
01855     const CvMat* get_train_sample_idx() { return train_sample_idx; };
01856     const CvMat* get_test_sample_idx() { return test_sample_idx; };
01857     void mix_train_and_test_idx();
01858     void set_train_test_split( const CvTrainTestSplit * spl);
01859     
01860     const CvMat* get_var_idx();
01861     void chahge_var_idx( int vi, bool state ); // state == true sets the vi-th variable as a predictor; NOTE(review): "chahge" looks like a typo of "change", but renaming would break callers
01862 
01863     const CvMat* get_var_types();
01864     int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; }; // NOTE(review): no null check on var_types and no range check on var_idx
01865     // the following 2 methods allow changing variable types;
01866     // use these methods to assign the CV_VAR_CATEGORICAL type to a categorical variable
01867     // with numerical labels; in the other cases var types are correctly determined automatically
01868     void set_var_types( const char* str );  // str examples:
01869                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
01870                                             // "cat", "ord" (all vars are categorical/ordered)
01871     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }    
01872  
01873     void set_delimiter( char ch );
01874     char get_delimiter() { return delimiter; };
01875 
01876     void set_miss_ch( char ch );
01877     char get_miss_ch() { return miss_ch; };
01878     
01879 protected:
01880     virtual void clear();
01881 
01882     void str_to_flt_elem( const char* token, float& flt_elem, int& type); // flt_elem and type are output references
01883     void free_train_test_idx();
01884     
01885     char delimiter; // exposed through set_delimiter()/get_delimiter()
01886     char miss_ch; // exposed through set_miss_ch()/get_miss_ch()
01887     //char flt_separator;
01888 
01889     CvMat* values; // returned by get_values()
01890     CvMat* missing; // returned by get_missing()
01891     CvMat* var_types; // indexed per variable by get_var_type()
01892     CvMat* var_idx_mask;
01893 
01894     CvMat* response_out; // header
01895     CvMat* var_idx_out; // mat
01896     CvMat* var_types_out; // mat
01897 
01898     int response_idx; // returned by get_response_idx()
01899 
01900     int train_sample_count;
01901     bool mix;
01902    
01903     int total_class_count;
01904     std::map<std::string, int> *class_map; // string-to-int map; presumably maps class names to numeric labels - confirm
01905 
01906     CvMat* train_sample_idx; // returned by get_train_sample_idx()
01907     CvMat* test_sample_idx; // returned by get_test_sample_idx()
01908     int* sample_idx; // data of train_sample_idx and test_sample_idx
01909 
01910     CvRNG rng;
01911 };
01912 
01913 
01914 namespace cv // aliases that expose the Cv-prefixed ML types under namespace cv
01915 {
01916     
01917 typedef CvStatModel StatModel;
01918 typedef CvParamGrid ParamGrid;
01919 typedef CvNormalBayesClassifier NormalBayesClassifier;
01920 typedef CvKNearest KNearest;
01921 typedef CvSVMParams SVMParams;
01922 typedef CvSVMKernel SVMKernel;
01923 typedef CvSVMSolver SVMSolver;
01924 typedef CvSVM SVM;
01925 typedef CvEMParams EMParams;
01926 typedef CvEM ExpectationMaximization;
01927 typedef CvDTreeParams DTreeParams;
01928 typedef CvMLData TrainData;
01929 typedef CvDTree DecisionTree;
01930 typedef CvForestTree ForestTree;
01931 typedef CvRTParams RandomTreeParams;
01932 typedef CvRTrees RandomTrees;
01933 typedef CvERTreeTrainData ERTreeTRainData; // NOTE(review): alias is spelled "TRain" with a capital R; kept as declared because it is a public name
01934 typedef CvForestERTree ERTree;
01935 typedef CvERTrees ERTrees;
01936 typedef CvBoostParams BoostParams;
01937 typedef CvBoostTree BoostTree;
01938 typedef CvBoost Boost;
01939 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
01940 typedef CvANN_MLP NeuralNet_MLP;
01941     
01942 }
01943 
01944 #endif
01945 /* End of file. */