博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
K-D树 C++实现
阅读量:2441 次
发布时间:2019-05-10

本文共 7872 字,大约阅读时间需要 26 分钟。

       K-D树主要用于实现机器学习算法中的K近邻（KNN）算法。单纯的K-D树搜索只能找到最近邻，但结合优先队列（维护当前距离最小的K个候选点）就可以实现K近邻搜索。这里只是把K-D树简单地实现了一下（头文件 KDTree.h，模板参数 K 为特征的维数），经过简单测试，暂时没有发现重大bug。

#ifndef KDTREE_H
#define KDTREE_H

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// NOTE(review): these using-declarations sit at global scope in a header and
// therefore leak into every includer; they are kept only so existing code that
// relies on them keeps compiling.
using ::std::vector;
using ::std::cout;
using ::std::endl;

namespace sx {

typedef float DataType;
typedef unsigned int UInt;

// A point in feature space together with a caller-chosen identifier.
struct Feature {
    vector<DataType> data;  // coordinates; KDTree<K> expects exactly K of them
    int id;                 // value reported back by KDTree::FindNearestFeature

    Feature() : id(0) {}
    Feature(const vector<DataType>& d, int i) : data(d), id(i) {}
};

// K-dimensional tree over Feature points supporting exact nearest-neighbour
// queries under Euclidean distance.
//
// K is the number of coordinates per feature and must be positive; Build()
// asserts that every feature has exactly K coordinates. A node at depth d
// splits its subtree on coordinate d % K, using the median point along that
// coordinate as the node itself.
//
// NOTE(review): the template parameter list was lost in the pasted source;
// `int K` is assumed here (the driver instantiates KDTree<2>).
template <int K>
class KDTree {
public:
    KDTree();
    virtual ~KDTree();
    // Deep copy: the new tree owns an independent copy of every node.
    KDTree(const KDTree& rhs);
    // Deep-copy assignment; self-assignment is a no-op.
    const KDTree& operator=(const KDTree& rhs);

    // Frees every node; the tree becomes empty.
    void Clean();
    // Replaces the current contents with a tree over matrix_feature (which is
    // not modified). An empty vector yields an empty tree.
    void Build(const vector<Feature>& matrix_feature);
    // Returns the id of the feature closest to target, or -1 if the tree is
    // empty.
    int FindNearestFeature(const Feature& target) const;
    // As above, and additionally stores the Euclidean distance to that feature
    // in min_difference (numeric_limits<DataType>::max() if the tree is empty).
    int FindNearestFeature(const Feature& target, DataType& min_difference) const;
    // Prints every node (pre-order) and its child links to stdout.
    void Show() const;

private:
    struct KDNode {
        KDNode* left;
        KDNode* right;
        Feature feature;
        int depth;  // distance from the root; selects the split coordinate
        KDNode(const Feature& f, KDNode* lt, KDNode* rt, int d)
            : left(lt), right(rt), feature(f), depth(d) {}
    };

    typedef vector<Feature>::iterator FeatureIter;

    // Orders features by the single coordinate index_comparator.
    struct Comparator {
        int index_comparator;
        explicit Comparator(int ix) : index_comparator(ix) {}
        bool operator()(const Feature& lhs, const Feature& rhs) const {
            return lhs.data[index_comparator] < rhs.data[index_comparator];
        }
    };

    KDNode* root_;

    KDNode* Clone(KDNode* t) const;
    void Clean(KDNode*& t);
    KDNode* BuildRange(FeatureIter first, FeatureIter last, int depth);
    DataType Feature2FeatureDifference(const Feature& f1, const Feature& f2) const;
    void SearchNearest(const Feature& target, const KDNode* t,
                       const KDNode*& best, DataType& best_difference) const;
    void Show(KDNode* t) const;
};

template <int K>
KDTree<K>::KDTree() : root_(NULL) {}

template <int K>
KDTree<K>::~KDTree() {
    Clean();
}

// Clones directly instead of delegating to operator=, which would first
// Clean() a root_ that has not been initialised yet.
template <int K>
KDTree<K>::KDTree(const KDTree& rhs) : root_(Clone(rhs.root_)) {}

template <int K>
const KDTree<K>& KDTree<K>::operator=(const KDTree& rhs) {
    if (this != &rhs) {
        // Copy first so *this is left untouched if allocation throws.
        KDNode* copy = Clone(rhs.root_);
        Clean();
        root_ = copy;
    }
    return *this;
}

template <int K>
void KDTree<K>::Clean() {
    Clean(root_);
}

template <int K>
void KDTree<K>::Build(const vector<Feature>& matrix_feature) {
    for (std::size_t i = 0; i < matrix_feature.size(); ++i) {
        assert(matrix_feature[i].data.size() == static_cast<std::size_t>(K));
    }
    // Median selection reorders elements, so work on a private copy.
    vector<Feature> work(matrix_feature);
    KDNode* fresh = BuildRange(work.begin(), work.end(), 0);
    Clean();  // release the previous tree rather than leaking it
    root_ = fresh;
}

template <int K>
int KDTree<K>::FindNearestFeature(const Feature& target) const {
    DataType min_difference;
    return FindNearestFeature(target, min_difference);
}

template <int K>
int KDTree<K>::FindNearestFeature(const Feature& target, DataType& min_difference) const {
    const KDNode* best = NULL;
    min_difference = std::numeric_limits<DataType>::max();
    SearchNearest(target, root_, best, min_difference);
    return best == NULL ? -1 : best->feature.id;
}

template <int K>
void KDTree<K>::Show() const {
    Show(root_);
}

// Returns a deep copy of the subtree rooted at t (NULL for NULL).
template <int K>
typename KDTree<K>::KDNode* KDTree<K>::Clone(KDNode* t) const {
    if (t == NULL) {
        return NULL;
    }
    // Recurse so the copy shares no nodes with the source tree.
    return new KDNode(t->feature, Clone(t->left), Clone(t->right), t->depth);
}

// Frees the subtree rooted at t and resets t to NULL.
template <int K>
void KDTree<K>::Clean(KDNode*& t) {
    if (t != NULL) {
        Clean(t->left);
        Clean(t->right);
        delete t;
    }
    t = NULL;
}

// Builds a subtree from [first, last) and returns its root (NULL if empty).
// The median along coordinate depth % K becomes the node; the elements
// nth_element places before it form the left subtree, the rest the right.
template <int K>
typename KDTree<K>::KDNode* KDTree<K>::BuildRange(FeatureIter first, FeatureIter last, int depth) {
    if (first == last) {
        return NULL;
    }
    FeatureIter middle = first + (last - first) / 2;
    std::nth_element(first, middle, last, Comparator(depth % K));
    KDNode* node = new KDNode(*middle, NULL, NULL, depth);
    node->left = BuildRange(first, middle, depth + 1);
    node->right = BuildRange(middle + 1, last, depth + 1);
    return node;
}

// Euclidean distance between two features of equal dimension.
template <int K>
DataType KDTree<K>::Feature2FeatureDifference(const Feature& f1, const Feature& f2) const {
    assert(f1.data.size() == f2.data.size());
    DataType diff = 0;
    for (std::size_t i = 0; i < f1.data.size(); ++i) {
        const DataType delta = f1.data[i] - f2.data[i];
        diff += delta * delta;
    }
    return std::sqrt(diff);
}

// Depth-first nearest-neighbour search of the subtree rooted at t.
// best / best_difference hold the closest node seen so far and are updated in
// place. Nodes are visited parent, near child, far child, and only a strictly
// smaller distance replaces the current best, so among equidistant nodes the
// first one visited wins.
template <int K>
void KDTree<K>::SearchNearest(const Feature& target, const KDNode* t,
                              const KDNode*& best, DataType& best_difference) const {
    if (t == NULL) {
        return;
    }
    const DataType difference = Feature2FeatureDifference(target, t->feature);
    // Check best == NULL explicitly so the first node always counts, however
    // large its distance is (no magic "infinite" sentinel needed).
    if (best == NULL || difference < best_difference) {
        best = t;
        best_difference = difference;
    }

    const int axis = t->depth % K;
    const bool go_left = target.data[axis] < t->feature.data[axis];
    const KDNode* near_side = go_left ? t->left : t->right;
    const KDNode* far_side = go_left ? t->right : t->left;

    SearchNearest(target, near_side, best, best_difference);

    // Every point on the far side of the splitting plane is at least this far
    // from target, so that side can only help if the plane is closer than the
    // best distance found so far.
    const DataType plane_distance = std::fabs(target.data[axis] - t->feature.data[axis]);
    if (plane_distance < best_difference) {
        SearchNearest(target, far_side, best, best_difference);
    }
}

// Prints t and its descendants in pre-order; an empty subtree prints nothing.
template <int K>
void KDTree<K>::Show(KDNode* t) const {
    if (t == NULL) {
        return;
    }
    cout << "ID: " << t->feature.id << endl;
    cout << "Data: ";
    for (std::size_t i = 0; i < t->feature.data.size(); ++i) {
        cout << t->feature.data[i] << " ";
    }
    cout << endl;
    if (t->left != NULL) {
        cout << "Left: " << t->feature.id << " -> " << t->left->feature.id << endl;
        Show(t->left);
    }
    if (t->right != NULL) {
        cout << "Right: " << t->feature.id << " -> " << t->right->feature.id << endl;
        Show(t->right);
    }
}

}  // namespace sx

#endif  // KDTREE_H

      下面是测试用的 main.cc 文件：

// =============================================================================////       Filename:  main.cc////    Description:  K-D Tree////        Version:  1.0//        Created:  04/11/2013 04:28:02 PM//       Revision:  none//       Compiler:  g++////         Author:  Geek SphinX(Perstare et Praestare), geek.sphinx@zoho.com//   Organization:  Hefei University of Technology//// =============================================================================#include "KDTree.h"#include 
using namespace sx;using namespace std;int main(int argc, char *argv[]) { KDTree<2> kdtree; vector
vf; int idx = 0; for (int i = 0; i < 5; ++i) { for (int j = 0; j < 5; ++j) { vector
vd; vd.push_back(i); vd.push_back(j); vf.push_back(Feature(vd, idx++)); } } kdtree.Build(vf); kdtree.Show(); Feature target; int n; DataType x, y; cin >> n; while (n--) { cin >> x >> y; vector
vd; vd.push_back(x); vd.push_back(y); target = Feature(vd, 0); DataType md; int t = kdtree.FindNearestFeature(target, md); cout << "Result is " << t << endl; for (int i = 0; i < (int)vf[t].data.size(); ++i) { cout << vf[t].data[i] << " "; } cout << endl; cout << "Minimum Difference is " << md << endl; } return 0;}

转载地址:http://dsbqb.baihongyu.com/

你可能感兴趣的文章
如何在CentOS 8上安装Node.js
查看>>
如何在Ubuntu 20.04上安装Git
查看>>
javascript深度图_在JavaScript中深度克隆对象(及其工作方式)
查看>>
centos ssh密钥_如何在CentOS 8上设置SSH密钥
查看>>
JavaScript中的初学者排序算法:冒泡,选择和插入排序
查看>>
debian 10 安装_如何在Debian 10上安装Webmin
查看>>
使用CentOS 8进行初始服务器设置
查看>>
ecmascript v3_节点v12中的新ECMAScript模块简介
查看>>
vue jest 测试_Vue.js中的Jest快照测试简介
查看>>
Electron.js简介-第2部分:Todo应用
查看>>
gatsby_Gatsby CLI快速参考
查看>>
vue中的突变方法在哪_了解GraphQL中的突变
查看>>
redis修改配置重启命令_如何从命令行更改Redis的配置
查看>>
盖茨比乔布斯_使用wrapRootElement挂钩在盖茨比进行状态管理
查看>>
盖茨比乔布斯_通过盖茨比使用Airtable
查看>>
mern技术栈好处?_如何开始使用MERN堆栈
查看>>
路由器接路由器_路由器之战:到达路由器vsReact路由器
查看>>
rxjs 搜索_如何使用RxJS构建搜索栏
查看>>
如何在Debian 10上安装MariaDB
查看>>
go函数的可变长参数_如何在Go中使用可变参数函数
查看>>