当前位置:网站首页>Pytorch. NN implementation of multi-layer perceptron
Implementing a multi-layer perceptron with PyTorch's torch.nn
2022-07-19 10:50:00 【phac123】
Overview
Building the model directly from nn layers
Complete code
d2lzh_pytorch.py
import random
from IPython import display
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
import torch.nn as nn
def use_svg_display():
    """Ask IPython to render matplotlib figures as SVG (vector graphics)."""
    # NOTE(review): `display.set_matplotlib_formats` is reported as deprecated
    # in newer IPython releases -- confirm against the installed version.
    image_format = 'svg'
    display.set_matplotlib_formats(image_format)
def set_figsize(figsize=(3.5, 2.5)):
    """Switch to SVG output and set matplotlib's default figure size.

    Args:
        figsize: Value stored in ``plt.rcParams['figure.figsize']``;
            defaults to ``(3.5, 2.5)``.
    """
    # Select the vector output format first, then record the default size.
    use_svg_display()
    rc_params = plt.rcParams
    rc_params['figure.figsize'] = figsize
def data_iter(batch_size, features, labels):
    """Yield shuffled mini-batches of ``(features, labels)``.

    The example indices are shuffled once, then walked in steps of
    ``batch_size``; rows are gathered along dimension 0 with
    ``index_select``.

    Args:
        batch_size: Number of examples per batch (step of the index range).
        features: Tensor whose first dimension indexes the examples.
        labels: Tensor aligned with ``features`` along dimension 0.

    Yields:
        A ``(features_batch, labels_batch)`` pair. The last batch is shorter
        when the number of examples is not a multiple of ``batch_size``.
    """
    total = len(features)
    order = list(range(total))
    random.shuffle(order)
    for start in range(0, total, batch_size):
        # Slicing clamps at the end of the list, so the final batch simply
        # holds whatever examples are left.
        stop = min(start + batch_size, total)
        picked = torch.LongTensor(order[start:stop])
        batch_features = features.index_select(0, picked)
        batch_labels = labels.index_select(0, picked)
        yield batch_features, batch_labels
def linreg(X, w, b):
    """Linear regression model: return ``torch.mm(X, w) + b``.

    Args:
        X: Left operand of the matrix product.
        w: Right operand of the matrix product (the weights).
        b: Bias added to the product.
    """
    product = torch.mm(X, w)
    return product + b
def squared_loss(y_hat, y):
    """Element-wise squared loss ``(y_hat - y) ** 2 / 2``.

    ``y`` is first viewed with the size of ``y_hat`` so that the subtraction
    is element-wise rather than broadcast. No reduction is applied.

    Args:
        y_hat: Predictions.
        y: Targets; must hold as many elements as ``y_hat``.

    Returns:
        A tensor with the size of ``y_hat``.
    """
    target = y.view(y_hat.size())
    residual = y_hat - target
    return residual ** 2 / 2
def sgd(params, lr, batch_size):
    """Mini-batch stochastic gradient descent step, applied in place.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``.
    Gradients are NOT cleared here; the caller is responsible for zeroing
    them between steps.

    Args:
        params: Iterable of tensors to update.
        lr: Learning rate.
        batch_size: Divisor applied to the gradient (the loss is expected to
            be summed over the batch by the caller).
    """
    # Update under no_grad instead of writing through `.data`: the step is
    # still excluded from the autograd graph, but autograd keeps tracking the
    # in-place modification.
    with torch.no_grad():
        for param in params:
            # A parameter that took no part in the last backward pass has no
            # gradient; skip it (as torch.optim.SGD does) rather than failing
            # on `lr * None`.
            if param.grad is None:
                continue
            param -= lr * param.grad / batch_size
def get_fashion_mnist_labels(labels):
    """Map numeric Fashion-MNIST class indices to their text names.

    Args:
        labels: Iterable of class indices; each item is converted with
            ``int()`` before the lookup.

    Returns:
        A list of class-name strings, in the order of ``labels``.
    """
    class_names = (
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    )
    return [class_names[int(index)] for index in labels]
def show_fashion_mnist(images, labels):
    """Draw several images with their labels side by side in one row.

    Args:
        images: Sequence of tensors; each one is viewed as 28x28 and
            converted with ``.numpy()`` before being drawn.
        labels: Titles, one per image, in the same order.
    """
    # Nothing to draw; plt.subplots cannot create a grid with zero columns.
    if len(images) == 0:
        return
    use_svg_display()
    # squeeze=False always returns a 2-D array of axes. Without it a single
    # image yields one bare Axes object, which cannot be iterated by zip().
    # The figure handle is not needed, hence `_`.
    _, axes_grid = plt.subplots(1, len(images), figsize=(12, 12), squeeze=False)
    for ax, img, lbl in zip(axes_grid[0], images, labels):
        ax.imshow(img.view((28, 28)).numpy())
        ax.set_title(lbl)
        # Hide the tick axes: they carry no information for image tiles.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
def load_data_fashion_mnist(batch_size):
    """Download (if needed) Fashion-MNIST and wrap it in two DataLoaders.

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_iter, test_iter)``: the training loader shuffles its data,
        the test loader does not.
    """
    def _split(train):
        # Both splits share the root directory and the ToTensor transform.
        return torchvision.datasets.FashionMNIST(
            root='Datasets/FashionMNIST', train=train, download=True,
            transform=transforms.ToTensor())

    mnist_train = _split(True)
    mnist_test = _split(False)

    # On Windows no extra worker processes are used for reading the data;
    # elsewhere four workers are started.
    num_workers = 0 if sys.platform.startswith('win') else 4

    loader_cls = torch.utils.data.DataLoader
    train_iter = loader_cls(mnist_train, batch_size=batch_size, shuffle=True,
                            num_workers=num_workers)
    test_iter = loader_cls(mnist_test, batch_size=batch_size, shuffle=False,
                           num_workers=num_workers)
    return train_iter, test_iter
def evaluate_accuracy(data_iter, net):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    The predicted class of each example is the argmax of ``net(X)`` along
    dimension 1; it is compared with the label ``y``.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches.
        net: Callable mapping ``X`` to per-class scores. If it is a
            ``torch.nn.Module`` in training mode, it is switched to
            evaluation mode for the duration of the call and then restored.

    Returns:
        Fraction of correctly classified examples, as a float.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no examples.
    """
    # Layers such as dropout and batch norm behave differently in training
    # mode, so evaluate in eval mode and restore the caller's mode afterwards.
    was_training = isinstance(net, torch.nn.Module) and net.training
    if was_training:
        net.eval()
    acc_sum, n = 0.0, 0
    try:
        # No gradients are needed for evaluation; skip graph construction.
        with torch.no_grad():
            for X, y in data_iter:
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if was_training:
            net.train()
    return acc_sum / n
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):
    """Train ``net`` and print loss and accuracy after every epoch.

    Parameters are updated either by ``optimizer`` (when given) or by the
    hand-written ``sgd(params, lr, batch_size)`` step.

    Args:
        net: Model; called as ``net(X)``.
        train_iter: Iterable of ``(X, y)`` training batches.
        test_iter: Iterable of ``(X, y)`` batches passed to
            ``evaluate_accuracy`` after each epoch.
        loss: Callable ``loss(y_hat, y)``; its result is reduced with
            ``.sum()`` before back-propagation.
        num_epochs: Number of passes over ``train_iter``.
        batch_size: Forwarded to ``sgd`` when no optimizer is given.
        params: Parameters for the manual ``sgd`` path.
        lr: Learning rate for the manual ``sgd`` path.
        optimizer: Optional optimizer exposing ``zero_grad()`` and ``step()``.
    """
    use_optimizer = optimizer is not None
    for epoch in range(num_epochs):
        loss_total, correct_total, seen = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            batch_loss = loss(y_hat, y).sum()

            # Clear the gradients left over from the previous step.
            if use_optimizer:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            batch_loss.backward()

            # Apply the update through whichever mechanism was supplied.
            if use_optimizer:
                optimizer.step()
            else:
                sgd(params, lr, batch_size)

            loss_total += batch_loss.item()
            correct_total += (y_hat.argmax(dim=1) == y).sum().item()
            seen += y.shape[0]

        test_acc = evaluate_accuracy(test_iter, net)
        # NOTE(review): the printed loss is the accumulated batch loss divided
        # by the number of examples; if `loss` already averages over the batch
        # this is smaller than the per-example loss -- confirm intent.
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, loss_total / seen, correct_total / seen, test_acc))
class FlattenLayer(nn.Module):
    """Layer that reshapes its input to ``(x.shape[0], -1)``.

    The first dimension is kept and all remaining dimensions are collapsed
    into one.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` viewed as a 2-D tensor with the first dimension kept."""
        leading = x.shape[0]
        return x.view(leading, -1)
main.py
import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
# Define the model: flatten the input, then one hidden layer with ReLU,
# followed by a linear output layer.
num_inputs = 784
num_hiddens = 256
num_outputs = 10

net = nn.Sequential(
    d2l.FlattenLayer(),
    nn.Linear(num_inputs, num_hiddens),
    nn.ReLU(),
    nn.Linear(num_hiddens, num_outputs),
)

# Re-initialise every parameter returned by net.parameters() from a normal
# distribution with mean 0 and standard deviation 0.01.
for param in net.parameters():
    init.normal_(param, mean=0, std=0.01)
# Read the data and train the model.
batch_size = 256
num_epochs = 5

# The data loaders may start worker processes (num_workers=4 outside
# Windows). On platforms that start workers with "spawn", each worker
# re-imports this script, so loading and training must only run when the
# file is executed as the main program.
if __name__ == "__main__":
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    # params and lr are None because the optimizer performs the updates.
    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)
边栏推荐
- 使用tesseract.js-offline识别图片文字记录
- 二分类学习推广到多分类学习
- antd表单设置数组字段
- 如何在双链笔记软件中建立仪表盘和知识库?以嵌入式小组件库 NotionPet 为例
- If you use mybatics to access Damon database, is it exactly the same? Because the SQL syntax has not changed. Right?
- [design process] Net ORM FreeSQL wheredynamicfilter dynamic table query function
- 基于网络编码的卫星网络容量提升方法
- 电商销售数据分析与预测(日期数据统计、按天统计、按月统计)
- Pytorch学习记录2 线性回归(Tensor,Variable)
- input number 纯数字输入 限制长度 限制 最大值
猜你喜欢

OpenCV编程:OpenCV3.X训练自己的分类器

军品研制过程所需文件-进阶版

The use and Simulation of stack and queue in STL

因果学习将开启下一代AI浪潮?九章云极DataCanvas正式发布YLearn因果学习开源项目

ue4对动画蓝图的理解

Opencv programming: opencv3 X trains its own classifier

Structure the combat battalion | module 7

过拟合与欠拟合

LeetCode 2315. 统计星号(字符串)

C serialport configuration and attribute understanding
随机推荐
[makefile] some notes on the use of makefile
Autojs learning - multi function treasure chest - bottom
Pytoch learning record 2 linear regression (tensor, variable)
Win10的环境变量配置
unity3d如何利用asset store下载一些有用的资源包
Brush questions with binary tree (2)
MySQL query error
基于网络编码的卫星网络容量提升方法
二分类学习推广到多分类学习
Data Lake solutions of various manufacturers
Through middle order traversal and pre order traversal, the subsequent traversal will always build a binary tree
电商销售数据分析与预测(日期数据统计、按天统计、按月统计)
[acwing] game 60 c-acwing 4496 eat fruit
基于网络编码的卫星网络容量提升方法
SVN学习
[LeetCode周赛复盘] 第 302 场周赛20220717
antd 下拉多选传值到后台做查询操作
Pytorch学习记录2 线性回归(Tensor,Variable)
OpenCV编程:OpenCV3.X训练自己的分类器
If you use mybatics to access Damon database, is it exactly the same? Because the SQL syntax has not changed. Right?