31
#ifndef OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
32
#define OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
42
#include "result_set.h"
44
#include "allocator.h"
51
struct
KDTreeSingleIndexParams :
public
IndexParams
53
KDTreeSingleIndexParams(
int
leaf_max_size = 10,
bool
reorder =
true,
int
dim = -1)
55
(*this)[
"algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
56
(*this)[
"leaf_max_size"] = leaf_max_size;
57
(*this)[
"reorder"] = reorder;
69
template
<
typename
Distance>
70
class
KDTreeSingleIndex :
public
NNIndex<Distance>
73
typedef
typename
Distance::ElementType ElementType;
74
typedef
typename
Distance::ResultType DistanceType;
84
KDTreeSingleIndex(
const
Matrix<ElementType>& inputData,
const
IndexParams& params = KDTreeSingleIndexParams(),
85
Distance d = Distance() ) :
86
dataset_(inputData), index_params_(params), distance_(d)
88
size_ = dataset_.rows;
91
int
dim_param = get_param(params,
"dim",-1);
92
if
(dim_param>0) dim_ = dim_param;
93
leaf_max_size_ = get_param(params,
"leaf_max_size",10);
94
reorder_ = get_param(params,
"reorder",
true);
98
for
(
size_t
i = 0; i < size_; i++) {
103
KDTreeSingleIndex(
const
KDTreeSingleIndex&);
104
KDTreeSingleIndex& operator=(
const
KDTreeSingleIndex&);
111
if
(reorder_)
delete[] data_.data;
117
void
buildIndex() CV_OVERRIDE
119
computeBoundingBox(root_bbox_);
120
root_node_ = divideTree(0, (
int)size_, root_bbox_ );
124
data_ = cvflann::Matrix<ElementType>(
new
ElementType[size_*dim_], size_, dim_);
125
for
(
size_t
i=0; i<size_; ++i) {
126
for
(
size_t
j=0; j<dim_; ++j) {
127
data_[i][j] = dataset_[vind_[i]][j];
136
flann_algorithm_t getType() const CV_OVERRIDE
138
return
FLANN_INDEX_KDTREE_SINGLE;
142
void
saveIndex(FILE* stream) CV_OVERRIDE
144
save_value(stream, size_);
145
save_value(stream, dim_);
146
save_value(stream, root_bbox_);
147
save_value(stream, reorder_);
148
save_value(stream, leaf_max_size_);
149
save_value(stream, vind_);
151
save_value(stream, data_);
153
save_tree(stream, root_node_);
157
void
loadIndex(FILE* stream) CV_OVERRIDE
159
load_value(stream, size_);
160
load_value(stream, dim_);
161
load_value(stream, root_bbox_);
162
load_value(stream, reorder_);
163
load_value(stream, leaf_max_size_);
164
load_value(stream, vind_);
166
load_value(stream, data_);
171
load_tree(stream, root_node_);
174
index_params_[
"algorithm"] = getType();
175
index_params_[
"leaf_max_size"] = leaf_max_size_;
176
index_params_[
"reorder"] = reorder_;
182
size_t
size() const CV_OVERRIDE
190
size_t
veclen() const CV_OVERRIDE
199
int
usedMemory() const CV_OVERRIDE
201
return
(
int)(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*
sizeof(int));
213
void
knnSearch(
const
Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists,
int
knn,
const
SearchParams& params) CV_OVERRIDE
221
KNNSimpleResultSet<DistanceType> resultSet(knn);
222
for
(
size_t
i = 0; i < queries.rows; i++) {
223
resultSet.init(indices[i], dists[i]);
224
findNeighbors(resultSet, queries[i], params);
228
IndexParams getParameters() const CV_OVERRIDE
230
return
index_params_;
242
void
findNeighbors(ResultSet<DistanceType>& result,
const
ElementType* vec,
const
SearchParams& searchParams) CV_OVERRIDE
244
float
epsError = 1+get_param(searchParams,
"eps",0.0f);
246
std::vector<DistanceType> dists(dim_,0);
247
DistanceType distsq = computeInitialDistances(vec, dists);
248
searchLevel(result, vec, root_node_, distsq, dists, epsError);
268
DistanceType divlow, divhigh;
272
Node* child1, * child2;
274
typedef
Node* NodePtr;
279
DistanceType low, high;
282
typedef
std::vector<Interval> BoundingBox;
284
typedef
BranchStruct<NodePtr, DistanceType> BranchSt;
285
typedef
BranchSt* Branch;
290
void
save_tree(FILE* stream, NodePtr tree)
292
save_value(stream, *tree);
293
if
(tree->child1!=NULL) {
294
save_tree(stream, tree->child1);
296
if
(tree->child2!=NULL) {
297
save_tree(stream, tree->child2);
302
void
load_tree(FILE* stream, NodePtr& tree)
304
tree = pool_.allocate<Node>();
305
load_value(stream, *tree);
306
if
(tree->child1!=NULL) {
307
load_tree(stream, tree->child1);
309
if
(tree->child2!=NULL) {
310
load_tree(stream, tree->child2);
315
void
computeBoundingBox(BoundingBox& bbox)
318
for
(
size_t
i=0; i<dim_; ++i) {
319
bbox[i].low = (DistanceType)dataset_[0][i];
320
bbox[i].high = (DistanceType)dataset_[0][i];
322
for
(
size_t
k=1; k<dataset_.rows; ++k) {
323
for
(
size_t
i=0; i<dim_; ++i) {
324
if
(dataset_[k][i]<bbox[i].low) bbox[i].low = (DistanceType)dataset_[k][i];
325
if
(dataset_[k][i]>bbox[i].high) bbox[i].high = (DistanceType)dataset_[k][i];
340
NodePtr divideTree(
int
left,
int
right, BoundingBox& bbox)
342
NodePtr node = pool_.allocate<Node>();
345
if
( (right-left) <= leaf_max_size_) {
346
node->child1 = node->child2 = NULL;
351
for
(
size_t
i=0; i<dim_; ++i) {
352
bbox[i].low = (DistanceType)dataset_[vind_[left]][i];
353
bbox[i].high = (DistanceType)dataset_[vind_[left]][i];
355
for
(
int
k=left+1; k<right; ++k) {
356
for
(
size_t
i=0; i<dim_; ++i) {
357
if
(bbox[i].low>dataset_[vind_[k]][i]) bbox[i].low=(DistanceType)dataset_[vind_[k]][i];
358
if
(bbox[i].high<dataset_[vind_[k]][i]) bbox[i].high=(DistanceType)dataset_[vind_[k]][i];
366
middleSplit_(&vind_[0]+left, right-left, idx, cutfeat, cutval, bbox);
368
node->divfeat = cutfeat;
370
BoundingBox left_bbox(bbox);
371
left_bbox[cutfeat].high = cutval;
372
node->child1 = divideTree(left, left+idx, left_bbox);
374
BoundingBox right_bbox(bbox);
375
right_bbox[cutfeat].low = cutval;
376
node->child2 = divideTree(left+idx, right, right_bbox);
378
node->divlow = left_bbox[cutfeat].high;
379
node->divhigh = right_bbox[cutfeat].low;
381
for
(
size_t
i=0; i<dim_; ++i) {
382
bbox[i].low =
std::min(left_bbox[i].low, right_bbox[i].low);
383
bbox[i].high =
std::max(left_bbox[i].high, right_bbox[i].high);
390
void
computeMinMax(
int* ind,
int
count,
int
dim, ElementType& min_elem, ElementType& max_elem)
392
min_elem = dataset_[ind[0]][dim];
393
max_elem = dataset_[ind[0]][dim];
394
for
(
int
i=1; i<count; ++i) {
395
ElementType val = dataset_[ind[i]][dim];
396
if
(val<min_elem) min_elem = val;
397
if
(val>max_elem) max_elem = val;
401
void
middleSplit(
int* ind,
int
count,
int& index,
int& cutfeat, DistanceType& cutval,
const
BoundingBox& bbox)
404
ElementType max_span = bbox[0].high-bbox[0].low;
406
cutval = (bbox[0].high+bbox[0].low)/2;
407
for
(
size_t
i=1; i<dim_; ++i) {
408
ElementType span = bbox[i].high-bbox[i].low;
412
cutval = (bbox[i].high+bbox[i].low)/2;
417
ElementType min_elem, max_elem;
418
computeMinMax(ind, count, cutfeat, min_elem, max_elem);
419
cutval = (min_elem+max_elem)/2;
420
max_span = max_elem - min_elem;
424
for
(
size_t
i=0; i<dim_; ++i) {
426
ElementType span = bbox[i].high-bbox[i].low;
428
computeMinMax(ind, count, i, min_elem, max_elem);
429
span = max_elem - min_elem;
433
cutval = (min_elem+max_elem)/2;
438
planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
440
if
(lim1>count/2) index = lim1;
441
else
if
(lim2<count/2) index = lim2;
442
else
index = count/2;
446
void
middleSplit_(
int* ind,
int
count,
int& index,
int& cutfeat, DistanceType& cutval,
const
BoundingBox& bbox)
448
const
float
EPS=0.00001f;
449
DistanceType max_span = bbox[0].high-bbox[0].low;
450
for
(
size_t
i=1; i<dim_; ++i) {
451
DistanceType span = bbox[i].high-bbox[i].low;
456
DistanceType max_spread = -1;
458
for
(
size_t
i=0; i<dim_; ++i) {
459
DistanceType span = bbox[i].high-bbox[i].low;
460
if
(span>(DistanceType)((1-EPS)*max_span)) {
461
ElementType min_elem, max_elem;
462
computeMinMax(ind, count, (
int)i, min_elem, max_elem);
463
DistanceType spread = (DistanceType)(max_elem-min_elem);
464
if
(spread>max_spread) {
471
DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
472
ElementType min_elem, max_elem;
473
computeMinMax(ind, count, cutfeat, min_elem, max_elem);
475
if
(split_val<min_elem) cutval = (DistanceType)min_elem;
476
else
if
(split_val>max_elem) cutval = (DistanceType)max_elem;
477
else
cutval = split_val;
480
planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
482
if
(lim1>count/2) index = lim1;
483
else
if
(lim2<count/2) index = lim2;
484
else
index = count/2;
497
void
planeSplit(
int* ind,
int
count,
int
cutfeat, DistanceType cutval,
int& lim1,
int& lim2)
503
while
(left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
504
while
(left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
505
if
(left>right)
break;
506
std::swap(ind[left], ind[right]); ++left; --right;
514
while
(left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
515
while
(left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
516
if
(left>right)
break;
517
std::swap(ind[left], ind[right]); ++left; --right;
522
DistanceType computeInitialDistances(
const
ElementType* vec, std::vector<DistanceType>& dists)
524
DistanceType distsq = 0.0;
526
for
(
size_t
i = 0; i < dim_; ++i) {
527
if
(vec[i] < root_bbox_[i].low) {
528
dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, (
int)i);
531
if
(vec[i] > root_bbox_[i].high) {
532
dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, (
int)i);
543
void
searchLevel(ResultSet<DistanceType>& result_set,
const
ElementType* vec,
const
NodePtr node, DistanceType mindistsq,
544
std::vector<DistanceType>& dists,
const
float
epsError)
547
if
((node->child1 == NULL)&&(node->child2 == NULL)) {
548
DistanceType worst_dist = result_set.worstDist();
550
for
(
int
i=node->left; i<node->right; ++i) {
551
DistanceType dist = distance_(vec, data_[i], dim_, worst_dist);
552
if
(dist<worst_dist) {
553
result_set.addPoint(dist,vind_[i]);
557
for
(
int
i=node->left; i<node->right; ++i) {
558
DistanceType dist = distance_(vec, data_[vind_[i]], dim_, worst_dist);
559
if
(dist<worst_dist) {
560
result_set.addPoint(dist,vind_[i]);
568
int
idx = node->divfeat;
569
ElementType val = vec[idx];
570
DistanceType diff1 = val - node->divlow;
571
DistanceType diff2 = val - node->divhigh;
575
DistanceType cut_dist;
576
if
((diff1+diff2)<0) {
577
bestChild = node->child1;
578
otherChild = node->child2;
579
cut_dist = distance_.accum_dist(val, node->divhigh, idx);
582
bestChild = node->child2;
583
otherChild = node->child1;
584
cut_dist = distance_.accum_dist( val, node->divlow, idx);
588
searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);
590
DistanceType dst = dists[idx];
591
mindistsq = mindistsq + cut_dist - dst;
592
dists[idx] = cut_dist;
593
if
(mindistsq*epsError<=result_set.worstDist()) {
594
searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);
604
const
Matrix<ElementType> dataset_;
606
IndexParams index_params_;
615
std::vector<int> vind_;
617
Matrix<ElementType> data_;
627
BoundingBox root_bbox_;
636
PooledAllocator pool_;
CV_EXPORTS_W void max(InputArray src1, InputArray src2, OutputArray dst)
Calculates per-element maximum of two arrays or an array and a scalar.
CV_EXPORTS_W void min(InputArray src1, InputArray src2, OutputArray dst)
Calculates per-element minimum of two arrays or an array and a scalar.
#define CV_Assert(expr)
Checks a condition at runtime and throws exception if it fails
Definition:
base.hpp:342
CV_EXPORTS void swap(Mat &a, Mat &b)
Swaps two matrices