当前位置:网站首页>Pytorch框架 学习记录1 CIFAR-10分类
Pytorch框架 学习记录1 CIFAR-10分类
2022-07-17 12:48:00 【好像几块钱】
- 使用torchvision加载并预处理CIFAR-10数据集
- 定义网络
- 定义损失函数和优化器
- 训练网络并更新网络参数
- 测试网络
import torchvision as tv
import torch as t
import torchvision.transforms as transforms #transforms模块专门用来进行图像预处理
from torchvision.transforms import ToPILImage
show=ToPILImage()
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))#标准化
])
trainset=tv.datasets.CIFAR10(
root='./cifar10',
train=True,
download=True,
transform=transform
)
trainloader=t.utils.data.DataLoader( #此处该接口将Pytroch已有数据接口的输入按照
trainset, #batch_size封装成Tensor,后续只需要再包装
batch_size=4, #成Variadle即可做为模型的输入
shuffle=True,
num_workers=2
)
testset=tv.datasets.CIFAR10(
root='./cifar10',
train=False,
download=True,
transform=transform
)
testloader=t.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,
num_workers=2
)
classes=('plane','car','bird','cat',
'deer','dog','frog','horse','ship','truck') #cifar10数据集中图像分为这10个类别
(data,label)=trainset[100]
print(classes[label])
show((data+1)/2).resize((100,100)) #显示索引为100的一张图像
Files already downloaded and verified Files already downloaded and verified shipOut[1]:
dataiter=iter(trainloader) #封装成迭代器
images,labels=dataiter.next() #迭代器
print(''.join('%11s'%classes[labels[j]] for j in range(4))) #按一定格式输出标签
show(tv.utils.make_grid((images+1)/2)).resize((400,100))
#按一定格式显示图片
#torchvision.utils.make_grid()的作用是将若干幅图像拼成一副图像horse ship car shipOut[21]:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1=nn.Conv2d(3, 6, 5)
self.conv2=nn.Conv2d(6, 16, 5)
self.fc1 =nn.Linear(16*5*5, 120)
self.fc2 =nn.Linear(120, 84)
self.fc3 =nn.Linear(84,10)
def forward(self, x):
x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x=F.max_pool2d(F.relu(self.conv2(x)),2)
x=x.view(x.size()[0],-1) #将多维度的Tensor展平成一维,这个-1指的是列数不定的
x=F.relu(self.fc1(x)) #情况下,根据原来Tensor内容和Tensor的大小自动分配列数
x=F.relu(self.fc2(x)) #relu函数激活
x=self.fc3(x)
return x
net=Net()
print(net)
#定义并输出网络Net( (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=400, out_features=120, bias=True) (fc2): Linear(in_features=120, out_features=84, bias=True) (fc3): Linear(in_features=84, out_features=10, bias=True) )
#定义损失函数和优化器
from torch.autograd import Variable
from torch import optim
criterion =nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
#训练2轮
for epoch in range(2):
running_loss=0.0
for i,data in enumerate(trainloader,0): #注释见下,trainloader可迭代
inputs,labels=data
inputs,labels=Variable(inputs),Variable(labels) #输入都需转为Variable
optimizer.zero_grad() #梯度清零
outputs=net(inputs)
loss=criterion(outputs,labels) #求损失
loss.backward() #反向传播
optimizer.step() #参数更新
#打印log信息
running_loss+=loss.item()
if i%2000 == 1999:
print('[%d, %5d] loss:%.3f'%
(epoch+1, i+1, running_loss/2000))
running_loss=0.0
print('Finished Training')
#enumerate(sequence,[strat=0])函数将一个可遍历的数据对象组合为一个索引序列,同时列出
#数据和数据下标,一般用在for循环中。sequence表示一个序列或其他支持迭代对象,start表示下
#标起始位置。
[1, 2000] loss:0.911 [1, 4000] loss:0.931 [1, 6000] loss:0.926 [1, 8000] loss:0.940 [1, 10000] loss:0.952 [1, 12000] loss:0.975 [2, 2000] loss:0.873 [2, 4000] loss:0.886 [2, 6000] loss:0.883 [2, 8000] loss:0.893 [2, 10000] loss:0.899 [2, 12000] loss:0.922 Finished Training
#测试图片实际label
dataiter =iter(testloader)
images,labels=dataiter.next()
print('实际的label: ',''.join('%08s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid(images/2-0.5)).resize((400,100))实际的label: frog truck truck deerOut[47]:
#放入网络预测图片对应标签
outputs=net(Variable(images))
_,predicted=t.max(outputs.data, 1)
print('预测结果:',' '.join('%5s'%classes[j] for j in range(4)))预测结果: plane car bird cat由于在实验时忘记冻结训练块(新取块开始测试),模型训练不至2轮,可能导致过拟合,所以4张图片测试结果都不对。
#统计1000张测试集的准确率
correct =0
total=0
for data in testloader:
images,labels=data
outputs=net(Variable(images))
_,predicted=t.max(outputs.data, 1)
total+=labels.size(0)
correct+=(predicted==labels).sum()
print('1000张测试集中的准确率为:%d %%'%(100*correct/total))1000张测试集中的准确率为:68 %
2022/7/13
边栏推荐
- Autojs learning - Dynamic decryption
- unity3d中的旋转
- unity3d如何利用asset store下载一些有用的资源包
- Detailed explanation of C language custom types
- Huawei wireless device configuration intelligent roaming
- 【CSP-J 2021】总结
- 高效理解 FreeSql WhereDynamicFilter,深入了解设计初衷[.NET ORM]
- 电商销售数据分析与预测(日期数据统计、按天统计、按月统计)
- 华为防火墙认证技术
- B. Accuratelee [double pointer] [substr function]
猜你喜欢

多元线性回归详解

王者荣耀商城异地多活架构设计

string类的介绍及模拟实现

2022 windows penetration test of "Cyberspace Security" of Hunan secondary vocational group (ultra detailed)

国产旗舰手机价格泡沫严重,二手比新机更划算,要不然就买iPhone

Aike AI frontier promotion (7.17)

STL中stack和queue的使用以及模拟实现

Bidirectional NAT Technology

Know what it is, and know why, JS object creation and inheritance

Hcip day 1 7.15
随机推荐
C# SerialPort配置和属性了解
MFC|框架下自绘CEdit控件
unity3d中的旋转
Autojs learning - Dynamic decryption
SAP Fiori Launchpad 上看不到任何 tile 应该怎么办?
The R language uses the plot function in the native package (basic import package, graphics) to visualize the scatter plot
bazel使用教程 转
mysql不能启动了?相关组件缺失?系统升级?组件不匹配?开始重装mysql
Date -- machine test topic for postgraduate entrance examination of Guizhou University
Autojs learning - multi function treasure chest - medium
The use and Simulation of stack and queue in STL
[makefile] some notes on the use of makefile
String类型函数传递问题
MFC | self drawn CEdit control under the framework
Detailed explanation of C language custom types
圆桌实录:炉边对话——如何在 Web3 实现创新
Through middle order traversal and pre order traversal, the subsequent traversal will always build a binary tree
yarn(cdh)中的虚拟cpu和内存
HCIA review and answer 2022.7.6
Data Lake solutions of various manufacturers


