31
#ifndef OPENCV_FLANN_KDTREE_INDEX_H_
32
#define OPENCV_FLANN_KDTREE_INDEX_H_
41
#include "dynamic_bitset.h"
43
#include "result_set.h"
45
#include "allocator.h"
53
struct
KDTreeIndexParams :
public
IndexParams
55
KDTreeIndexParams(
int
trees = 4)
57
(*this)[
"algorithm"] = FLANN_INDEX_KDTREE;
58
(*this)[
"trees"] = trees;
69
template
<
typename
Distance>
70
class
KDTreeIndex :
public
NNIndex<Distance>
73
typedef
typename
Distance::ElementType ElementType;
74
typedef
typename
Distance::ResultType DistanceType;
84
KDTreeIndex(
const
Matrix<ElementType>& inputData,
const
IndexParams& params = KDTreeIndexParams(),
85
Distance d = Distance() ) :
86
dataset_(inputData), index_params_(params), distance_(d)
88
size_ = dataset_.rows;
89
veclen_ = dataset_.cols;
91
trees_ = get_param(index_params_,
"trees",4);
92
tree_roots_ =
new
NodePtr[trees_];
96
for
(
size_t
i = 0; i < size_; ++i) {
100
mean_ =
new
DistanceType[veclen_];
101
var_ =
new
DistanceType[veclen_];
105
KDTreeIndex(
const
KDTreeIndex&);
106
KDTreeIndex& operator=(
const
KDTreeIndex&);
113
if
(tree_roots_!=NULL) {
114
delete[] tree_roots_;
123
void
buildIndex() CV_OVERRIDE
126
for
(
int
i = 0; i < trees_; i++) {
128
#ifndef OPENCV_FLANN_USE_STD_RAND
131
std::random_shuffle(vind_.begin(), vind_.end());
134
tree_roots_[i] = divideTree(&vind_[0],
int(size_) );
139
flann_algorithm_t getType() const CV_OVERRIDE
141
return
FLANN_INDEX_KDTREE;
145
void
saveIndex(FILE* stream) CV_OVERRIDE
147
save_value(stream, trees_);
148
for
(
int
i=0; i<trees_; ++i) {
149
save_tree(stream, tree_roots_[i]);
155
void
loadIndex(FILE* stream) CV_OVERRIDE
157
load_value(stream, trees_);
158
if
(tree_roots_!=NULL) {
159
delete[] tree_roots_;
161
tree_roots_ =
new
NodePtr[trees_];
162
for
(
int
i=0; i<trees_; ++i) {
163
load_tree(stream,tree_roots_[i]);
166
index_params_[
"algorithm"] = getType();
167
index_params_[
"trees"] = tree_roots_;
173
size_t
size() const CV_OVERRIDE
181
size_t
veclen() const CV_OVERRIDE
190
int
usedMemory() const CV_OVERRIDE
192
return
int(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*
sizeof(
int));
204
void
findNeighbors(ResultSet<DistanceType>& result,
const
ElementType* vec,
const
SearchParams& searchParams) CV_OVERRIDE
206
const
int
maxChecks = get_param(searchParams,
"checks", 32);
207
const
float
epsError = 1+get_param(searchParams,
"eps",0.0f);
208
const
bool
explore_all_trees = get_param(searchParams,
"explore_all_trees",
false);
210
if
(maxChecks==FLANN_CHECKS_UNLIMITED) {
211
getExactNeighbors(result, vec, epsError);
214
getNeighbors(result, vec, maxChecks, epsError, explore_all_trees);
218
IndexParams getParameters() const CV_OVERRIDE
220
return
index_params_;
240
Node* child1, * child2;
242
typedef
Node* NodePtr;
243
typedef
BranchStruct<NodePtr, DistanceType> BranchSt;
244
typedef
BranchSt* Branch;
248
void
save_tree(FILE* stream, NodePtr tree)
250
save_value(stream, *tree);
251
if
(tree->child1!=NULL) {
252
save_tree(stream, tree->child1);
254
if
(tree->child2!=NULL) {
255
save_tree(stream, tree->child2);
260
void
load_tree(FILE* stream, NodePtr& tree)
262
tree = pool_.allocate<Node>();
263
load_value(stream, *tree);
264
if
(tree->child1!=NULL) {
265
load_tree(stream, tree->child1);
267
if
(tree->child2!=NULL) {
268
load_tree(stream, tree->child2);
282
NodePtr divideTree(
int* ind,
int
count)
284
NodePtr node = pool_.allocate<Node>();
288
node->child1 = node->child2 = NULL;
289
node->divfeat = *ind;
295
meanSplit(ind, count, idx, cutfeat, cutval);
297
node->divfeat = cutfeat;
298
node->divval = cutval;
299
node->child1 = divideTree(ind, idx);
300
node->child2 = divideTree(ind+idx, count-idx);
312
void
meanSplit(
int* ind,
int
count,
int& index,
int& cutfeat, DistanceType& cutval)
314
memset(mean_,0,veclen_*
sizeof(DistanceType));
315
memset(var_,0,veclen_*
sizeof(DistanceType));
320
int
cnt =
std::min((
int)SAMPLE_MEAN+1, count);
321
for
(
int
j = 0; j < cnt; ++j) {
322
ElementType* v = dataset_[ind[j]];
323
for
(
size_t
k=0; k<veclen_; ++k) {
327
for
(
size_t
k=0; k<veclen_; ++k) {
332
for
(
int
j = 0; j < cnt; ++j) {
333
ElementType* v = dataset_[ind[j]];
334
for
(
size_t
k=0; k<veclen_; ++k) {
335
DistanceType dist = v[k] - mean_[k];
336
var_[k] += dist * dist;
340
cutfeat = selectDivision(var_);
341
cutval = mean_[cutfeat];
344
planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
346
if
(lim1>count/2) index = lim1;
347
else
if
(lim2<count/2) index = lim2;
348
else
index = count/2;
353
if
((lim1==count)||(lim2==0)) index = count/2;
361
int
selectDivision(DistanceType* v)
364
size_t
topind[RAND_DIM];
367
for
(
size_t
i = 0; i < veclen_; ++i) {
368
if
((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
370
if
(num < RAND_DIM) {
378
while
(j > 0 && v[topind[j]] > v[topind[j-1]]) {
385
int
rnd = rand_int(num);
386
return
(
int)topind[rnd];
399
void
planeSplit(
int* ind,
int
count,
int
cutfeat, DistanceType cutval,
int& lim1,
int& lim2)
405
while
(left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
406
while
(left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
407
if
(left>right)
break;
408
std::swap(ind[left], ind[right]); ++left; --right;
413
while
(left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
414
while
(left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
415
if
(left>right)
break;
416
std::swap(ind[left], ind[right]); ++left; --right;
425
void
getExactNeighbors(ResultSet<DistanceType>& result,
const
ElementType* vec,
float
epsError)
430
fprintf(stderr,
"It doesn't make any sense to use more than one tree for exact search");
433
searchLevelExact(result, vec, tree_roots_[0], 0.0, epsError);
443
void
getNeighbors(ResultSet<DistanceType>& result,
const
ElementType* vec,
444
int
maxCheck,
float
epsError,
bool
explore_all_trees =
false)
450
Heap<BranchSt>* heap =
new
Heap<BranchSt>((
int)size_);
451
DynamicBitset checked(size_);
454
for
(i = 0; i < trees_; ++i) {
455
searchLevel(result, vec, tree_roots_[i], 0, checkCount, maxCheck,
456
epsError, heap, checked, explore_all_trees);
457
if
(!explore_all_trees && (checkCount >= maxCheck) && result.full())
462
while
( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
463
searchLevel(result, vec, branch.node, branch.mindist, checkCount, maxCheck,
464
epsError, heap, checked,
false);
478
void
searchLevel(ResultSet<DistanceType>& result_set,
const
ElementType* vec, NodePtr node, DistanceType mindist,
int& checkCount,
int
maxCheck,
479
float
epsError, Heap<BranchSt>* heap, DynamicBitset& checked,
bool
explore_all_trees =
false)
481
if
(result_set.worstDist()<mindist) {
487
if
((node->child1 == NULL)&&(node->child2 == NULL)) {
492
int
index = node->divfeat;
493
if
( checked.test(index) ||
494
(!explore_all_trees && (checkCount>=maxCheck) && result_set.full()) ) {
500
DistanceType dist = distance_(dataset_[index], vec, veclen_);
501
result_set.addPoint(dist,index);
507
ElementType val = vec[node->divfeat];
508
DistanceType diff = val - node->divval;
509
NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
510
NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
520
DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
522
if
((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) {
523
heap->insert( BranchSt(otherChild, new_distsq) );
527
searchLevel(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
533
void
searchLevelExact(ResultSet<DistanceType>& result_set,
const
ElementType* vec,
const
NodePtr node, DistanceType mindist,
const
float
epsError)
536
if
((node->child1 == NULL)&&(node->child2 == NULL)) {
537
int
index = node->divfeat;
538
DistanceType dist = distance_(dataset_[index], vec, veclen_);
539
result_set.addPoint(dist,index);
544
ElementType val = vec[node->divfeat];
545
DistanceType diff = val - node->divval;
546
NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
547
NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
557
DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
560
searchLevelExact(result_set, vec, bestChild, mindist, epsError);
562
if
(new_distsq*epsError<=result_set.worstDist()) {
563
searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
597
std::vector<int> vind_;
602
const
Matrix<ElementType> dataset_;
604
IndexParams index_params_;
617
NodePtr* tree_roots_;
626
PooledAllocator pool_;
CV_EXPORTS_W void min(InputArray src1, InputArray src2, OutputArray dst)
Calculates per-element minimum of two arrays or an array and a scalar.
CV_EXPORTS_W void randShuffle(InputOutputArray dst, double iterFactor=1., RNG *rng=0)
Shuffles the array elements randomly.
#define CV_Assert(expr)
Checks a condition at runtime and throws exception if it fails
Definition:
base.hpp:342
CV_EXPORTS void swap(Mat &a, Mat &b)
Swaps two matrices