00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef _OPENCV_KDTREE_H_
00032 #define _OPENCV_KDTREE_H_
00033
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038
00039 #include "opencv2/flann/general.h"
00040 #include "opencv2/flann/nn_index.h"
00041 #include "opencv2/flann/matrix.h"
00042 #include "opencv2/flann/result_set.h"
00043 #include "opencv2/flann/heap.h"
00044 #include "opencv2/flann/allocator.h"
00045 #include "opencv2/flann/random.h"
00046 #include "opencv2/flann/saving.h"
00047
00048
00049 namespace cvflann
00050 {
00051
00052 struct CV_EXPORTS KDTreeIndexParams : public IndexParams {
00053 KDTreeIndexParams(int trees_ = 4) : IndexParams(FLANN_INDEX_KDTREE), trees(trees_) {};
00054
00055 int trees;
00056
00057 void print() const
00058 {
00059 logger().info("Index type: %d\n",(int)algorithm);
00060 logger().info("Trees: %d\n", trees);
00061 }
00062
00063 };
00064
00065
00072 template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type >
00073 class KDTreeIndex : public NNIndex<ELEM_TYPE>
00074 {
00075
00076 enum {
00082 SAMPLE_MEAN = 100,
00090 RAND_DIM=5
00091 };
00092
00093
00097 int numTrees;
00098
00102 int* vind;
00103
00104
00108 const Matrix<ELEM_TYPE> dataset;
00109
00110 const IndexParams& index_params;
00111
00112 size_t size_;
00113 size_t veclen_;
00114
00115
00116 DIST_TYPE* mean;
00117 DIST_TYPE* var;
00118
00119
00120
00121
00128 struct TreeSt {
00134 int divfeat;
00138 DIST_TYPE divval;
00142 TreeSt *child1, *child2;
00143 };
00144 typedef TreeSt* Tree;
00145
00149 Tree* trees;
00150 typedef BranchStruct<Tree> BranchSt;
00151 typedef BranchSt* Branch;
00152
00160 PooledAllocator pool;
00161
00162
00163
00164 public:
00165
00166 flann_algorithm_t getType() const
00167 {
00168 return FLANN_INDEX_KDTREE;
00169 }
00170
00178 KDTreeIndex(const Matrix<ELEM_TYPE>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams() ) :
00179 dataset(inputData), index_params(params)
00180 {
00181 size_ = dataset.rows;
00182 veclen_ = dataset.cols;
00183
00184 numTrees = params.trees;
00185 trees = new Tree[numTrees];
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198 vind = new int[size_];
00199 for (size_t i = 0; i < size_; i++) {
00200 vind[i] = (int)i;
00201 }
00202
00203 mean = new DIST_TYPE[veclen_];
00204 var = new DIST_TYPE[veclen_];
00205 }
00206
00210 ~KDTreeIndex()
00211 {
00212 delete[] vind;
00213 if (trees!=NULL) {
00214 delete[] trees;
00215 }
00216 delete[] mean;
00217 delete[] var;
00218 }
00219
00220
00224 void buildIndex()
00225 {
00226
00227 for (int i = 0; i < numTrees; i++) {
00228
00229 for (int j = (int)size_; j > 0; --j) {
00230 int rnd = rand_int(j);
00231 std::swap(vind[j-1], vind[rnd]);
00232 }
00233 trees[i] = divideTree(0, (int)size_ - 1);
00234 }
00235 }
00236
00237
00238
00239 void saveIndex(FILE* stream)
00240 {
00241 save_value(stream, numTrees);
00242 for (int i=0;i<numTrees;++i) {
00243 save_tree(stream, trees[i]);
00244 }
00245 }
00246
00247
00248
00249 void loadIndex(FILE* stream)
00250 {
00251 load_value(stream, numTrees);
00252
00253 if (trees!=NULL) {
00254 delete[] trees;
00255 }
00256 trees = new Tree[numTrees];
00257 for (int i=0;i<numTrees;++i) {
00258 load_tree(stream,trees[i]);
00259 }
00260 }
00261
00262
00266 size_t size() const
00267 {
00268 return size_;
00269 }
00270
00274 size_t veclen() const
00275 {
00276 return veclen_;
00277 }
00278
00279
00284 int usedMemory() const
00285 {
00286 return (int)(pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int));
00287 }
00288
00289
00299 void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams)
00300 {
00301 int maxChecks = searchParams.checks;
00302
00303 if (maxChecks<0) {
00304 getExactNeighbors(result, vec);
00305 } else {
00306 getNeighbors(result, vec, maxChecks);
00307 }
00308 }
00309
00310 const IndexParams* getParameters() const
00311 {
00312 return &index_params;
00313 }
00314
00315 private:
00316
00317 KDTreeIndex& operator=(const KDTreeIndex&);
00318 KDTreeIndex(const KDTreeIndex&);
00319
00320
00321 void save_tree(FILE* stream, Tree tree)
00322 {
00323 save_value(stream, *tree);
00324 if (tree->child1!=NULL) {
00325 save_tree(stream, tree->child1);
00326 }
00327 if (tree->child2!=NULL) {
00328 save_tree(stream, tree->child2);
00329 }
00330 }
00331
00332
00333 void load_tree(FILE* stream, Tree& tree)
00334 {
00335 tree = pool.allocate<TreeSt>();
00336 load_value(stream, *tree);
00337 if (tree->child1!=NULL) {
00338 load_tree(stream, tree->child1);
00339 }
00340 if (tree->child2!=NULL) {
00341 load_tree(stream, tree->child2);
00342 }
00343 }
00344
00345
00355 Tree divideTree(int first, int last)
00356 {
00357 Tree node = pool.allocate<TreeSt>();
00358
00359
00360 if (first == last) {
00361 node->child1 = node->child2 = NULL;
00362 node->divfeat = vind[first];
00363 }
00364 else {
00365 chooseDivision(node, first, last);
00366 subdivide(node, first, last);
00367 }
00368
00369 return node;
00370 }
00371
00372
00378 void chooseDivision(Tree node, int first, int last)
00379 {
00380 memset(mean,0,veclen_*sizeof(DIST_TYPE));
00381 memset(var,0,veclen_*sizeof(DIST_TYPE));
00382
00383
00384
00385
00386 int end = std::min(first + SAMPLE_MEAN, last);
00387 for (int j = first; j <= end; ++j) {
00388 ELEM_TYPE* v = dataset[vind[j]];
00389 for (size_t k=0; k<veclen_; ++k) {
00390 mean[k] += v[k];
00391 }
00392 }
00393 for (size_t k=0; k<veclen_; ++k) {
00394 mean[k] /= (end - first + 1);
00395 }
00396
00397
00398 for (int j = first; j <= end; ++j) {
00399 ELEM_TYPE* v = dataset[vind[j]];
00400 for (size_t k=0; k<veclen_; ++k) {
00401 DIST_TYPE dist = v[k] - mean[k];
00402 var[k] += dist * dist;
00403 }
00404 }
00405
00406 node->divfeat = selectDivision(var);
00407 node->divval = mean[node->divfeat];
00408
00409 }
00410
00411
00416 int selectDivision(DIST_TYPE* v)
00417 {
00418 int num = 0;
00419 int topind[RAND_DIM];
00420
00421
00422 for (size_t i = 0; i < veclen_; ++i) {
00423 if (num < RAND_DIM || v[i] > v[topind[num-1]]) {
00424
00425 if (num < RAND_DIM) {
00426 topind[num++] = (int)i;
00427 }
00428 else {
00429 topind[num-1] = (int)i;
00430 }
00431
00432 int j = num - 1;
00433 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
00434 std::swap(topind[j], topind[j-1]);
00435 --j;
00436 }
00437 }
00438 }
00439
00440 int rnd = rand_int(num);
00441 return topind[rnd];
00442 }
00443
00444
00449 void subdivide(Tree node, int first, int last)
00450 {
00451
00452 int i = first;
00453 int j = last;
00454 while (i <= j) {
00455 int ind = vind[i];
00456 ELEM_TYPE val = dataset[ind][node->divfeat];
00457 if (val < node->divval) {
00458 ++i;
00459 } else {
00460
00461 std::swap(vind[i], vind[j]);
00462 --j;
00463 }
00464 }
00465
00466
00467
00468
00469 if ( (i == first) || (i == last+1)) {
00470 i = (first+last+1)/2;
00471 }
00472
00473 node->child1 = divideTree(first, i - 1);
00474 node->child2 = divideTree(i, last);
00475 }
00476
00477
00478
00483 void getExactNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec)
00484 {
00485
00486
00487 if (numTrees > 1) {
00488 fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00489 }
00490 if (numTrees>0) {
00491 searchLevelExact(result, vec, trees[0], 0.0);
00492 }
00493 assert(result.full());
00494 }
00495
00501 void getNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, int maxCheck)
00502 {
00503 int i;
00504 BranchSt branch;
00505
00506 int checkCount = 0;
00507 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00508 std::vector<bool> checked(size_,false);
00509
00510
00511 for (i = 0; i < numTrees; ++i) {
00512 searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck, heap, checked);
00513 }
00514
00515
00516 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00517 searchLevel(result, vec, branch.node, branch.mindistsq, checkCount, maxCheck, heap, checked);
00518 }
00519
00520 delete heap;
00521
00522 assert(result.full());
00523 }
00524
00525
00531 void searchLevel(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq, int& checkCount, int maxCheck,
00532 Heap<BranchSt>* heap, std::vector<bool>& checked)
00533 {
00534 if (result.worstDist()<mindistsq) {
00535
00536 return;
00537 }
00538
00539
00540 if (node->child1 == NULL && node->child2 == NULL) {
00541
00542
00543
00544
00545
00546 if (checked[node->divfeat] == true || checkCount>=maxCheck) {
00547 if (result.full()) return;
00548 }
00549 checkCount++;
00550 checked[node->divfeat] = true;
00551
00552 result.addPoint(dataset[node->divfeat],node->divfeat);
00553 return;
00554 }
00555
00556
00557 ELEM_TYPE val = vec[node->divfeat];
00558 DIST_TYPE diff = val - node->divval;
00559 Tree bestChild = (diff < 0) ? node->child1 : node->child2;
00560 Tree otherChild = (diff < 0) ? node->child2 : node->child1;
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570 DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq);
00571
00572 if (new_distsq < result.worstDist() || !result.full()) {
00573 heap->insert( BranchSt::make_branch(otherChild, new_distsq) );
00574 }
00575
00576
00577 searchLevel(result, vec, bestChild, mindistsq, checkCount, maxCheck, heap, checked);
00578 }
00579
00583 void searchLevelExact(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq)
00584 {
00585 if (mindistsq>result.worstDist()) {
00586 return;
00587 }
00588
00589
00590 if (node->child1 == NULL && node->child2 == NULL) {
00591
00592
00593
00594
00595
00596
00597
00598
00599
00600 result.addPoint(dataset[node->divfeat],node->divfeat);
00601 return;
00602 }
00603
00604
00605 ELEM_TYPE val = vec[node->divfeat];
00606 DIST_TYPE diff = val - node->divval;
00607 Tree bestChild = (diff < 0) ? node->child1 : node->child2;
00608 Tree otherChild = (diff < 0) ? node->child2 : node->child1;
00609
00610
00611
00612 searchLevelExact(result, vec, bestChild, mindistsq);
00613 DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq);
00614 searchLevelExact(result, vec, otherChild, new_distsq);
00615 }
00616
00617 };
00618
00619 }
00620
00621 #endif //_OPENCV_KDTREE_H_