当前位置:网站首页>《PyTorch深度学习实践》-B站 刘二大人-day3
《PyTorch深度学习实践》-B站 刘二大人-day3
2022-07-17 05:20:00 【爱编程的西瓜】
反向传播back propagation
B站 刘二大人 ,传送门PyTroch 深度学习实践——反向传播
代码说明:
1、w是Tensor(张量类型),Tensor中包含data和grad,data和grad也是Tensor。grad初始为None,调用l.backward()方法后w.grad为Tensor,故更新w.data时需使用w.grad.data。如果w需要计算梯度,那构建的计算图中,跟w相关的tensor都默认需要计算梯度。
刘老师视频中a = torch.Tensor([1.0]) 本文中更改为 a = torch.tensor([1.0])。两种方法都可以,个人习惯第二种。
import torch
a = torch.tensor([1.0])
a.requires_grad = True # 或者 a.requires_grad_()
print(a)
print(a.data)
print(a.type()) # a的类型是tensor
print(a.data.type()) # a.data的类型是tensor
print(a.grad)
print(type(a.grad))
结果为:
2、w是Tensor, forward函数的返回值也是Tensor,loss函数的返回值也是Tensor
3、本算法中反向传播主要体现在,l.backward()。调用该方法后w.grad由None更新为Tensor类型,且w.grad.data的值用于后续w.data的更新。
l.backward()会把计算图中所有需要梯度(grad)的地方都会求出来,然后把梯度都存在对应的待求的参数中,最终计算图被释放。
取tensor中的data是不会构建计算图的。
import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.tensor([1.0]) # w的初值为1.0
w.requires_grad = True # 需要计算梯度
def forward(x):
return x*w # w是一个Tensor
def loss(x, y):
y_pred = forward(x)
return (y_pred - y)**2
print("predict (before training)", 4, forward(4).item())
for epoch in range(100):
for x, y in zip(x_data, y_data):
l =loss(x,y) # l是一个张量,tensor主要是在建立计算图 forward, compute the loss
l.backward() # backward,compute grad for Tensor whose requires_grad set to True
print('\tgrad:', x, y, w.grad.item())
w.data = w.data - 0.01 * w.grad.data # 权重更新时,注意grad也是一个tensor
w.grad.data.zero_() # after update, remember set the grad to zero
print('progress:', epoch, l.item()) # 取出loss使用l.item,不要直接使用l(l是tensor会构建计算图)
print("predict (after training)", 4, forward(4).item())
课程中留的三个作业
1、手动推导线性模型y=w*x,损失函数loss=(ŷ-y)²下,当数据集x=2,y=4的时候,反向传播的过程。
答:
2、手动推导线性模型 y=w*x+b,损失函数loss=(ŷ-y)²下,当数据集x=1,y=2的时候,反向传播的过程。
答:
3、画出二次模型y=w1x²+w2x+b,损失函数loss=(ŷ-y)²的计算图,并且手动推导反向传播的过程,最后用pytorch的代码实现。
答:
构建和推导的过程
import numpy as np
import matplotlib.pyplot as plt
import torch
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w1 = torch.Tensor([1.0])#初始权值
w1.requires_grad = True#计算梯度,默认是不计算的
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = True
def forward(x):
return w1 * x**2 + w2 * x + b
def loss(x,y):#构建计算图
y_pred = forward(x)
return (y_pred-y) **2
print('Predict (befortraining)',4,forward(4))
for epoch in range(100):
l = loss(1, 2)#为了在for循环之前定义l,以便之后的输出,无实际意义
for x,y in zip(x_data,y_data):
l = loss(x, y)
l.backward()
print('\tgrad:',x,y,w1.grad.item(),w2.grad.item(),b.grad.item())
w1.data = w1.data - 0.01*w1.grad.data #注意这里的grad是一个tensor,所以要取他的data
w2.data = w2.data - 0.01 * w2.grad.data
b.data = b.data - 0.01 * b.grad.data
w1.grad.data.zero_() #释放之前计算的梯度
w2.grad.data.zero_()
b.grad.data.zero_()
print('Epoch:',epoch,l.item())
print('Predict(after training)',4,forward(4).item())
结束语
在用y=w1x²+w2x+b的模型训练100次后可以看到当x=4时,y=8.5,与正确值8相差比较大。原因可能是数据集本身是一次函数的数据,模型是二次函数。所以模型本身就不适合这个数据集,所以才导致预测结果和正确值相差比较大的情况。
边栏推荐
猜你喜欢

Unity2d learning Fox game production process 1: basic game character control, animation effects, lens control, item collection, bug optimization

2022/07/10 group 5 Ding Shuai's study notes day03

感知智能手機上用戶的關注狀態

Visual saliency based visual gaze estimation

虚拟现实中的眼睛跟踪

《PyTorch深度学习实践》-B站 刘二大人-day5

Robot stitching gesture recognition and classification
![Open source online markdown editor -- [editor.md]](/img/f3/b37acf934aa2526d99c8f585b6f229.png)
Open source online markdown editor -- [editor.md]

《PyTorch深度学习实践》-B站 刘二大人-day4

2022/07/11 group 5 Ding Shuai's study notes day04
随机推荐
山西省第二届网络安全技能大赛(企业组)部分赛题WP(四)
Set the index library structure, add suggestions that can be automatically completed to users, and turn some fields into collections and put them into suggestions
斑点检测 记录
Experiment 4 operator overloading and virtual functions
从零开始的 Rust 语言 blas 库之预备篇(2)—— blas 矩阵格式详解
Local makefile compile other folder files specify obj directory
Vscode Tips 1
虚拟现实中的眼睛跟踪
Internship written examination answers
Busybox specified date modification temporarily does not require clock -w to write to hardware
XOR gun (bit operation, thinking, interval violence)
Solutions to slow transmission speed of FileZilla virtual machine
Restclient query document
Positional change of the eyeball during eye movements: evidence of translational movement
Robot stitching gesture recognition and classification
三维凝视估计,没有明确的个人校准2018
#MySql MySql 计算今年有多少天周末(周六、日)
Talking about several solutions of cross domain
Key points of embedded C language (const, static, volatile, bit operation)
单表查询、添加、更新与删除数据