/***********************************************************************
 * Software License Agreement (BSD License)
 *
 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
 *
 * THE BSD LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029 *************************************************************************/ 00030 00031 #ifndef _OPENCV_KDTREE_H_ 00032 #define _OPENCV_KDTREE_H_ 00033 00034 #include <algorithm> 00035 #include <map> 00036 #include <cassert> 00037 #include <cstring> 00038 00039 #include "opencv2/flann/general.h" 00040 #include "opencv2/flann/nn_index.h" 00041 #include "opencv2/flann/matrix.h" 00042 #include "opencv2/flann/result_set.h" 00043 #include "opencv2/flann/heap.h" 00044 #include "opencv2/flann/allocator.h" 00045 #include "opencv2/flann/random.h" 00046 #include "opencv2/flann/saving.h" 00047 00048 00049 namespace cvflann 00050 { 00051 00052 struct CV_EXPORTS KDTreeIndexParams : public IndexParams { 00053 KDTreeIndexParams(int trees_ = 4) : IndexParams(FLANN_INDEX_KDTREE), trees(trees_) {}; 00054 00055 int trees; // number of randomized trees to use (for kdtree) 00056 00057 void print() const 00058 { 00059 logger().info("Index type: %d\n",(int)algorithm); 00060 logger().info("Trees: %d\n", trees); 00061 } 00062 00063 }; 00064 00065 00072 template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type > 00073 class KDTreeIndex : public NNIndex<ELEM_TYPE> 00074 { 00075 00076 enum { 00082 SAMPLE_MEAN = 100, 00090 RAND_DIM=5 00091 }; 00092 00093 00097 int numTrees; 00098 00102 int* vind; 00103 00104 00108 const Matrix<ELEM_TYPE> dataset; 00109 00110 const IndexParams& index_params; 00111 00112 size_t size_; 00113 size_t veclen_; 00114 00115 00116 DIST_TYPE* mean; 00117 DIST_TYPE* var; 00118 00119 00120 /*--------------------- Internal Data Structures --------------------------*/ 00121 00128 struct TreeSt { 00134 int divfeat; 00138 DIST_TYPE divval; 00142 TreeSt *child1, *child2; 00143 }; 00144 typedef TreeSt* Tree; 00145 00149 Tree* trees; 00150 typedef BranchStruct<Tree> BranchSt; 00151 typedef BranchSt* Branch; 00152 00160 PooledAllocator pool; 00161 00162 00163 00164 public: 00165 00166 flann_algorithm_t getType() const 00167 { 00168 return 
FLANN_INDEX_KDTREE; 00169 } 00170 00178 KDTreeIndex(const Matrix<ELEM_TYPE>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams() ) : 00179 dataset(inputData), index_params(params) 00180 { 00181 size_ = dataset.rows; 00182 veclen_ = dataset.cols; 00183 00184 numTrees = params.trees; 00185 trees = new Tree[numTrees]; 00186 00187 // get the parameters 00188 // if (params.find("trees") != params.end()) { 00189 // numTrees = (int)params["trees"]; 00190 // trees = new Tree[numTrees]; 00191 // } 00192 // else { 00193 // numTrees = -1; 00194 // trees = NULL; 00195 // } 00196 00197 // Create a permutable array of indices to the input vectors. 00198 vind = new int[size_]; 00199 for (size_t i = 0; i < size_; i++) { 00200 vind[i] = (int)i; 00201 } 00202 00203 mean = new DIST_TYPE[veclen_]; 00204 var = new DIST_TYPE[veclen_]; 00205 } 00206 00210 ~KDTreeIndex() 00211 { 00212 delete[] vind; 00213 if (trees!=NULL) { 00214 delete[] trees; 00215 } 00216 delete[] mean; 00217 delete[] var; 00218 } 00219 00220 00224 void buildIndex() 00225 { 00226 /* Construct the randomized trees. */ 00227 for (int i = 0; i < numTrees; i++) { 00228 /* Randomize the order of vectors to allow for unbiased sampling. 
*/ 00229 for (int j = (int)size_; j > 0; --j) { 00230 int rnd = rand_int(j); 00231 std::swap(vind[j-1], vind[rnd]); 00232 } 00233 trees[i] = divideTree(0, (int)size_ - 1); 00234 } 00235 } 00236 00237 00238 00239 void saveIndex(FILE* stream) 00240 { 00241 save_value(stream, numTrees); 00242 for (int i=0;i<numTrees;++i) { 00243 save_tree(stream, trees[i]); 00244 } 00245 } 00246 00247 00248 00249 void loadIndex(FILE* stream) 00250 { 00251 load_value(stream, numTrees); 00252 00253 if (trees!=NULL) { 00254 delete[] trees; 00255 } 00256 trees = new Tree[numTrees]; 00257 for (int i=0;i<numTrees;++i) { 00258 load_tree(stream,trees[i]); 00259 } 00260 } 00261 00262 00266 size_t size() const 00267 { 00268 return size_; 00269 } 00270 00274 size_t veclen() const 00275 { 00276 return veclen_; 00277 } 00278 00279 00284 int usedMemory() const 00285 { 00286 return (int)(pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int)); // pool memory and vind array memory 00287 } 00288 00289 00299 void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams) 00300 { 00301 int maxChecks = searchParams.checks; 00302 00303 if (maxChecks<0) { 00304 getExactNeighbors(result, vec); 00305 } else { 00306 getNeighbors(result, vec, maxChecks); 00307 } 00308 } 00309 00310 const IndexParams* getParameters() const 00311 { 00312 return &index_params; 00313 } 00314 00315 private: 00316 00317 KDTreeIndex& operator=(const KDTreeIndex&); 00318 KDTreeIndex(const KDTreeIndex&); 00319 00320 00321 void save_tree(FILE* stream, Tree tree) 00322 { 00323 save_value(stream, *tree); 00324 if (tree->child1!=NULL) { 00325 save_tree(stream, tree->child1); 00326 } 00327 if (tree->child2!=NULL) { 00328 save_tree(stream, tree->child2); 00329 } 00330 } 00331 00332 00333 void load_tree(FILE* stream, Tree& tree) 00334 { 00335 tree = pool.allocate<TreeSt>(); 00336 load_value(stream, *tree); 00337 if (tree->child1!=NULL) { 00338 load_tree(stream, tree->child1); 00339 } 00340 if 
(tree->child2!=NULL) { 00341 load_tree(stream, tree->child2); 00342 } 00343 } 00344 00345 00355 Tree divideTree(int first, int last) 00356 { 00357 Tree node = pool.allocate<TreeSt>(); // allocate memory 00358 00359 /* If only one exemplar remains, then make this a leaf node. */ 00360 if (first == last) { 00361 node->child1 = node->child2 = NULL; /* Mark as leaf node. */ 00362 node->divfeat = vind[first]; /* Store index of this vec. */ 00363 } 00364 else { 00365 chooseDivision(node, first, last); 00366 subdivide(node, first, last); 00367 } 00368 00369 return node; 00370 } 00371 00372 00378 void chooseDivision(Tree node, int first, int last) 00379 { 00380 memset(mean,0,veclen_*sizeof(DIST_TYPE)); 00381 memset(var,0,veclen_*sizeof(DIST_TYPE)); 00382 00383 /* Compute mean values. Only the first SAMPLE_MEAN values need to be 00384 sampled to get a good estimate. 00385 */ 00386 int end = std::min(first + SAMPLE_MEAN, last); 00387 for (int j = first; j <= end; ++j) { 00388 ELEM_TYPE* v = dataset[vind[j]]; 00389 for (size_t k=0; k<veclen_; ++k) { 00390 mean[k] += v[k]; 00391 } 00392 } 00393 for (size_t k=0; k<veclen_; ++k) { 00394 mean[k] /= (end - first + 1); 00395 } 00396 00397 /* Compute variances (no need to divide by count). */ 00398 for (int j = first; j <= end; ++j) { 00399 ELEM_TYPE* v = dataset[vind[j]]; 00400 for (size_t k=0; k<veclen_; ++k) { 00401 DIST_TYPE dist = v[k] - mean[k]; 00402 var[k] += dist * dist; 00403 } 00404 } 00405 /* Select one of the highest variance indices at random. */ 00406 node->divfeat = selectDivision(var); 00407 node->divval = mean[node->divfeat]; 00408 00409 } 00410 00411 00416 int selectDivision(DIST_TYPE* v) 00417 { 00418 int num = 0; 00419 int topind[RAND_DIM]; 00420 00421 /* Create a list of the indices of the top RAND_DIM values. */ 00422 for (size_t i = 0; i < veclen_; ++i) { 00423 if (num < RAND_DIM || v[i] > v[topind[num-1]]) { 00424 /* Put this element at end of topind. 
*/ 00425 if (num < RAND_DIM) { 00426 topind[num++] = (int)i; /* Add to list. */ 00427 } 00428 else { 00429 topind[num-1] = (int)i; /* Replace last element. */ 00430 } 00431 /* Bubble end value down to right location by repeated swapping. */ 00432 int j = num - 1; 00433 while (j > 0 && v[topind[j]] > v[topind[j-1]]) { 00434 std::swap(topind[j], topind[j-1]); 00435 --j; 00436 } 00437 } 00438 } 00439 /* Select a random integer in range [0,num-1], and return that index. */ 00440 int rnd = rand_int(num); 00441 return topind[rnd]; 00442 } 00443 00444 00449 void subdivide(Tree node, int first, int last) 00450 { 00451 /* Move vector indices for left subtree to front of list. */ 00452 int i = first; 00453 int j = last; 00454 while (i <= j) { 00455 int ind = vind[i]; 00456 ELEM_TYPE val = dataset[ind][node->divfeat]; 00457 if (val < node->divval) { 00458 ++i; 00459 } else { 00460 /* Move to end of list by swapping vind i and j. */ 00461 std::swap(vind[i], vind[j]); 00462 --j; 00463 } 00464 } 00465 /* If either list is empty, it means we have hit the unlikely case 00466 in which all remaining features are identical. Split in the middle 00467 to maintain a balanced tree. 00468 */ 00469 if ( (i == first) || (i == last+1)) { 00470 i = (first+last+1)/2; 00471 } 00472 00473 node->child1 = divideTree(first, i - 1); 00474 node->child2 = divideTree(i, last); 00475 } 00476 00477 00478 00483 void getExactNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec) 00484 { 00485 // checkID -= 1; /* Set a different unique ID for each search. 
*/ 00486 00487 if (numTrees > 1) { 00488 fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search"); 00489 } 00490 if (numTrees>0) { 00491 searchLevelExact(result, vec, trees[0], 0.0); 00492 } 00493 assert(result.full()); 00494 } 00495 00501 void getNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, int maxCheck) 00502 { 00503 int i; 00504 BranchSt branch; 00505 00506 int checkCount = 0; 00507 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_); 00508 std::vector<bool> checked(size_,false); 00509 00510 /* Search once through each tree down to root. */ 00511 for (i = 0; i < numTrees; ++i) { 00512 searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck, heap, checked); 00513 } 00514 00515 /* Keep searching other branches from heap until finished. */ 00516 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) { 00517 searchLevel(result, vec, branch.node, branch.mindistsq, checkCount, maxCheck, heap, checked); 00518 } 00519 00520 delete heap; 00521 00522 assert(result.full()); 00523 } 00524 00525 00531 void searchLevel(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq, int& checkCount, int maxCheck, 00532 Heap<BranchSt>* heap, std::vector<bool>& checked) 00533 { 00534 if (result.worstDist()<mindistsq) { 00535 // printf("Ignoring branch, too far\n"); 00536 return; 00537 } 00538 00539 /* If this is a leaf node, then do check and return. */ 00540 if (node->child1 == NULL && node->child2 == NULL) { 00541 00542 /* Do not check same node more than once when searching multiple trees. 00543 Once a vector is checked, we set its location in vind to the 00544 current checkID. 
00545 */ 00546 if (checked[node->divfeat] == true || checkCount>=maxCheck) { 00547 if (result.full()) return; 00548 } 00549 checkCount++; 00550 checked[node->divfeat] = true; 00551 00552 result.addPoint(dataset[node->divfeat],node->divfeat); 00553 return; 00554 } 00555 00556 /* Which child branch should be taken first? */ 00557 ELEM_TYPE val = vec[node->divfeat]; 00558 DIST_TYPE diff = val - node->divval; 00559 Tree bestChild = (diff < 0) ? node->child1 : node->child2; 00560 Tree otherChild = (diff < 0) ? node->child2 : node->child1; 00561 00562 /* Create a branch record for the branch not taken. Add distance 00563 of this feature boundary (we don't attempt to correct for any 00564 use of this feature in a parent node, which is unlikely to 00565 happen and would have only a small effect). Don't bother 00566 adding more branches to heap after halfway point, as cost of 00567 adding exceeds their value. 00568 */ 00569 00570 DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq); 00571 // if (2 * checkCount < maxCheck || !result.full()) { 00572 if (new_distsq < result.worstDist() || !result.full()) { 00573 heap->insert( BranchSt::make_branch(otherChild, new_distsq) ); 00574 } 00575 00576 /* Call recursively to search next level down. */ 00577 searchLevel(result, vec, bestChild, mindistsq, checkCount, maxCheck, heap, checked); 00578 } 00579 00583 void searchLevelExact(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq) 00584 { 00585 if (mindistsq>result.worstDist()) { 00586 return; 00587 } 00588 00589 /* If this is a leaf node, then do check and return. */ 00590 if (node->child1 == NULL && node->child2 == NULL) { 00591 00592 /* Do not check same node more than once when searching multiple trees. 00593 Once a vector is checked, we set its location in vind to the 00594 current checkID. 
00595 */ 00596 // if (vind[node->divfeat] == checkID) 00597 // return; 00598 // vind[node->divfeat] = checkID; 00599 00600 result.addPoint(dataset[node->divfeat],node->divfeat); 00601 return; 00602 } 00603 00604 /* Which child branch should be taken first? */ 00605 ELEM_TYPE val = vec[node->divfeat]; 00606 DIST_TYPE diff = val - node->divval; 00607 Tree bestChild = (diff < 0) ? node->child1 : node->child2; 00608 Tree otherChild = (diff < 0) ? node->child2 : node->child1; 00609 00610 00611 /* Call recursively to search next level down. */ 00612 searchLevelExact(result, vec, bestChild, mindistsq); 00613 DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq); 00614 searchLevelExact(result, vec, otherChild, new_distsq); 00615 } 00616 00617 }; // class KDTree 00618 00619 } // namespace cvflann 00620 00621 #endif //_OPENCV_KDTREE_H_