PyTorch learning diary (4)
2022-07-19 07:08:00 【When to order】
Today I am learning how to build a convolutional neural network.
1. Building a convolutional neural network to classify the MNIST dataset
1.1 Getting the data
Build the training set and the test set (used here as the validation set) separately, then use DataLoader to iterate over the data in batches:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import numpy as np

# Define hyperparameters
input_size = 28   # each image is 28*28 pixels
num_classes = 10  # number of label classes (digits 0-9)
num_epochs = 3    # total number of training epochs
batch_size = 64   # number of samples per batch

# Training set, read with the MNIST class in torchvision.datasets
train_dataset = datasets.MNIST(root='./data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# Test set
test_dataset = datasets.MNIST(root='./data',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# Build batched data loaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)  # shuffling is not needed for evaluation

1.2 Building the convolutional network module
A convolution layer, a ReLU layer and a pooling layer are usually packaged together as one block (here with nn.Sequential). Note that the output of the convolutional part is a set of feature maps, which must be flattened into a vector before it can be used for a classification or regression task:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(      # input is (1, 28, 28)
            nn.Conv2d(
                in_channels=1,           # number of input channels; 1 for grayscale images
                out_channels=16,         # number of feature maps to produce, i.e. number of convolution kernels
                kernel_size=5,           # convolution kernel size
                stride=1,                # stride
                padding=2,               # with stride=1, padding=(kernel_size-1)//2 keeps the output the same size as the input
            ),                           # output is (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), # output is (16, 14, 14)
        )
        self.conv2 = nn.Sequential(      # input is (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # output is (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),             # output is (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer producing 10 class scores

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)        # flatten: the result is (batch_size, 32*7*7)
        # print(x.size())                # uncomment to check shapes; printing on every batch floods the output
        out = self.out(x)
        return out
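The shape comments in the class, (16,14,14) after conv1 and (32,7,7) after conv2, and the 32*7*7 input size of the linear layer all follow from the standard output-size formula for convolution and pooling layers, out = floor((in + 2*padding - kernel_size) / stride) + 1. Below is a minimal pure-Python sketch that traces the shapes by hand, assuming no dilation and that nn.MaxPool2d uses its default stride (equal to kernel_size); the helper names conv_out and trace_cnn_shapes are hypothetical.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output length of one spatial dimension for a conv or pooling layer (dilation=1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_cnn_shapes(size=28):
    """Follow a size x size MNIST image through conv1 and conv2 of the CNN above."""
    shapes = []
    # conv1: Conv2d(1, 16, kernel_size=5, stride=1, padding=2) keeps 28x28
    size = conv_out(size, kernel_size=5, stride=1, padding=2)
    # MaxPool2d(kernel_size=2): stride defaults to kernel_size, so 28 -> 14
    size = conv_out(size, kernel_size=2, stride=2)
    shapes.append((16, size, size))
    # conv2: Conv2d(16, 32, 5, 1, 2) keeps 14x14, then MaxPool2d(2) halves it to 7x7
    size = conv_out(size, kernel_size=5, stride=1, padding=2)
    size = conv_out(size, kernel_size=2, stride=2)
    shapes.append((32, size, size))
    return shapes


shapes = trace_cnn_shapes(28)
print(shapes)        # [(16, 14, 14), (32, 7, 7)]
c, h, w = shapes[-1]
print(c * h * w)     # 1568, the in_features of self.out
```

With stride=1 the formula reduces to in + 2*padding - kernel_size + 1, which equals in exactly when padding = (kernel_size - 1) / 2, i.e. padding=2 for a 5x5 kernel, as the comment in the class says.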
1.3 Training the network model
# Use accuracy as the evaluation metric
def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1]             # index of the largest score = predicted class
    rights = pred.eq(labels.data.view_as(pred)).sum()  # number of correct predictions
    return rights, len(labels)

# Instantiate the model
net = CNN()
# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizer
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
for epoch in range(num_epochs):
    # Keep the results for the current epoch
    train_rights = []

    for batch_idx, (data, target) in enumerate(train_loader):  # iterate over each batch in the loader
        net.train()
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        right = accuracy(output, target)
        train_rights.append(right)

        if batch_idx % 100 == 0:
            net.eval()
            val_rights = []
            with torch.no_grad():  # gradients are not needed during evaluation
                for (data, target) in test_loader:
                    output = net(data)
                    right = accuracy(output, target)
                    val_rights.append(right)

            # Accuracy calculation: pool (correct, total) over all batches
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTraining accuracy: {:.2f}%\tTest accuracy: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(),
                100. * train_r[0].item() / train_r[1],
                100. * val_r[0].item() / val_r[1],
            ))
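The accuracy() helper takes the index of the largest score in each row (torch.max(..., 1)[1]) as the predicted class, counts how many match the labels, and returns (correct, total). The loop then sums both parts across batches before dividing, so batches of different sizes (the last batch is usually smaller) are weighted correctly. Below is a minimal pure-Python sketch of the same bookkeeping on made-up scores rather than real model output; the function names batch_accuracy and overall_accuracy are hypothetical.

```python
def batch_accuracy(logits, labels):
    """Return (number correct, batch size), mirroring accuracy() above.

    logits: list of per-sample score lists; labels: list of int class indices.
    """
    # argmax of each row = predicted class
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    rights = sum(int(p == y) for p, y in zip(preds, labels))
    return rights, len(labels)


def overall_accuracy(results):
    """Pool (rights, total) pairs across batches, as train_r / val_r do, and return a percentage."""
    rights = sum(r for r, _ in results)
    total = sum(n for _, n in results)
    return 100.0 * rights / total


# Two toy batches with 3 classes; the second is smaller, like a final partial batch.
batch1 = batch_accuracy([[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 3.0]], [1, 0, 2])
batch2 = batch_accuracy([[2.0, 1.0, 0.0]], [1])
print(batch1, batch2)                      # (3, 3) (0, 1)
print(overall_accuracy([batch1, batch2]))  # 75.0
```

Averaging the per-batch percentages instead would give (100 + 0) / 2 = 50%, overweighting the single-sample batch. Pooling the counts gives the true 3/4 = 75%.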