include/opencv2/flann/kdtree_index.h
Go to the documentation of this file.
00001 /***********************************************************************
00002  * Software License Agreement (BSD License)
00003  *
00004  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
00005  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
00006  *
00007  * THE BSD LICENSE
00008  *
00009  * Redistribution and use in source and binary forms, with or without
00010  * modification, are permitted provided that the following conditions
00011  * are met:
00012  *
00013  * 1. Redistributions of source code must retain the above copyright
00014  *    notice, this list of conditions and the following disclaimer.
00015  * 2. Redistributions in binary form must reproduce the above copyright
00016  *    notice, this list of conditions and the following disclaimer in the
00017  *    documentation and/or other materials provided with the distribution.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00020  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00021  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00022  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00023  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00024  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00025  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00026  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00028  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *************************************************************************/
00030 
00031 #ifndef _OPENCV_KDTREE_H_
00032 #define _OPENCV_KDTREE_H_
00033 
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <map>
#include <vector>
00038 
00039 #include "opencv2/flann/general.h"
00040 #include "opencv2/flann/nn_index.h"
00041 #include "opencv2/flann/matrix.h"
00042 #include "opencv2/flann/result_set.h"
00043 #include "opencv2/flann/heap.h"
00044 #include "opencv2/flann/allocator.h"
00045 #include "opencv2/flann/random.h"
00046 #include "opencv2/flann/saving.h"
00047 
00048 
00049 namespace cvflann
00050 {
00051 
00052 struct CV_EXPORTS KDTreeIndexParams : public IndexParams {
00053     KDTreeIndexParams(int trees_ = 4) : IndexParams(FLANN_INDEX_KDTREE), trees(trees_) {};
00054 
00055     int trees;                 // number of randomized trees to use (for kdtree)
00056 
00057     void print() const
00058     {
00059         logger().info("Index type: %d\n",(int)algorithm);
00060         logger().info("Trees: %d\n", trees);
00061     }
00062 
00063 };
00064 
00065 
00072 template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type >
00073 class KDTreeIndex : public NNIndex<ELEM_TYPE>
00074 {
00075 
00076     enum {
00082         SAMPLE_MEAN = 100,
00090         RAND_DIM=5
00091     };
00092 
00093 
00097     int numTrees;
00098 
00102     int* vind;
00103 
00104 
00108     const Matrix<ELEM_TYPE> dataset;
00109 
00110     const IndexParams& index_params;
00111 
00112     size_t size_;
00113     size_t veclen_;
00114 
00115 
00116     DIST_TYPE* mean;
00117     DIST_TYPE* var;
00118 
00119 
00120     /*--------------------- Internal Data Structures --------------------------*/
00121 
00128     struct TreeSt {
00134         int divfeat;
00138         DIST_TYPE divval;
00142         TreeSt *child1, *child2;
00143     };
00144     typedef TreeSt* Tree;
00145 
00149     Tree* trees;
00150     typedef BranchStruct<Tree> BranchSt;
00151     typedef BranchSt* Branch;
00152 
00160     PooledAllocator pool;
00161 
00162 
00163 
00164 public:
00165 
00166     flann_algorithm_t getType() const
00167     {
00168         return FLANN_INDEX_KDTREE;
00169     }
00170 
00178     KDTreeIndex(const Matrix<ELEM_TYPE>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams() ) :
00179         dataset(inputData), index_params(params)
00180     {
00181         size_ = dataset.rows;
00182         veclen_ = dataset.cols;
00183 
00184         numTrees = params.trees;
00185         trees = new Tree[numTrees];
00186 
00187         // get the parameters
00188 //        if (params.find("trees") != params.end()) {
00189 //          numTrees = (int)params["trees"];
00190 //          trees = new Tree[numTrees];
00191 //        }
00192 //        else {
00193 //          numTrees = -1;
00194 //          trees = NULL;
00195 //        }
00196 
00197         // Create a permutable array of indices to the input vectors.
00198         vind = new int[size_];
00199         for (size_t i = 0; i < size_; i++) {
00200             vind[i] = (int)i;
00201         }
00202 
00203         mean = new DIST_TYPE[veclen_];
00204         var = new DIST_TYPE[veclen_];
00205     }
00206 
00210     ~KDTreeIndex()
00211     {
00212         delete[] vind;
00213         if (trees!=NULL) {
00214             delete[] trees;
00215         }
00216         delete[] mean;
00217         delete[] var;
00218     }
00219 
00220 
00224     void buildIndex()
00225     {
00226         /* Construct the randomized trees. */
00227         for (int i = 0; i < numTrees; i++) {
00228             /* Randomize the order of vectors to allow for unbiased sampling. */
00229             for (int j = (int)size_; j > 0; --j) {
00230                 int rnd = rand_int(j);
00231                 std::swap(vind[j-1], vind[rnd]);
00232             }
00233             trees[i] = divideTree(0, (int)size_ - 1);
00234         }
00235     }
00236 
00237 
00238 
00239     void saveIndex(FILE* stream)
00240     {
00241         save_value(stream, numTrees);
00242         for (int i=0;i<numTrees;++i) {
00243             save_tree(stream, trees[i]);
00244         }
00245     }
00246 
00247 
00248 
00249     void loadIndex(FILE* stream)
00250     {
00251         load_value(stream, numTrees);
00252 
00253         if (trees!=NULL) {
00254             delete[] trees;
00255         }
00256         trees = new Tree[numTrees];
00257         for (int i=0;i<numTrees;++i) {
00258             load_tree(stream,trees[i]);
00259         }
00260     }
00261 
00262 
00266     size_t size() const
00267     {
00268         return size_;
00269     }
00270 
00274     size_t veclen() const
00275     {
00276         return veclen_;
00277     }
00278 
00279 
00284     int usedMemory() const
00285     {
00286         return  (int)(pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int));   // pool memory and vind array memory
00287     }
00288 
00289 
00299     void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams)
00300     {
00301         int maxChecks = searchParams.checks;
00302 
00303         if (maxChecks<0) {
00304             getExactNeighbors(result, vec);
00305         } else {
00306             getNeighbors(result, vec, maxChecks);
00307         }
00308     }
00309 
00310     const IndexParams* getParameters() const
00311     {
00312         return &index_params;
00313     }
00314 
00315 private:
00316 
00317     KDTreeIndex& operator=(const KDTreeIndex&);
00318     KDTreeIndex(const KDTreeIndex&);
00319 
00320 
00321     void save_tree(FILE* stream, Tree tree)
00322     {
00323         save_value(stream, *tree);
00324         if (tree->child1!=NULL) {
00325             save_tree(stream, tree->child1);
00326         }
00327         if (tree->child2!=NULL) {
00328             save_tree(stream, tree->child2);
00329         }
00330     }
00331 
00332 
00333     void load_tree(FILE* stream, Tree& tree)
00334     {
00335         tree = pool.allocate<TreeSt>();
00336         load_value(stream, *tree);
00337         if (tree->child1!=NULL) {
00338             load_tree(stream, tree->child1);
00339         }
00340         if (tree->child2!=NULL) {
00341             load_tree(stream, tree->child2);
00342         }
00343     }
00344 
00345 
00355     Tree divideTree(int first, int last)
00356     {
00357         Tree node = pool.allocate<TreeSt>(); // allocate memory
00358 
00359         /* If only one exemplar remains, then make this a leaf node. */
00360         if (first == last) {
00361             node->child1 = node->child2 = NULL;    /* Mark as leaf node. */
00362             node->divfeat = vind[first];    /* Store index of this vec. */
00363         }
00364         else {
00365             chooseDivision(node, first, last);
00366             subdivide(node, first, last);
00367         }
00368 
00369         return node;
00370     }
00371 
00372 
00378     void chooseDivision(Tree node, int first, int last)
00379     {
00380         memset(mean,0,veclen_*sizeof(DIST_TYPE));
00381         memset(var,0,veclen_*sizeof(DIST_TYPE));
00382 
00383         /* Compute mean values.  Only the first SAMPLE_MEAN values need to be
00384             sampled to get a good estimate.
00385         */
00386         int end = std::min(first + SAMPLE_MEAN, last);
00387         for (int j = first; j <= end; ++j) {
00388             ELEM_TYPE* v = dataset[vind[j]];
00389             for (size_t k=0; k<veclen_; ++k) {
00390                 mean[k] += v[k];
00391             }
00392         }
00393         for (size_t k=0; k<veclen_; ++k) {
00394             mean[k] /= (end - first + 1);
00395         }
00396 
00397         /* Compute variances (no need to divide by count). */
00398         for (int j = first; j <= end; ++j) {
00399             ELEM_TYPE* v = dataset[vind[j]];
00400             for (size_t k=0; k<veclen_; ++k) {
00401                 DIST_TYPE dist = v[k] - mean[k];
00402                 var[k] += dist * dist;
00403             }
00404         }
00405         /* Select one of the highest variance indices at random. */
00406         node->divfeat = selectDivision(var);
00407         node->divval = mean[node->divfeat];
00408 
00409     }
00410 
00411 
00416     int selectDivision(DIST_TYPE* v)
00417     {
00418         int num = 0;
00419         int topind[RAND_DIM];
00420 
00421         /* Create a list of the indices of the top RAND_DIM values. */
00422         for (size_t i = 0; i < veclen_; ++i) {
00423             if (num < RAND_DIM  ||  v[i] > v[topind[num-1]]) {
00424                 /* Put this element at end of topind. */
00425                 if (num < RAND_DIM) {
00426                     topind[num++] = (int)i;            /* Add to list. */
00427                 }
00428                 else {
00429                     topind[num-1] = (int)i;         /* Replace last element. */
00430                 }
00431                 /* Bubble end value down to right location by repeated swapping. */
00432                 int j = num - 1;
00433                 while (j > 0  &&  v[topind[j]] > v[topind[j-1]]) {
00434                     std::swap(topind[j], topind[j-1]);
00435                     --j;
00436                 }
00437             }
00438         }
00439         /* Select a random integer in range [0,num-1], and return that index. */
00440         int rnd = rand_int(num);
00441         return topind[rnd];
00442     }
00443 
00444 
00449     void subdivide(Tree node, int first, int last)
00450     {
00451         /* Move vector indices for left subtree to front of list. */
00452         int i = first;
00453         int j = last;
00454         while (i <= j) {
00455             int ind = vind[i];
00456             ELEM_TYPE val = dataset[ind][node->divfeat];
00457             if (val < node->divval) {
00458                 ++i;
00459             } else {
00460                 /* Move to end of list by swapping vind i and j. */
00461                 std::swap(vind[i], vind[j]);
00462                 --j;
00463             }
00464         }
00465         /* If either list is empty, it means we have hit the unlikely case
00466             in which all remaining features are identical. Split in the middle
00467             to maintain a balanced tree.
00468         */
00469         if ( (i == first) || (i == last+1)) {
00470             i = (first+last+1)/2;
00471         }
00472 
00473         node->child1 = divideTree(first, i - 1);
00474         node->child2 = divideTree(i, last);
00475     }
00476 
00477 
00478 
00483     void getExactNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec)
00484     {
00485 //      checkID -= 1;  /* Set a different unique ID for each search. */
00486 
00487         if (numTrees > 1) {
00488             fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00489         }
00490         if (numTrees>0) {
00491             searchLevelExact(result, vec, trees[0], 0.0);
00492         }
00493         assert(result.full());
00494     }
00495 
00501     void getNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, int maxCheck)
00502     {
00503         int i;
00504         BranchSt branch;
00505 
00506         int checkCount = 0;
00507         Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00508         std::vector<bool> checked(size_,false);
00509 
00510         /* Search once through each tree down to root. */
00511         for (i = 0; i < numTrees; ++i) {
00512             searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck, heap, checked);
00513         }
00514 
00515         /* Keep searching other branches from heap until finished. */
00516         while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00517             searchLevel(result, vec, branch.node, branch.mindistsq, checkCount, maxCheck, heap, checked);
00518         }
00519 
00520         delete heap;
00521 
00522         assert(result.full());
00523     }
00524 
00525 
00531     void searchLevel(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq, int& checkCount, int maxCheck,
00532             Heap<BranchSt>* heap, std::vector<bool>& checked)
00533     {
00534         if (result.worstDist()<mindistsq) {
00535 //          printf("Ignoring branch, too far\n");
00536             return;
00537         }
00538 
00539         /* If this is a leaf node, then do check and return. */
00540         if (node->child1 == NULL  &&  node->child2 == NULL) {
00541 
00542             /* Do not check same node more than once when searching multiple trees.
00543                 Once a vector is checked, we set its location in vind to the
00544                 current checkID.
00545             */
00546             if (checked[node->divfeat] == true || checkCount>=maxCheck) {
00547                 if (result.full()) return;
00548             }
00549             checkCount++;
00550             checked[node->divfeat] = true;
00551 
00552             result.addPoint(dataset[node->divfeat],node->divfeat);
00553             return;
00554         }
00555 
00556         /* Which child branch should be taken first? */
00557         ELEM_TYPE val = vec[node->divfeat];
00558         DIST_TYPE diff = val - node->divval;
00559         Tree bestChild = (diff < 0) ? node->child1 : node->child2;
00560         Tree otherChild = (diff < 0) ? node->child2 : node->child1;
00561 
00562         /* Create a branch record for the branch not taken.  Add distance
00563             of this feature boundary (we don't attempt to correct for any
00564             use of this feature in a parent node, which is unlikely to
00565             happen and would have only a small effect).  Don't bother
00566             adding more branches to heap after halfway point, as cost of
00567             adding exceeds their value.
00568         */
00569 
00570         DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq);
00571 //      if (2 * checkCount < maxCheck  ||  !result.full()) {
00572         if (new_distsq < result.worstDist() ||  !result.full()) {
00573             heap->insert( BranchSt::make_branch(otherChild, new_distsq) );
00574         }
00575 
00576         /* Call recursively to search next level down. */
00577         searchLevel(result, vec, bestChild, mindistsq, checkCount, maxCheck, heap, checked);
00578     }
00579 
00583     void searchLevelExact(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq)
00584     {
00585         if (mindistsq>result.worstDist()) {
00586             return;
00587         }
00588 
00589         /* If this is a leaf node, then do check and return. */
00590         if (node->child1 == NULL  &&  node->child2 == NULL) {
00591 
00592             /* Do not check same node more than once when searching multiple trees.
00593                 Once a vector is checked, we set its location in vind to the
00594                 current checkID.
00595             */
00596 //          if (vind[node->divfeat] == checkID)
00597 //              return;
00598 //          vind[node->divfeat] = checkID;
00599 
00600             result.addPoint(dataset[node->divfeat],node->divfeat);
00601             return;
00602         }
00603 
00604         /* Which child branch should be taken first? */
00605         ELEM_TYPE val = vec[node->divfeat];
00606         DIST_TYPE diff = val - node->divval;
00607         Tree bestChild = (diff < 0) ? node->child1 : node->child2;
00608         Tree otherChild = (diff < 0) ? node->child2 : node->child1;
00609 
00610 
00611         /* Call recursively to search next level down. */
00612         searchLevelExact(result, vec, bestChild, mindistsq);
00613         DIST_TYPE new_distsq = (DIST_TYPE)flann_dist(&val, &val+1, &node->divval, mindistsq);
00614         searchLevelExact(result, vec, otherChild, new_distsq);
00615     }
00616 
00617 };   // class KDTree
00618 
00619 } // namespace cvflann
00620 
00621 #endif //_OPENCV_KDTREE_H_