00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef OPENCV_FLANN_KDTREE_INDEX_H_
00032 #define OPENCV_FLANN_KDTREE_INDEX_H_
00033
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038
00039 #include "general.h"
00040 #include "nn_index.h"
00041 #include "dynamic_bitset.h"
00042 #include "matrix.h"
00043 #include "result_set.h"
00044 #include "heap.h"
00045 #include "allocator.h"
00046 #include "random.h"
00047 #include "saving.h"
00048
00049
00050 namespace cvflann
00051 {
00052
00053 struct KDTreeIndexParams : public IndexParams
00054 {
00055 KDTreeIndexParams(int trees = 4)
00056 {
00057 (*this)["algorithm"] = FLANN_INDEX_KDTREE;
00058 (*this)["trees"] = trees;
00059 }
00060 };
00061
00062
00069 template <typename Distance>
00070 class KDTreeIndex : public NNIndex<Distance>
00071 {
00072 public:
00073 typedef typename Distance::ElementType ElementType;
00074 typedef typename Distance::ResultType DistanceType;
00075
00076
00084 KDTreeIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeIndexParams(),
00085 Distance d = Distance() ) :
00086 dataset_(inputData), index_params_(params), distance_(d)
00087 {
00088 size_ = dataset_.rows;
00089 veclen_ = dataset_.cols;
00090
00091 trees_ = get_param(index_params_,"trees",4);
00092 tree_roots_ = new NodePtr[trees_];
00093
00094
00095 vind_.resize(size_);
00096 for (size_t i = 0; i < size_; ++i) {
00097 vind_[i] = int(i);
00098 }
00099
00100 mean_ = new DistanceType[veclen_];
00101 var_ = new DistanceType[veclen_];
00102 }
00103
00104
00105 KDTreeIndex(const KDTreeIndex&);
00106 KDTreeIndex& operator=(const KDTreeIndex&);
00107
00111 ~KDTreeIndex()
00112 {
00113 if (tree_roots_!=NULL) {
00114 delete[] tree_roots_;
00115 }
00116 delete[] mean_;
00117 delete[] var_;
00118 }
00119
00123 void buildIndex()
00124 {
00125
00126 for (int i = 0; i < trees_; i++) {
00127
00128 std::random_shuffle(vind_.begin(), vind_.end());
00129 tree_roots_[i] = divideTree(&vind_[0], int(size_) );
00130 }
00131 }
00132
00133
00134 flann_algorithm_t getType() const
00135 {
00136 return FLANN_INDEX_KDTREE;
00137 }
00138
00139
00140 void saveIndex(FILE* stream)
00141 {
00142 save_value(stream, trees_);
00143 for (int i=0; i<trees_; ++i) {
00144 save_tree(stream, tree_roots_[i]);
00145 }
00146 }
00147
00148
00149
00150 void loadIndex(FILE* stream)
00151 {
00152 load_value(stream, trees_);
00153 if (tree_roots_!=NULL) {
00154 delete[] tree_roots_;
00155 }
00156 tree_roots_ = new NodePtr[trees_];
00157 for (int i=0; i<trees_; ++i) {
00158 load_tree(stream,tree_roots_[i]);
00159 }
00160
00161 index_params_["algorithm"] = getType();
00162 index_params_["trees"] = tree_roots_;
00163 }
00164
00168 size_t size() const
00169 {
00170 return size_;
00171 }
00172
00176 size_t veclen() const
00177 {
00178 return veclen_;
00179 }
00180
00185 int usedMemory() const
00186 {
00187 return int(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int));
00188 }
00189
00199 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
00200 {
00201 int maxChecks = get_param(searchParams,"checks", 32);
00202 float epsError = 1+get_param(searchParams,"eps",0.0f);
00203
00204 if (maxChecks==FLANN_CHECKS_UNLIMITED) {
00205 getExactNeighbors(result, vec, epsError);
00206 }
00207 else {
00208 getNeighbors(result, vec, maxChecks, epsError);
00209 }
00210 }
00211
00212 IndexParams getParameters() const
00213 {
00214 return index_params_;
00215 }
00216
00217 private:
00218
00219
00220
00221 struct Node
00222 {
00226 int divfeat;
00230 DistanceType divval;
00234 Node* child1, * child2;
00235 };
00236 typedef Node* NodePtr;
00237 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00238 typedef BranchSt* Branch;
00239
00240
00241
00242 void save_tree(FILE* stream, NodePtr tree)
00243 {
00244 save_value(stream, *tree);
00245 if (tree->child1!=NULL) {
00246 save_tree(stream, tree->child1);
00247 }
00248 if (tree->child2!=NULL) {
00249 save_tree(stream, tree->child2);
00250 }
00251 }
00252
00253
00254 void load_tree(FILE* stream, NodePtr& tree)
00255 {
00256 tree = pool_.allocate<Node>();
00257 load_value(stream, *tree);
00258 if (tree->child1!=NULL) {
00259 load_tree(stream, tree->child1);
00260 }
00261 if (tree->child2!=NULL) {
00262 load_tree(stream, tree->child2);
00263 }
00264 }
00265
00266
00276 NodePtr divideTree(int* ind, int count)
00277 {
00278 NodePtr node = pool_.allocate<Node>();
00279
00280
00281 if ( count == 1) {
00282 node->child1 = node->child2 = NULL;
00283 node->divfeat = *ind;
00284 }
00285 else {
00286 int idx;
00287 int cutfeat;
00288 DistanceType cutval;
00289 meanSplit(ind, count, idx, cutfeat, cutval);
00290
00291 node->divfeat = cutfeat;
00292 node->divval = cutval;
00293 node->child1 = divideTree(ind, idx);
00294 node->child2 = divideTree(ind+idx, count-idx);
00295 }
00296
00297 return node;
00298 }
00299
00300
00306 void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
00307 {
00308 memset(mean_,0,veclen_*sizeof(DistanceType));
00309 memset(var_,0,veclen_*sizeof(DistanceType));
00310
00311
00312
00313
00314 int cnt = std::min((int)SAMPLE_MEAN+1, count);
00315 for (int j = 0; j < cnt; ++j) {
00316 ElementType* v = dataset_[ind[j]];
00317 for (size_t k=0; k<veclen_; ++k) {
00318 mean_[k] += v[k];
00319 }
00320 }
00321 for (size_t k=0; k<veclen_; ++k) {
00322 mean_[k] /= cnt;
00323 }
00324
00325
00326 for (int j = 0; j < cnt; ++j) {
00327 ElementType* v = dataset_[ind[j]];
00328 for (size_t k=0; k<veclen_; ++k) {
00329 DistanceType dist = v[k] - mean_[k];
00330 var_[k] += dist * dist;
00331 }
00332 }
00333
00334 cutfeat = selectDivision(var_);
00335 cutval = mean_[cutfeat];
00336
00337 int lim1, lim2;
00338 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00339
00340 if (lim1>count/2) index = lim1;
00341 else if (lim2<count/2) index = lim2;
00342 else index = count/2;
00343
00344
00345
00346
00347 if ((lim1==count)||(lim2==0)) index = count/2;
00348 }
00349
00350
00355 int selectDivision(DistanceType* v)
00356 {
00357 int num = 0;
00358 size_t topind[RAND_DIM];
00359
00360
00361 for (size_t i = 0; i < veclen_; ++i) {
00362 if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
00363
00364 if (num < RAND_DIM) {
00365 topind[num++] = i;
00366 }
00367 else {
00368 topind[num-1] = i;
00369 }
00370
00371 int j = num - 1;
00372 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
00373 std::swap(topind[j], topind[j-1]);
00374 --j;
00375 }
00376 }
00377 }
00378
00379 int rnd = rand_int(num);
00380 return (int)topind[rnd];
00381 }
00382
00383
00393 void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00394 {
00395
00396 int left = 0;
00397 int right = count-1;
00398 for (;; ) {
00399 while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
00400 while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
00401 if (left>right) break;
00402 std::swap(ind[left], ind[right]); ++left; --right;
00403 }
00404 lim1 = left;
00405 right = count-1;
00406 for (;; ) {
00407 while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
00408 while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
00409 if (left>right) break;
00410 std::swap(ind[left], ind[right]); ++left; --right;
00411 }
00412 lim2 = left;
00413 }
00414
00419 void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError)
00420 {
00421
00422
00423 if (trees_ > 1) {
00424 fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
00425 }
00426 if (trees_>0) {
00427 searchLevelExact(result, vec, tree_roots_[0], 0.0, epsError);
00428 }
00429 assert(result.full());
00430 }
00431
00437 void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError)
00438 {
00439 int i;
00440 BranchSt branch;
00441
00442 int checkCount = 0;
00443 Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
00444 DynamicBitset checked(size_);
00445
00446
00447 for (i = 0; i < trees_; ++i) {
00448 searchLevel(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked);
00449 }
00450
00451
00452 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
00453 searchLevel(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked);
00454 }
00455
00456 delete heap;
00457
00458 assert(result.full());
00459 }
00460
00461
00467 void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
00468 float epsError, Heap<BranchSt>* heap, DynamicBitset& checked)
00469 {
00470 if (result_set.worstDist()<mindist) {
00471
00472 return;
00473 }
00474
00475
00476 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00477
00478
00479
00480
00481 int index = node->divfeat;
00482 if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return;
00483 checked.set(index);
00484 checkCount++;
00485
00486 DistanceType dist = distance_(dataset_[index], vec, veclen_);
00487 result_set.addPoint(dist,index);
00488
00489 return;
00490 }
00491
00492
00493 ElementType val = vec[node->divfeat];
00494 DistanceType diff = val - node->divval;
00495 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00496 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00507
00508 if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) {
00509 heap->insert( BranchSt(otherChild, new_distsq) );
00510 }
00511
00512
00513 searchLevel(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
00514 }
00515
00519 void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError)
00520 {
00521
00522 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00523 int index = node->divfeat;
00524 DistanceType dist = distance_(dataset_[index], vec, veclen_);
00525 result_set.addPoint(dist,index);
00526 return;
00527 }
00528
00529
00530 ElementType val = vec[node->divfeat];
00531 DistanceType diff = val - node->divval;
00532 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
00533 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
00534
00535
00536
00537
00538
00539
00540
00541
00542
00543 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
00544
00545
00546 searchLevelExact(result_set, vec, bestChild, mindist, epsError);
00547
00548 if (new_distsq*epsError<=result_set.worstDist()) {
00549 searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
00550 }
00551 }
00552
00553
00554 private:
00555
00556 enum
00557 {
00563 SAMPLE_MEAN = 100,
00571 RAND_DIM=5
00572 };
00573
00574
00578 int trees_;
00579
00583 std::vector<int> vind_;
00584
00588 const Matrix<ElementType> dataset_;
00589
00590 IndexParams index_params_;
00591
00592 size_t size_;
00593 size_t veclen_;
00594
00595
00596 DistanceType* mean_;
00597 DistanceType* var_;
00598
00599
00603 NodePtr* tree_roots_;
00604
00612 PooledAllocator pool_;
00613
00614 Distance distance_;
00615
00616
00617 };
00618
00619 }
00620
00621 #endif //OPENCV_FLANN_KDTREE_INDEX_H_