当前位置:网站首页>PyTorch学习笔记【4】:从图像学习
PyTorch学习笔记【4】:从图像学习
2022-07-17 05:10:00 【zzzyzh】
文章目录
前言
本文是基于《Pytorch深度学习实战》一书第七章的内容所整理的学习笔记
相关代码的解释以及对应的拓展。
本文使用的代码均基于jupyter
1. 微小图像数据集合
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import torch
torch.set_printoptions(edgeitems=2, linewidth=75)
torch.manual_seed(123)
1.1. 下载CIFAR-10
from torchvision import datasets
data_path = 'data/p1ch7/'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True) # 实例画一个数据集用于训练数据,如果数据集不存在,则TorchVision将下载该数据集
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True) # 使用train=False,获取一个数据集用于验证数据,并在需要时再次下载该数据集
class_names = ['airplane','automobile','bird','cat','deer',
'dog','frog','horse','ship','truck']
fig = plt.figure(figsize=(8,3))
num_classes = 10
for i in range(num_classes):
ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
ax.set_title(class_names[i])
img = next(img for img, label in cifar10 if label == i)
plt.imshow(img)
plt.show()
数据集子模块为我们提供了对最流行的计算机视觉数据集的预存储访问,在每种情况下,数据集都作为torch.utils.data.Dataset的子类返回。
type(cifar10).__mro__
- mro
为了方便且快速地看清继承关系和顺序,可以用__mro__方法来获取这个类的调用顺序。
class X(object):pass
class Y(object):pass
class A(X, Y):pass
class B(Y):pass
class C(A, B):pass
C.__mro__w
1.2. Dataset类
表示数据集的抽象类,它不一定持有数据,但是它提供了对七进行统一访问的函数__len__()和__getitem__(),且子类必须继承上述两个函数
- len():获取数据集长度
len(cifar10) - getitem():获取样本对,模型直接通过这一函数获得一对样本对{x:y}
img, label = cifar10[99] img, label, class_names[label]
且这个对象是RGB PIL(Python Imaging Library)图像的一个实例,可以被打印出来
plt.imshow(img)
plt.show()

1.3. Dataset变换
torchvision.transforms
这个模块定义了一组可组合的、类似函数的对象,它可以作为参数传递到TorchVision模块的数据集
1.3.1. 可用对象的列表
from torchvision import transforms
dir(transforms)
- dir()
内置的函数 dir() 可以找到模块内定义的所有名称。以一个字符串列表的形式返回
1.3.2. ToTensor()
一旦ToTensor被实例化,就可以向调用函数一样调用它,以PIL图像作为参数,返回一个张量作为输出
from torchvision import transforms
to_tensor = transforms.ToTensor()
img_t = to_tensor(img)
img_t.shape
我们也可以将transform直接作为参数传递给dataset.CIFAR10
tensor_cifar10 = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
与之前相比,访问数据集的元素将返回一个张量,而不是PIL图像
img_t, _ = tensor_cifar10[99]
type(img_t), img_t.shape, img_t.dtype
img_t.min(), img_t.max()
打印此时的图片结果
plt.imshow(img_t.permute(1, 2, 0)) # 将轴的顺序由CxHxW改为HxWxC以匹配matplotlib
plt.show()

1.4. 数据归一化
变换非常方便,因为我们可以使用 transforms.Compose()将它们连接起来,以实现一系列transform操作,然后在数据加载器中直接透明地进行数据归一化和数据增强操作。
对每个通道进行归一化使其具有相同的分布,可以保证在相同的学习率下,通过梯度下降实现通道信息的混合和更新。
为了使每个通道的均值为0、标准差为 1,我们可以应用以下转换来计算数据集中每个通道的平均值和标淮差:v_n[c]=(v[c]-mean[c]) / stdev[c]。这正是transforms.Normalize0/所做的。
我们将数据集返回的素有张量沿着一个额外的纬度进行堆叠
imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3)
imgs.shape
- stack()
沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
可以理解为把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量……以此类推,也就是在增加新的维度进行堆叠。# 假设是时间步T1的输出 T1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 假设是时间步T2的输出 T2 = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]]) print(torch.stack((T1,T2),dim=0).shape) print(torch.stack((T1,T2),dim=1).shape) print(torch.stack((T1,T2),dim=1)) print(torch.stack((T1,T2),dim=2).shape)
计算每个通道的平均值
imgs.view(3, -1) # view(3, -1)保留了3个通道,并将剩余的所有纬度合并为一个纬度,从而计算出适当的尺寸大小。这里我们的3x32x32的图像被转换成一个3x1024的向量,然后对每个通道的1024个元素取平均值
- torch.view(-1) & torch.view(a, -1)
在参数a已知的情况下自动补齐列向量长度
imgs.view(3, -1).mean(dim=1) # 计算均值
imgs.view(3, -1).std(dim=1) # 计算标准差
transforms.Normalize((0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616)) # 初始化Normalize变换
使用Compose连接多个变换
transformed_cifar10 = datasets.CIFAR10(
data_path, train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
transformed_cifar10_val = datasets.CIFAR10(
data_path, train=False, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
注意,此时,从数据集绘制的图像不能为我们提供实际图像的真实表示
img_t, _ = transformed_cifar10[99]
plt.imshow(img_t.permute(1, 2, 0))
plt.show()

2. 区分鸟和飞机
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import torch
torch.set_printoptions(edgeitems=2)
torch.manual_seed(123)
class_names = ['airplane','automobile','bird','cat','deer',
'dog','frog','horse','ship','truck']
from torchvision import datasets, transforms
data_path = 'data/p1ch7/'
cifar10 = datasets.CIFAR10(
data_path, train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
cifar10_val = datasets.CIFAR10(
data_path, train=False, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
2.1. 构建数据集
从CIFAR10中创建一个只包含鸟和飞机的数据集子集
label_map = {
0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
for img, label in cifar10
if label in [0, 2]]
cifar2_val = [(img, label_map[label])
for img, label in cifar10_val
if label in [0, 2]]
2.2. 一个全连接模型
import torch.nn as nn
n_out = 2
model = nn.Sequential(
nn.Linear(
3072, # 输入特征
512, # 隐藏层的大小
),
nn.Tanh(),
nn.Linear(
512, # 隐藏层的大小
n_out, # 输出类
)
)
2.3. 分类器的输出
我们需要认识到输出是分类的:它要么是一只鸟,要么是一架飞机。当我们必须表示一个分类变量时,我们应该用该变量的独热编码表示,如对于飞机使用[1,0],对于鸟使用[0,1],顺序任意。
理想情况下,网络将为飞机输出 torch.tensor([1.0,0.0]),为鸟输出; torch.tensor([0.0,1.0])。实际上,由于我们的分类器并不是很完美的,我们可以期望网络输出介于二者之间的结果。关键的实现是我们可以将输出解释为概率:第1项是“飞机”的概率,第2项是“鸟”的概率。
一些额外的约束。
- 输出的每个元素必须在[0.0,1.0]的范围内(结果的概率不能小于0或大于 1)。
- 输出元素的总和必须为1.0(我们确信这2种结果中的一种将发生)
2.4. 用概率表示输出
Softmax,它获取一个值向量并生成另一个相同纬度的向量,其中的值满足我们刚刚列出的表示概率的约束条件。
def softmax(x):
return torch.exp(x) / torch.exp(x).sum()
x = torch.tensor([1.0, 2.0, 3.0])
softmax(x), softmax(x).sum()
Softmax是一个单调函数,因为输入中的较小值对应输出中的较小值。但是,它并不是比率不变的,因为值之间的比率没有被保留。
softmax = nn.Softmax(dim=1)
x = torch.tensor([[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0]])
softmax(x)
在搭建的神经网络模型的末尾添加一个nn.Softmax()
model = nn.Sequential(
nn.Linear(3072, 512),
nn.Tanh(),
nn.Linear(512, 2),
nn.Softmax(dim=1))
img, _ = cifar2[0]
img_batch = img.view(-1).unsqueeze(0)
out = model(img_batch)
out
训练之后,通过输出概率的argmax来获得作为索引的标签,即获得最大概率的索引
_, index = torch.max(out, dim=1)
index
2.5. 分类的损失
我们希望惩罚错误分类,所以我们需要最大化的是与正确的类out[class_index]相关的概率。其中out是softmax的输出,class_index
是一个向量。
与正确类相关的概率,被称为我们的模型给定参数的似然,即我们想要一个损失函数,当概率很低的时候,损失非常高——低到其他选择都有比它更高的概率。相反,当概率高于其他选择时,损失应该很低,而且我们并不是真的专注于将概率提高到1。
负对数似然(NLL),表达式为 NLL =-sum(log(out_i[c_i])),其中sum()用于对N个样本求和,而c_i是样本i的目标类别。
计算分类损失的步骤:
- 运行正向传播,并从最后的线性层获得输出值
- 计算它们的Softmax,并获得概率
- 取于目标类别对应的预测概率(参数的可能性)
- 计算它的对数,在它前面加上一个符号,再添加到损失中
修改模型为:
model = nn.Sequential(
nn.Linear(3072, 512),
nn.Tanh(),
nn.Linear(512, 2),
nn.LogSoftmax(dim=1))
实例化NLL损失
loss = nn.NLLLoss()
img, label = cifar2[0]
out = model(img.view(-1).unsqueeze(0))
loss(out, torch.tensor([label]))
2.6. 训练分类器
2.6.1. 训练循环:
- 对整个数据集进行平均更新
- 更新每个样本的模型
- 在小批量上平均更新
import torch
import torch.nn as nn
import torch.optim as optim
model = nn.Sequential(
nn.Linear(3072, 512),
nn.Tanh(),
nn.Linear(512, 2),
nn.LogSoftmax(dim=1))
learning_rate = 1e-2
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.NLLLoss()
n_epochs = 100
for epoch in range(n_epochs):
for img, label in cifar2:
out = model(img.view(-1).unsqueeze(0))
loss = loss_fn(out, torch.tensor([label]))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch: %d, Loss: %f" % (epoch, float(loss)))
通过在每个迭代周期上变换样本并一次估计一个或几个样本的梯度(提高稳定性),我们在梯度下降中有效地引入了随机性。
在小批量上估计的梯度是在整个数据集上估计的梯度的较差近似值,有助于收敛并防止优化在过程中陷入局部极小。
通常,小批量是一个固定的大小,需要我们在训练之前设置,就像学习率一样。这些被称为超参数,以区别于模型的参数。
2.6.3. DataLoader
有助于打乱数据和组织数据
数据加载器的工作是从数据集中采样小批量,这是我们能够灵活地选择不同的采样策略。一种非常常见的策略是在每个迭代周期洗牌后进行均匀采样。
DataLoader()构造函数至少接收一个数据集对象作为输入,以及batch_size和一个shuffle布尔值,该布尔值指示数据是否需要在每个迭代周期开始时被重新打乱:
import torch
import torch.nn as nn
import torch.optim as optim
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)
model = nn.Sequential(
nn.Linear(3072, 128),
nn.Tanh(),
nn.Linear(128, 2),
nn.LogSoftmax(dim=1))
learning_rate = 1e-2
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.NLLLoss()
n_epochs = 100
for epoch in range(n_epochs):
for imgs, labels in train_loader:
outputs = model(imgs.view(imgs.shape[0], -1))
loss = loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch: %d, Loss: %f" % (epoch, float(loss)))
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in train_loader:
outputs = model(imgs.view(imgs.shape[0], -1))
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy: %f" % (correct / total))
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in val_loader:
outputs = model(imgs.view(imgs.shape[0], -1))
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy: %f" % (correct / total))
更新网络结构,以获取更高的性能
使用交叉熵损失函数替换均方差损失函数
model = nn.Sequential(
nn.Linear(3072, 1024),
nn.Tanh(),
nn.Linear(1024, 512),
nn.Tanh(),
nn.Linear(512, 128),
nn.Tanh(),
nn.Linear(128, 2))
loss_fn = nn.CrossEntropyLoss()
import torch
import torch.nn as nn
import torch.optim as optim
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,
shuffle=True)
model = nn.Sequential(
nn.Linear(3072, 1024),
nn.Tanh(),
nn.Linear(1024, 512),
nn.Tanh(),
nn.Linear(512, 128),
nn.Tanh(),
nn.Linear(128, 2))
learning_rate = 1e-2
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
n_epochs = 100
for epoch in range(n_epochs):
for imgs, labels in train_loader:
outputs = model(imgs.view(imgs.shape[0], -1))
loss = loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch: %d, Loss: %f" % (epoch, float(loss)))
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,
shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in train_loader:
outputs = model(imgs.view(imgs.shape[0], -1))
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy: %f" % (correct / total))
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,
shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in val_loader:
outputs = model(imgs.view(imgs.shape[0], -1))
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy: %f" % (correct / total))
观察模型中可训练的参数的数量
sum([p.numel() for p in model.parameters()])
sum([p.numel() for p in model.parameters() if p.requires_grad == True])
linear = nn.Linear(3072, 1024)
linear.weight.shape, linear.bias.shape
总结
本文主要讲解了:
- 如何构建一个前馈神经网络
- 使用Dataset和DataLoader加载数据
- 了解分类损失
边栏推荐
- Bottomsheetdialogfragment imitation Tiktok comment box
- 5. Spark核心编程(1)
- 跨域和处理跨域
- 电商用户行为实时分析系统(Flink1.10.1)
- Object to map
- Pointnet++代码详解(五):sample_and_group函数和samle_and_group_all函数
- 使用OpenCV、ONNXRuntime部署YOLOV7目标检测——记录贴
- Wxml template syntax in wechat applet
- gradle
- MySQL learning notes (4) - (basic crud) operate the data of tables in the database
猜你喜欢

Use of MySQL

10 question 10 answer: do you really know thread pools?

2021-05-21

MySQL learning notes (5) -- join join table query, self join query, paging and sorting, sub query and nested query

C语言实现迭代实现二分查找

1. Neusoft cross border e-commerce warehouse demand specification document

微信小程序的自定义组件

INRIAPerson数据集转化为yolo训练格式并可视化
![[first launch in the whole network] will an abnormal main thread cause the JVM to exit?](/img/ae/5df25d64c2f29292bbfb21f696bbb0.png)
[first launch in the whole network] will an abnormal main thread cause the JVM to exit?

MySQL事务
随机推荐
软件过程与管理总复习
1.東軟跨境電商數倉需求規格說明文檔
Problems encountered by kotlin generics
JNI实用笔记
Use of MySQL
C语言——冒泡排序
6. Data warehouse design for data warehouse construction
Write a timed self-test
Common components of wechat applet
SQL练习题集合
3.东软跨境电商数仓项目架构设计
软件过程与管理复习(七)
PCM静默检测
The future of data Lakehouse - Open
1. Dongsoft Cross - Border E - commerce Data Warehouse Requirement specification document
Page navigation of wechat applet
Composants communs des applets Wechat
Unable to determine Electron version. Please specify an Electron version
Gradle custom plug-in
微信小程序代码的构成