PyTorch Deep Learning Practice (Bilibili, 刘二大人), day 7
2022-07-17 05:20:00 【爱编程的西瓜】
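Multi-class classification
Notes on the multi-class classification lecture from 刘二大人's PyTorch Deep Learning Practice course on Bilibili
In the multi-class lecture, the task is to classify the digits shown in the images, i.e. to decide which digit each one is. This raises two requirements for the outputs: first, the outputs should suppress one another so that the probabilities sum to exactly 1; second, every probability must be greater than 0.
(Screenshot from the video)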
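Softmax meets both requirements. As a minimal sketch in plain Python (the function name `softmax` and the example scores are illustrative, not taken from the lecture), exponentiating makes every value positive, and dividing by the sum makes the values add up to 1:

```python
import math

def softmax(scores):
    """Map raw scores to probabilities that are positive and sum to 1."""
    # Subtracting the maximum does not change the result but avoids overflow.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative raw scores, including a negative one.
probs = softmax([0.2, 0.1, -0.1])
print(probs)
print(sum(probs))
```

The largest score still receives the largest probability, so the ordering of the classes is preserved.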
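Notes:
1. The input to softmax needs no further nonlinear transformation, i.e. no activation function (relu) is applied before softmax. Softmax does two things: if its inputs contain negative numbers, exponentiation maps them to positive numbers; and dividing by the sum of the exponentials makes the probabilities of all classes sum to 1.
2. The labels y are encoded as one-hot. My understanding of one-hot is that exactly one position is 1 and all others are 0. (The one-hot encoding of the labels is handled by the algorithm, though; the input to the algorithm is still the raw label.)
3. In multi-class problems the label y is a LongTensor. For a 0-9 classification problem, y = torch.LongTensor([3]) corresponds to the one-hot vector [0,0,0,1,0,0,0,0,0,0]. (Note: what is passed to the loss is the LongTensor of class indices, not the one-hot vector. By contrast, the target in the diabetes dataset was a FloatTensor.)
4. CrossEntropyLoss <==> LogSoftmax + NLLLoss. With CrossEntropyLoss, the output of the last (linear) layer needs no further transformation. With NLLLoss, the output of the last (linear) layer must first go through softmax and then log, i.e. LogSoftmax.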
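Notes 3 and 4 can be checked numerically. The sketch below uses made-up logits and labels (not data from the lecture) to compare CrossEntropyLoss on raw logits with LogSoftmax followed by NLLLoss, and to show the one-hot vector for a class-index label:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # made-up output of a final linear layer: 4 samples, 10 classes
target = torch.tensor([3, 0, 9, 1])  # class indices; integer tensors default to int64 (LongTensor)

# CrossEntropyLoss takes the raw logits directly.
ce = torch.nn.CrossEntropyLoss()(logits, target)

# NLLLoss expects log-probabilities, so apply LogSoftmax first.
log_probs = torch.nn.LogSoftmax(dim=1)(logits)
nll = torch.nn.NLLLoss()(log_probs, target)

# The one-hot form of the labels, shown only for illustration.
one_hot = F.one_hot(target, num_classes=10)

print(ce.item(), nll.item())
print(one_hot[0].tolist())
```

The two losses should agree up to floating-point error, which is why a model trained with CrossEntropyLoss returns the raw output of its last linear layer.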



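Code notes:
1. Lecture 8 used from torch.utils.data import Dataset; lecture 9 uses from torchvision import datasets. The dataset classes in datasets already implement the __init__, __getitem__ and __len__ magic methods.
2. With dim=1, torch.max returns two values: the first is the maximum of each row, the second is the index of the maximum in each row.
3. Fully connected neural network
4. torch.no_grad() (related: the with statement in Python)
5. The "_" in the code (related: the various uses of underscores in Python)
6. Usage of torch.max() (related: an explanation of torch.max())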
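As a small illustration of notes 2, 5 and 6, the sketch below calls torch.max on a made-up 2x3 tensor of scores (the values are illustrative only):

```python
import torch

# Made-up scores for 2 samples and 3 classes.
outputs = torch.tensor([[0.1, 2.0, -1.0],
                        [1.5, 0.3, 0.9]])

# With dim=1, torch.max returns (values, indices) for each row.
values, indices = torch.max(outputs, dim=1)

# "_" is a conventional name for a value that will not be used.
_, predicted = torch.max(outputs, dim=1)

print(values)     # row maxima
print(predicted)  # index of the maximum in each row
```

In the classifier, the index of the row maximum is the predicted class, so only the second return value is kept.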
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# prepare dataset
batch_size = 64
# Convert to tensor, then normalize with the MNIST mean and standard deviation
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


# design model using class
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)  # each image is a 28x28 matrix, i.e. 784 values
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # -1 lets PyTorch infer the mini-batch size
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)  # no activation (no nonlinear transformation) on the last layer


model = Net()

# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# training cycle: forward, backward, update
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        # get one batch of inputs and labels
        inputs, target = data
        optimizer.zero_grad()
        # model predictions, shape (64, 10)
        outputs = model(inputs)
        # cross-entropy loss: outputs (64, 10), target (64)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            # report the average loss over the last 300 batches
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            # dim=1: take the max over the class dimension, one prediction per sample
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # element-wise comparison between tensors
    print('accuracy on test set: %d %% ' % (100*correct/total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
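
The reshaping in forward() and the accuracy bookkeeping in test() can be checked on a tiny hand-made batch. In this sketch the images, labels and outputs are all made up, and the outputs tensor merely stands in for model(images):

```python
import torch

# Made-up batch: 3 "images" of shape 1x28x28 and their labels.
images = torch.zeros(3, 1, 28, 28)
labels = torch.tensor([1, 0, 2])

# view(-1, 784) flattens each image; -1 lets PyTorch infer the batch size.
flat = images.view(-1, 784)

# Made-up model outputs for 3 classes (stand-in for model(images)).
outputs = torch.tensor([[0.1, 2.0, -1.0],
                        [1.5, 0.3, 0.9],
                        [0.7, 0.2, 0.1]])

with torch.no_grad():  # no gradient tracking is needed for evaluation
    _, predicted = torch.max(outputs, dim=1)
    # Element-wise comparison gives a boolean tensor; sum() counts the matches.
    correct = (predicted == labels).sum().item()
    total = labels.size(0)

print(flat.shape, correct, total)
```

Here the third sample is predicted as class 0 while its label is 2, so two of the three predictions count as correct.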