Skip to main content
 首页 » 人工智能

R 实现线性判别分析教程

2022年07月19日179qlqwjy

本文介绍线性判别分析概念,并通过示例介绍R的实现过程。

介绍线性判别分析模型

线性判别分析用于基于一组变量把响应变量分为俩类或更多的算法。但线性判别算法对数据有一些要求:

  • 响应变量必须是类别变量。线性判别是分类算法,因此响应变量应该是类别变量。

  • 预测变量应遵循正太分布。首先检查每个预测变量是否大致符合正太分布,如果不满足,需要选择转换算法使其近似满足。

  • 每个预测变量有相同的标准差。现实中很难能够满足该条件,但我们可以对数据进行标准化,让变量统一为标准差为1,均值为0.

  • 检查异常值。在用于LDA之前要检查异常值。可以简单通过箱线图或散点图查进行检测。

LDA模型在现实中应用广泛,下面简单举例:

  • 市场营销

零售公司经常使用LDA将购物者分为几类。然后利用建立LDA模型来预测特定购物者是低消费者、中等消费者还是高消费者,使用预测变量如收入、年度总消费额和家庭人数等变量。

  • 医学领域

医院或医疗机构的研究人员通常利用LDA预测给定一组异常细胞是否会导致轻微、中度或严重疾病。

  • 产品研发

一些公司会利用LDA模型预测消费者属于每天、每周、每月或年使用他们的产品,基于预测变量有性别、年度收入、使用类似产品的频率。

  • 生态领域

研究者利用LDA模型预测是否给定珊瑚礁的健康状况:好、中等、坏、严重。预测变量包括大小、年度污染情况、年份。

加载实现库

library(MASS) 
library(ggplot2) 

载入数据

我们打算使用内置的iris数据,下面代码展示如何载入查看数据。

str(iris) 
 
# 'data.frame':	150 obs. of  5 variables: 
#  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ... 
#  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ... 
#  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... 
#  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... 
#  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ... 

我们看到共包括5个变量150个观察记录。下面通过线性判别分析模型对给定鸢尾花进行分类。

我们使用下面四个预测变量:

  • Sepal.length
  • Sepal.Width
  • Petal.Length
  • Petal.Width

预测响应变量为Species,分别包括三类:

  • setosa
  • versicolor
  • virginica

数据标准化

线性判别算法其中一个关键假设为每个预测变量具有相同的标准差。一种简单办法可以对预测变量进行标准化,这样预测变量统一为均值为0、方差为1。

我们使用内置的scale函数,并利用apply函数进行验证:

iris[1:4] <- scale(iris[1:4]) 
 
apply(iris[1:4], 2, mean) 
#  Sepal.Length   Sepal.Width  Petal.Length   Petal.Width  
# -3.219358e-18 -4.916405e-18 -1.440616e-17 -1.822508e-17  
 
apply(iris[1:4], 2, sd) 
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width  
#            1            1            1            1  

创建训练和测试数据集

接下来我们把数据分为训练集和测试集:

set.seed(1) 
 
sample <- sample(c(TRUE, FALSE), nrow(iris), replace = TRUE, prob = c(.7, .3)) 
 
train <- iris[sample, ] 
test  <- iris[!sample,] 

拟合LDA模型

下面我们利用MASS包中的lda函数实现LDA模型:

library(MASS) 
model <- lda(Species~., data=train) 
model 
 
# Call: 
# lda(Species ~ ., data = train) 
#  
# Prior probabilities of groups: 
#     setosa versicolor  virginica  
#  0.3207547  0.3207547  0.3584906  
#  
# Group means: 
#            Sepal.Length Sepal.Width Petal.Length Petal.Width 
# setosa       -1.0397484   0.8131654   -1.2891006  -1.2570316 
# versicolor    0.1820921  -0.6038909    0.3403524   0.2208153 
# virginica     0.9582674  -0.1919146    1.0389776   1.1229172 
#  
# Coefficients of linear discriminants: 
#                     LD1        LD2 
# Sepal.Length  0.7922820  0.5294210 
# Sepal.Width   0.5710586  0.7130743 
# Petal.Length -4.0762061 -2.7305131 
# Petal.Width  -2.0602181  2.6326229 
#  
# Proportion of trace: 
#    LD1    LD2  
# 0.9921 0.0079  

下面我们解释上面的输出:

  • 每组的先验概率

这些表示训练集数据中每组的概率。如:所有训练集中35.8%的观测值属于virginica类别。

  • 组均值

这些数据显示每类每个预测变量的均值。

  • 线性判别系数

这里展示了LDA模型的判别规则,每个预测变量的线性组合情况:

  • LD1: .792Sepal.Length + .571Sepal.Width – 4.076Petal.Length – 2.06Petal.Width
  • LD2: .529Sepal.Length + .713Sepal.Width – 2.731Petal.Length + 2.63Petal.Width
  • 分离百分比

这些展示了每个线性判别函数实现的分离百分比。

使用模型进行预测

我们已经使用训练数据拟合了模型,下面使用模型对测试数据进行预测:

predicted <- predict(model, test) 
 
names(predicted) 
 
head(predicted$class) 
# [1] setosa setosa setosa setosa setosa setosa 
# Levels: setosa versicolor virginica 
 
head(predicted$posterior) 
 
#    setosa   versicolor    virginica 
# 4       1 2.425563e-17 1.341984e-35 
# 6       1 1.400976e-21 4.482684e-40 
# 7       1 3.345770e-19 1.511748e-37 
# 15      1 6.389105e-31 7.361660e-53 
# 17      1 1.193282e-25 2.238696e-45 
# 18      1 6.445594e-22 4.894053e-41 
 
head(predicted$x) 
 
#          LD1        LD2 
# 4   7.150360 -0.7177382 
# 6   7.961538  1.4839408 
# 7   7.504033  0.2731178 
# 15 10.170378  1.9859027 
# 17  8.885168  2.1026494 
# 18  8.113443  0.7563902 

我们看到输出列表中包括三个变量:

  • class 预测类型
  • posterior 每个类别对应的后验概率
  • x 线性判别

下面我们来看LDA模型正确预测类型的百分比:

mean(predicted$class == test$Species) 
# [1] 1 

输出显示模型预测正确率100%。在现实世界中模型很少能够预测每个类别都完全正确,因内置iris数据集比较简单,预测结果比较好。

可视化结果

最后,我们创建LDA图形观察线性判别模型,通过图示方式展示三种类型区分情况:

library(ggplot2) 
lda_plot <- cbind(train, predict(model)$x) 
 
ggplot(lda_plot, aes(LD1, LD2)) + geom_point(aes(color=Species)) 

在这里插入图片描述


本文参考链接:https://blog.csdn.net/neweastsun/article/details/122501074