ml.hpp
Go to the documentation of this file.
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
20 //
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
24 //
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #ifndef __OPENCV_ML_HPP__
42 #define __OPENCV_ML_HPP__
43 
44 #include "opencv2/core/core.hpp"
45 #include <limits.h>
46 
47 #ifdef __cplusplus
48 
49 #include <map>
50 #include <string>
51 #include <iostream>
52 
53 // Apple defines a check() macro somewhere in the debug headers
54 // that interferes with a method definition in this header
55 #undef check
56 
57 /****************************************************************************************\
58 * Main struct definitions *
59 \****************************************************************************************/
60 
61 /* log(2*PI) */
62 #define CV_LOG2PI (1.8378770664093454835606594728112)
63 
/* Sample layout flags: the two values below say whether the samples of a
   <trainData> matrix are stored as its columns or as its rows. */
64 /* columns of <trainData> matrix are training samples */
65 #define CV_COL_SAMPLE 0
66 
67 /* rows of <trainData> matrix are training samples */
68 #define CV_ROW_SAMPLE 1
69 
/* Non-zero when the CV_ROW_SAMPLE bit is set in <flags>, zero otherwise. */
70 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
71 
/* Plain aggregate describing a set of vectors.
   NOTE(review): listing indices 76 and 79 are absent from this numbered view,
   so any members declared on those lines are not shown here. */
72 struct CvVectors
73 {
74  int type;
75  int dims, count;
/* The visible alternatives overlay a float** and a double** in one union;
   presumably <type> tells which one is valid - TODO confirm. */
77  union
78  {
80  float** fl;
81  double** db;
82  } data;
83 };
84 
/* Everything between this #if 0 and the matching #endif is discarded by the
   preprocessor; it is kept in the file as disabled code only.
   NOTE(review): listing indices 95 and 107 are absent from this view, so the
   end of the typedef and the header of the second function are not shown. */
85 #if 0
86 /* A structure, representing the lattice range of statmodel parameters.
87  It is used for optimizing statmodel parameters by cross-validation method.
88  The lattice is logarithmic, so <step> must be greater than 1. */
89 typedef struct CvParamLattice
90 {
91  double min_val;
92  double max_val;
93  double step;
94 }
96 
/* Builds a lattice from the two bounds given in either order: min_val/max_val
   are sorted with MIN/MAX and the step is raised to at least 1. */
97 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
98  double log_step )
99 {
100  CvParamLattice pl;
101  pl.min_val = MIN( min_val, max_val );
102  pl.max_val = MAX( min_val, max_val );
103  pl.step = MAX( log_step, 1. );
104  return pl;
105 }
106 
/* Body of a function whose header is not shown; it returns a lattice with all
   three fields set to 0. */
108 {
109  CvParamLattice pl = {0,0,0};
110  return pl;
111 }
112 #endif
113 
114 /* Variable type */
/* CV_VAR_NUMERICAL and CV_VAR_ORDERED are two names for the same value (0). */
115 #define CV_VAR_NUMERICAL 0
116 #define CV_VAR_ORDERED 0
117 #define CV_VAR_CATEGORICAL 1
118 
/* One string identifier per model kind.
   NOTE(review): presumably the type names used when a model is stored or
   loaded - confirm against the implementation files. */
119 #define CV_TYPE_NAME_ML_SVM "opencv-ml-svm"
120 #define CV_TYPE_NAME_ML_KNN "opencv-ml-knn"
121 #define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian"
122 #define CV_TYPE_NAME_ML_EM "opencv-ml-em"
123 #define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree"
124 #define CV_TYPE_NAME_ML_TREE "opencv-ml-tree"
125 #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
126 #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
127 #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
128 #define CV_TYPE_NAME_ML_ERTREES "opencv-ml-extremely-randomized-trees"
129 #define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees"
130 
/* Error kinds: the calc_error() declarations in this header document their
   <type> argument as one of these two values. */
131 #define CV_TRAIN_ERROR 0
132 #define CV_TEST_ERROR 1
133 
/* Common base class of the model classes declared in this header
   (CvNormalBayesClassifier, CvKNearest, CvSVM, CvDTree and CvRTrees all derive
   from it). It declares the shared clear / save / load / write / read
   interface; every one of those methods is virtual. */
134 class CV_EXPORTS_W CvStatModel
135 {
136 public:
137  CvStatModel();
138  virtual ~CvStatModel();
139 
140  virtual void clear();
141 
/* File-name based entry points; <name> is optional and defaults to 0. */
142  CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
143  CV_WRAP virtual void load( const char* filename, const char* name=0 );
144 
/* Entry points working on an already opened CvFileStorage. */
145  virtual void write( CvFileStorage* storage, const char* name ) const;
146  virtual void read( CvFileStorage* storage, CvFileNode* node );
147 
148 protected:
/* NOTE(review): presumably the name used when save()/load() get name==0 -
   confirm in the implementation. */
149  const char* default_model_name;
150 };
151 
152 /****************************************************************************************\
153 * Normal Bayes Classifier *
154 \****************************************************************************************/
155 
156 /* The structure, representing the grid range of statmodel parameters.
157  It is used for optimizing statmodel accuracy by varying model parameters,
158  the accuracy estimate being computed by cross-validation.
159  The grid is logarithmic, so <step> must be greater than 1. */
160 
161 class CvMLData;
162 
/* Range of values for one model parameter, given as a minimum, a maximum and a
   step. The enum names the SVM parameters such a grid can describe. */
163 struct CV_EXPORTS_W_MAP CvParamGrid
164 {
165  // SVM params type
166  enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
167 
/* NOTE(review): listing index 168 (the header of this inline body) is absent
   from this view. The body sets min_val, max_val and step to 0. */
169  {
170  min_val = max_val = step = 0;
171  }
172 
/* Three-value constructor; its inline definition follows the struct. */
173  CvParamGrid( double min_val, double max_val, double log_step );
174  //CvParamGrid( int param_id );
/* Only declared here; what it validates is defined elsewhere. */
175  bool check() const;
176 
177  CV_PROP_RW double min_val;
178  CV_PROP_RW double max_val;
179  CV_PROP_RW double step;
180 };
181 
/* Stores the three arguments unchanged into min_val, max_val and step.
   No ordering or range check is done here (the struct declares a separate
   check() method). */
182 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
183 {
184  min_val = _min_val;
185  max_val = _max_val;
186  step = _log_step;
187 }
188 
/* Normal Bayes classifier model, derived from CvStatModel.
   Each of constructor / train / predict is declared twice: once taking CvMat
   pointers and once taking cv::Mat references (the latter carry CV_WRAP). */
189 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
190 {
191 public:
192  CV_WRAP CvNormalBayesClassifier();
193  virtual ~CvNormalBayesClassifier();
194 
/* CvMat-based interface; varIdx and sampleIdx are optional (default 0). */
195  CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
196  const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
197 
198  virtual bool train( const CvMat* trainData, const CvMat* responses,
199  const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
200 
201  virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
202  CV_WRAP virtual void clear();
203 
/* cv::Mat-based interface; the optional matrices default to an empty cv::Mat. */
204  CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
205  const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
206  CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
207  const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
208  bool update=false );
209  CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
210 
211  virtual void write( CvFileStorage* storage, const char* name ) const;
212  virtual void read( CvFileStorage* storage, CvFileNode* node );
213 
214 protected:
215  int var_count, var_all;
/* NOTE(review): listing indices 216-224 are absent from this view, so the
   remaining protected members are not shown. */
225 };
226 
227 
228 /****************************************************************************************\
229 * K-Nearest Neighbour Classifier *
230 \****************************************************************************************/
231 
232 // k Nearest Neighbors
/* Model derived from CvStatModel. Construction, training and the
   nearest-neighbour search are offered for both CvMat pointers and cv::Mat. */
233 class CV_EXPORTS_W CvKNearest : public CvStatModel
234 {
235 public:
236 
237  CV_WRAP CvKNearest();
238  virtual ~CvKNearest();
239 
/* CvMat-based interface; the maximum k defaults to 32. */
240  CvKNearest( const CvMat* trainData, const CvMat* responses,
241  const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
242 
243  virtual bool train( const CvMat* trainData, const CvMat* responses,
244  const CvMat* sampleIdx=0, bool is_regression=false,
245  int maxK=32, bool updateBase=false );
246 
/* All output arguments are optional pointers (default 0). */
247  virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
248  const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
249 
/* cv::Mat-based interface. */
250  CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
251  const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
252 
253  CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
254  const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
255  int maxK=32, bool updateBase=false );
256 
/* Two cv::Mat overloads: one with optional pointer outputs, and a wrapped one
   that takes all three outputs by reference. */
257  virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
258  const float** neighbors=0, cv::Mat* neighborResponses=0,
259  cv::Mat* dist=0 ) const;
260  CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
261  CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
262 
263  virtual void clear();
/* Read-only accessors. */
264  int get_max_k() const;
265  int get_var_count() const;
266  int get_sample_count() const;
267  bool is_regression() const;
268 
/* NOTE(review): the two methods below look like internal steps of
   find_nearest() that are nevertheless declared public - confirm. */
269  virtual float write_results( int k, int k1, int start, int end,
270  const float* neighbor_responses, const float* dist, CvMat* _results,
271  CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
272 
273  virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
274  float* neighbor_responses, const float** neighbors, float* dist ) const;
275 
276 protected:
277 
278  int max_k, var_count;
279  int total;
/* NOTE(review): listing indices 280-281 are absent from this view. */
282 };
283 
284 /****************************************************************************************\
285 * Support Vector Machines *
286 \****************************************************************************************/
287 
288 // SVM training parameters
/* Plain parameter bundle: a default constructor, a constructor taking every
   field, and the fields themselves. All fields except class_weights carry
   CV_PROP_RW. */
289 struct CV_EXPORTS_W_MAP CvSVMParams
290 {
291  CvSVMParams();
292  CvSVMParams( int svm_type, int kernel_type,
293  double degree, double gamma, double coef0,
294  double Cvalue, double nu, double p,
295  CvMat* class_weights, CvTermCriteria term_crit );
296 
/* Model kind and kernel kind, followed by the kernel coefficients. */
297  CV_PROP_RW int svm_type;
298  CV_PROP_RW int kernel_type;
299  CV_PROP_RW double degree; // for poly
300  CV_PROP_RW double gamma; // for poly/rbf/sigmoid
301  CV_PROP_RW double coef0; // for poly/sigmoid
302 
/* Parameters whose use depends on svm_type, as noted per field. */
303  CV_PROP_RW double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
304  CV_PROP_RW double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
305  CV_PROP_RW double p; // for CV_SVM_EPS_SVR
/* Raw pointer; ownership is not visible from this declaration. */
306  CvMat* class_weights; // for CV_SVM_C_SVC
307  CV_PROP_RW CvTermCriteria term_crit; // termination criteria
308 };
309 
310 
/* Kernel evaluation object for the SVM code. The concrete kernel is selected
   through a pointer-to-member of type Calc, stored in calc_func; the
   calc_linear / calc_rbf / calc_poly / calc_sigmoid methods below all share
   that member-function signature. */
311 struct CV_EXPORTS CvSVMKernel
312 {
/* Signature shared by the per-kernel calc_* methods. */
313  typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
314  const float* another, float* results );
315  CvSVMKernel();
316  CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
317  virtual bool create( const CvSVMParams* params, Calc _calc_func );
318  virtual ~CvSVMKernel();
319 
320  virtual void clear();
/* NOTE(review): presumably dispatches through calc_func - confirm in the
   implementation file. */
321  virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
322 
/* NOTE(review): listing index 323 is absent from this view. */
324  Calc calc_func;
325 
/* Same shape as Calc plus two extra scalars, alpha and beta. */
326  virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
327  const float* another, float* results,
328  double alpha, double beta );
329 
330  virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
331  const float* another, float* results );
332  virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
333  const float* another, float* results );
334  virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
335  const float* another, float* results );
336  virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
337  const float* another, float* results );
338 };
339 
340 
342 {
345  float* data;
346 };
347 
348 
350 {
351  double obj;
352  double rho;
355  double r; // for Solver_NU
356 };
357 
/* Solver object used by the SVM code. Three strategy hooks (working-set
   selection, row retrieval and rho computation) are passed in as
   pointers-to-member and stored in the *_func fields near the end.
   NOTE(review): many listing indices inside this class are absent from this
   view (e.g. 384, 387, 391, 395, 401-404, 406-409, 411, 417, 419, 425), so
   several parameter lines and data members are not shown. */
358 class CV_EXPORTS CvSVMSolver
359 {
360 public:
/* Member-function pointer types for the three strategy hooks. */
361  typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
362  typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
363  typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
364 
365  CvSVMSolver();
366 
/* The full constructor and create() take the same argument list. */
367  CvSVMSolver( int count, int var_count, const float** samples, schar* y,
368  int alpha_count, double* alpha, double Cp, double Cn,
369  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
370  SelectWorkingSet select_working_set, CalcRho calc_rho );
371  virtual bool create( int count, int var_count, const float** samples, schar* y,
372  int alpha_count, double* alpha, double Cp, double Cn,
373  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
374  SelectWorkingSet select_working_set, CalcRho calc_rho );
375  virtual ~CvSVMSolver();
376 
377  virtual void clear();
378  virtual bool solve_generic( CvSVMSolutionInfo& si );
379 
/* One solve_* entry point per SVM formulation. The classification variants
   take schar labels, the regression variants take float responses. */
380  virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
381  double Cp, double Cn, CvMemStorage* storage,
382  CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
383  virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
385  double* alpha, CvSVMSolutionInfo& si );
386  virtual bool solve_one_class( int count, int var_count, const float** samples,
388  double* alpha, CvSVMSolutionInfo& si );
389 
390  virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
392  double* alpha, CvSVMSolutionInfo& si );
393 
394  virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
396  double* alpha, CvSVMSolutionInfo& si );
397 
398  virtual float* get_row_base( int i, bool* _existed );
399  virtual float* get_row( int i, float* dst );
400 
405  const float** samples;
410 
412 
413  double* G;
414  double* alpha;
415 
416  // -1 - lower bound, 0 - free, 1 - upper bound
418 
420  double* b;
421  float* buf[2];
422  double eps;
423  int max_iter;
424  double C[2]; // C[0] == Cn, C[1] == Cp
426 
/* Storage for the three strategy hooks. */
427  SelectWorkingSet select_working_set_func;
428  CalcRho calc_rho_func;
429  GetRow get_row_func;
430 
/* Candidate implementations matching the SelectWorkingSet and CalcRho types. */
431  virtual bool select_working_set( int& i, int& j );
432  virtual bool select_working_set_nu_svm( int& i, int& j );
433  virtual void calc_rho( double& rho, double& r );
434  virtual void calc_rho_nu_svm( double& rho, double& r );
435 
/* Candidate implementations matching the GetRow type. */
436  virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
437  virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
438  virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
439 };
440 
441 
443 {
444  double rho;
445  int sv_count;
446  double* alpha;
447  int* sv_index;
448 };
449 
450 
451 // SVM model
/* Support vector machine model derived from CvStatModel, with CvMat- and
   cv::Mat-based constructors, train / train_auto and predict overloads.
   NOTE(review): several listing indices are absent from this view (e.g. 469,
   473, 500-505, 539-540, 544-547, 549-550), so some parameter lines and
   protected members are not shown. */
452 class CV_EXPORTS_W CvSVM : public CvStatModel
453 {
454 public:
455  // SVM type
456  enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
457 
458  // SVM kernel type
459  enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
460 
461  // SVM params type
/* Same numeric values as the SVM_* enumerators of CvParamGrid; used below as
   the argument of get_default_grid(). */
462  enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
463 
464  CV_WRAP CvSVM();
465  virtual ~CvSVM();
466 
/* CvMat-based interface. */
467  CvSVM( const CvMat* trainData, const CvMat* responses,
468  const CvMat* varIdx=0, const CvMat* sampleIdx=0,
470 
471  virtual bool train( const CvMat* trainData, const CvMat* responses,
472  const CvMat* varIdx=0, const CvMat* sampleIdx=0,
474 
/* Training with one CvParamGrid per tunable parameter. Each grid defaults to
   get_default_grid() for the matching parameter id; kfold defaults to 10. */
475  virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
476  const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
477  int kfold = 10,
478  CvParamGrid Cgrid = get_default_grid(CvSVM::C),
479  CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA),
480  CvParamGrid pGrid = get_default_grid(CvSVM::P),
481  CvParamGrid nuGrid = get_default_grid(CvSVM::NU),
482  CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF),
483  CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
484  bool balanced=false );
485 
486  virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
487  virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;
488 
/* cv::Mat-based interface. */
489  CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
490  const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
491  CvSVMParams params=CvSVMParams() );
492 
493  CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
494  const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
495  CvSVMParams params=CvSVMParams() );
496 
497  CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
498  const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
499  int k_fold = 10,
506  bool balanced=false);
507  CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
/* Wrapped under the name predict_all. */
508  CV_WRAP_AS(predict_all) void predict( cv::InputArray samples, cv::OutputArray results ) const;
509 
510  CV_WRAP virtual int get_support_vector_count() const;
511  virtual const float* get_support_vector(int i) const;
/* Returns a copy of the stored parameters.
   NOTE(review): the ';' after the inline body is redundant. */
512  virtual CvSVMParams get_params() const { return params; };
513  CV_WRAP virtual void clear();
514 
515  static CvParamGrid get_default_grid( int param_id );
516 
517  virtual void write( CvFileStorage* storage, const char* name ) const;
518  virtual void read( CvFileStorage* storage, CvFileNode* node );
/* Column count of var_idx when var_idx is non-null, otherwise var_all. */
519  CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
520 
521 protected:
522 
523  virtual bool set_params( const CvSVMParams& params );
524  virtual bool train1( int sample_count, int var_count, const float** samples,
525  const void* responses, double Cp, double Cn,
526  CvMemStorage* _storage, double* alpha, double& rho );
527  virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
528  const CvMat* responses, CvMemStorage* _storage, double* alpha );
529  virtual void create_kernel();
530  virtual void create_solver();
531 
/* Protected overload working on a raw float row. */
532  virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
533 
534  virtual void write_params( CvFileStorage* fs ) const;
535  virtual void read_params( CvFileStorage* fs, CvFileNode* node );
536 
537  void optimize_linear_svm();
538 
541  int var_all;
542  float** sv;
543  int sv_total;
548 
551 
552 private:
/* Copy constructor and copy assignment are declared private and no
   definition is visible here, so outside code cannot copy a CvSVM. */
553  CvSVM(const CvSVM&);
554  CvSVM& operator = (const CvSVM&);
555 };
556 
557 /****************************************************************************************\
558 * Expectation - Maximization *
559 \****************************************************************************************/
560 namespace cv
561 {
/* Expectation-Maximization model, derived from Algorithm.
   Three training entry points are declared: train(), trainE() (takes initial
   means and optionally covariances/weights) and trainM() (takes initial
   probabilities).
   NOTE(review): several listing indices are absent from this view (e.g. 575,
   583, 591, 597, 618, 635-638, 647), so some parameter lines and members are
   not shown. */
562 class CV_EXPORTS_W EM : public Algorithm
563 {
564 public:
565  // Type of covariance matrices
566  enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
567 
568  // Default parameters
569  enum {DEFAULT_NCLUSTERS=5, DEFAULT_MAX_ITERS=100};
570 
571  // The initial step
572  enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
573 
574  CV_WRAP EM(int nclusters=EM::DEFAULT_NCLUSTERS, int covMatType=EM::COV_MAT_DIAGONAL,
576  EM::DEFAULT_MAX_ITERS, FLT_EPSILON));
577 
578  virtual ~EM();
579  CV_WRAP virtual void clear();
580 
/* All output arrays of the training methods are optional (default noArray()). */
581  CV_WRAP virtual bool train(InputArray samples,
582  OutputArray logLikelihoods=noArray(),
584  OutputArray probs=noArray());
585 
586  CV_WRAP virtual bool trainE(InputArray samples,
587  InputArray means0,
588  InputArray covs0=noArray(),
589  InputArray weights0=noArray(),
590  OutputArray logLikelihoods=noArray(),
592  OutputArray probs=noArray());
593 
594  CV_WRAP virtual bool trainM(InputArray samples,
595  InputArray probs0,
596  OutputArray logLikelihoods=noArray(),
598  OutputArray probs=noArray());
599 
/* Returns a Vec2d; what its two components mean is not visible in this
   header. */
600  CV_WRAP Vec2d predict(InputArray sample,
601  OutputArray probs=noArray()) const;
602 
603  CV_WRAP bool isTrained() const;
604 
605  AlgorithmInfo* info() const;
606  virtual void read(const FileNode& fn);
607 
608 protected:
609 
/* The pointer arguments presumably are optional initial values selected by
   startStep - TODO confirm. */
610  virtual void setTrainData(int startStep, const Mat& samples,
611  const Mat* probs0,
612  const Mat* means0,
613  const vector<Mat>* covs0,
614  const Mat* weights0);
615 
616  bool doTrain(int startStep,
617  OutputArray logLikelihoods,
619  OutputArray probs);
620  virtual void eStep();
621  virtual void mStep();
622 
623  void clusterTrainSamples();
624  void decomposeCovs();
625  void computeLogWeightDivDet();
626 
627  Vec2d computeProbabilities(const Mat& sample, Mat* probs) const;
628 
629  // all inner matrices have type CV_64FC1
630  CV_PROP_RW int nclusters;
631  CV_PROP_RW int covMatType;
632  CV_PROP_RW int maxIters;
633  CV_PROP_RW double epsilon;
634 
639 
640  CV_PROP Mat weights;
641  CV_PROP Mat means;
642  CV_PROP vector<Mat> covs;
643 
644  vector<Mat> covsEigenValues;
645  vector<Mat> covsRotateMats;
646  vector<Mat> invCovsEigenValues;
648 };
649 } // namespace cv
650 
651 /****************************************************************************************\
652 * Decision Tree *
653 \****************************************************************************************/\
655 {
656  unsigned short* u;
657  int* i;
658 };
659 
660 
/* Treats <subset> as an array of 32-bit words and tests bit (idx & 31) of
   word (idx >> 5): evaluates to -1 when that bit is set and to +1 when it is
   clear.
   NOTE(review): <subset> is not parenthesized in the expansion, and the mask
   is built with a signed 1, so (idx & 31) == 31 shifts into the sign bit. */
661 #define CV_DTREE_CAT_DIR(idx,subset) \
662  (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
663 
665 {
666  int var_idx;
668  int inversed;
669  float quality;
671  union
672  {
673  int subset[2];
674  struct
675  {
676  float c;
678  }
679  ord;
680  };
681 };
682 
684 {
686  int Tn;
687  double value;
688 
692 
694 
696  int depth;
697  int* num_valid;
698  int offset;
699  int buf_idx;
700  double maxlr;
701 
702  // global pruning data
704  double alpha;
706 
707  // cross-validation pruning data
708  int* cv_Tn;
709  double* cv_node_risk;
710  double* cv_node_error;
711 
712  int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
713  void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
714 };
715 
716 
/* Training parameters for decision trees: a bundle of public fields with a
   default constructor and a constructor that takes every field.
   Note that the constructor's argument order differs from the field
   declaration order. */
717 struct CV_EXPORTS_W_MAP CvDTreeParams
718 {
719  CV_PROP_RW int max_categories;
720  CV_PROP_RW int max_depth;
721  CV_PROP_RW int min_sample_count;
722  CV_PROP_RW int cv_folds;
723  CV_PROP_RW bool use_surrogates;
724  CV_PROP_RW bool use_1se_rule;
725  CV_PROP_RW bool truncate_pruned_tree;
726  CV_PROP_RW float regression_accuracy;
/* Raw pointer without CV_PROP_RW; ownership is not visible from this
   declaration. */
727  const float* priors;
728 
729  CvDTreeParams();
730  CvDTreeParams( int max_depth, int min_sample_count,
731  float regression_accuracy, bool use_surrogates,
732  int max_categories, int cv_folds,
733  bool use_1se_rule, bool truncate_pruned_tree,
734  const float* priors );
735 };
736 
737 
/* Training-data holder used by the decision-tree classes (CvDTree::train has
   an overload that takes a CvDTreeTrainData*).
   NOTE(review): many listing indices are absent from this view (e.g. 740, 782,
   798, 801, 807, 809-811, 813-814, 821-822, 824, 828-829, 831, 833-834, 836,
   838-841, 843), so the default constructor line and most data members are
   not shown. */
738 struct CV_EXPORTS CvDTreeTrainData
739 {
741  CvDTreeTrainData( const CvMat* trainData, int tflag,
742  const CvMat* responses, const CvMat* varIdx=0,
743  const CvMat* sampleIdx=0, const CvMat* varType=0,
744  const CvMat* missingDataMask=0,
745  const CvDTreeParams& params=CvDTreeParams(),
746  bool _shared=false, bool _add_labels=false );
747  virtual ~CvDTreeTrainData();
748 
/* Same argument list as the constructor above plus _update_data. */
749  virtual void set_data( const CvMat* trainData, int tflag,
750  const CvMat* responses, const CvMat* varIdx=0,
751  const CvMat* sampleIdx=0, const CvMat* varType=0,
752  const CvMat* missingDataMask=0,
753  const CvDTreeParams& params=CvDTreeParams(),
754  bool _shared=false, bool _add_labels=false,
755  bool _update_data=false );
756  virtual void do_responses_copy();
757 
758  virtual void get_vectors( const CvMat* _subsample_idx,
759  float* values, uchar* missing, float* responses, bool get_class_idx=false );
760 
761  virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
762 
763  virtual void write_params( CvFileStorage* fs ) const;
764  virtual void read_params( CvFileStorage* fs, CvFileNode* node );
765 
766  // release all the data
767  virtual void clear();
768 
769  int get_num_classes() const;
770  int get_var_type(int vi) const;
771  int get_work_var_count() const {return work_var_count;}
772 
/* Per-node accessors. Each takes one or more caller-supplied scratch buffers
   (*_buf arguments). */
773  virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
774  virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
775  virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
776  virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
777  virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
778  virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
779  const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
780  virtual int get_child_buf_idx( CvDTreeNode* n );
781 
783 
784  virtual bool set_params( const CvDTreeParams& params );
/* Node and split creation / release helpers. */
785  virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
786  int storage_idx, int offset );
787 
788  virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
789  int split_point, int inversed, float quality );
790  virtual CvDTreeSplit* new_split_cat( int vi, float quality );
791  virtual void free_node_data( CvDTreeNode* node );
792  virtual void free_train_data();
793  virtual void free_node( CvDTreeNode* node );
794 
795  int sample_count, var_all, var_count, max_c_count;
796  int ord_var_count, cat_var_count, work_var_count;
797  bool have_labels, have_priors;
799  int tflag;
800 
802  const CvMat* responses;
803  CvMat* responses_copy; // used in Boosting
804 
805  int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
806  bool shared;
808 
812 
/* Returns (work_var_count + 1) * sample_count; both factors are cast to
   size_t before the multiplication. */
815  inline size_t get_length_subbuf() const
816  {
817  size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
818  return res;
819  }
820 
823 
825  CvMat* var_type; // i-th element =
826  // k<0 - ordered
827  // k>=0 - categorical, see k-th element of cat_* arrays
830 
832 
835 
837 
842 
844 };
845 
846 class CvDTree;
847 class CvForestTree;
848 
849 namespace cv
850 {
851  struct DTreeBestSplitFinder;
852  struct ForestTreeBestSplitFinder;
853 }
854 
/* Decision tree model derived from CvStatModel. train() is overloaded for
   CvMat inputs, cv::Mat inputs, a CvMLData object and a prepared
   CvDTreeTrainData; predict() returns a pointer to a CvDTreeNode.
   NOTE(review): listing indices 895, 940-942 and 945 are absent from this
   view, so one parameter line and the data members are not shown. */
855 class CV_EXPORTS_W CvDTree : public CvStatModel
856 {
857 public:
858  CV_WRAP CvDTree();
859  virtual ~CvDTree();
860 
861  virtual bool train( const CvMat* trainData, int tflag,
862  const CvMat* responses, const CvMat* varIdx=0,
863  const CvMat* sampleIdx=0, const CvMat* varType=0,
864  const CvMat* missingDataMask=0,
865  CvDTreeParams params=CvDTreeParams() );
866 
867  virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
868 
869  // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
870  virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
871 
872  virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
873 
874  virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
875  bool preprocessedInput=false ) const;
876 
/* cv::Mat-based interface. */
877  CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
878  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
879  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
880  const cv::Mat& missingDataMask=cv::Mat(),
881  CvDTreeParams params=CvDTreeParams() );
882 
883  CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
884  bool preprocessedInput=false ) const;
/* Variable importance is exposed both as cv::Mat and as const CvMat*. */
885  CV_WRAP virtual cv::Mat getVarImportance();
886 
887  virtual const CvMat* get_var_importance();
888  CV_WRAP virtual void clear();
889 
890  virtual void read( CvFileStorage* fs, CvFileNode* node );
891  virtual void write( CvFileStorage* fs, const char* name ) const;
892 
893  // special read & write methods for trees in the tree ensembles
894  virtual void read( CvFileStorage* fs, CvFileNode* node,
896  virtual void write( CvFileStorage* fs ) const;
897 
898  const CvDTreeNode* get_root() const;
899  int get_pruned_tree_idx() const;
900  CvDTreeTrainData* get_data();
901 
902 protected:
/* Grants cv::DTreeBestSplitFinder access to the protected members below. */
903  friend struct cv::DTreeBestSplitFinder;
904 
905  virtual bool do_train( const CvMat* _subsample_idx );
906 
907  virtual void try_split_node( CvDTreeNode* n );
908  virtual void split_node_data( CvDTreeNode* n );
909  virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
/* Split search, one method per combination of variable kind (ord / cat) and
   task (class / reg); all four share the same optional arguments. */
910  virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
911  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
912  virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
913  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
914  virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
915  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
916  virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
917  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
918  virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
919  virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
920  virtual double calc_node_dir( CvDTreeNode* node );
921  virtual void complete_node_dir( CvDTreeNode* node );
922  virtual void cluster_categories( const int* vectors, int vector_count,
923  int var_count, int* sums, int k, int* cluster_labels );
924 
925  virtual void calc_node_value( CvDTreeNode* node );
926 
/* Pruning helpers. */
927  virtual void prune_cv();
928  virtual double update_tree_rnc( int T, int fold );
929  virtual int cut_tree( int T, int fold, double min_alpha );
930  virtual void free_prune_data(bool cut_tree);
931  virtual void free_tree();
932 
/* Serialization helpers for nodes and splits. */
933  virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
934  virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
935  virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
936  virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
937  virtual void write_tree_nodes( CvFileStorage* fs ) const;
938  virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
939 
943 
944 public:
946 };
947 
948 
949 /****************************************************************************************\
950 * Random Trees Classifier *
951 \****************************************************************************************/
952 
953 class CvRTrees;
954 
/* Tree type derived from CvDTree whose train() and read() overloads take an
   additional CvRTrees* forest argument.
   NOTE(review): listing indices 976 and 981-983 are absent from this view, so
   one parameter line and the protected members after the friend declaration
   are not shown. */
955 class CV_EXPORTS CvForestTree: public CvDTree
956 {
957 public:
958  CvForestTree();
959  virtual ~CvForestTree();
960 
961  virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
962 
/* Returns data->var_count, or 0 when data is null. */
963  virtual int get_var_count() const {return data ? data->var_count : 0;}
964  virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
965 
/* The declarations between BEGIN and END repeat signatures that CvDTree
   already declares; the original comment says they exist to avoid warnings. */
966  /* dummy methods to avoid warnings: BEGIN */
967  virtual bool train( const CvMat* trainData, int tflag,
968  const CvMat* responses, const CvMat* varIdx=0,
969  const CvMat* sampleIdx=0, const CvMat* varType=0,
970  const CvMat* missingDataMask=0,
971  CvDTreeParams params=CvDTreeParams() );
972 
973  virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
974  virtual void read( CvFileStorage* fs, CvFileNode* node );
975  virtual void read( CvFileStorage* fs, CvFileNode* node,
977  /* dummy methods to avoid warnings: END */
978 
979 protected:
/* Grants cv::ForestTreeBestSplitFinder access to the protected members. */
980  friend struct cv::ForestTreeBestSplitFinder;
981 
984 };
985 
986 
/* Random-forest training parameters: extends CvDTreeParams with forest-level
   settings.
   NOTE(review): listing index 992 is absent from this view, so one member is
   not shown. The full constructor also takes max_num_of_trees_in_the_forest,
   forest_accuracy and termcrit_type, which presumably fill that member -
   TODO confirm. */
987 struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
988 {
989  //Parameters for the forest
990  CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
991  CV_PROP_RW int nactive_vars;
993 
994  CvRTParams();
995  CvRTParams( int max_depth, int min_sample_count,
996  float regression_accuracy, bool use_surrogates,
997  int max_categories, const float* priors, bool calc_var_importance,
998  int nactive_vars, int max_num_of_trees_in_the_forest,
999  float forest_accuracy, int termcrit_type );
1000 };
1001 
1002 
1003 class CV_EXPORTS_W CvRTrees : public CvStatModel // Random trees model; individual trees are reachable through get_tree()/get_tree_count()
1004 {
1005 public:
1006  CV_WRAP CvRTrees();
1007  virtual ~CvRTrees();
1008  virtual bool train( const CvMat* trainData, int tflag, // CvMat-based training; every argument after responses has a default
1009  const CvMat* responses, const CvMat* varIdx=0,
1010  const CvMat* sampleIdx=0, const CvMat* varType=0,
1011  const CvMat* missingDataMask=0,
1012  CvRTParams params=CvRTParams() );
1013 
1014  virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); // training from a CvMLData container
1015  virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const; // missing is optional (defaults to 0)
1016  virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
1017 
1018  CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, // cv::Mat overloads of the methods above; an empty cv::Mat() replaces the 0 defaults
1019  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1020  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1021  const cv::Mat& missingDataMask=cv::Mat(),
1022  CvRTParams params=CvRTParams() );
1023  CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1024  CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1025  CV_WRAP virtual cv::Mat getVarImportance(); // returns a cv::Mat; see also get_var_importance() below, which returns const CvMat*
1026 
1027  CV_WRAP virtual void clear();
1028 
1029  virtual const CvMat* get_var_importance();
1030  virtual float get_proximity( const CvMat* sample1, const CvMat* sample2, // both missing arguments are optional (default 0)
1031  const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
1032 
1033  virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1034 
1035  virtual float get_train_error();
1036 
1037  virtual void read( CvFileStorage* fs, CvFileNode* node );
1038  virtual void write( CvFileStorage* fs, const char* name ) const;
1039 
1040  CvMat* get_active_var_mask(); // non-virtual accessors: return pointers to internal state (non-const)
1041  CvRNG* get_rng();
1042 
1043  int get_tree_count() const;
1044  CvForestTree* get_tree(int i) const; // NOTE(review): valid range of i is not visible here; presumably [0, get_tree_count()) - confirm
1045 
1046 protected:
1047  virtual std::string getName() const;
1048 
1049  virtual bool grow_forest( const CvTermCriteria term_crit ); // also declared by CvERTrees, which derives from this class
1050 
1051  // array of the trees of the forest
1054  int ntrees; // NOTE(review): listing lines 1052-1053 (the member the comment above refers to) are missing from this extract
1056  double oob_error; // NOTE(review): listing line 1055 is missing from this extract
1059  // NOTE(review): listing lines 1057-1058 and 1060-1061 are missing from this extract
1062 };
1063 
1064 /****************************************************************************************\
1065 * Extremely randomized trees Classifier *
1066 \****************************************************************************************/
1067 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData // train-data type for the extremely randomized trees section; the virtuals below presumably override CvDTreeTrainData's - confirm
1068 {
1069  virtual void set_data( const CvMat* trainData, int tflag, // everything after responses is optional; the three trailing bool flags default to false
1070  const CvMat* responses, const CvMat* varIdx=0,
1071  const CvMat* sampleIdx=0, const CvMat* varType=0,
1072  const CvMat* missingDataMask=0,
1073  const CvDTreeParams& params=CvDTreeParams(),
1074  bool _shared=false, bool _add_labels=false,
1075  bool _update_data=false );
1076  virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, // results come back through the ord_values/missing out-pointers; *_buf are caller-provided buffers
1077  const float** ord_values, const int** missing, int* sample_buf = 0 );
1078  virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); // NOTE(review): whether the returned pointer aliases the passed buffer is not visible here - confirm
1079  virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
1080  virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
1081  virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
1082  float* responses, bool get_class_idx=false );
1083  virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
1085 }; // NOTE(review): listing line 1084 is missing from this extract
1086 
1087 class CV_EXPORTS CvForestERTree : public CvForestTree // forest tree for the extremely randomized trees section; declares only protected hooks
1088 {
1089 protected:
1090  virtual double calc_node_dir( CvDTreeNode* node );
1091  virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, // four split finders with one shared signature; names suggest ordered/categorical variable x classification/regression - confirm
1092  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1093  virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1094  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1095  virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1096  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1097  virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1098  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1099  virtual void split_node_data( CvDTreeNode* n );
1100 };
1101 
1102 class CV_EXPORTS_W CvERTrees : public CvRTrees // extremely randomized trees model; keeps the CvRTrees interface and re-declares train(), getName() and grow_forest()
1103 {
1104 public:
1105  CV_WRAP CvERTrees();
1106  virtual ~CvERTrees();
1107  virtual bool train( const CvMat* trainData, int tflag, // same parameter list as CvRTrees::train (CvMat flavour)
1108  const CvMat* responses, const CvMat* varIdx=0,
1109  const CvMat* sampleIdx=0, const CvMat* varType=0,
1110  const CvMat* missingDataMask=0,
1111  CvRTParams params=CvRTParams());
1112  CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, // same parameter list as CvRTrees::train (cv::Mat flavour)
1113  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1114  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1115  const cv::Mat& missingDataMask=cv::Mat(),
1116  CvRTParams params=CvRTParams());
1117  virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1118 protected:
1119  virtual std::string getName() const;
1120  virtual bool grow_forest( const CvTermCriteria term_crit );
1121 };
1122 
1123 
1124 /****************************************************************************************\
1125 * Boosted tree classifier *
1126 \****************************************************************************************/
1127 
1128 struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams // boosting parameters: CvDTreeParams plus the four ensemble-level fields below
1129 {
1130  CV_PROP_RW int boost_type; // NOTE(review): presumably one of the CvBoost "Boosting type" enum values - confirm
1131  CV_PROP_RW int weak_count; // NOTE(review): presumably the number of weak classifiers to train - confirm
1132  CV_PROP_RW int split_criteria; // NOTE(review): presumably one of the CvBoost "Splitting criteria" enum values - confirm
1133  CV_PROP_RW double weight_trim_rate;
1134 
1135  CvBoostParams(); // default values are assigned in the implementation, not in this header
1136  CvBoostParams( int boost_type, int weak_count, double weight_trim_rate, // note: this constructor has no split_criteria argument
1137  int max_depth, bool use_surrogates, const float* priors );
1138 };
1139 
1140 
1141 class CvBoost; // forward declaration: CvBoostTree below takes CvBoost* in train() and read()
1142 
1143 class CV_EXPORTS CvBoostTree: public CvDTree // CvDTree variant whose train()/read() additionally receive the CvBoost ensemble
1144 {
1145 public:
1146  CvBoostTree();
1147  virtual ~CvBoostTree();
1148 
1149  virtual bool train( CvDTreeTrainData* trainData, // ensemble-aware training overload
1150  const CvMat* subsample_idx, CvBoost* ensemble );
1151 
1152  virtual void scale( double s ); // NOTE(review): what is scaled by s is not visible in this header - confirm in the implementation
1153  virtual void read( CvFileStorage* fs, CvFileNode* node, // ensemble-aware read overload
1154  CvBoost* ensemble, CvDTreeTrainData* _data );
1155  virtual void clear();
1156 
1157  /* dummy methods to avoid warnings: BEGIN */
1158  virtual bool train( const CvMat* trainData, int tflag, // NOTE(review): presumably re-declares the CvDTree overloads hidden by the ensemble-aware ones above - confirm in the .cpp
1159  const CvMat* responses, const CvMat* varIdx=0,
1160  const CvMat* sampleIdx=0, const CvMat* varType=0,
1161  const CvMat* missingDataMask=0,
1162  CvDTreeParams params=CvDTreeParams() );
1163  virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
1164 
1165  virtual void read( CvFileStorage* fs, CvFileNode* node );
1166  virtual void read( CvFileStorage* fs, CvFileNode* node, // NOTE(review): listing line 1167 (the end of this declaration) is missing from this extract
1168  /* dummy methods to avoid warnings: END */
1169 
1170 protected:
1171 
1172  virtual void try_split_node( CvDTreeNode* n );
1173  virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1174  virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1175  virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, // four split finders with one shared signature; names suggest ordered/categorical variable x classification/regression - confirm
1176  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1177  virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1178  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1179  virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1180  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1181  virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1182  float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1183  virtual void calc_node_value( CvDTreeNode* n );
1184  virtual double calc_node_dir( CvDTreeNode* n );
1185 
1187 }; // NOTE(review): listing line 1186 is missing from this extract
1188 
1189 
1190 class CV_EXPORTS_W CvBoost : public CvStatModel // boosted tree classifier; weak predictors are exposed through get_weak_predictors()
1191 {
1192 public:
1193  // Boosting type
1194  enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1195 
1196  // Splitting criteria
1197  enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 }; // NOTE(review): value 2 is not assigned; the reason is not visible here
1198 
1199  CV_WRAP CvBoost();
1200  virtual ~CvBoost();
1201 
1202  CvBoost( const CvMat* trainData, int tflag, // constructor taking the same arguments as the CvMat train() below, minus update
1203  const CvMat* responses, const CvMat* varIdx=0,
1204  const CvMat* sampleIdx=0, const CvMat* varType=0,
1205  const CvMat* missingDataMask=0,
1206  CvBoostParams params=CvBoostParams() );
1207 
1208  virtual bool train( const CvMat* trainData, int tflag, // CvMat-based training; every argument after responses has a default
1209  const CvMat* responses, const CvMat* varIdx=0,
1210  const CvMat* sampleIdx=0, const CvMat* varType=0,
1211  const CvMat* missingDataMask=0,
1212  CvBoostParams params=CvBoostParams(),
1213  bool update=false );
1214 
1215  virtual bool train( CvMLData* data, // training from a CvMLData container
1216  CvBoostParams params=CvBoostParams(),
1217  bool update=false );
1218 
1219  virtual float predict( const CvMat* sample, const CvMat* missing=0, // CvMat-based prediction; slice defaults to CV_WHOLE_SEQ
1220  CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1221  bool raw_mode=false, bool return_sum=false ) const;
1222 
1223  CV_WRAP CvBoost( const cv::Mat& trainData, int tflag, // cv::Mat flavour of the training constructor
1224  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1225  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1226  const cv::Mat& missingDataMask=cv::Mat(),
1227  CvBoostParams params=CvBoostParams() );
1228 
1229  CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, // cv::Mat flavour of train()
1230  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1231  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1232  const cv::Mat& missingDataMask=cv::Mat(),
1233  CvBoostParams params=CvBoostParams(),
1234  bool update=false );
1235 
1236  CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(), // cv::Mat flavour of predict(); note it has no weak_responses argument
1237  const cv::Range& slice=cv::Range::all(), bool rawMode=false,
1238  bool returnSum=false ) const;
1239 
1240  virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1241 
1242  CV_WRAP virtual void prune( CvSlice slice );
1243 
1244  CV_WRAP virtual void clear();
1245 
1246  virtual void write( CvFileStorage* storage, const char* name ) const;
1247  virtual void read( CvFileStorage* storage, CvFileNode* node );
1248  virtual const CvMat* get_active_vars(bool absolute_idx=true);
1249 
1250  CvSeq* get_weak_predictors(); // non-virtual accessors: return pointers/references to internal state
1251 
1252  CvMat* get_weights();
1253  CvMat* get_subtree_weights();
1254  CvMat* get_weak_response();
1255  const CvBoostParams& get_params() const;
1256  const CvDTreeTrainData* get_data() const;
1257 
1258 protected:
1259 
1260  void update_weights_impl( CvBoostTree* tree, double initial_weights[2] ); // non-virtual; initial_weights is a two-element array (decays to double*)
1261 
1262  virtual bool set_params( const CvBoostParams& params );
1263  virtual void update_weights( CvBoostTree* tree );
1264  virtual void trim_weights();
1265  virtual void write_params( CvFileStorage* fs ) const;
1266  virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1267 
1271  // NOTE(review): listing lines 1268-1270 are missing from this extract
1275  // NOTE(review): listing lines 1272-1274 are missing from this extract
1283 }; // NOTE(review): listing lines 1276-1282 are missing from this extract
1284 
1285 
1286 /****************************************************************************************\
1287 * Gradient Boosted Trees *
1288 \****************************************************************************************/
1289 
1290 // DataType: STRUCT CvGBTreesParams
1291 // Parameters of GBT (Gradient Boosted Trees model), including single
1292 // tree settings and ensemble parameters.
1293 //
1294 // weak_count - count of trees in the ensemble
1295 // loss_function_type - loss function used for ensemble training
1296 // subsample_portion - portion of the whole training set used for
1297 // training every single tree.
1298 // subsample_portion value is in (0.0, 1.0].
1299 // subsample_portion == 1.0 when the whole dataset is
1300 // used on each step. Count of samples used on each
1301 // step is computed as
1302 // int(total_samples_count * subsample_portion).
1303 // shrinkage - regularization parameter.
1304 // Each tree prediction is multiplied by the shrinkage value.
1305 
1306 
1307 struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams // GBT parameters: single-tree settings (CvDTreeParams) plus the ensemble fields below
1308 {
1309  CV_PROP_RW int weak_count; // count of trees in the ensemble
1310  CV_PROP_RW int loss_function_type; // loss function used for ensemble training
1311  CV_PROP_RW float subsample_portion; // portion of the training set used per tree, in (0.0, 1.0]
1312  CV_PROP_RW float shrinkage; // regularization parameter; each tree prediction is multiplied by it
1313 
1314  CvGBTreesParams(); // default values are assigned in the implementation, not in this header
1315  CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage, // note: argument order differs from the field order above
1316  float subsample_portion, int max_depth, bool use_surrogates );
1317 };
1318 
1319 // DataType: CLASS CvGBTrees
1320 // Gradient Boosting Trees (GBT) algorithm implementation.
1321 //
1322 // data - training dataset
1323 // params - parameters of the CvGBTrees
1324 // weak - array[0..(class_count-1)] of CvSeq
1325 // for storing tree ensembles
1326 // orig_response - original responses of the training set samples
1327 // sum_response - predictions of the current model on the training dataset.
1328 // This matrix is updated on every iteration.
1329 // sum_response_tmp - predictions of the model on the training set on the next
1330 // step. On every iteration the values of sum_response_tmp are
1331 // computed from the sum_response values. When the current
1332 // step is complete, the sum_response values become equal to
1333 // sum_response_tmp.
1334 // sample_idx - indices of samples used for training the ensemble.
1335 // The CvGBTrees training procedure takes a set of samples
1336 // (train_data) and a set of responses (responses).
1337 // Only pairs (train_data[i], responses[i]), where i is
1338 // in sample_idx, are used for training the ensemble.
1339 // subsample_train - indices of samples used for training a single decision
1340 // tree on the current step. These indices are counted
1341 // relative to sample_idx, so that pairs
1342 // (train_data[sample_idx[i]], responses[sample_idx[i]])
1343 // are used for training a decision tree.
1344 // The training set is randomly split
1345 // in two parts (subsample_train and subsample_test)
1346 // on every iteration according to the subsample_portion parameter.
1347 // subsample_test - relative indices of samples from the training set
1348 // which are not used for training a tree on the current
1349 // step.
1350 // missing - mask of the missing values in the training set. This
1351 // matrix has the same size as train_data. 1 - missing
1352 // value, 0 - not a missing value.
1353 // class_labels - output class labels map.
1354 // rng - random number generator. Used for splitting the
1355 // training set.
1356 // class_count - count of output classes.
1357 // class_count == 1 in the case of regression,
1358 // and > 1 in the case of classification.
1359 // delta - Huber loss function parameter.
1360 // base_value - start point of the gradient descent procedure.
1361 // model prediction is
1362 // f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
1363 // f_0 is the base value.
1364 
1365 
1366 
1367 class CV_EXPORTS_W CvGBTrees : public CvStatModel // Gradient boosted trees model (regression and classification, see the loss enum below)
1368 {
1369 public:
1370 
1371  /*
1372  // DataType: ENUM
1373  // Loss functions implemented in CvGBTrees.
1374  //
1375  // SQUARED_LOSS
1376  // problem: regression
1377  // loss = (x - x')^2
1378  //
1379  // ABSOLUTE_LOSS
1380  // problem: regression
1381  // loss = abs(x - x')
1382  //
1383  // HUBER_LOSS
1384  // problem: regression
1385  // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
1386  // 1/2*(x - x')^2, if abs(x - x') <= delta,
1387  // where delta is the alpha-quantile of pseudo responses from
1388  // the training set.
1389  //
1390  // DEVIANCE_LOSS
1391  // problem: classification
1392  //
1393  */
1394  enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS}; // resulting values: 0, 1, 3, 4 (2 is not assigned)
1395 
1396 
1397  /*
1398  // Default constructor. Creates a model only (without training).
1399  // Should be followed by one form of the train(...) function.
1400  //
1401  // API
1402  // CvGBTrees();
1403 
1404  // INPUT
1405  // OUTPUT
1406  // RESULT
1407  */
1408  CV_WRAP CvGBTrees();
1409 
1410 
1411  /*
1412  // Full form constructor. Creates a gradient boosting model and
1413  // trains it.
1414  //
1415  // API
1416  // CvGBTrees( const CvMat* trainData, int tflag,
1417  const CvMat* responses, const CvMat* varIdx=0,
1418  const CvMat* sampleIdx=0, const CvMat* varType=0,
1419  const CvMat* missingDataMask=0,
1420  CvGBTreesParams params=CvGBTreesParams() );
1421 
1422  // INPUT
1423  // trainData - a set of input feature vectors.
1424  // size of matrix is
1425  // <count of samples> x <variables count>
1426  // or <variables count> x <count of samples>
1427  // depending on the tflag parameter.
1428  // matrix values are float.
1429  // tflag - a flag showing how samples are stored in the
1430  // trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
1431  // or column by column (tflag=CV_COL_SAMPLE).
1432  // responses - a vector of responses corresponding to the samples
1433  // in trainData.
1434  // varIdx - indices of used variables. zero value means that all
1435  // variables are active.
1436  // sampleIdx - indices of used samples. zero value means that all
1437  // samples from trainData are in the training set.
1438  // varType - vector of <variables count> length. gives every
1439  // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1440  // varType = 0 means all variables are numerical.
1441  // missingDataMask - a mask of missing values in trainData.
1442  // missingDataMask = 0 means that there are no missing
1443  // values.
1444  // params - parameters of GBT algorithm.
1445  // OUTPUT
1446  // RESULT
1447  */
1448  CvGBTrees( const CvMat* trainData, int tflag,
1449  const CvMat* responses, const CvMat* varIdx=0,
1450  const CvMat* sampleIdx=0, const CvMat* varType=0,
1451  const CvMat* missingDataMask=0,
1452  CvGBTreesParams params=CvGBTreesParams() );
1453 
1454 
1455  /*
1456  // Destructor.
1457  */
1458  virtual ~CvGBTrees();
1459 
1460 
1461  /*
1462  // Gradient tree boosting model training
1463  //
1464  // API
1465  // virtual bool train( const CvMat* trainData, int tflag,
1466  const CvMat* responses, const CvMat* varIdx=0,
1467  const CvMat* sampleIdx=0, const CvMat* varType=0,
1468  const CvMat* missingDataMask=0,
1469  CvGBTreesParams params=CvGBTreesParams(),
1470  bool update=false );
1471 
1472  // INPUT
1473  // trainData - a set of input feature vectors.
1474  // size of matrix is
1475  // <count of samples> x <variables count>
1476  // or <variables count> x <count of samples>
1477  // depending on the tflag parameter.
1478  // matrix values are float.
1479  // tflag - a flag showing how samples are stored in the
1480  // trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
1481  // or column by column (tflag=CV_COL_SAMPLE).
1482  // responses - a vector of responses corresponding to the samples
1483  // in trainData.
1484  // varIdx - indices of used variables. zero value means that all
1485  // variables are active.
1486  // sampleIdx - indices of used samples. zero value means that all
1487  // samples from trainData are in the training set.
1488  // varType - vector of <variables count> length. gives every
1489  // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1490  // varType = 0 means all variables are numerical.
1491  // missingDataMask - a mask of missing values in trainData.
1492  // missingDataMask = 0 means that there are no missing
1493  // values.
1494  // params - parameters of GBT algorithm.
1495  // update - is not supported now. (!)
1496  // OUTPUT
1497  // RESULT
1498  // Error state.
1499  */
1500  virtual bool train( const CvMat* trainData, int tflag, // NOTE(review): listing line 1504 (the params argument per the API comment above) is missing from this extract
1501  const CvMat* responses, const CvMat* varIdx=0,
1502  const CvMat* sampleIdx=0, const CvMat* varType=0,
1503  const CvMat* missingDataMask=0,
1505  bool update=false );
1506 
1507 
1508  /*
1509  // Gradient tree boosting model training
1510  //
1511  // API
1512  // virtual bool train( CvMLData* data,
1513  CvGBTreesParams params=CvGBTreesParams(),
1514  bool update=false );
1515 
1516  // INPUT
1517  // data - training set.
1518  // params - parameters of GBT algorithm.
1519  // update - is not supported now. (!)
1520  // OUTPUT
1521  // RESULT
1522  // Error state.
1523  */
1524  virtual bool train( CvMLData* data, // NOTE(review): listing line 1525 (the params argument per the API comment above) is missing from this extract
1526  bool update=false );
1527 
1528 
1529  /*
1530  // Response value prediction
1531  //
1532  // API
1533  // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1534  CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1535  int k=-1 ) const;
1536 
1537  // INPUT
1538  // sample - input sample of the same type as in the training set.
1539  // missing - missing values mask. missing=0 if there are no
1540  // missing values in sample vector.
1541  // weakResponses - predictions of all of the trees.
1542  // not implemented (!)
1543  // slice - part of the ensemble used for prediction.
1544  // slice = CV_WHOLE_SEQ when all trees are used.
1545  // k - index of the ensemble to use.
1546  // k is in {-1,0,1,..,<count of output classes-1>}.
1547  // in the case of classification problem
1548  // <count of output classes-1> ensembles are built.
1549  // If k = -1 ordinary prediction is the result,
1550  // otherwise function gives the prediction of the
1551  // k-th ensemble only.
1552  // OUTPUT
1553  // RESULT
1554  // Predicted value.
1555  */
1556  virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1557  CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1558  int k=-1 ) const;
1559 
1560  /*
1561  // Response value prediction.
1562  // Parallel version (when TBB is available)
1563  //
1564  // API
1565  // virtual float predict( const CvMat* sample, const CvMat* missing=0,
1566  CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1567  int k=-1 ) const;
1568 
1569  // INPUT
1570  // sample - input sample of the same type as in the training set.
1571  // missing - missing values mask. missing=0 if there are no
1572  // missing values in sample vector.
1573  // weakResponses - predictions of all of the trees.
1574  // not implemented (!)
1575  // slice - part of the ensemble used for prediction.
1576  // slice = CV_WHOLE_SEQ when all trees are used.
1577  // k - index of the ensemble to use.
1578  // k is in {-1,0,1,..,<count of output classes-1>}.
1579  // in the case of classification problem
1580  // <count of output classes-1> ensembles are built.
1581  // If k = -1 ordinary prediction is the result,
1582  // otherwise function gives the prediction of the
1583  // k-th ensemble only.
1584  // OUTPUT
1585  // RESULT
1586  // Predicted value.
1587  */
1588  virtual float predict( const CvMat* sample, const CvMat* missing=0,
1589  CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1590  int k=-1 ) const;
1591 
1592  /*
1593  // Deletes all the data.
1594  //
1595  // API
1596  // virtual void clear();
1597 
1598  // INPUT
1599  // OUTPUT
1600  // delete data, weak, orig_response, sum_response,
1601  // weak_eval, subsample_train, subsample_test,
1602  // sample_idx, missing, class_labels
1603  // delta = 0.0
1604  // RESULT
1605  */
1606  CV_WRAP virtual void clear();
1607 
1608  /*
1609  // Compute error on the train/test set.
1610  //
1611  // API
1612  // virtual float calc_error( CvMLData* _data, int type,
1613  // std::vector<float> *resp = 0 );
1614  //
1615  // INPUT
1616  // _data - dataset
1617  // type - defines which error is to compute: train (CV_TRAIN_ERROR) or
1618  // test (CV_TEST_ERROR).
1619  // OUTPUT
1620  // resp - vector of predictions
1621  // RESULT
1622  // Error value.
1623  */
1624  virtual float calc_error( CvMLData* _data, int type,
1625  std::vector<float> *resp = 0 );
1626 
1627  /*
1628  //
1629  // Write parameters of the GBT model and data. Write learned model.
1630  //
1631  // API
1632  // virtual void write( CvFileStorage* fs, const char* name ) const;
1633  //
1634  // INPUT
1635  // fs - file storage to write the model to.
1636  // name - model name.
1637  // OUTPUT
1638  // RESULT
1639  */
1640  virtual void write( CvFileStorage* fs, const char* name ) const;
1641 
1642 
1643  /*
1644  //
1645  // Read parameters of the GBT model and data. Read learned model.
1646  //
1647  // API
1648  // virtual void read( CvFileStorage* fs, CvFileNode* node );
1649  //
1650  // INPUT
1651  // fs - file storage to read parameters from.
1652  // node - file node.
1653  // OUTPUT
1654  // RESULT
1655  */
1656  virtual void read( CvFileStorage* fs, CvFileNode* node );
1657 
1658 
1659  // new-style C++ interface
1660  CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
1661  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1662  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1663  const cv::Mat& missingDataMask=cv::Mat(),
1664  CvGBTreesParams params=CvGBTreesParams() );
1665 
1666  CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, // NOTE(review): listing line 1670 is missing from this extract (one argument is not shown)
1667  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1668  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1669  const cv::Mat& missingDataMask=cv::Mat(),
1671  bool update=false );
1672 
1673  CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(), // cv::Mat flavour of predict(); note it has no weakResponses argument
1674  const cv::Range& slice = cv::Range::all(),
1675  int k=-1 ) const;
1676 
1677 protected:
1678 
1679  /*
1680  // Compute the gradient vector components.
1681  //
1682  // API
1683  // virtual void find_gradient( const int k = 0);
1684 
1685  // INPUT
1686  // k - used for classification problem, determining current
1687  // tree ensemble.
1688  // OUTPUT
1689  // changes components of data->responses
1690  // which correspond to samples used for training
1691  // on the current step.
1692  // RESULT
1693  */
1694  virtual void find_gradient( const int k = 0);
1695 
1696 
1697  /*
1698  //
1699  // Change values in tree leaves according to the used loss function.
1700  //
1701  // API
1702  // virtual void change_values(CvDTree* tree, const int k = 0);
1703  //
1704  // INPUT
1705  // tree - decision tree to change.
1706  // k - used for classification problem, determining current
1707  // tree ensemble.
1708  // OUTPUT
1709  // changes 'value' fields of the tree's leaves.
1710  // changes sum_response_tmp.
1711  // RESULT
1712  */
1713  virtual void change_values(CvDTree* tree, const int k = 0);
1714 
1715 
1716  /*
1717  //
1718  // Find optimal constant prediction value according to the used loss
1719  // function.
1720  // The goal is to find a constant which gives the minimal total loss
1721  // on the _Idx samples.
1722  //
1723  // API
1724  // virtual float find_optimal_value( const CvMat* _Idx );
1725  //
1726  // INPUT
1727  // _Idx - indices of the samples from the training set.
1728  // OUTPUT
1729  // RESULT
1730  // optimal constant value.
1731  */
1732  virtual float find_optimal_value( const CvMat* _Idx );
1733 
1734 
1735  /*
1736  //
1737  // Randomly split the whole training set in two parts according
1738  // to params.subsample_portion.
1739  //
1740  // API
1741  // virtual void do_subsample();
1742  //
1743  // INPUT
1744  // OUTPUT
1745  // subsample_train - indices of samples used for training
1746  // subsample_test - indices of samples used for test
1747  // RESULT
1748  */
1749  virtual void do_subsample();
1750 
1751 
1752  /*
1753  //
1754  // Internal recursive function giving an array of the leaves of a subtree.
1755  //
1756  // API
1757  // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1758  //
1759  // INPUT
1760  // node - current node of the recursion.
1761  // OUTPUT
1762  // count - count of leaves in the subtree.
1763  // leaves - array of pointers to leaves.
1764  // RESULT
1765  */
1766  void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1767 
1768 
1769  /*
1770  //
1771  // Get leaves of the tree.
1772  //
1773  // API
1774  // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1775  //
1776  // INPUT
1777  // dtree - decision tree.
1778  // OUTPUT
1779  // len - count of the leaves.
1780  // RESULT
1781  // CvDTreeNode** - array of pointers to leaves.
1782  */
1783  CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len ); // NOTE(review): ownership of the returned array is not stated here - confirm who frees it
1784 
1785 
1786  /*
1787  //
1788  // Is it a regression or a classification.
1789  //
1790  // API
1791  // bool problem_type();
1792  //
1793  // INPUT
1794  // OUTPUT
1795  // RESULT
1796  // false if it is a classification problem,
1797  // true - if regression.
1798  */
1799  virtual bool problem_type() const;
1800 
1801 
1802  /*
1803  //
1804  // Write parameters of the GBT model.
1805  //
1806  // API
1807  // virtual void write_params( CvFileStorage* fs ) const;
1808  //
1809  // INPUT
1810  // fs - file storage to write parameters to.
1811  // OUTPUT
1812  // RESULT
1813  */
1814  virtual void write_params( CvFileStorage* fs ) const;
1815 
1816 
1817  /*
1818  //
1819  // Read parameters of the GBT model and data.
1820  //
1821  // API
1822  // virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1823  //
1824  // INPUT
1825  // fs, fnode - file storage and file node to read parameters from.
1826  // OUTPUT
1827  // params - parameters of the GBT model.
1828  // data - contains information about the structure
1829  // of the data set (count of variables,
1830  // their types, etc.).
1831  // class_labels - output class labels map.
1832  // RESULT
1833  */
1834  virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1835  int get_len(const CvMat* mat) const; // NOTE(review): undocumented helper; presumably returns the element count of mat - confirm
1836 
1837 
1840  // NOTE(review): listing lines 1838-1839 are missing from this extract
1850  // NOTE(review): listing lines 1841-1849 are missing from this extract
1852  // NOTE(review): listing lines 1851 and 1853 are missing from this extract
1854  float delta; // Huber loss function parameter
1855  float base_value; // start point of the gradient descent procedure (f_0)
1856 
1857 };
1858 
1859 
1860 
1861 /****************************************************************************************\
1862 * Artificial Neural Networks (ANN) *
1863 \****************************************************************************************/
1864 
1866 
1867 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams // training settings for CvANN_MLP: method selector plus per-method parameters
1868 {
1870  CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method, // NOTE(review): listing line 1869 is missing from this extract; how param1/param2 map to the fields below is not visible here - confirm
1871  double param1, double param2=0 );
1872  ~CvANN_MLP_TrainParams();
1873 
1874  enum { BACKPROP=0, RPROP=1 }; // NOTE(review): presumably the valid values of train_method - confirm
1875 
1877  CV_PROP_RW int train_method; // NOTE(review): listing line 1876 is missing from this extract
1878 
1879  // backpropagation parameters
1880  CV_PROP_RW double bp_dw_scale, bp_moment_scale;
1881 
1882  // rprop parameters
1883  CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1884 };
1885 
1886 
1887 class CV_EXPORTS_W CvANN_MLP : public CvStatModel // artificial neural network model (multi-layer, per layer_sizes and the per-layer weights array)
1888 {
1889 public:
1890  CV_WRAP CvANN_MLP();
1891  CvANN_MLP( const CvMat* layerSizes, // constructor taking the same arguments as create() below
1892  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1893  double fparam1=0, double fparam2=0 );
1894 
1895  virtual ~CvANN_MLP();
1896 
1897  virtual void create( const CvMat* layerSizes, // activation defaults to SIGMOID_SYM with both function parameters 0
1898  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1899  double fparam1=0, double fparam2=0 );
1900 
1901  virtual int train( const CvMat* inputs, const CvMat* outputs, // NOTE(review): listing line 1903 is missing from this extract (one argument is not shown)
1902  const CvMat* sampleWeights, const CvMat* sampleIdx=0,
1904  int flags=0 );
1905  virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const; // outputs is an output argument (CV_OUT)
1906 
1907  CV_WRAP CvANN_MLP( const cv::Mat& layerSizes, // cv::Mat flavours of the constructor, create(), train() and predict()
1908  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1909  double fparam1=0, double fparam2=0 );
1910 
1911  CV_WRAP virtual void create( const cv::Mat& layerSizes,
1912  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1913  double fparam1=0, double fparam2=0 );
1914 
1915  CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs, // NOTE(review): listing line 1917 is missing from this extract (one argument is not shown)
1916  const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
1918  int flags=0 );
1919 
1920  CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
1921 
1922  CV_WRAP virtual void clear();
1923 
1924  // possible activation functions
1925  enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1926 
1927  // available training flags
1928  enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 }; // distinct bits (1, 2, 4), so they can be OR-ed into the flags argument of train()
1929 
1930  virtual void read( CvFileStorage* fs, CvFileNode* node );
1931  virtual void write( CvFileStorage* storage, const char* name ) const;
1932 
1933  int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; } // layer_sizes->cols, or 0 while layer_sizes is null
1934  const CvMat* get_layer_sizes() { return layer_sizes; } // may return null (no check is made)
1935  double* get_weights(int layer) // weights[layer], or 0 when layer_sizes/weights is null or layer is out of range
1936  {
1937  return layer_sizes && weights &&
1938  (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0; // unsigned compare rejects negative layer; the upper bound layer_sizes->cols is inclusive
1939  }
1940 
1941  virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1942 
1943 protected:
1944 
1945  virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs, // results come back through the _ivecs/_ovecs/_sw out-pointers
1946  const CvMat* _sample_weights, const CvMat* sampleIdx,
1947  CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1948 
1949  // sequential random backpropagation
1950  virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1951 
1952  // RPROP algorithm
1953  virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1954 
1955  virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1956  virtual void set_activ_func( int _activ_func=SIGMOID_SYM, // same defaults as create(): SIGMOID_SYM, 0, 0
1957  double _f_param1=0, double _f_param2=0 );
1958  virtual void init_weights();
1959  virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1960  virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1961  virtual void calc_input_scale( const CvVectors* vecs, int flags );
1962  virtual void calc_output_scale( const CvVectors* vecs, int flags );
1963 
1964  virtual void write_params( CvFileStorage* fs ) const;
1965  virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1966 
1970  double** weights; // indexed by layer in get_weights(); NOTE(review): listing lines 1967-1969 are missing from this extract
1971  double f_param1, f_param2;
1972  double min_val, max_val, min_val1, max_val1;
1974  int max_count, max_buf_sz; // NOTE(review): listing line 1973 is missing from this extract
1977 }; // NOTE(review): listing lines 1975-1976 are missing from this extract
1978 
1979 /****************************************************************************************\
1980 * Auxiliary functions declarations *
1981 \****************************************************************************************/
1982 
1983 /* Generates <sample> from a multivariate normal distribution, where <mean> is an
1984  average row vector and <cov> is a symmetric covariance matrix.
      <rng> is declared with CV_DEFAULT(0). */
1985 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1986  CvRNG* rng CV_DEFAULT(0) );
1987 
1988 /* Generates <sample> from a Gaussian mixture distribution.
      <sampClasses> is declared with CV_DEFAULT(0).
      NOTE(review): listing line 1990 is absent from this view; the definition index at
      the end of the listing places `CvMat * covs[]` on that line, so the parameter
      list shown here is likely missing that argument. */
1989 CVAPI(void) cvRandGaussMixture( CvMat* means[],
1991  float weights[],
1992  int clsnum,
1993  CvMat* sample,
1994  CvMat* sampClasses CV_DEFAULT(0) );
1995 
// Test-set type constant; the only CV_TS_* value defined in this part of the listing.
1996 #define CV_TS_CONCENTRIC_SPHERES 0
1997 
1998 /* Creates a test set. <samples> and <responses> are passed as CvMat** arguments,
      and the declaration ends with a variadic tail ("...").
      NOTE(review): <type> presumably takes a CV_TS_* constant such as
      CV_TS_CONCENTRIC_SPHERES -- confirm against the implementation. */
1999 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
2000  int num_samples,
2001  int num_features,
2002  CvMat** responses,
2003  int num_classes, ... );
2004 
2005 /****************************************************************************************\
2006 * Data *
2007 \****************************************************************************************/
2008 
// NOTE(review): presumably these select whether a train/test split is given as an
// absolute sample count or as a portion (see CvTrainTestSplit) -- confirm usage.
2009 #define CV_COUNT 0
2010 #define CV_PORTION 1
2011 
// Describes a train/test split. The training part is held in a union, so it is
// stored either as an int (count) or as a float (portion), never both.
2012 struct CV_EXPORTS CvTrainTestSplit
2013 {
2014  CvTrainTestSplit();
      // Two overloads: one takes the training part as an int count, the other as a
      // float portion. Both default mix to true.
2015  CvTrainTestSplit( int train_sample_count, bool mix = true);
2016  CvTrainTestSplit( float train_sample_portion, bool mix = true);
2017 
2018  union
2019  {
2020  int count;
2021  float portion;
2022  } train_sample_part;
      // NOTE(review): listing line 2023 is absent here; the definition index at the end
      // of the listing places `int train_sample_part_mode` on that line.
2024 
2025  bool mix;
2026 };
2027 
// CvMLData: container exposing values, responses, a missing-value mask, variable
// types/indices and train/test sample indices through the accessors below; data is
// loaded with read_csv().
// NOTE(review): several numbered lines of the protected member list are absent from
// this listing (2081, 2085-2088, 2091-2092, 2094, 2096, 2099, 2102-2103, 2106); the
// definition index at the end of the listing names the members declared there.
2028 class CV_EXPORTS CvMLData
2029 {
2030 public:
2031  CvMLData();
2032  virtual ~CvMLData();
2033 
2034  // returns:
2035  // 0 - OK
2036  // -1 - the file cannot be opened or is not correct
2037  int read_csv( const char* filename );
2038 
2039  const CvMat* get_values() const;
2040  const CvMat* get_responses();
2041  const CvMat* get_missing() const;
2042 
2043  void set_response_idx( int idx ); // the old response becomes a predictor, new response_idx = idx
2044  // if idx < 0 there will be no response
2045  int get_response_idx() const;
2046 
2047  void set_train_test_split( const CvTrainTestSplit * spl );
2048  const CvMat* get_train_sample_idx() const;
2049  const CvMat* get_test_sample_idx() const;
2050  void mix_train_and_test_idx();
2051 
2052  const CvMat* get_var_idx();
2053  void chahge_var_idx( int vi, bool state ); // misspelled name (kept for backward compatibility),
2054  // use change_var_idx
2055  void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
2056 
2057  const CvMat* get_var_types();
2058  int get_var_type( int var_idx ) const;
2059  // the following 2 methods allow changing variable types
2060  // use these methods to assign the CV_VAR_CATEGORICAL type to a categorical variable
2061  // with numerical labels; in the other cases var types are correctly determined automatically
2062  void set_var_types( const char* str ); // str examples:
2063  // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
2064  // "cat", "ord" (all vars are categorical/ordered)
2065  void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
2066 
2067  void set_delimiter( char ch );
2068  char get_delimiter() const;
2069 
2070  void set_miss_ch( char ch );
2071  char get_miss_ch() const;
2072 
2073  const std::map<std::string, int>& get_class_labels_map() const;
2074 
2075 protected:
2076  virtual void clear();
2077 
2078  void str_to_flt_elem( const char* token, float& flt_elem, int& type);
2079  void free_train_test_idx();
2080 
2082  char miss_ch;
2083  //char flt_separator;
2084 
2089 
2090  CvMat* response_out; // header
2093 
2095 
2097  bool mix;
2098 
2100  std::map<std::string, int> class_map;
2101 
2104  int* sample_idx; // data of train_sample_idx and test_sample_idx
2105 
2107 };
2108 
2109 
// C++-namespace aliases for the Cv-prefixed ML types, plus module-level declarations.
2110 namespace cv
2111 {
2112 
      // NOTE(review): listing lines 2113-2119, 2121-2131 and 2133-2136 are absent here;
      // the definition index at the end of the listing places further typedefs on them
      // (e.g. StatModel at 2113, KNearest at 2116, NeuralNet_MLP at 2134).
2120 typedef CvSVM SVM;
2132 typedef CvBoost Boost;
2137 
      // Declaration of an explicit specialization of Ptr<>::delete_obj for CvDTreeSplit.
2138 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
2139 
2140 CV_EXPORTS bool initModule_ml(void);
2141 
2142 }
2143 
2144 #endif // __cplusplus
2145 #endif // __OPENCV_ML_HPP__
2146 
2147 /* End of file. */
Mat trainSamples
Definition: ml.hpp:635
CV_PROP_RW CvTermCriteria term_crit
Definition: ml.hpp:1876
char miss_ch
Definition: ml.hpp:2082
GLdouble GLdouble GLdouble r
double * min_val
Definition: core_c.h:833
void find_nearest(const Matrix< typename Distance::ElementType > &dataset, typename Distance::ElementType *query, int *matches, int nn, int skip=0, Distance distance=Distance())
Definition: ground_truth.h:42
CvBoost * ensemble
Definition: ml.hpp:1186
const CvArr CvSeq CvSeq CvMemStorage CvSURFParams params
Definition: compat.hpp:647
Definition: ml.hpp:1003
void set_num_valid(int vi, int n)
Definition: ml.hpp:713
CV_PROP Mat means
Definition: ml.hpp:641
int train_sample_count
Definition: ml.hpp:2096
CV_PROP_RW int cv_folds
Definition: ml.hpp:722
CV_EXPORTS bool initModule_ml(void)
schar * alpha_status
Definition: ml.hpp:417
GLenum GLint GLint y
Definition: core_c.h:613
CvMat * missing
Definition: ml.hpp:2086
const _OutputArray & OutputArray
Definition: core.hpp:1449
int train_sample_part_mode
Definition: ml.hpp:2023
CvFileNode * node
Definition: core_c.h:1638
Random Number Generator.
Definition: core.hpp:2019
int * sample_idx
Definition: ml.hpp:2104
int total_class_count
Definition: ml.hpp:2099
CvMat * weak_eval
Definition: ml.hpp:1278
virtual bool grow_forest(const CvTermCriteria term_crit)
signed char schar
Definition: types_c.h:174
Termination criteria in iterative algorithms.
Definition: core.hpp:2091
double eps
Definition: ml.hpp:422
int get_work_var_count() const
Definition: ml.hpp:771
CvGBTreesParams params
Definition: ml.hpp:1839
CvMat * split_buf
Definition: ml.hpp:822
double tree_error
Definition: ml.hpp:705
cv::RNG * rng
Definition: ml.hpp:2106
Definition: ml.hpp:442
CV_PROP_RW int weak_count
Definition: ml.hpp:1131
uint64 CvRNG
Definition: types_c.h:399
CV_PROP_RW bool calc_var_importance
Definition: ml.hpp:990
A short numerical vector.
Definition: core.hpp:84
const char const char ** filename
Definition: core_c.h:1750
CvMat int int CvMat int num_classes
Definition: ml.hpp:1999
int split_point
Definition: ml.hpp:677
int * sv_index
Definition: ml.hpp:447
const CvMat * train_data
Definition: ml.hpp:801
CvMat * wbuf
Definition: ml.hpp:1968
CvSVMKernel * kernel
Definition: ml.hpp:425
CV_PROP_RW int train_method
Definition: ml.hpp:1877
CvMat * cat_ofs
Definition: ml.hpp:810
int tflag
Definition: ml.hpp:799
CV_PROP_RW double C
Definition: ml.hpp:303
GLuint start
cv::RNG * rng
Definition: ml.hpp:1060
CvDTreeNode * data_root
Definition: ml.hpp:836
CV_PROP_RW double coef0
Definition: ml.hpp:301
CvSVMParams SVMParams
Definition: ml.hpp:2117
Definition: ml.hpp:163
CvMemStorage * tree_storage
Definition: ml.hpp:833
CvMat * priors_mult
Definition: ml.hpp:829
Definition: ml.hpp:562
int subset[2]
Definition: ml.hpp:673
const CvMat * get_layer_sizes()
Definition: ml.hpp:1934
CvMat * layer_sizes
Definition: ml.hpp:1967
Definition: ml.hpp:654
CvMat * var_idx
Definition: ml.hpp:216
void delete_obj()
deletes the object. Override if needed
Definition: operations.hpp:2612
Definition: ml.hpp:717
File Storage Node class.
Definition: core.hpp:4119
CvMat * train_sample_idx
Definition: ml.hpp:2102
double node_risk
Definition: ml.hpp:705
int alpha_count
Definition: ml.hpp:411
CV_PROP_RW double min_val
Definition: ml.hpp:177
CV_PROP_RW int nclusters
Definition: ml.hpp:630
Definition: ml.hpp:1190
CvSVMKernelRow * rows
Definition: ml.hpp:409
CV_PROP_RW float subsample_portion
Definition: ml.hpp:1311
double ** db
Definition: ml.hpp:81
virtual CvDTreeSplit * find_split_cat_reg(CvDTreeNode *n, int vi, float init_quality=0, CvDTreeSplit *_split=0, uchar *ext_buf=0)
CvMat * class_labels
Definition: ml.hpp:540
float c
Definition: ml.hpp:676
CvMat * var_idx
Definition: ml.hpp:824
float ** sv
Definition: ml.hpp:542
Definition: ml.hpp:1087
double * G
Definition: ml.hpp:413
CvDTreeSplit * next
Definition: ml.hpp:670
bool shared
Definition: ml.hpp:806
CvMat * cls_labels
Definition: ml.hpp:217
CV_PROP_RW double epsilon
Definition: ml.hpp:633
const int * idx
Definition: core_c.h:323
uchar ** ptr
Definition: ml.hpp:79
static Range all()
Definition: operations.hpp:2219
int sample_count
Definition: ml.hpp:695
float ** fl
Definition: ml.hpp:80
The 2D range class.
Definition: core.hpp:979
virtual void split_node_data(CvDTreeNode *n)
int * num_valid
Definition: ml.hpp:697
int activ_func
Definition: ml.hpp:1973
CV_PROP_RW double step
Definition: ml.hpp:179
Mat trainProbs
Definition: ml.hpp:636
CvMat * var_importance
Definition: ml.hpp:1057
CvSVMKernelRow * prev
Definition: ml.hpp:343
CvGBTrees GradientBoostingTrees
Definition: ml.hpp:2136
union CvVectors::@334 data
CvSet * nv_heap
Definition: ml.hpp:841
bool have_priors
Definition: ml.hpp:797
CvANN_MLP NeuralNet_MLP
Definition: ml.hpp:2134
void int double rho
Definition: imgproc_c.h:603
CvMat * subsample_train
Definition: ml.hpp:1846
int int int flags
Definition: highgui_c.h:186
CvRTrees * forest
Definition: ml.hpp:983
CvMat ** sum
Definition: ml.hpp:219
CvMat int int num_features
Definition: ml.hpp:1999
virtual const int * get_cat_var_data(CvDTreeNode *n, int vi, int *cat_values_buf)
GLsizei GLsizei GLenum GLenum const GLvoid * data
Definition: core_c.h:403
int get_num_valid(int vi)
Definition: ml.hpp:712
Definition: ml.hpp:566
int count
Definition: ml.hpp:2020
virtual CvDTreeSplit * find_split_ord_class(CvDTreeNode *n, int vi, float init_quality=0, CvDTreeSplit *_split=0, uchar *ext_buf=0)
CvRect r
Definition: core_c.h:1282
Definition: ml.hpp:1128
const CvMat const CvMat const CvMat CvMat CvMat CvMat CvMat CvSize CvMat CvMat * T
Definition: calib3d.hpp:270
size_t get_length_subbuf() const
Definition: ml.hpp:815
struct CvDTreeSplit::@342::@344 ord
double * alpha
Definition: ml.hpp:414
virtual std::string getName() const
int total
Definition: ml.hpp:279
CvParamGrid()
Definition: ml.hpp:168
Definition: defines.h:94
CvArr const CvMat * kernel
Definition: imgproc_c.h:89
CvRNG * rng
Definition: core_c.h:652
virtual CvDTreeSplit * find_split_ord_reg(CvDTreeNode *n, int vi, float init_quality=0, CvDTreeSplit *_split=0, uchar *ext_buf=0)
CvSet * split_heap
Definition: ml.hpp:839
double rho
Definition: ml.hpp:352
int inversed
Definition: ml.hpp:668
CvRTParams RandomTreeParams
Definition: ml.hpp:2125
GLXDrawable GLXDrawable read
CvNormalBayesClassifier NormalBayesClassifier
Definition: ml.hpp:2115
struct CvParamLattice CvParamLattice
virtual CvDTreeNode * subsample_data(const CvMat *_subsample_idx)
CvDTreeParams DTreeParams
Definition: ml.hpp:2121
const char * default_model_name
Definition: ml.hpp:149
CvMat * weights
Definition: ml.hpp:1280
GLuint res
double min_val1
Definition: ml.hpp:1972
Definition: ml.hpp:189
CV_PROP_RW CvTermCriteria term_crit
Definition: ml.hpp:992
void int step
Definition: core_c.h:403
CV_EXPORTS_W void write(FileStorage &fs, const string &name, int value)
CVAPI(void) cvRandMVNormal(CvMat *mean
CV_PROP_RW bool use_surrogates
Definition: ml.hpp:723
CV_PROP_RW double nu
Definition: ml.hpp:304
CvMat * var_importance
Definition: ml.hpp:941
Proxy datatype for passing Mat's and vector<>'s as input parameters.
Definition: core.hpp:1312
double min_val
Definition: ml.hpp:91
void clear(const ColorA &color=ColorA::black(), bool clearDepthBuffer=true)
CvMat * var_types_out
Definition: ml.hpp:2092
CV_PROP_RW float shrinkage
Definition: ml.hpp:1312
Definition: types_c.h:1202
CvMat * missing
Definition: ml.hpp:1848
virtual void read(CvFileStorage *fs, CvFileNode *node)
CvSVMKernel SVMKernel
Definition: ml.hpp:2118
CvMat * cat_count
Definition: ml.hpp:809
CV_PROP Mat weights
Definition: ml.hpp:640
CvMat CvMat CvRNG *rng CV_DEFAULT(0))
std::map< std::string, int > class_map
Definition: ml.hpp:2100
const char const char * str
Definition: core_c.h:1552
CvMat * subsample_test
Definition: ml.hpp:1847
Definition: ml.hpp:1067
typedef void(CV_CDECL *CvMouseCallback)(int event
int depth
Definition: ml.hpp:696
CvMat * priors
Definition: ml.hpp:828
CV_PROP_RW double rp_dw_plus
Definition: ml.hpp:1883
virtual void get_vectors(const CvMat *_subsample_idx, float *values, uchar *missing, float *responses, bool get_class_idx=false)
int var_count
Definition: ml.hpp:402
CV_PROP_RW int max_categories
Definition: ml.hpp:719
Definition: ml.hpp:289
int var_idx
Definition: ml.hpp:666
CV_PROP_RW bool use_1se_rule
Definition: ml.hpp:724
int dims
Definition: ml.hpp:75
const GLbyte * weights
int class_count
Definition: ml.hpp:1853
CvSeq ** weak
Definition: ml.hpp:1841
CvMat * active_vars_abs
Definition: ml.hpp:1273
CvMat ** count
Definition: ml.hpp:218
CvMat * class_weights
Definition: ml.hpp:306
CvDTreeNode * left
Definition: ml.hpp:690
CV_PROP_RW int kernel_type
Definition: ml.hpp:298
GLenum GLenum GLvoid * row
CvMat * cov
Definition: ml.hpp:1985
SourceFileRef load(const DataSourceRef &dataSource, size_t sampleRate=0)
int Tn
Definition: ml.hpp:686
CvSet * node_heap
Definition: ml.hpp:838
CvBoostTree BoostTree
Definition: ml.hpp:2131
CV_PROP_RW double bp_moment_scale
Definition: ml.hpp:1880
CvMat * response_out
Definition: ml.hpp:2090
cv::RNG * rng
Definition: ml.hpp:1976
double value
Definition: ml.hpp:687
Definition: ml.hpp:738
CvDTreeTrainData * data
Definition: ml.hpp:1838
CvParamGrid ParamGrid
Definition: ml.hpp:2114
const _InputArray & InputArray
Definition: core.hpp:1447
double ** weights
Definition: ml.hpp:1970
Definition: ml.hpp:1102
CV_PROP_RW int nactive_vars
Definition: ml.hpp:991
bool have_subsample
Definition: ml.hpp:1282
virtual void read(CvFileStorage *storage, CvFileNode *node)
CvSVMSolver * solver
Definition: ml.hpp:549
Definition: ml.hpp:1925
CV_PROP_RW bool truncate_pruned_tree
Definition: ml.hpp:725
CV_PROP_RW double weight_trim_rate
Definition: ml.hpp:1133
Definition: types_c.h:1364
GetRow get_row_func
Definition: ml.hpp:429
int cache_line_size
Definition: ml.hpp:404
Definition: ml.hpp:683
the desired accuracy or change in parameters at which the iterative algorithm stops ...
Definition: core.hpp:2098
CvDTreeNode * parent
Definition: ml.hpp:689
CvMat * sample_weights
Definition: ml.hpp:1969
CvArr const CvMat * mat
Definition: core_c.h:700
int type
Definition: ml.hpp:74
CV_EXPORTS void split(const Mat &src, Mat *mvbegin)
copies each plane of a multi-channel array to a dedicated array
int get_layer_count()
Definition: ml.hpp:1933
Definition: ml.hpp:341
vector< Mat > covsRotateMats
Definition: ml.hpp:645
CalcRho calc_rho_func
Definition: ml.hpp:428
Definition: ml.hpp:855
Definition: ml.hpp:452
GLclampf GLclampf GLclampf alpha
Definition: core_c.h:687
CvSet * cv_heap
Definition: ml.hpp:840
int sv_count
Definition: ml.hpp:445
CvERTreeTrainData ERTreeTRainData
Definition: ml.hpp:2127
Definition: types_c.h:1272
virtual void write(CvFileStorage *storage, const char *name) const
CvDTree DecisionTree
Definition: ml.hpp:2123
vector< Mat > covsEigenValues
Definition: ml.hpp:644
double oob_error
Definition: ml.hpp:1056
double f_param2
Definition: ml.hpp:1971
const CvMat * responses
Definition: ml.hpp:802
float delta
Definition: ml.hpp:1854
CvDTreeTrainData * data
Definition: ml.hpp:942
size_t size_t CvMemStorage * storage
Definition: core_c.h:946
CvMLData TrainData
Definition: ml.hpp:2122
CV_INLINE CvParamLattice cvParamLattice(double min_val, double max_val, double log_step)
Definition: ml.hpp:97
int offset
Definition: ml.hpp:698
CvArr int CvScalar param1
Definition: core_c.h:649
CvMat * var_idx
Definition: ml.hpp:544
const CvMat CvMat CvMat int k
Definition: legacy.hpp:3052
CvBoost Boost
Definition: ml.hpp:2132
CvMat * orig_response
Definition: ml.hpp:1276
CvMat * sum_response
Definition: ml.hpp:1277
CvBoostParams BoostParams
Definition: ml.hpp:2130
GLintptr offset
int buf_size
Definition: ml.hpp:805
GLuint GLuint GLsizei count
Definition: core_c.h:973
Definition: ml.hpp:1867
GLenum GLsizei n
CvSVMKernelRow lru_list
Definition: ml.hpp:408
struct CvFileStorage CvFileStorage
Definition: types_c.h:1740
CvSlice slice
Definition: core_c.h:1053
CvRTrees RandomTrees
Definition: ml.hpp:2126
CvMat CvMat * sample
Definition: ml.hpp:1985
int sample_count
Definition: ml.hpp:401
CvMat * responses_copy
Definition: ml.hpp:803
Definition: types_c.h:1828
const CvSVMParams * params
Definition: ml.hpp:323
float *(CvSVMSolver::* GetRow)(int i, float *row, float *dst, bool existed)
Definition: ml.hpp:362
Definition: ml.hpp:955
int count
Definition: ml.hpp:75
Mat trainLabels
Definition: ml.hpp:638
bool mix
Definition: ml.hpp:2097
double alpha
Definition: ml.hpp:704
CvMat * counts
Definition: ml.hpp:813
CvDTreeNode * right
Definition: ml.hpp:691
CvANN_MLP_TrainParams ANN_MLP_TrainParams
Definition: ml.hpp:2133
const float * priors
Definition: ml.hpp:727
CvGBTreesParams GradientBoostingTreeParams
Definition: ml.hpp:2135
CvMat ** cov_rotate_mats
Definition: ml.hpp:223
Definition: types_c.h:645
Definition: ml.hpp:311
Mat trainLogLikelihoods
Definition: ml.hpp:637
CvMemStorage * temp_storage
Definition: ml.hpp:834
GLenum GLuint GLint GLint layer
double const CvArr double beta
Definition: core_c.h:523
Definition: ml.hpp:664
double upper_bound_p
Definition: ml.hpp:353
The Core Functionality.
Definition: ml.hpp:1887
the maximum number of iterations or elements to compute
Definition: core.hpp:2096
int pruned_tree_idx
Definition: ml.hpp:945
The n-dimensional matrix class.
Definition: core.hpp:1688
virtual CvDTreeSplit * find_surrogate_split_ord(CvDTreeNode *n, int vi, uchar *ext_buf=0)
Definition: ml.hpp:987
CvMat * subsample_mask
Definition: ml.hpp:1279
CV_PROP_RW int boost_type
Definition: ml.hpp:1130
int ntrees
Definition: ml.hpp:1054
char delimiter
Definition: ml.hpp:2081
Definition: types_c.h:997
Definition: ml.hpp:462
Definition: ml.hpp:358
CvSVMParams params
Definition: ml.hpp:539
CvMat * var_types
Definition: ml.hpp:2087
int max_iter
Definition: ml.hpp:423
CvArr * mean
Definition: core_c.h:802
Definition: core.hpp:4465
GLuint GLuint end
CV_WRAP int get_var_count() const
Definition: ml.hpp:519
OutputArray OutputArray labels
Definition: imgproc.hpp:823
Definition: ml.hpp:349
CvMat * active_vars
Definition: ml.hpp:1272
GLfloat GLfloat p
GLuint GLuint GLsizei GLenum type
Definition: core_c.h:114
Definition: ml.hpp:462
GLenum const GLfloat * params
Definition: compat.hpp:688
virtual void try_split_node(CvDTreeNode *n)
GLsizei samples
CvDTreeParams params
Definition: ml.hpp:831
CV_PROP_RW int weak_count
Definition: ml.hpp:1309
int var_count
Definition: ml.hpp:278
bool mix
Definition: ml.hpp:2025
CvMat * direction
Definition: ml.hpp:821
double upper_bound_n
Definition: ml.hpp:354
CvScalar scale
Definition: core_c.h:518
GLuint const GLchar * name
Definition: core_c.h:1546
bool have_active_cat_vars
Definition: ml.hpp:1274
SelectWorkingSet select_working_set_func
Definition: ml.hpp:427
CvMat ** inv_eigen_values
Definition: ml.hpp:222
CV_PROP vector< Mat > covs
Definition: ml.hpp:642
CvMat * sample_idx
Definition: ml.hpp:1845
CV_INLINE CvParamLattice cvDefaultParamLattice(void)
Definition: ml.hpp:107
float portion
Definition: ml.hpp:2021
CV_PROP_RW int maxIters
Definition: ml.hpp:632
CvMat int int CvMat ** responses
Definition: ml.hpp:1999
virtual double calc_node_dir(CvDTreeNode *node)
Definition: types_c.h:198
CV_PROP_RW int max_depth
Definition: ml.hpp:720
int * i
Definition: ml.hpp:657
CvMat * sum_response_tmp
Definition: ml.hpp:1844
int is_buf_16u
Definition: ml.hpp:807
CvMat * class_weights
Definition: ml.hpp:545
Definition: ml.hpp:1143
CvMat * test_sample_idx
Definition: ml.hpp:2103
Definition: ml.hpp:2028
CvMat * orig_response
Definition: ml.hpp:1842
Calc calc_func
Definition: ml.hpp:324
double obj
Definition: ml.hpp:351
CV_PROP_RW float regression_accuracy
Definition: ml.hpp:726
CvMemStorage * storage
Definition: ml.hpp:407
CvSVMDecisionFunc * decision_func
Definition: ml.hpp:546
int * cv_Tn
Definition: ml.hpp:708
CvMat * cat_map
Definition: ml.hpp:811
int cache_size
Definition: ml.hpp:403
GLfloat bias
virtual const int * get_sample_indices(CvDTreeNode *n, int *indices_buf)
Definition: ml.hpp:462
Definition: ml.hpp:134
int work_var_count
Definition: ml.hpp:796
virtual void clear()
CV_EXPORTS int check(const Mat &data, double min_val, double max_val, vector< int > *idx)
double * get_weights(int layer)
Definition: ml.hpp:1935
virtual void set_data(const CvMat *trainData, int tflag, const CvMat *responses, const CvMat *varIdx=0, const CvMat *sampleIdx=0, const CvMat *varType=0, const CvMat *missingDataMask=0, const CvDTreeParams &params=CvDTreeParams(), bool _shared=false, bool _add_labels=false, bool _update_data=false)
Base class for high-level OpenCV algorithms.
Definition: core.hpp:4390
CvMat * var_type
Definition: ml.hpp:825
CvSVMKernelRow * next
Definition: ml.hpp:344
Definition: ml.hpp:72
CvArr int CvScalar CvScalar param2
Definition: core_c.h:649
CvMat * sum_response
Definition: ml.hpp:1843
int n
Definition: legacy.hpp:3070
Definition: ml.hpp:1367
int var_count
Definition: ml.hpp:795
CvMat * active_var_mask
Definition: ml.hpp:1061
CV_PROP_RW double degree
Definition: ml.hpp:299
CvSeq * weak
Definition: ml.hpp:1270
unsigned char uchar
Definition: types_c.h:170
int buf_idx
Definition: ml.hpp:699
CvForestTree ForestTree
Definition: ml.hpp:2124
Definition: types_c.h:1333
double step
Definition: ml.hpp:93
CvDTreeTrainData * data
Definition: ml.hpp:1268
double const CvArr double double gamma
Definition: core_c.h:523
GLuint dst
Definition: calib3d.hpp:134
CvKNearest KNearest
Definition: ml.hpp:2116
CV_PROP_RW int split_criteria
Definition: ml.hpp:1132
float quality
Definition: ml.hpp:669
double rho
Definition: ml.hpp:444
CvMat * buf
Definition: ml.hpp:814
void * parent
Definition: core_c.h:1459
int var_count
Definition: ml.hpp:215
CvSVM SVM
Definition: ml.hpp:2120
CvMemStorage * storage
Definition: ml.hpp:547
CvDTreeTrainData * data
Definition: ml.hpp:1053
bool is_classifier
Definition: ml.hpp:798
GLenum GLsizei len
virtual CvDTreeSplit * find_surrogate_split_cat(CvDTreeNode *n, int vi, uchar *ext_buf=0)
CvSVMKernel * kernel
Definition: ml.hpp:550
CvERTrees ERTrees
Definition: ml.hpp:2129
virtual int get_var_count() const
Definition: ml.hpp:963
Definition: ml.hpp:2012
double r
Definition: ml.hpp:355
unsigned short * u
Definition: ml.hpp:656
CvMat ** avg
Definition: ml.hpp:221
CvBoostParams params
Definition: ml.hpp:1269
cv::RNG * rng
Definition: ml.hpp:843
CvMat * covs[]
Definition: ml.hpp:1990
double * cv_node_risk
Definition: ml.hpp:709
const float ** samples
Definition: ml.hpp:405
CV_PROP_RW int covMatType
Definition: ml.hpp:631
CvANN_MLP_TrainParams params
Definition: ml.hpp:1975
int sv_total
Definition: ml.hpp:543
const CvSVMParams * params
Definition: ml.hpp:406
GLboolean GLenum GLenum GLvoid * values
const CvArr CvArr CvStereoBMState * state
Definition: calib3d.hpp:353
CvForestTree ** trees
Definition: ml.hpp:1052
Mat logWeightDivDet
Definition: ml.hpp:647
CvMat * class_labels
Definition: ml.hpp:1849
CV_PROP_RW int loss_function_type
Definition: ml.hpp:1310
static CvParamGrid get_default_grid(int param_id)
CV_PROP_RW double gamma
Definition: ml.hpp:300
CV_PROP_RW double p
Definition: ml.hpp:305
int max_count
Definition: ml.hpp:1974
Definition: ml.hpp:89
Definition: ml.hpp:462
virtual const int * get_cv_labels(CvDTreeNode *n, int *labels_buf)
double * cv_node_error
Definition: ml.hpp:710
CvForestERTree ERTree
Definition: ml.hpp:2128
CvMat float int clsnum
Definition: ml.hpp:1990
int class_idx
Definition: ml.hpp:685
double * alpha
Definition: ml.hpp:446
CvMat * subtree_weights
Definition: ml.hpp:1281
Definition: ml.hpp:462
double max_val
Definition: ml.hpp:92
CvMat int num_samples
Definition: ml.hpp:1999
CV_EXPORTS OutputArray noArray()
CV_PROP_RW CvTermCriteria term_crit
Definition: ml.hpp:307
float * data
Definition: ml.hpp:345
virtual bool train(const CvMat *trainData, int tflag, const CvMat *responses, const CvMat *varIdx=0, const CvMat *sampleIdx=0, const CvMat *varType=0, const CvMat *missingDataMask=0, CvRTParams params=CvRTParams())
int condensed_idx
Definition: ml.hpp:667
GLdouble s
CV_PROP_RW int svm_type
Definition: ml.hpp:297
virtual bool train(const CvMat *trainData, int tflag, const CvMat *responses, const CvMat *varIdx=0, const CvMat *sampleIdx=0, const CvMat *varType=0, const CvMat *missingDataMask=0, CvDTreeParams params=CvDTreeParams())
virtual void get_ord_var_data(CvDTreeNode *n, int vi, float *ord_values_buf, int *sorted_indices_buf, const float **ord_values, const int **sorted_indices, int *sample_indices_buf)
double * b
Definition: ml.hpp:420
const CvMat * missing_mask
Definition: ml.hpp:1084
float base_value
Definition: ml.hpp:1855
CvVectors * samples
Definition: ml.hpp:281
int nsamples
Definition: ml.hpp:1058
CvDTreeSplit * split
Definition: ml.hpp:693
int response_idx
Definition: ml.hpp:2094
CvMat * values
Definition: ml.hpp:2085
CV_PROP_RW double max_val
Definition: ml.hpp:178
Definition: ml.hpp:462
virtual CV_WRAP void clear()
CvMat * c
Definition: ml.hpp:224
Definition: ml.hpp:233
CvSVMSolver SVMSolver
Definition: ml.hpp:2119
cv::RNG * rng
Definition: ml.hpp:1851
CvPoint3D64f double * dist
Definition: legacy.hpp:556
Definition: ml.hpp:569
Proxy datatype for passing Mat's and vector<>'s as input parameters.
Definition: core.hpp:1400
double tree_risk
Definition: ml.hpp:705
Definition: ml.hpp:569
double maxlr
Definition: ml.hpp:700
Definition: ml.hpp:1307
CvMat * var_idx_mask
Definition: ml.hpp:2088
CV_PROP_RW int min_sample_count
Definition: ml.hpp:721
int int sample_count
Definition: legacy.hpp:1177
virtual void calc_node_value(CvDTreeNode *node)
CvVectors * next
Definition: ml.hpp:76
double double * max_val
Definition: core_c.h:833
bool regression
Definition: ml.hpp:280
int complexity
Definition: ml.hpp:703
CvStatModel StatModel
Definition: ml.hpp:2113
virtual CvDTreeSplit * find_best_split(CvDTreeNode *n)
CvDTreeNode * root
Definition: ml.hpp:940
schar * y
Definition: ml.hpp:419
int nclasses
Definition: ml.hpp:1055
vector< Mat > invCovsEigenValues
Definition: ml.hpp:646
virtual CvDTreeSplit * find_split_cat_class(CvDTreeNode *n, int vi, float init_quality=0, CvDTreeSplit *_split=0, uchar *ext_buf=0)
int var_all
Definition: ml.hpp:541
CvMat * var_idx_out
Definition: ml.hpp:2091
CvMat ** productsum
Definition: ml.hpp:220