00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #pragma once
00024
00025 #include "cinder/Cinder.h"
00026 #include "cinder/Vector.h"
00027
00028 #include <vector>
00029 #include <float.h>
00030 #include <stdlib.h>
00031 #include <algorithm>
00032 #include <utility>
00033
00034 namespace cinder {
00035
00036
/// Compact per-node record used by KdTree for a K-dimensional tree.
///
/// The split axis is stored in a 2-bit field (values 0..3) and the right-child
/// index in a 29-bit field. A node whose splitAxis equals K is a leaf.
template<unsigned char K>
struct KdNode {
	/// Configures an interior node that splits at coordinate \a p along axis \a a.
	/// The node starts out with no children recorded.
	void init( float p, uint32_t a )
	{
		hasLeftChild = 0;
		// Every bit of the 29-bit field set; KdTree treats any index that is
		// not below its node count as "no right child".
		rightChild = ~0;
		splitAxis = a;
		splitPos = p;
	}

	/// Configures a leaf: the axis is set to K and no children are recorded.
	/// splitPos is left untouched.
	void initLeaf()
	{
		hasLeftChild = 0;
		rightChild = ~0;
		splitAxis = K;
	}

	float		splitPos;          // coordinate of the splitting plane along splitAxis
	uint32_t	splitAxis:2;       // axis index, or K for a leaf
	uint32_t	hasLeftChild:1;    // non-zero when a left child exists
	uint32_t	rightChild:29;     // index of the right child, all ones when absent
};
00056
/// Lookup processor that ignores every reported point; it is the default
/// LookupProc template argument of KdTree.
struct NullLookupProc {
	/// Does nothing with the reported point.
	/// Declared const because KdTree invokes process() through a
	/// `const LookupProc &`; a non-const member could not be called there.
	/// \param id           index of the reported point (unused)
	/// \param distSqrd     squared distance of the point to the query (unused)
	/// \param maxDistSqrd  current squared search radius; left unchanged
	void process( uint32_t /*id*/, float /*distSqrd*/, float & /*maxDistSqrd*/ ) const {}
};
00061
00062 template <typename NodeData, unsigned char K=3, class LookupProc = NullLookupProc> class KdTree {
00063 public:
00064 typedef std::pair<const NodeData*, uint32_t> NodeDataIndex;
00065
00066
00067 template<typename NodeDataVector>
00068 KdTree( const NodeDataVector &data );
00069 KdTree() {}
00070 template<typename NodeDataVector>
00071 void initialize( const NodeDataVector &d );
00072 ~KdTree() {
00073 free( nodes );
00074 delete[] mNodeData;
00075 }
00076 void recursiveBuild( uint32_t nodeNum, uint32_t start, uint32_t end, std::vector<NodeDataIndex> &buildNodes );
00077 void lookup( const NodeData &p, const LookupProc &process, float maxDist ) const;
00078 void findNearest( float p[K], float result[K], uint32_t *resultIndex ) const;
00079
00080 private:
00081
00082 void privateLookup(uint32_t nodeNum, float p[K], const LookupProc &process, float &maxDistSquared) const;
00083 void privateFindNearest( uint32_t nodeNum, float p[K], float &maxDistSquared, float result[K], uint32_t *resultIndex ) const;
00084
00085 KdNode<K> *nodes;
00086 NodeDataIndex *mNodeData;
00087 uint32_t nNodes, nextFreeNode;
00088 };
00089
00090
00091
/// Adapter that tells KdTree how many elements a container of type NDV holds.
/// The generic version forwards to the container's size() member.
template<typename NDV>
struct NodeDataVectorTraits
{
	/// Returns the element count of \a ndv converted to uint32_t.
	static uint32_t getSize( const NDV &ndv )
	{
		const uint32_t count = static_cast<uint32_t>( ndv.size() );
		return count;
	}
};
00099
/// Coordinate access for point types exposing x, y and z members.
template<typename NodeData>
struct NodeDataTraits
{
	/// Returns the coordinate of \a data on \a axis: 0 selects x, 1 selects y and
	/// every other value selects z.
	static float getAxis( const NodeData &data, int axis )
	{
		switch( axis ) {
			case 0:
				return data.x;
			case 1:
				return data.y;
			default:
				return static_cast<float>( data.z );
		}
	}

	/// Fixed-axis accessors for x, y and z respectively.
	static float getAxis0( const NodeData &data ) { return static_cast<float>( data.x ); }
	static float getAxis1( const NodeData &data ) { return static_cast<float>( data.y ); }
	static float getAxis2( const NodeData &data ) { return static_cast<float>( data.z ); }

	/// Squared Euclidean distance between \a data and the point given by the
	/// three coordinates in \a k.
	static float distanceSquared( const NodeData &data, float k[3] )
	{
		const auto dx = data.x - k[0];
		const auto dy = data.y - k[1];
		const auto dz = data.z - k[2];

		float sum = dx * dx;
		sum += dy * dy;
		sum += dz * dz;
		return sum;
	}
};
00118
00119 template<>
00120 struct NodeDataTraits<Vec2f>
00121 {
00122 static float getAxis( const Vec2f &data, int axis ) {
00123 if( axis == 0 ) return data.x;
00124 else return data.y;
00125 }
00126 static float getAxis0( const Vec2f &data ) { return static_cast<float>( data.x ); }
00127 static float getAxis1( const Vec2f &data ) { return static_cast<float>( data.y ); }
00128 static float distanceSquared( const Vec2f &data, float k[2] ) {
00129 float result = ( data.x - k[0] ) * ( data.x - k[0] );
00130 result += ( data.y - k[1] ) * ( data.y - k[1] );
00131 return result;
00132 }
00133 };
00134
00135 template<typename NodeData> struct CompareNode {
00136 CompareNode( int a ) { axis = a; }
00137 int axis;
00138 bool operator()(const std::pair<const NodeData*,uint32_t> &d1,
00139 const std::pair<const NodeData*,uint32_t> &d2) const {
00140 return NodeDataTraits<NodeData>::getAxis( *d1.first, axis ) == NodeDataTraits<NodeData>::getAxis( *d2.first, axis ) ? ( d1.first < d2.first ) :
00141 NodeDataTraits<NodeData>::getAxis( *d1.first, axis ) < NodeDataTraits<NodeData>::getAxis( *d2.first, axis );
00142 }
00143 };
00144
00145
00146 template<typename NodeData, unsigned char K, typename LookupProc>
00147 template<typename NodeDataVector>
00148 KdTree<NodeData, K, LookupProc>::KdTree(const NodeDataVector &d)
00149 {
00150 initialize( d );
00151 }
00152
00153 template<typename NodeData, unsigned char K, typename LookupProc>
00154 template<typename NodeDataVector>
00155 void KdTree<NodeData, K, LookupProc>::initialize( const NodeDataVector &d )
00156 {
00157 nNodes = NodeDataVectorTraits<NodeDataVector>::getSize( d );
00158 nextFreeNode = 1;
00159 nodes = (KdNode<K> *)malloc(nNodes * sizeof(KdNode<K>));
00160 mNodeData = new NodeDataIndex[nNodes];
00161 std::vector<NodeDataIndex> buildNodes;
00162 buildNodes.reserve( nNodes );
00163 for( uint32_t i = 0; i < nNodes; ++i )
00164 buildNodes.push_back( std::make_pair( &d[i], i ) );
00165
00166 recursiveBuild( 0, 0, nNodes, buildNodes );
00167 }
00168
00169 template<typename NodeData, unsigned char K, typename LookupProc>
00170 void KdTree<NodeData, K, LookupProc>::recursiveBuild( uint32_t nodeNum, uint32_t start, uint32_t end, std::vector<NodeDataIndex> &buildNodes )
00171 {
00172
00173 if( start + 1 == end) {
00174 nodes[nodeNum].initLeaf();
00175 mNodeData[nodeNum] = buildNodes[start];
00176 return;
00177 }
00178
00179
00180 float boundMin[K], boundMax[K];
00181 for( unsigned char k = 0; k < K; ++k ) {
00182 boundMin[k] = FLT_MAX;
00183 boundMax[k] = FLT_MIN;
00184 }
00185
00186 for( uint32_t i = start; i < end; ++i ) {
00187 for( uint8_t axis = 0; axis < K; axis++ ) {
00188
00189 boundMin[axis] = std::min( boundMin[axis], NodeDataTraits<NodeData>::getAxis( *buildNodes[i].first, axis ) );
00190 boundMax[axis] = std::max( boundMax[axis], NodeDataTraits<NodeData>::getAxis( *buildNodes[i].first, axis ) );
00191 }
00192 }
00193 int splitAxis = 0;
00194 float maxExtent = boundMax[0] - boundMin[0];
00195 for( unsigned char k = 1; k < K; ++k ) {
00196 if( boundMax[k] - boundMin[k] > maxExtent ) {
00197 splitAxis = k;
00198 maxExtent = boundMax[k] - boundMin[k];
00199 }
00200 }
00201 uint32_t splitPos = ( start + end ) / 2;
00202 std::nth_element( &buildNodes[start], &buildNodes[splitPos], &buildNodes[end-1] + 1, CompareNode<NodeData>(splitAxis) );
00203
00204 nodes[nodeNum].init( NodeDataTraits<NodeData>::getAxis( *buildNodes[splitPos].first, splitAxis ), splitAxis );
00205 mNodeData[nodeNum] = buildNodes[splitPos];
00206 if( start < splitPos ) {
00207 nodes[nodeNum].hasLeftChild = 1;
00208 uint32_t childNum = nextFreeNode++;
00209 recursiveBuild( childNum, start, splitPos, buildNodes );
00210 }
00211 if( splitPos + 1 < end ) {
00212 nodes[nodeNum].rightChild = nextFreeNode++;
00213 recursiveBuild( nodes[nodeNum].rightChild, splitPos + 1, end, buildNodes );
00214 }
00215 }
00216
00217 template<typename NodeData, unsigned char K, typename LookupProc>
00218 void KdTree<NodeData, K, LookupProc>::lookup( const NodeData &p, const LookupProc &proc, float maxDist ) const
00219 {
00220 float maxDistSqrd = maxDist * maxDist;
00221 float pt[K];
00222 for( unsigned char k = 0; k < K; ++k )
00223 pt[k] = NodeDataTraits<NodeData>::getAxis( p, k );
00224
00225 privateLookup( 0, pt, proc, maxDistSqrd );
00226 }
00227
00228 template<typename NodeData, unsigned char K, typename LookupProc>
00229 void KdTree<NodeData, K, LookupProc>::privateLookup( uint32_t nodeNum, float p[K], const LookupProc &process, float &maxDistSquared ) const
00230 {
00231 KdNode<K> *node = &nodes[nodeNum];
00232
00233 int axis = node->splitAxis;
00234 if( axis != K ) {
00235 float dist2 = ( p[axis] - node->splitPos ) * ( p[axis] - node->splitPos );
00236 if( p[axis] <= node->splitPos ) {
00237 if(node->hasLeftChild)
00238 privateLookup( nodeNum + 1, p, process, maxDistSquared );
00239 if( ( dist2 < maxDistSquared ) && ( node->rightChild < nNodes ) )
00240 privateLookup( node->rightChild, p, process, maxDistSquared );
00241 }
00242 else {
00243 if( node->rightChild < nNodes )
00244 privateLookup( node->rightChild, p, process, maxDistSquared );
00245 if( ( dist2 < maxDistSquared ) && node->hasLeftChild )
00246 privateLookup( nodeNum + 1, p, process, maxDistSquared );
00247 }
00248 }
00249
00250 float distSqr = 0.0f;
00251 for( unsigned char k = 0; k < K; ++k ) {
00252 float v = NodeDataTraits<NodeData>::getAxis( *mNodeData[nodeNum].first, k ) - p[k];
00253 distSqr += v * v;
00254 }
00255
00256 if( distSqr < maxDistSquared )
00257 process.process( mNodeData[nodeNum].second, distSqr, maxDistSquared );
00258 }
00259
00260
00261 template<typename NodeData, unsigned char K, typename LookupProc>
00262 void KdTree<NodeData, K, LookupProc>::findNearest( float p[K], float result[K], uint32_t *resultIndex ) const
00263 {
00264 float maxDist = FLT_MAX;
00265 *resultIndex = -1;
00266 privateFindNearest( 0, p, maxDist, result, resultIndex );
00267 }
00268
00269 template<typename NodeData, unsigned char K, typename LookupProc>
00270 void KdTree<NodeData, K, LookupProc>::privateFindNearest( uint32_t nodeNum, float p[K], float &maxDistSquared, float result[K], uint32_t *resultIndex ) const
00271 {
00272 KdNode<K> *node = &nodes[nodeNum];
00273
00274 int axis = node->splitAxis;
00275 if( axis != K ) {
00276 float dist2 = (p[axis] - node->splitPos) * (p[axis] - node->splitPos);
00277 if( p[axis] <= node->splitPos ) {
00278 if( node->hasLeftChild )
00279 privateFindNearest( nodeNum+1, p, maxDistSquared, result, resultIndex );
00280 if( ( dist2 < maxDistSquared ) && ( node->rightChild < nNodes) )
00281 privateFindNearest( node->rightChild, p, maxDistSquared, result, resultIndex );
00282 }
00283 else {
00284 if( node->rightChild < nNodes)
00285 privateFindNearest(node->rightChild,
00286 p,
00287 maxDistSquared, result, resultIndex );
00288 if( dist2 < maxDistSquared && node->hasLeftChild)
00289 privateFindNearest(nodeNum+1,
00290 p,
00291 maxDistSquared, result, resultIndex );
00292 }
00293 }
00294
00295 float distSqr = NodeDataTraits<NodeData>::distanceSquared( *mNodeData[nodeNum].first, p );
00296 if( distSqr < maxDistSquared ) {
00297 maxDistSquared = distSqr;
00298 for( unsigned char k = 0; k < K; ++k )
00299 result[k] = NodeDataTraits<NodeData>::getAxis( *mNodeData[nodeNum].first, k );
00300 *resultIndex = mNodeData[nodeNum].second;
00301 }
00302 }
00303
00304 }