当前位置:网站首页>【pytorch】简单的线性回归模型
【pytorch】简单的线性回归模型
2022-07-15 23:59:00 【李峻枫】
前言
回归是一种能更加多个变量之间的关系进行建模的一种方法,其在机器学习中有着官方运用。线性回归是其中最最最最简单的一种,其假设自变量与因变量之间是线性关系。利用pytorch就可以简单地写出线性回归的代码。
线性回归
首先需要知道线性回归的基本假设:
- 自变量和因变量之间是线性关系,并且允许存在些噪声。
- 存在的噪声都是比较“正常”的,符号正态分布。
因此可以用一个简单是式子来表示这个模型。
y ^ = w T x + b \hat{y} = w^Tx+b y^=wTx+b
其中w,x均是列向量。
对于这个简单的问题,可以用数学的方法,直接求出解析解,最常见的就说最小二乘法。
但是,因为线性回归模型过于简单,才可以这样求出答案,对于其他的模型,是不可能的,因此在这里同样是使用随机梯度下降法。
数据集
这种简单的数据集,并不需要去哪里下载,直接自行生成即可。
def data_maker(w, b, n_size): # y=w*x+b,n个数据
X = torch.normal(0 , 1 , (n_size , len(w))) # n*len(w)的参数
y = torch.matmul(X , w) + b
y = y + torch.normal(0, 0.01 , y.shape)
return X , y.reshape((-1 , 1))
读取
不失一般性,一般都是读取一个batch的,为了方便,可以利用yield将其写成一个迭代器。
def data_iter(batch_size , x , y):
n = len(x)
index = list(range(n))
random.shuffle(index)
for i in range(0 , n , batch_size):
batch_index = torch.tensor(index[i:min(i + batch_size,n)])
yield x[batch_index], y[batch_index]
模型
为了更好的理解,此处使用自己定义的函数。
但是需要用到pytorch的自动求梯度。
def linreg(X,w,b):
return torch.matmul(X , w) + b
def loss_function(Y , y):
return (Y - y.reshape(Y.shape))**2/2
def SGD(params , learning_rate , batch_size):
with torch.no_grad():
for param in params:
param -= learning_rate * param.grad / batch_size
param.grad.zero_()
训练
这里和普通的模型差不多
#training
for epoch in range(num_epochs):
for X ,Y in data_iter(batch_size , x , y):
l = loss(net(X , w,b), Y)
l.sum().backward()
SGD([w,b] , 0.001 , batch_size)
with torch.no_grad():
train_loss = loss(net(x , w , b), y)
print(f'epoch{
epoch+1} , loos {
float(train_loss.mean())}')
完整代码
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
import random
# In[2]:
def data_maker(w, b, n_size): # y=w*x+b,n个数据
X = torch.normal(0 , 1 , (n_size , len(w))) # n*len(w)的参数
y = torch.matmul(X , w) + b
y = y + torch.normal(0, 0.01 , y.shape)
return X , y.reshape((-1 , 1))
# In[3]:
W = torch.tensor([4.0])
B = 1
x , y = data_maker(W, B, 1000)
# In[4]:
def data_iter(batch_size , x , y):
n = len(x)
index = list(range(n))
random.shuffle(index)
for i in range(0 , n , batch_size):
batch_index = torch.tensor(index[i:min(i + batch_size,n)])
yield x[batch_index], y[batch_index]
# In[5]:
def linreg(X,w,b):
return torch.matmul(X , w) + b
def loss_function(Y , y):
return (Y - y.reshape(Y.shape))**2/2
def SGD(params , learning_rate , batch_size):
with torch.no_grad():
for param in params:
param -= learning_rate * param.grad / batch_size
param.grad.zero_()
# In[6]:
batch_size = 5
num_epochs = 50
net = linreg
loss = loss_function
w = torch.normal(0,1 , size=(1,1) , requires_grad= True)
b = torch.normal(0,1 , size=(1,1) , requires_grad= True)
# In[7]:
#training
for epoch in range(num_epochs):
for X ,Y in data_iter(batch_size , x , y):
l = loss(net(X , w,b), Y)
l.sum().backward()
SGD([w,b] , 0.001 , batch_size)
with torch.no_grad():
train_loss = loss(net(x , w , b), y)
print(f'epoch{
epoch+1} , loos {
float(train_loss.mean())}')
# In[8]:
print(w , b)
边栏推荐
- autojs脚本备忘
- Reading a data driven graph generic model for temporary interaction networks
- 如何基于知识图谱技术构建现代搜索引擎系统、智能问答系统、智能推荐系统?
- Flink(二)时间和窗口
- Flutter lifecycle
- 曲伟海:坚持选择不放弃 是实现初心的法宝
- Reading a data driven graph generic model for temporary interaction networks
- 低代码渲染那些事
- ThinkPHP 3 adds word segmentation weight search function and phpanalysis plug-in
- PBFT简单介绍
猜你喜欢

英特尔助力开立医疗推动超声产检智能化

Acceptance test experiment based on fitness

The secret of the three moving averages in the spot gold trend chart

JVM briefly introduces GC garbage collection mechanism

深圳开展建设工程合同网签试点,法大大助力建筑数字化

Flink (III) processing function

02. Resttemplate learning notes

不同的评估方法适合哪种机器学习模型?

现货黄金走势图中三条均线的秘密

02、RestTemplate学习笔记
随机推荐
Reading a data driven graph generic model for temporary interaction networks
How does wechat applet realize pull-down refresh?
Envoy monitoring management
计算除去部门最高工资,和最低工资的平均工资(字节跳动面试)
删除.idea目录后,svn菜单恢复操作
flutter provide
Installation method of memory module on dual channel (Dual CPU) server motherboard
转本结束暑假2022.6.29-7.13我的深圳之行(体验)
喜事乐,一场蓄谋已久的购物新时代
Financial industry open platform
Double thread guessing numbers
免费SSL证书申请及部署实践
[development tutorial 6] crazy shell · open source Bluetooth smart health watch - touch
Interface test - process test supports batch parameter import, and the test efficiency is directly full
Memory management: memory allocation and recycling
Torch 常用 Tricks 总结
PHP conversion hours ago
Flink(七)Flink SQL
02. Resttemplate learning notes
双通道(双CPU)服务器主板上内存条的安装方式