PyTorch Framework Study Notes 1: CIFAR-10 Classification
2022-07-17 12:48:00 【好像几块钱】
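- Load and preprocess the CIFAR-10 dataset with torchvision
- Define the network
- Define the loss function and optimizer
- Train the network and update its parameters
- Test the network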
import torchvision as tv
import torch as t
import torchvision.transforms as transforms #the transforms module is dedicated to image preprocessing
from torchvision.transforms import ToPILImage
show=ToPILImage()
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #normalize each channel: (x-0.5)/0.5 maps [0,1] to [-1,1]
])
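`ToTensor()` yields pixel values in [0, 1], and `Normalize(mean, std)` then applies `(x - mean) / std` to each channel. The sketch below reproduces that arithmetic on single pixel values in plain Python. The helpers `normalize` and `denormalize` are illustrative names, not torchvision API. It shows why mean = std = 0.5 maps [0, 1] onto [-1, 1], and why the post displays images with `(x+1)/2`:

```python
# Per-channel arithmetic of transforms.Normalize(mean, std): out = (in - mean) / std
# Here mean = std = 0.5 for every channel, as in the transform above.
MEAN, STD = 0.5, 0.5

def normalize(x: float) -> float:
    """Map a ToTensor() pixel value in [0, 1] to [-1, 1]."""
    return (x - MEAN) / STD

def denormalize(y: float) -> float:
    """Inverse mapping; with MEAN = STD = 0.5 this equals (y + 1) / 2."""
    return y * STD + MEAN

for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel), "->", denormalize(normalize(pixel)))
```

Any display step must apply the inverse mapping. Otherwise a PIL conversion would receive values outside [0, 1].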
trainset=tv.datasets.CIFAR10(
root='./cifar10',
train=True,
download=True,
transform=transform
)
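trainloader=t.utils.data.DataLoader( #DataLoader takes the samples served by a PyTorch dataset and packs
trainset,                            #them into Tensors of batch_size; in old PyTorch (<0.4) a batch then
batch_size=4,                        #had to be wrapped in a Variable to serve as the model's input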
shuffle=True,
num_workers=2
)
testset=tv.datasets.CIFAR10(
root='./cifar10',
train=False,
download=True,
transform=transform
)
testloader=t.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,
num_workers=2
)
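Conceptually, a DataLoader turns a dataset of N samples into a sequence of index batches. The indices are shuffled when `shuffle=True`, and the last partial batch is kept because `drop_last` defaults to False. The plain-Python sketch below shows only that bookkeeping. The helper `batch_indices` is hypothetical. The real DataLoader also collates the samples into Tensors and can load them in `num_workers` processes. CIFAR-10 has 50,000 training and 10,000 test images, so `batch_size=4` gives 12,500 and 2,500 batches:

```python
import math
import random

def batch_indices(n, batch_size, shuffle, seed=0):
    """Yield one list of sample indices per mini-batch (conceptual DataLoader sketch)."""
    order = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(order)  # new random order of the samples
    for start in range(0, n, batch_size):
        yield order[start:start + batch_size]  # last batch may be smaller

TRAIN_SIZE, TEST_SIZE, BATCH_SIZE = 50000, 10000, 4
print(math.ceil(TRAIN_SIZE / BATCH_SIZE))  # training batches per epoch
print(math.ceil(TEST_SIZE / BATCH_SIZE))   # test batches
```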
classes=('plane','car','bird','cat',
'deer','dog','frog','horse','ship','truck') #the images in CIFAR-10 fall into these 10 classes
(data,label)=trainset[100]
print(classes[label])
show((data+1)/2).resize((100,100)) #display the image at index 100; (data+1)/2 undoes the normalization
Output:
Files already downloaded and verified
Files already downloaded and verified
ship
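dataiter=iter(trainloader) #turn the DataLoader into an iterator
images,labels=next(dataiter) #fetch one batch; next(dataiter) works in all versions, while dataiter.next() was removed in newer PyTorch
print(''.join('%11s'%classes[labels[j]] for j in range(4))) #print the 4 labels, each right-aligned in 11 characters
show(tv.utils.make_grid((images+1)/2)).resize((400,100))
#display the batch as one picture
#torchvision.utils.make_grid() tiles several images into a single image
Output:
horse ship car ship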
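`make_grid` places at most `nrow` images per row (default 8) and inserts `padding` pixels around and between them (default 2). The sketch below reproduces that size rule in plain Python. It is written from torchvision's layout logic. The helper `make_grid_size` is hypothetical and not part of torchvision. It shows that one batch of four 32×32 images becomes a 36×138 grid, so `resize((400,100))` keeps roughly the original aspect ratio:

```python
import math

def make_grid_size(n_images, height, width, nrow=8, padding=2):
    """Return (H, W) of the grid produced by torchvision.utils.make_grid."""
    xmaps = min(nrow, n_images)            # images per row
    ymaps = math.ceil(n_images / xmaps)    # number of rows
    grid_h = (height + padding) * ymaps + padding
    grid_w = (width + padding) * xmaps + padding
    return grid_h, grid_w

print(make_grid_size(4, 32, 32))  # (36, 138)
```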
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(3, 6, 5)  #3 input channels, 6 output channels, 5x5 kernel
        self.conv2=nn.Conv2d(6, 16, 5)
        self.fc1 =nn.Linear(16*5*5, 120)
        self.fc2 =nn.Linear(120, 84)
        self.fc3 =nn.Linear(84,10)     #10 outputs, one score per class
    def forward(self, x):
        x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        x=F.max_pool2d(F.relu(self.conv2(x)),2)
        x=x.view(x.size()[0],-1) #flatten each sample to a 1-D vector: the batch dimension is kept and -1
        x=F.relu(self.fc1(x))    #lets PyTorch infer the number of columns from the tensor's total size
        x=F.relu(self.fc2(x))    #ReLU activation
        x=self.fc3(x)
        return x
net=Net()
print(net)
#define and print the network
Output:
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
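The value `16*5*5` in `fc1` follows from tracking the spatial size of a 32×32 CIFAR-10 image through the layers. With no padding, an output side is `(size + 2*padding - kernel) // stride + 1`, and `max_pool2d` uses a stride equal to its kernel size by default. A minimal plain-Python check, using a hypothetical helper `conv2d_out`:

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

def lenet_flat_features(side=32):
    s = conv2d_out(side, 5)          # conv1: 32 -> 28
    s = conv2d_out(s, 2, stride=2)   # max_pool2d (2, 2): 28 -> 14
    s = conv2d_out(s, 5)             # conv2: 14 -> 10
    s = conv2d_out(s, 2, stride=2)   # max_pool2d 2: 10 -> 5
    return 16 * s * s                # 16 channels * 5 * 5

print(lenet_flat_features())  # 400, matching in_features=400 of fc1
```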
#define the loss function and optimizer
from torch.autograd import Variable
from torch import optim
criterion =nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
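For a single sample, `nn.CrossEntropyLoss` computes `-log(softmax(logits)[target])`, that is `logsumexp(logits) - logits[target]`, and by default it averages over the batch. `optim.SGD` with momentum keeps a buffer per parameter and updates `buf = momentum*buf + grad` and then `param -= lr*buf`. This assumes no dampening, weight decay or Nesterov, which matches the call above. The sketch below implements both formulas on Python scalars to show the arithmetic. It is not PyTorch's implementation, and the function names are hypothetical:

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: logsumexp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def sgd_momentum_step(param, grad, buf, lr=0.001, momentum=0.9):
    """One momentum-SGD update of a scalar parameter; returns (new_param, new_buf)."""
    buf = momentum * buf + grad
    return param - lr * buf, buf

# Equal scores for all 10 classes give loss ln(10)
print(round(cross_entropy([0.0] * 10, 3), 3))  # 2.303
```

A freshly initialized 10-class network produces nearly uniform scores. Its loss therefore starts near ln(10) ≈ 2.3.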
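#train for 2 epochs
for epoch in range(2):
    running_loss=0.0
    for i,data in enumerate(trainloader,0): #see the note below; trainloader is iterable
        inputs,labels=data
        inputs,labels=Variable(inputs),Variable(labels) #inputs wrapped in Variable (since PyTorch 0.4 this just returns the Tensor)
        optimizer.zero_grad() #reset gradients to zero; otherwise backward() accumulates them
        outputs=net(inputs)
        loss=criterion(outputs,labels) #compute the loss
        loss.backward() #backpropagation
        optimizer.step() #update the parameters
        #print log information
        running_loss+=loss.item()
        if i%2000 == 1999: #every 2000 mini-batches, print their average loss
            print('[%d, %5d] loss:%.3f'%
                  (epoch+1, i+1, running_loss/2000))
            running_loss=0.0
print('Finished Training')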
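#enumerate(sequence, [start=0]) turns an iterable object into an indexed sequence, yielding
#both the index and the element; it is mostly used in for loops. sequence is a sequence or any
#other iterable object, and start is the starting index.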
Output:
[1,  2000] loss:0.911
[1,  4000] loss:0.931
[1,  6000] loss:0.926
[1,  8000] loss:0.940
[1, 10000] loss:0.952
[1, 12000] loss:0.975
[2,  2000] loss:0.873
[2,  4000] loss:0.886
[2,  6000] loss:0.883
[2,  8000] loss:0.893
[2, 10000] loss:0.899
[2, 12000] loss:0.922
Finished Training
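The log has six entries per epoch because `enumerate` numbers the 12,500 mini-batches from 0, and `i%2000 == 1999` fires only at batches 2000, 4000, ..., 12000. The sketch below uses `range(12500)` as a stand-in for `trainloader` and reproduces that schedule. It also shows `enumerate` with a non-default `start`:

```python
# Stand-in for trainloader: 12500 mini-batches per epoch (50000 images / batch_size 4)
fake_trainloader = range(12500)
LOG_EVERY = 2000

logged_steps = []
for i, batch in enumerate(fake_trainloader, 0):  # i counts batches from 0
    if i % LOG_EVERY == LOG_EVERY - 1:           # same test as i%2000 == 1999
        logged_steps.append(i + 1)               # the value printed as i+1 in the log
print(logged_steps)  # [2000, 4000, 6000, 8000, 10000, 12000]

# enumerate with start=1
print(list(enumerate(['plane', 'car', 'bird'], start=1)))  # [(1, 'plane'), (2, 'car'), (3, 'bird')]
```

The losses of the last 500 batches of each epoch are added to `running_loss` but never printed, because `running_loss` is reset at the start of the next epoch.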
#check the true labels of some test images
dataiter =iter(testloader)
images,labels=next(dataiter)
print('Actual labels: ',''.join('%8s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images+1)/2)).resize((400,100)) #(images+1)/2 undoes the normalization before display
Output:
Actual labels: frog truck truck deer
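#feed the images to the network and predict their labels
outputs=net(Variable(images))
_,predicted=t.max(outputs.data, 1) #index of the highest score in each row = predicted class
print('Predicted: ',' '.join('%5s'%classes[predicted[j]] for j in range(4)))
Output of the version originally run, which printed classes[j] instead of classes[predicted[j]]:
Predicted: plane car bird cat
Note: that output is not the model's prediction. Indexing classes with j always prints classes[0] to classes[3] (plane car bird cat), whatever the network outputs, so it says nothing about whether the four predictions were right. The print statement must index classes with predicted[j]. Separately, I forgot to freeze the training cell during the experiment (testing was done in a new cell), so the model was trained for more than 2 epochs, which may have led to overfitting.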
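`t.max(outputs, 1)` returns the maximum score in each row together with its index. That index is the predicted class, and it is looked up in `classes`. A minimal plain-Python sketch of the same argmax-and-lookup step follows. The scores are made-up numbers for illustration, not output of the network above:

```python
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

def predict_labels(scores):
    """For each row of scores (one row per image), return the name of the class with the largest score."""
    names = []
    for row in scores:
        best = max(range(len(row)), key=lambda k: row[k])  # argmax over the 10 scores
        names.append(CLASSES[best])
    return names

# Hypothetical scores for 2 images
example_scores = [
    [0.1, 0.0, 0.2, 0.0, 0.0, 0.0, 2.5, 0.3, 0.0, 0.1],  # largest at index 6
    [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 3.0],  # largest at index 9
]
print(predict_labels(example_scores))  # ['frog', 'truck']
```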
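#compute the accuracy on the 10000-image test set
correct =0
total=0
with t.no_grad(): #no gradients are needed for evaluation
    for data in testloader:
        images,labels=data
        outputs=net(Variable(images))
        _,predicted=t.max(outputs.data, 1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item() #.item() turns the 0-dim tensor into a Python int
print('Accuracy on the 10000 test images: %d %%'%(100*correct/total))
Output:
Accuracy on the 10000 test images: 68 %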
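The test loop accumulates two counters over all batches: the number of images seen and the number of correct predictions. Accuracy is their ratio. The sketch below shows that bookkeeping on plain lists. Each `(predicted, labels)` pair stands for one mini-batch, the helper `accuracy_percent` is hypothetical, and the batches are made-up values:

```python
def accuracy_percent(batches):
    """batches: iterable of (predicted, labels) pairs, one pair per mini-batch."""
    correct = total = 0
    for predicted, labels in batches:
        total += len(labels)                                       # like labels.size(0)
        correct += sum(p == y for p, y in zip(predicted, labels))  # like (predicted==labels).sum()
    return 100.0 * correct / total

# Two hypothetical batches of 4: 3 correct in each
example_batches = [([6, 9, 9, 4], [6, 9, 1, 4]),
                   ([0, 1, 2, 3], [0, 1, 2, 5])]
print(accuracy_percent(example_batches))  # 75.0
```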
2022/7/13