当前位置:网站首页>Pytorch学习记录2 线性回归(Tensor,Variable)
Pytorch学习记录2 线性回归(Tensor,Variable)
2022-07-17 12:48:00 【好像几块钱】
Tensor
import torch as t
%matplotlib inline #内嵌画图
from matplotlib import pyplot as plt
from IPython import display
t.manual_seed(1000)
def get_fake_data(batch_size=8):
'''产生随机数据,y=x*2+3,加上了一些噪音'''
x=t.rand(batch_size,1)*20
y=x*2+(1+t.randn(batch_size,1))*3
return x,y
x,y=get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
#scatter必须用numpy
<matplotlib.collections.PathCollection at 0x12d6aea47c0>
#随机初始化参数
w=t.rand(1,1)
b=t.zeros(1,1)
lr=0.001 #学习率
for ii in range(20000):
x,y=get_fake_data()
y_pred=x.mm(w)+b.expand_as(y)
loss=0.5*(y_pred-y)**2 #均方误差
loss=loss.sum()
#手动计算梯度
dloss=1
dy_pred=dloss*(y_pred-y)
dw=x.t().mm(dy_pred)
db=dy_pred.sum()
#更新参数
w.sub_(lr*dw)
b.sub_(lr*db)
if ii%1000==0:
#画图
display.clear_output(wait=True)
x=t.arange(0,20).view(-1,1) #arange返回整型numpy
y=x.float().mm(w)+b.expand_as(x)
plt.plot(x.numpy(),y.numpy()) #predicted
x2,y2=get_fake_data(batch_size=20)
plt.scatter(x2.numpy(),y2.numpy()) #true data
plt.xlim(0,20)
plt.ylim(0,41)
plt.show()
plt.pause(0.5)
print(w.squeeze(),b.squeeze()) #压缩
tensor(2.1227) tensor(2.9495)
Variable
import torch as t
from torch.autograd import Variable as V
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display
t.manual_seed(1000)
def get_fake_data(batch_size=8):
'''产生随机数据y=2*x+3,产生一些噪音'''
x=t.rand(batch_size,1)*20
y=x*2+(1+t.randn(batch_size,1))*3
return x,y
x,y=get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())<matplotlib.collections.PathCollection at 0x12d6aea47c0>
w=V(t.rand(1,1),requires_grad=True)
b=V(t.zeros(1,1),requires_grad=True)
lr=0.001
for ii in range(8000):
x,y=get_fake_data()
x,y=V(x),V(y)
y_pred=x.mm(w)+b.expand_as(y)
loss=0.5*(y_pred-y)**2
loss=loss.sum()
loss.backward() #backward函数的输入值和返回值是variable
w.data.sub_(lr*w.grad.data)
b.data.sub_(lr*b.grad.data)
w.grad.data.zero_() #每次反向传播前要将梯度清0
b.grad.data.zero_()
if ii %1000==0:
display.clear_output(wait=True)
x=t.arange(0,20).view(-1,1) #t.arange产生的是整形numpy
y=x.float().mm(w.data)+b.data.expand_as(x)
plt.plot(x.numpy(),y.numpy()) #predicted
x2,y2=get_fake_data(batch_size=20)
plt.scatter(x2.numpy(),y2.numpy()) #true data
plt.xlim(0,20)
plt.ylim(0,41)
plt.show()
plt.pause(0.5)
print(w.data.squeeze(),b.squeeze())
tensor(1.9790) tensor(3.0168, grad_fn=<SqueezeBackward0>)
2022/7/15
边栏推荐
- c# treeView 树形结构递归处理(企业集团型层次树形展示)
- HCIA 静态综合实验报告 7.10
- 【CSP-J 2021】总结
- 分类任务中的类别不平衡问题
- Bidirectional NAT Technology
- String类型函数传递问题
- Zhongke Panyun - Cyberspace Security packet capture topic b.pcap
- Design of the multi live architecture in different places of the king glory mall
- YARN环境中应用程序JAR包冲突问题的分析及解决
- LVI-SAM:激光-IMU-相机紧耦合建图
猜你喜欢

C# SerialPort配置和属性了解
![高效理解 FreeSql WhereDynamicFilter,深入了解设计初衷[.NET ORM]](/img/cb/76200539c59bb865e60e5ea1121feb.png)
高效理解 FreeSql WhereDynamicFilter,深入了解设计初衷[.NET ORM]

Avi 部署使用指南(2):Avi 架构概述

Takeout ordering system based on wechat applet
![Effectively understand FreeSQL wheredynamicfilter and deeply understand the original design intention [.net orm]](/img/cb/76200539c59bb865e60e5ea1121feb.png)
Effectively understand FreeSQL wheredynamicfilter and deeply understand the original design intention [.net orm]

Convert excel table to word table, and keep the formula in Excel table unchanged

NJCTF 2017messager

因果学习将开启下一代AI浪潮?九章云极DataCanvas正式发布YLearn因果学习开源项目

Know what it is, and know why, JS object creation and inheritance

分类任务中的类别不平衡问题
随机推荐
如何使用SVG制作沿任意路径排布的文字效果
koa2 连接 mysql 数据库实现增删改查操作
Design of the multi live architecture in different places of the king glory mall
R language uses the KAP function of epidisplay package to calculate the proportion of calculation consistency of paired contingency tables and the value of kappa statistics, and uses xtabs function to
STL中stack和queue的使用以及模拟实现
String type function transfer problem
Figure an introduction to the interpretable method of neural network and a code example of gnnexplainer interpreting prediction
On the structural types of C language
NJCTF 2017messager
读已提交级别下 注解事务+分布式锁结合引起的事故--活动购买机会的错乱
Domestic flagship mobile phones have a serious price foam, and second-hand phones are more cost-effective than new ones, or buy iPhones
双向NAT技术
HCIA 复习作答 2022.7.6
顺序表的基本建立,以及增删改查的相关操作(c语言描述之顺序表)
opencv 画黑色矩形,并写上序号
微信小程序云开发 1 - 数据库
C # treeview tree structure recursive processing (enterprise group type hierarchical tree display)
新能源赛道已经高位风险,提醒大家注意安全
The select function of dplyr package in R language deletes the data columns in dataframe containing the specified string content (drop columns in dataframe)
HCIA static basic experiment 7.8


