当前位置:网站首页>Fisher线性判别分析Fisher Linear Distrimination
Fisher线性判别分析Fisher Linear Distrimination
2022-07-17 00:10:00 【elkluh】
Fisher线性判别分析是一种线性分类方法,它的主要思想是:是类内的方差小,类均值之间相差比较大。(类间大,类内小)
以两个类的分类为例:
将两个类由在x1,x2上投影到向量u 上,这样由二维转到了一维,然后将两类从两团点的中间分开。

如果要使类间相差大的话,那么每个类的平均数之间也会相差大,设分别为加号点和减号点的平均值,那么投影后,他们距离的平方,也就是
尽可能大。
如果要使类内方差小的话,那么两个类投影到直线(向量)上后,他们的点分别为
,表示两个协方差的投影。


所以他们的和也要尽可能小
因此把大的作为分子,小的作为分母,他们相除的整体就是越大越好,

设

则 
对J(u)进行求导,则有
令导数等于0, 得到 
括号里的可以用一个缩放值来代替
则 
因为
和
在相乘之后变为常数
因此最终 
就是我们要投影的向量。
import matplotlib.pyplot as plt
import numpy as np
def gauss2D(x, m, C):
Ci = np.linalg.inv(C) #求矩阵的逆
dC = np.linalg.det(C) #求矩阵的行列式
num = np.exp(-0.5 * np.dot((x-m).T, np.dot(Ci,(x-m))))
den = 2 * np.pi * (dC**0.5) #计算矩阵的密度函数
return num/den
def twoDGaussianPlot(nx, ny, m, C):
x = np.linspace(-6, 6, nx)
y = np.linspace(-6, 6, ny)
X, Y = np.meshgrid(x, y, indexing='ij')
Z = np.zeros([nx,ny])
for i in range(nx):
for j in range(ny):
xvec = np.array([X[i,j], Y[i,j]])
Z[i,j] = gauss2D(xvec, m, C)
return X, Y, Z
X = np.random.randn(200, 2)
C1 = np.array([[2,1],[1,2]])
C2 = np.array([[2,1],[1,2]])
m1 = np.array([0, 3])
m2 = np.array([3,2.5])
A = np.linalg.cholesky(C1)
Y1 = X @ A.T + m1
Y2 = X @ A.T + m2
plt.figure(1)
plt.scatter(Y1[:,0], Y1[:,1], c='c', s=4)
plt.scatter(Y2[:,0], Y2[:,1], c='m', s=4)
Xp, Yp, Zp = twoDGaussianPlot(40,50,m1,C1)
plt.contour(Xp, Yp, Zp, 5)
Xp2, Yp2, Zp2 = twoDGaussianPlot(40,50,m2,C2)
plt.contour(Xp2, Yp2, Zp2, 5)
uF = np.linalg.inv(C1 + C2)@(m1-m2)
print(uF)
#ax.arrow(0, 0, *(uF*10), color='b', linewidth=2.0, head_width=0.20, head_length=0.25)
plt.arrow(0, 0, *(uF), color='b', linewidth=2.0, head_width=0.30, head_length=0.35)
plt.axis('equal')
plt.grid()
plt.xlim([-6,6])
plt.ylim([-5,8])
plt.savefig('density graph.png')
yp1 = Y1 @ uF
yp2 = Y2 @ uF
plt.figure(2)
plt.rcParams.update({'font.size':16})
plt.hist(yp1, bins=40)
plt.hist(yp2, bins=40)
plt.savefig('histogramprojections.png')


边栏推荐
- 【文献阅读】Small-Footprint Keyword Spotting with Multi-Scale Temporal Convolution
- [literature reading] counting integer points in parametric polymers using barvinok's rational functions
- ViLT Vision-and-Language Transformer Without Convolution or Region Supervision
- 数组定义格式
- 5章 性能平台GodEye源码分析-第三方模块
- 雾计算中的数据安全问题综述
- Today's code farmer girl learned about nodejs and repl interactive interpreter
- nodejs-uuid
- Valgrind detailed tutorial (1) MemCheck
- 基于CSI的通信感知一体化设计:问题、挑战和展望
猜你喜欢

二阶边缘检测 - Laplacian of Guassian 高斯拉普拉斯算子

Common asynchronous sending code writing

Leveraging Semi-Supervised Learning for Fairness using Neural Networks

Handling Conditional Discrimination(可解释歧视和确切的歧视)

Windbos download and install openssh

动手学深度学习--多层感知机篇(MLP)
![[MySQL] windows install MySQL 5.7](/img/71/be5b0cc3e130c2b9f3884d90b9cd39.jpg)
[MySQL] windows install MySQL 5.7

Fair Multiple Decision Making Through Soft Interventions

集成学习

动手学深度学习---从全连接层到卷积层篇
随机推荐
MapReduce
Mxnet network model of show me the code (III)
CheckPoint and DataNode
VSCode中安装Go:tools failed to install.
Dhfs read / write process
【文献阅读】TENET: A Framework for Modeling Tensor Dataflow Based on Relation-centric Notation
The following packages have unmet dependencies: deepin.com.wechat:i386 : Depends: deepin-wine:i386
What are the NFT digital collection platforms? Which platforms are worth collecting?
Rivaliser pour la guerre clé des utilisateurs de stock, aider les entreprises à construire un système d'étiquetage parfait 丨 01 examen en direct
记录一次海外图片加载不出来的排查
实时开发平台建设实践,深入释放实时数据价值丨04期直播回顾
袋鼠云数栈基于CBO在Spark SQL优化上的探索
数据资产为王,如何解析企业数字化转型与数据资产管理的关系?
MXNet网络模型(四)GAN神经网络
apt-get update报错:Hash 校验和不符
组合键截图分析
WKWebView 设置自定义UserAgent正确姿势
The platform of digital collection NFT is good
基于移动终端的MIMO阵列三维成像技术
ACE下载地址