[PyTorch] Simple linear regression model
2022-07-18 11:41:00 【Li Junfeng】
Preface
Regression is a family of methods for modeling the relationship between variables, and it is widely used in machine learning. Linear regression is one of the simplest: it assumes a linear relationship between the independent variables and the dependent variable. With PyTorch, the code for linear regression can be written very simply.
Linear regression
First, we need to know the basic assumptions of linear regression:
- There is a linear relationship between the independent variables and the dependent variable, with some noise allowed.
- The noise is fairly "well behaved": it follows a normal distribution.
So the model can be expressed with a simple formula:
$\hat{y} = w^T x + b$
where $w$ and $x$ are column vectors.
For such a simple problem, the analytical solution can be obtained directly by mathematical means; the most common approach is the least squares method.
However, that only works because the linear regression model is so simple. For most other models a closed-form solution does not exist, so mini-batch stochastic gradient descent is used here as well.
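For reference, with the samples stacked as rows of a design matrix $X$ (and a column of ones appended for the bias), the least-squares solution has the well-known closed form
$w^* = (X^T X)^{-1} X^T y$
A quick sanity check of this, sketched here on the assumption that x and y are the data generated in the dataset section below (it is not part of the original code), could use torch.linalg.lstsq:
X_aug = torch.cat([x, torch.ones(len(x), 1)], dim=1)   # append a column of ones for the bias term
fit = torch.linalg.lstsq(X_aug, y)                      # closed-form least-squares fit
print(fit.solution)                                     # first entry ≈ the weight, last entry ≈ the bias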
Dataset
This data set is so simple that there is no need to download it from anywhere; it can be generated directly.
def data_maker(w, b, n_size):  # generate n_size samples with y = X @ w + b plus noise
    X = torch.normal(0, 1, (n_size, len(w)))   # features: n_size rows, len(w) columns
    y = torch.matmul(X, w) + b
    y = y + torch.normal(0, 0.01, y.shape)     # add small Gaussian noise
    return X, y.reshape((-1, 1))               # return labels as a column vector
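For example, a minimal call with the same true parameters used in the complete code below (weight 4.0, bias 1, 1000 samples), assuming torch has already been imported:
W = torch.tensor([4.0])          # true weight
B = 1                            # true bias
x, y = data_maker(W, B, 1000)    # 1000 noisy samples of y = 4x + 1
print(x.shape, y.shape)          # torch.Size([1000, 1]) torch.Size([1000, 1])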
Reading the data
Without loss of generality, data is usually read one batch at a time. For convenience, yield can be used to write the reader as a generator.
def data_iter(batch_size, x, y):
    n = len(x)
    index = list(range(n))
    random.shuffle(index)                      # shuffle sample order each epoch
    for i in range(0, n, batch_size):
        batch_index = torch.tensor(index[i:min(i + batch_size, n)])
        yield x[batch_index], y[batch_index]   # one mini-batch at a time
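As a quick sanity check (not in the original post), you can peek at the shape of the first mini-batch, using the x and y generated in the example above:
for X_batch, y_batch in data_iter(5, x, y):
    print(X_batch.shape, y_batch.shape)   # torch.Size([5, 1]) torch.Size([5, 1])
    break                                 # only inspect the first batch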
Model
For better understanding, the model, loss, and optimizer are defined by hand here, but PyTorch's automatic differentiation is still used to compute the gradients.
def linreg(X, w, b):                 # the linear model: y_hat = Xw + b
    return torch.matmul(X, w) + b

def loss_function(Y, y):             # squared loss; Y is the prediction, y is the label
    return (Y - y.reshape(Y.shape)) ** 2 / 2

def SGD(params, learning_rate, batch_size):
    with torch.no_grad():            # the update itself must not be tracked by autograd
        for param in params:
            param -= learning_rate * param.grad / batch_size
            param.grad.zero_()       # clear gradients for the next batch
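The hand-written SGD above does essentially what PyTorch's built-in optimizer does. A rough equivalent using torch.optim.SGD is sketched below; it is not the code used in this post, and it assumes the parameters w, b and batch_size defined in the complete code further down:
optimizer = torch.optim.SGD([w, b], lr=0.001)
for X, Y in data_iter(batch_size, x, y):
    optimizer.zero_grad()                                  # clear old gradients
    loss_function(linreg(X, w, b), Y).mean().backward()    # mean() plays the role of dividing by batch_size
    optimizer.step()                                       # update w and b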
Training
Training works the same way as for any other model: for each mini-batch, compute the loss, backpropagate, and update the parameters.
# training
for epoch in range(num_epochs):
    for X, Y in data_iter(batch_size, x, y):
        l = loss(net(X, w, b), Y)      # mini-batch loss
        l.sum().backward()             # accumulate gradients on w and b
        SGD([w, b], 0.001, batch_size)
    with torch.no_grad():
        train_loss = loss(net(x, w, b), y)
        print(f'epoch {epoch + 1}, loss {float(train_loss.mean())}')
Complete code
#!/usr/bin/env python
# coding: utf-8

# In[1]:
import torch
import random

# In[2]:
def data_maker(w, b, n_size):  # generate n_size samples with y = X @ w + b plus noise
    X = torch.normal(0, 1, (n_size, len(w)))   # features: n_size rows, len(w) columns
    y = torch.matmul(X, w) + b
    y = y + torch.normal(0, 0.01, y.shape)     # add small Gaussian noise
    return X, y.reshape((-1, 1))

# In[3]:
W = torch.tensor([4.0])        # true weight
B = 1                          # true bias
x, y = data_maker(W, B, 1000)

# In[4]:
def data_iter(batch_size, x, y):
    n = len(x)
    index = list(range(n))
    random.shuffle(index)
    for i in range(0, n, batch_size):
        batch_index = torch.tensor(index[i:min(i + batch_size, n)])
        yield x[batch_index], y[batch_index]

# In[5]:
def linreg(X, w, b):
    return torch.matmul(X, w) + b

def loss_function(Y, y):
    return (Y - y.reshape(Y.shape)) ** 2 / 2

def SGD(params, learning_rate, batch_size):
    with torch.no_grad():
        for param in params:
            param -= learning_rate * param.grad / batch_size
            param.grad.zero_()

# In[6]:
batch_size = 5
num_epochs = 50
net = linreg
loss = loss_function
w = torch.normal(0, 1, size=(1, 1), requires_grad=True)
b = torch.normal(0, 1, size=(1, 1), requires_grad=True)

# In[7]:
# training
for epoch in range(num_epochs):
    for X, Y in data_iter(batch_size, x, y):
        l = loss(net(X, w, b), Y)
        l.sum().backward()
        SGD([w, b], 0.001, batch_size)
    with torch.no_grad():
        train_loss = loss(net(x, w, b), y)
        print(f'epoch {epoch + 1}, loss {float(train_loss.mean())}')

# In[8]:
print(w, b)
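For comparison, the same model can also be written with PyTorch's high-level API. The following is only a sketch under the assumption that the generated data x, y and the data_iter function above are reused; the learning rate and number of epochs are illustrative choices, not values from the original post:
import torch
from torch import nn

model = nn.Linear(1, 1)                  # one input feature, one output
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(10):
    for X_batch, y_batch in data_iter(5, x, y):
        optimizer.zero_grad()
        criterion(model(X_batch), y_batch).backward()
        optimizer.step()

print(model.weight, model.bias)          # should end up close to 4.0 and 1.0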