00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
00032 #define OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
00033
00034 #include <algorithm>
00035 #include <map>
00036 #include <cassert>
00037 #include <cstring>
00038
00039 #include "general.h"
00040 #include "nn_index.h"
00041 #include "matrix.h"
00042 #include "result_set.h"
00043 #include "heap.h"
00044 #include "allocator.h"
00045 #include "random.h"
00046 #include "saving.h"
00047
00048 namespace cvflann
00049 {
00050
00051 struct KDTreeSingleIndexParams : public IndexParams
00052 {
00053 KDTreeSingleIndexParams(int leaf_max_size = 10, bool reorder = true, int dim = -1)
00054 {
00055 (*this)["algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
00056 (*this)["leaf_max_size"] = leaf_max_size;
00057 (*this)["reorder"] = reorder;
00058 (*this)["dim"] = dim;
00059 }
00060 };
00061
00062
00069 template <typename Distance>
00070 class KDTreeSingleIndex : public NNIndex<Distance>
00071 {
00072 public:
00073 typedef typename Distance::ElementType ElementType;
00074 typedef typename Distance::ResultType DistanceType;
00075
00076
00084 KDTreeSingleIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeSingleIndexParams(),
00085 Distance d = Distance() ) :
00086 dataset_(inputData), index_params_(params), distance_(d)
00087 {
00088 size_ = dataset_.rows;
00089 dim_ = dataset_.cols;
00090 int dim_param = get_param(params,"dim",-1);
00091 if (dim_param>0) dim_ = dim_param;
00092 leaf_max_size_ = get_param(params,"leaf_max_size",10);
00093 reorder_ = get_param(params,"reorder",true);
00094
00095
00096 vind_.resize(size_);
00097 for (size_t i = 0; i < size_; i++) {
00098 vind_[i] = (int)i;
00099 }
00100 }
00101
00102 KDTreeSingleIndex(const KDTreeSingleIndex&);
00103 KDTreeSingleIndex& operator=(const KDTreeSingleIndex&);
00104
00108 ~KDTreeSingleIndex()
00109 {
00110 if (reorder_) delete[] data_.data;
00111 }
00112
00116 void buildIndex()
00117 {
00118 computeBoundingBox(root_bbox_);
00119 root_node_ = divideTree(0, (int)size_, root_bbox_ );
00120
00121 if (reorder_) {
00122 delete[] data_.data;
00123 data_ = cvflann::Matrix<ElementType>(new ElementType[size_*dim_], size_, dim_);
00124 for (size_t i=0; i<size_; ++i) {
00125 for (size_t j=0; j<dim_; ++j) {
00126 data_[i][j] = dataset_[vind_[i]][j];
00127 }
00128 }
00129 }
00130 else {
00131 data_ = dataset_;
00132 }
00133 }
00134
00135 flann_algorithm_t getType() const
00136 {
00137 return FLANN_INDEX_KDTREE_SINGLE;
00138 }
00139
00140
00141 void saveIndex(FILE* stream)
00142 {
00143 save_value(stream, size_);
00144 save_value(stream, dim_);
00145 save_value(stream, root_bbox_);
00146 save_value(stream, reorder_);
00147 save_value(stream, leaf_max_size_);
00148 save_value(stream, vind_);
00149 if (reorder_) {
00150 save_value(stream, data_);
00151 }
00152 save_tree(stream, root_node_);
00153 }
00154
00155
00156 void loadIndex(FILE* stream)
00157 {
00158 load_value(stream, size_);
00159 load_value(stream, dim_);
00160 load_value(stream, root_bbox_);
00161 load_value(stream, reorder_);
00162 load_value(stream, leaf_max_size_);
00163 load_value(stream, vind_);
00164 if (reorder_) {
00165 load_value(stream, data_);
00166 }
00167 else {
00168 data_ = dataset_;
00169 }
00170 load_tree(stream, root_node_);
00171
00172
00173 index_params_["algorithm"] = getType();
00174 index_params_["leaf_max_size"] = leaf_max_size_;
00175 index_params_["reorder"] = reorder_;
00176 }
00177
00181 size_t size() const
00182 {
00183 return size_;
00184 }
00185
00189 size_t veclen() const
00190 {
00191 return dim_;
00192 }
00193
00198 int usedMemory() const
00199 {
00200 return (int)(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int));
00201 }
00202
00203
00212 void knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& params)
00213 {
00214 assert(queries.cols == veclen());
00215 assert(indices.rows >= queries.rows);
00216 assert(dists.rows >= queries.rows);
00217 assert(int(indices.cols) >= knn);
00218 assert(int(dists.cols) >= knn);
00219
00220 KNNSimpleResultSet<DistanceType> resultSet(knn);
00221 for (size_t i = 0; i < queries.rows; i++) {
00222 resultSet.init(indices[i], dists[i]);
00223 findNeighbors(resultSet, queries[i], params);
00224 }
00225 }
00226
00227 IndexParams getParameters() const
00228 {
00229 return index_params_;
00230 }
00231
00241 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
00242 {
00243 float epsError = 1+get_param(searchParams,"eps",0.0f);
00244
00245 std::vector<DistanceType> dists(dim_,0);
00246 DistanceType distsq = computeInitialDistances(vec, dists);
00247 searchLevel(result, vec, root_node_, distsq, dists, epsError);
00248 }
00249
00250 private:
00251
00252
00253
00254 struct Node
00255 {
00259 int left, right;
00263 int divfeat;
00267 DistanceType divlow, divhigh;
00271 Node* child1, * child2;
00272 };
00273 typedef Node* NodePtr;
00274
00275
00276 struct Interval
00277 {
00278 DistanceType low, high;
00279 };
00280
00281 typedef std::vector<Interval> BoundingBox;
00282
00283 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
00284 typedef BranchSt* Branch;
00285
00286
00287
00288
00289 void save_tree(FILE* stream, NodePtr tree)
00290 {
00291 save_value(stream, *tree);
00292 if (tree->child1!=NULL) {
00293 save_tree(stream, tree->child1);
00294 }
00295 if (tree->child2!=NULL) {
00296 save_tree(stream, tree->child2);
00297 }
00298 }
00299
00300
00301 void load_tree(FILE* stream, NodePtr& tree)
00302 {
00303 tree = pool_.allocate<Node>();
00304 load_value(stream, *tree);
00305 if (tree->child1!=NULL) {
00306 load_tree(stream, tree->child1);
00307 }
00308 if (tree->child2!=NULL) {
00309 load_tree(stream, tree->child2);
00310 }
00311 }
00312
00313
00314 void computeBoundingBox(BoundingBox& bbox)
00315 {
00316 bbox.resize(dim_);
00317 for (size_t i=0; i<dim_; ++i) {
00318 bbox[i].low = (DistanceType)dataset_[0][i];
00319 bbox[i].high = (DistanceType)dataset_[0][i];
00320 }
00321 for (size_t k=1; k<dataset_.rows; ++k) {
00322 for (size_t i=0; i<dim_; ++i) {
00323 if (dataset_[k][i]<bbox[i].low) bbox[i].low = (DistanceType)dataset_[k][i];
00324 if (dataset_[k][i]>bbox[i].high) bbox[i].high = (DistanceType)dataset_[k][i];
00325 }
00326 }
00327 }
00328
00329
00339 NodePtr divideTree(int left, int right, BoundingBox& bbox)
00340 {
00341 NodePtr node = pool_.allocate<Node>();
00342
00343
00344 if ( (right-left) <= leaf_max_size_) {
00345 node->child1 = node->child2 = NULL;
00346 node->left = left;
00347 node->right = right;
00348
00349
00350 for (size_t i=0; i<dim_; ++i) {
00351 bbox[i].low = (DistanceType)dataset_[vind_[left]][i];
00352 bbox[i].high = (DistanceType)dataset_[vind_[left]][i];
00353 }
00354 for (int k=left+1; k<right; ++k) {
00355 for (size_t i=0; i<dim_; ++i) {
00356 if (bbox[i].low>dataset_[vind_[k]][i]) bbox[i].low=(DistanceType)dataset_[vind_[k]][i];
00357 if (bbox[i].high<dataset_[vind_[k]][i]) bbox[i].high=(DistanceType)dataset_[vind_[k]][i];
00358 }
00359 }
00360 }
00361 else {
00362 int idx;
00363 int cutfeat;
00364 DistanceType cutval;
00365 middleSplit_(&vind_[0]+left, right-left, idx, cutfeat, cutval, bbox);
00366
00367 node->divfeat = cutfeat;
00368
00369 BoundingBox left_bbox(bbox);
00370 left_bbox[cutfeat].high = cutval;
00371 node->child1 = divideTree(left, left+idx, left_bbox);
00372
00373 BoundingBox right_bbox(bbox);
00374 right_bbox[cutfeat].low = cutval;
00375 node->child2 = divideTree(left+idx, right, right_bbox);
00376
00377 node->divlow = left_bbox[cutfeat].high;
00378 node->divhigh = right_bbox[cutfeat].low;
00379
00380 for (size_t i=0; i<dim_; ++i) {
00381 bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);
00382 bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);
00383 }
00384 }
00385
00386 return node;
00387 }
00388
00389 void computeMinMax(int* ind, int count, int dim, ElementType& min_elem, ElementType& max_elem)
00390 {
00391 min_elem = dataset_[ind[0]][dim];
00392 max_elem = dataset_[ind[0]][dim];
00393 for (int i=1; i<count; ++i) {
00394 ElementType val = dataset_[ind[i]][dim];
00395 if (val<min_elem) min_elem = val;
00396 if (val>max_elem) max_elem = val;
00397 }
00398 }
00399
00400 void middleSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
00401 {
00402
00403 ElementType max_span = bbox[0].high-bbox[0].low;
00404 cutfeat = 0;
00405 cutval = (bbox[0].high+bbox[0].low)/2;
00406 for (size_t i=1; i<dim_; ++i) {
00407 ElementType span = bbox[i].high-bbox[i].low;
00408 if (span>max_span) {
00409 max_span = span;
00410 cutfeat = i;
00411 cutval = (bbox[i].high+bbox[i].low)/2;
00412 }
00413 }
00414
00415
00416 ElementType min_elem, max_elem;
00417 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00418 cutval = (min_elem+max_elem)/2;
00419 max_span = max_elem - min_elem;
00420
00421
00422 size_t k = cutfeat;
00423 for (size_t i=0; i<dim_; ++i) {
00424 if (i==k) continue;
00425 ElementType span = bbox[i].high-bbox[i].low;
00426 if (span>max_span) {
00427 computeMinMax(ind, count, i, min_elem, max_elem);
00428 span = max_elem - min_elem;
00429 if (span>max_span) {
00430 max_span = span;
00431 cutfeat = i;
00432 cutval = (min_elem+max_elem)/2;
00433 }
00434 }
00435 }
00436 int lim1, lim2;
00437 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00438
00439 if (lim1>count/2) index = lim1;
00440 else if (lim2<count/2) index = lim2;
00441 else index = count/2;
00442 }
00443
00444
00445 void middleSplit_(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
00446 {
00447 const float EPS=0.00001f;
00448 DistanceType max_span = bbox[0].high-bbox[0].low;
00449 for (size_t i=1; i<dim_; ++i) {
00450 DistanceType span = bbox[i].high-bbox[i].low;
00451 if (span>max_span) {
00452 max_span = span;
00453 }
00454 }
00455 DistanceType max_spread = -1;
00456 cutfeat = 0;
00457 for (size_t i=0; i<dim_; ++i) {
00458 DistanceType span = bbox[i].high-bbox[i].low;
00459 if (span>(DistanceType)((1-EPS)*max_span)) {
00460 ElementType min_elem, max_elem;
00461 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00462 DistanceType spread = (DistanceType)(max_elem-min_elem);
00463 if (spread>max_spread) {
00464 cutfeat = (int)i;
00465 max_spread = spread;
00466 }
00467 }
00468 }
00469
00470 DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
00471 ElementType min_elem, max_elem;
00472 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
00473
00474 if (split_val<min_elem) cutval = (DistanceType)min_elem;
00475 else if (split_val>max_elem) cutval = (DistanceType)max_elem;
00476 else cutval = split_val;
00477
00478 int lim1, lim2;
00479 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
00480
00481 if (lim1>count/2) index = lim1;
00482 else if (lim2<count/2) index = lim2;
00483 else index = count/2;
00484 }
00485
00486
00496 void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
00497 {
00498
00499 int left = 0;
00500 int right = count-1;
00501 for (;; ) {
00502 while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
00503 while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
00504 if (left>right) break;
00505 std::swap(ind[left], ind[right]); ++left; --right;
00506 }
00507
00508
00509
00510 lim1 = left;
00511 right = count-1;
00512 for (;; ) {
00513 while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
00514 while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
00515 if (left>right) break;
00516 std::swap(ind[left], ind[right]); ++left; --right;
00517 }
00518 lim2 = left;
00519 }
00520
00521 DistanceType computeInitialDistances(const ElementType* vec, std::vector<DistanceType>& dists)
00522 {
00523 DistanceType distsq = 0.0;
00524
00525 for (size_t i = 0; i < dim_; ++i) {
00526 if (vec[i] < root_bbox_[i].low) {
00527 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, (int)i);
00528 distsq += dists[i];
00529 }
00530 if (vec[i] > root_bbox_[i].high) {
00531 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, (int)i);
00532 distsq += dists[i];
00533 }
00534 }
00535
00536 return distsq;
00537 }
00538
00542 void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindistsq,
00543 std::vector<DistanceType>& dists, const float epsError)
00544 {
00545
00546 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
00547 DistanceType worst_dist = result_set.worstDist();
00548 for (int i=node->left; i<node->right; ++i) {
00549 int index = reorder_ ? i : vind_[i];
00550 DistanceType dist = distance_(vec, data_[index], dim_, worst_dist);
00551 if (dist<worst_dist) {
00552 result_set.addPoint(dist,vind_[i]);
00553 }
00554 }
00555 return;
00556 }
00557
00558
00559 int idx = node->divfeat;
00560 ElementType val = vec[idx];
00561 DistanceType diff1 = val - node->divlow;
00562 DistanceType diff2 = val - node->divhigh;
00563
00564 NodePtr bestChild;
00565 NodePtr otherChild;
00566 DistanceType cut_dist;
00567 if ((diff1+diff2)<0) {
00568 bestChild = node->child1;
00569 otherChild = node->child2;
00570 cut_dist = distance_.accum_dist(val, node->divhigh, idx);
00571 }
00572 else {
00573 bestChild = node->child2;
00574 otherChild = node->child1;
00575 cut_dist = distance_.accum_dist( val, node->divlow, idx);
00576 }
00577
00578
00579 searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);
00580
00581 DistanceType dst = dists[idx];
00582 mindistsq = mindistsq + cut_dist - dst;
00583 dists[idx] = cut_dist;
00584 if (mindistsq*epsError<=result_set.worstDist()) {
00585 searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);
00586 }
00587 dists[idx] = dst;
00588 }
00589
00590 private:
00591
00595 const Matrix<ElementType> dataset_;
00596
00597 IndexParams index_params_;
00598
00599 int leaf_max_size_;
00600 bool reorder_;
00601
00602
00606 std::vector<int> vind_;
00607
00608 Matrix<ElementType> data_;
00609
00610 size_t size_;
00611 size_t dim_;
00612
00616 NodePtr root_node_;
00617
00618 BoundingBox root_bbox_;
00619
00627 PooledAllocator pool_;
00628
00629 Distance distance_;
00630 };
00631
00632 }
00633
00634 #endif //OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_