本文共 7872 字,大约阅读时间需要 26 分钟。
K-D树主要是为了实现机器学习算法中的K近邻算法,单纯的K-D树只能实现最近邻,但是结合优先队列就可以实现K近邻了,这里只是把K-D树简单地实现了一下,经过简单测试,暂时没有发现重大bug。
#ifndef KDTREE_H
#define KDTREE_H

// K-D tree over K-dimensional float points supporting nearest-neighbour
// queries.  NOTE(review): the published listing had every template argument
// stripped by the blog's HTML extraction; the `<...>` tokens below are
// reconstructed from usage (e.g. `KDTree<2>` in the test driver).
#include <vector>
#include <iostream>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>

using ::std::vector;
using ::std::cout;
using ::std::endl;

namespace sx {

typedef float DataType;
typedef unsigned int UInt;

// A single sample: a K-dimensional point plus a caller-supplied id that is
// returned by nearest-neighbour queries.
struct Feature {
    vector<DataType> data;
    int id;
    Feature() {}
    Feature(const vector<DataType> & d, int i) : data(d), id(i) {}
};

// K is the compile-time dimensionality; every Feature passed to Build must
// have data.size() == K (asserted in Build).
template <UInt K>
class KDTree {
public:
    KDTree();
    virtual ~KDTree();
    KDTree(const KDTree & rhs);
    const KDTree & operator = (const KDTree & rhs);

    // Releases all nodes; the tree becomes empty.
    void Clean();
    // (Re)builds the tree from the given samples (median split per axis).
    void Build(const vector<Feature> & matrix_feature);
    // Returns the id of the nearest stored Feature, or -1 on an empty tree.
    int FindNearestFeature(const Feature & target) const;
    // Same, additionally reporting the Euclidean distance to the winner.
    int FindNearestFeature(const Feature & target, DataType & min_difference) const;
    // Dumps the tree structure to stdout.
    void Show() const;

private:
    struct KDNode {
        KDNode * left;
        KDNode * right;
        Feature feature;
        int depth;  // split axis of this node is depth % K
        // Init order matches declaration order (the original listing
        // initialised feature before left/right, triggering -Wreorder).
        KDNode(const Feature & f, KDNode * lt, KDNode * rt, int d)
            : left(lt), right(rt), feature(f), depth(d) {}
    };

    KDNode * root_;

    // Orders Features by one coordinate; used to pick the median.
    struct Comparator {
        int index_comparator;
        Comparator(int ix) : index_comparator(ix) {}
        bool operator () (const Feature & lhs, const Feature & rhs) const {
            return lhs.data[index_comparator] < rhs.data[index_comparator];
        }
    };

    KDNode * Clone(KDNode * t) const;
    void Clean(KDNode * & t);
    void SortFeature(vector<Feature> & features, int index);
    void Build(const vector<Feature> & matrix_feature, KDNode * & t, int depth);
    DataType Feature2FeatureDifference(const Feature & f1, const Feature & f2) const;
    int FindNearestFeature(const Feature & target, DataType & min_difference, KDNode * t) const;
    void Show(KDNode * t) const;
};

template <UInt K>
KDTree<K>::KDTree() : root_(NULL) {}

template <UInt K>
KDTree<K>::~KDTree() {
    Clean();
}

// BUG FIX: the original delegated to operator= without initialising root_,
// so Clean() freed an uninitialized pointer (UB).
template <UInt K>
KDTree<K>::KDTree(const KDTree & rhs) : root_(NULL) {
    *this = rhs;
}

template <UInt K>
const KDTree<K> & KDTree<K>::operator = (const KDTree & rhs) {
    if (this != &rhs) {
        Clean();
        root_ = Clone(rhs.root_);
    }
    return *this;
}

template <UInt K>
void KDTree<K>::Clean() {
    Clean(root_);
}

template <UInt K>
void KDTree<K>::Build(const vector<Feature> & matrix_feature) {
    if (matrix_feature.size() != 0) {
        assert(matrix_feature[0].data.size() == K);
    }
    Build(matrix_feature, root_, 0);
}

template <UInt K>
int KDTree<K>::FindNearestFeature(const Feature & target) const {
    DataType min_difference;
    return FindNearestFeature(target, min_difference);
}

template <UInt K>
int KDTree<K>::FindNearestFeature(const Feature & target, DataType & min_difference) const {
    min_difference = 10e8;  // sentinel "infinity"; any real distance beats it
    return FindNearestFeature(target, min_difference, root_);
}

template <UInt K>
void KDTree<K>::Show() const {
    Show(root_);
    return ;
}

// BUG FIX: deep copy.  The original returned a node that reused t->left and
// t->right, so source and copy shared subtrees and double-freed on destruction.
template <UInt K>
typename KDTree<K>::KDNode * KDTree<K>::Clone(KDNode * t) const {
    if (NULL == t) {
        return NULL;
    }
    return new KDNode(t->feature, Clone(t->left), Clone(t->right), t->depth);
}

template <UInt K>
void KDTree<K>::Clean(KDNode * & t) {
    if (t != NULL) {
        Clean(t->left);
        Clean(t->right);
        delete t;
    }
    t = NULL;
}

template <UInt K>
void KDTree<K>::SortFeature(vector<Feature> & features, int index) {
    std::sort(features.begin(), features.end(), Comparator(index));
}

// Recursive median-split construction: sort on axis (depth % K), the median
// becomes this node, halves go to the children.
template <UInt K>
void KDTree<K>::Build(const vector<Feature> & matrix_feature, KDNode * & t, int depth) {
    if (matrix_feature.size() == 0) {
        t = NULL;
        return ;
    }
    vector<Feature> temp_feature = matrix_feature;
    vector<Feature> left_feature;
    vector<Feature> right_feature;
    SortFeature(temp_feature, depth % K);
    int length = (int)temp_feature.size();
    int middle_position = length / 2;
    t = new KDNode(temp_feature[middle_position], NULL, NULL, depth);
    for (int i = 0; i < middle_position; ++i) {
        left_feature.push_back(temp_feature[i]);
    }
    for (int i = middle_position + 1; i < length; ++i) {
        right_feature.push_back(temp_feature[i]);
    }
    Build(left_feature, t->left, depth + 1);
    Build(right_feature, t->right, depth + 1);
    return ;
}

// Euclidean distance between two equally-sized feature vectors.
template <UInt K>
DataType KDTree<K>::Feature2FeatureDifference(const Feature & f1, const Feature & f2) const {
    DataType diff = 0.0;
    assert(f1.data.size() == f2.data.size());
    for (int i = 0; i < (int)f1.data.size(); ++i) {
        diff += (f1.data[i] - f2.data[i]) * (f1.data[i] - f2.data[i]);
    }
    return std::sqrt(diff);
}

// Recursive search.  min_difference is the best distance found so far and is
// tightened in place; the return value is the id of the best Feature (-1 for
// an empty subtree).
template <UInt K>
int KDTree<K>::FindNearestFeature(const Feature & target, DataType & min_difference, KDNode * t) const {
    if (NULL == t) {
        return -1;
    }
    DataType diff_parent = Feature2FeatureDifference(target, t->feature);
    DataType diff_left = 10e8;
    DataType diff_right = 10e8;
    int result_parent = -1;
    int result_left = -1;
    int result_right = -1;
    if (diff_parent < min_difference) {
        min_difference = diff_parent;
        result_parent = t->feature.id;
    }
    if (NULL == t->left && NULL == t->right) {
        return result_parent;
    }
    if (NULL == t->left /* && t->right != NULL */) {
        result_right = FindNearestFeature(target, diff_right, t->right);
        if (diff_right < min_difference) {
            min_difference = diff_right;
            result_parent = result_right;
        }
        return result_parent;
    }
    if (NULL == t->right /* && t->left != NULL */) {
        result_left = FindNearestFeature(target, diff_left, t->left);
        if (diff_left < min_difference) {
            min_difference = diff_left;
            result_parent = result_left;
        }
        return result_parent;
    }
    int index_feature = t->depth % K;
    DataType diff_boundary = std::fabs(target.data[index_feature] - t->feature.data[index_feature]);
    if (target.data[index_feature] < t->feature.data[index_feature]) {
        result_left = FindNearestFeature(target, diff_left, t->left);
        if (diff_left < min_difference) {
            min_difference = diff_left;
            result_parent = result_left;
        }
        // BUG FIX: the far subtree can contain a closer point iff the
        // distance to the splitting plane beats the best found so far.
        // The original compared against the distance to the child's point,
        // which could wrongly skip (or needlessly visit) the far side.
        if (diff_boundary < min_difference) {
            result_right = FindNearestFeature(target, diff_right, t->right);
            if (diff_right < min_difference) {
                min_difference = diff_right;
                result_parent = result_right;
            }
        }
    } else {
        result_right = FindNearestFeature(target, diff_right, t->right);
        if (diff_right < min_difference) {
            min_difference = diff_right;
            result_parent = result_right;
        }
        if (diff_boundary < min_difference) {  // same plane-distance pruning fix
            result_left = FindNearestFeature(target, diff_left, t->left);
            if (diff_left < min_difference) {
                min_difference = diff_left;
                result_parent = result_left;
            }
        }
    }
    return result_parent;
}

// Pre-order dump of ids, coordinates and parent->child edges.
template <UInt K>
void KDTree<K>::Show(KDNode * t) const {
    if (NULL == t) {  // BUG FIX: the original dereferenced NULL on an empty tree
        return ;
    }
    cout << "ID: " << t->feature.id << endl;
    cout << "Data: ";
    for (int i = 0; i < (int)t->feature.data.size(); ++i) {
        cout << t->feature.data[i] << " ";
    }
    cout << endl;
    if (t->left != NULL) {
        cout << "Left: " << t->feature.id << " -> " << t->left->feature.id << endl;
        Show(t->left);
    }
    if (t->right != NULL) {
        cout << "Right: " << t->feature.id << " -> " << t->right->feature.id << endl;
        Show(t->right);
    }
    return ;
}

} /* sx */

#endif /* end of include guard: KDTREE_H */
下面是测试的main.cc文件
// =============================================================================//// Filename: main.cc//// Description: K-D Tree//// Version: 1.0// Created: 04/11/2013 04:28:02 PM// Revision: none// Compiler: g++//// Author: Geek SphinX(Perstare et Praestare), geek.sphinx@zoho.com// Organization: Hefei University of Technology//// =============================================================================#include "KDTree.h"#includeusing namespace sx;using namespace std;int main(int argc, char *argv[]) { KDTree<2> kdtree; vector vf; int idx = 0; for (int i = 0; i < 5; ++i) { for (int j = 0; j < 5; ++j) { vector vd; vd.push_back(i); vd.push_back(j); vf.push_back(Feature(vd, idx++)); } } kdtree.Build(vf); kdtree.Show(); Feature target; int n; DataType x, y; cin >> n; while (n--) { cin >> x >> y; vector vd; vd.push_back(x); vd.push_back(y); target = Feature(vd, 0); DataType md; int t = kdtree.FindNearestFeature(target, md); cout << "Result is " << t << endl; for (int i = 0; i < (int)vf[t].data.size(); ++i) { cout << vf[t].data[i] << " "; } cout << endl; cout << "Minimum Difference is " << md << endl; } return 0;}
转载地址:http://dsbqb.baihongyu.com/