K近邻(K-Nearest Neighbor ,简称KNN ) 是有监督非线性、非参数分类算法,非参数表示对数据集及其分布没有任何假设。它是最简单、最常用的分类算法之一,广泛应用于金融、医疗等领域。
K近邻算法
KNN算法中的k表示邻近数据结点的数量,其算法过程如下:
-
选择邻近结点数量K
-
计算出测试数据结点和K个最近结点的距离
-
在这个K个距离中,对每个分类进行计数
-
依据少数服从多数原则,将测试数据结点归入在K个点中占比最高的那一类
对于KNN分类算法,两点的距离计算采用欧式距离。请看下图:
假设数据集包括两类,分别为红色和蓝色表示。我们选择k为5,即基于欧式距离考虑5个最近结点,所以当测试新数据点时,5个结点,其中国三个蓝色、两个红色。则认为新数据点分类为蓝色。
实战示例
鸢尾花数据集(Iris)包括3种鸢尾(setosa, virginica, versicolor)各50个样本以及多个变量的数据集,是由英国统计学家和生物学家Ronald Fisher在其1936年的论文《The use of multiple measurements in taxonomic problems》中首次引用。Fisher从每个样本中测量了萼片和花瓣的长度和宽度等4个特征,并结合这4个特征建立了一个线性判别模型来区分不同的物种。
- 加载数据集,并查看概要信息
# Loading data
data(iris)
# Structure
str(iris)
- 执行KNN分类
# 加载依赖包
library(e1071)
library(caTools)
library(class)
# 加载数据
# data(iris)
# head(iris)
# 把数据集分为训练集和测试集
split <- sample.split(iris, SplitRatio = 0.7)
train_cl <- subset(iris, split == "TRUE")
test_cl <- subset(iris, split == "FALSE")
# 标准化特征变量
train_scale <- scale(train_cl[, 1:4])
test_scale <- scale(test_cl[, 1:4])
# 使用k=1 拟合 KNN 分类模型
classifier_knn <- knn(train = train_scale,
test = test_scale,
cl = train_cl$Species,
k = 1)
# classifier_knn
# 计算混淆矩阵
cm <- table(test_cl$Species, classifier_knn)
cm
# classifier_knn
# setosa versicolor virginica
# setosa 20 0 0
# versicolor 0 19 1
# virginica 0 0 20
# 模型评估 - 计算样本错误率
misClassError <- mean(classifier_knn != test_cl$Species)
# print(paste('Accuracy =', 1-misClassError))
# [1] "Accuracy = 0.933333333333333"
# K = 3
classifier_knn <- knn(train = train_scale,
test = test_scale,
cl = train_cl$Species,
k = 3)
misClassError <- mean(classifier_knn != test_cl$Species)
print(paste('Accuracy =', 1-misClassError))
# [1] "Accuracy = 0.933333333333333"
# K = 5
classifier_knn <- knn(train = train_scale,
test = test_scale,
cl = train_cl$Species,
k = 5)
misClassError <- mean(classifier_knn != test_cl$Species)
print(paste('Accuracy =', 1-misClassError))
# [1] "Accuracy = 0.95"
# K = 7
classifier_knn <- knn(train = train_scale,
test = test_scale,
cl = train_cl$Species,
k = 7)
misClassError <- mean(classifier_knn != test_cl$Species)
print(paste('Accuracy =', 1-misClassError))
# [1] "Accuracy = 0.966666666666667"
# K = 15
classifier_knn <- knn(train = train_scale,
test = test_scale,
cl = train_cl$Species,
k = 15)
misClassError <- mean(classifier_knn != test_cl$Species)
print(paste('Accuracy =', 1-misClassError))
# [1] "Accuracy = 0.983333333333333"
# K = 19
classifier_knn <- knn(train = train_scale,
test = test_scale,
cl = train_cl$Species,
k = 19)
misClassError <- mean(classifier_knn != test_cl$Species)
print(paste('Accuracy =', 1-misClassError))
# [1] "Accuracy = 0.966666666666667"
当k为15时,模型的准确率达到98.3%,比k为1、3、5、7时的准确率更高。k为19时的精度为96.7%,这意味着增加k值不会增加精度,因此K为15更为合适。
KNN优劣
KNN方法思路简单,易于理解,易于实现,无需估计参数。
该算法在分类时有两个主要不足。当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数 。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点 。
本文参考链接:https://blog.csdn.net/neweastsun/article/details/125474160