当前位置:网站首页>Pytorch手动实现多层感知机
Pytorch手动实现多层感知机
2022-07-17 13:08:00 【phac123】
简述
- 获取和读取数据;这里还是采用Fashion-MNIST数据集
- 定义模型参数,这里的输入层是28*28,隐藏层设为256,输出层为10
- 定义模型,激活函数使用ReLU,然后再经过一层线性层
- 定义损失函数,采用交叉熵损失函数
- 最后采用小批量随机梯度下降进行训练优化
- 注: 由于MXNet的SoftmaxCrossEntropyLoss在反向传播的时候相当于沿batch维求和,而PyTorch的CrossEntropyLoss默认是求平均,所以用PyTorch计算得到的loss比MXNet小很多(大概是MXNet计算得到的1/batch_size这个量级),反向传播得到的梯度也小很多。为了得到差不多的学习效果,我们把学习率设置成100.0。(之所以这么大,应该是因为d2lzh_pytorch里面的sgd函数在更新的时候又除以了batch_size,而PyTorch在计算loss的时候已经除过一次了,sgd这里其实不必再除。)
完整代码
d2lzh_pytorch.py
import random
from IPython import display
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
import torch.nn as nn
def use_svg_display():
    """Ask IPython to render matplotlib figures as SVG (vector graphics)."""
    vector_format = 'svg'
    display.set_matplotlib_formats(vector_format)
def set_figsize(figsize=(3.5, 2.5)):
    """Switch to SVG rendering and set matplotlib's default figure size.

    Args:
        figsize: (width, height) in inches stored in rcParams['figure.figsize'].
    """
    use_svg_display()
    rc = plt.rcParams
    rc['figure.figsize'] = figsize
def data_iter(batch_size, features, labels):
    """Yield shuffled (features, labels) mini-batches.

    The sample order is shuffled once per call with ``random.shuffle``.
    Batches are then taken in order of ``batch_size`` samples each. The last
    batch is shorter when the sample count is not a multiple of
    ``batch_size``.

    Args:
        batch_size: number of samples per batch.
        features: tensor whose first dimension indexes samples.
        labels: tensor whose first dimension indexes the same samples.

    Yields:
        (features_batch, labels_batch) selected along dimension 0.
    """
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for start in range(0, num_examples, batch_size):
        # Slicing clamps at the end, so the final (possibly short) batch needs no min().
        batch_idx = indices[start:start + batch_size]
        # Build each index tensor on the same device as the tensor it indexes,
        # so index_select also works for tensors that are not on the CPU.
        feat_idx = torch.tensor(batch_idx, dtype=torch.long, device=features.device)
        label_idx = torch.tensor(batch_idx, dtype=torch.long, device=labels.device)
        yield features.index_select(0, feat_idx), labels.index_select(0, label_idx)
def linreg(X, w, b):
    """Linear regression model: matrix product X·w followed by adding bias b."""
    return X.mm(w).add(b)
def squared_loss(y_hat, y):
    """Element-wise halved squared error between predictions and targets.

    ``y`` is reshaped (via ``view``) to ``y_hat``'s shape before subtracting.
    No reduction is applied, so the result has the same shape as ``y_hat``.
    """
    target = y.view(y_hat.size())
    residual = y_hat - target
    return residual ** 2 / 2
def sgd(params, lr, batch_size):
    """Mini-batch stochastic gradient descent step, applied in place.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``.

    The division by ``batch_size`` assumes the loss was *summed* over the
    batch. With a loss that already averages over the batch (e.g. PyTorch's
    default ``nn.CrossEntropyLoss``), the batch size is divided out twice.
    That is why callers using such a loss pick a much larger ``lr``.

    Args:
        params: iterable of tensors whose ``.grad`` has been populated.
        lr: learning rate.
        batch_size: divisor applied to each gradient.
    """
    # Update under no_grad instead of through ``.data`` so autograd neither
    # records the step nor loses track of the in-place modification.
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
def get_fashion_mnist_labels(labels):
    """Map numeric Fashion-MNIST class ids to their text names.

    Args:
        labels: iterable of values convertible with ``int()`` (ints, tensors...).

    Returns:
        A list of class-name strings, one per input label.
    """
    class_names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    return list(map(lambda idx: class_names[int(idx)], labels))
def show_fashion_mnist(images, labels):
    """Draw images in a single row, each titled with its label, then show them.

    Each image is reshaped with ``view((28, 28))`` and converted via
    ``.numpy()`` before plotting. Axis ticks are hidden. Nothing is drawn
    when ``images`` is empty.

    Args:
        images: sized sequence of image tensors with 28*28 elements each.
        labels: sequence of titles, matched to ``images`` by position.
    """
    use_svg_display()
    num_images = len(images)
    if num_images == 0:
        # plt.subplots(1, 0) raises; there is nothing to draw.
        return
    # squeeze=False always returns a 2-D array of Axes. With the default,
    # a single image yields a bare Axes object that zip() cannot iterate.
    _, axes = plt.subplots(1, num_images, figsize=(12, 12), squeeze=False)
    for ax, img, lbl in zip(axes[0], images, labels):
        ax.imshow(img.view((28, 28)).numpy())
        ax.set_title(lbl)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
    plt.show()
def load_data_fashion_mnist(batch_size, root='Datasets/FashionMNIST', num_workers=None):
    """Build Fashion-MNIST train/test DataLoaders.

    Args:
        batch_size: samples per mini-batch for both loaders.
        root: dataset directory. ``download=True`` is passed, so torchvision
            fetches the files when they are not already present there.
        num_workers: DataLoader worker processes. When None, the default is
            0 on Windows (load in the main process) and 4 elsewhere.

    Returns:
        (train_iter, test_iter). The training loader shuffles; the test
        loader does not.
    """
    to_tensor = transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True,
                                                    transform=to_tensor)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True,
                                                   transform=to_tensor)
    if num_workers is None:
        # 0 means no extra worker processes are used to load data.
        num_workers = 0 if sys.platform.startswith('win') else 4
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,
                                            num_workers=num_workers)
    return train_iter, test_iter
def evaluate_accuracy(data_iter, net):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    A prediction is the argmax over dimension 1 of ``net(X)``. It is
    compared with ``y``, and accuracy is correct predictions / total samples.

    The evaluation runs under ``torch.no_grad()`` so no autograd graph is
    built. When ``net`` is an ``nn.Module``, it is switched to eval mode for
    the pass and its previous training flag is restored afterwards.

    Args:
        data_iter: iterable of (X, y) batches.
        net: callable, or nn.Module, mapping X to per-class scores.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    is_module = isinstance(net, torch.nn.Module)
    was_training = is_module and net.training
    if is_module:
        net.eval()
    acc_sum, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if was_training:
            net.train()
    return acc_sum / n
def _batch_loss_total(loss, raw_loss, reduced_loss, batch_n):
    """Return the summed per-sample loss of one batch as a Python float."""
    # A scalar from a loss object with reduction == 'mean' (e.g. the default
    # nn.CrossEntropyLoss) is a batch average, so scale it back to a batch sum.
    # Per-sample outputs (and any other losses) are already summed by .sum().
    if raw_loss.dim() == 0 and getattr(loss, 'reduction', None) == 'mean':
        return reduced_loss.item() * batch_n
    return reduced_loss.item()


def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):
    """Train a classifier and print loss/accuracy after every epoch.

    Each batch's loss is ``loss(net(X), y).sum()`` and is back-propagated.
    Parameters are then updated either with ``optimizer`` or, when no
    optimizer is given, with this module's ``sgd(params, lr, batch_size)``.

    The printed ``loss`` is the average per-sample training loss of the
    epoch. For loss objects with ``reduction == 'mean'``, the batch average is
    scaled back to a batch sum before averaging over the epoch. The
    gradients themselves are unaffected.

    Args:
        net: callable mapping X to per-class scores.
        train_iter, test_iter: iterables of (X, y) batches.
        loss: callable(y_hat, y) returning per-sample losses or a scalar.
        num_epochs: number of passes over ``train_iter``.
        batch_size: divisor passed to ``sgd`` when no optimizer is used.
        params: tensors updated by ``sgd``; needed when ``optimizer`` is None.
        lr: learning rate for ``sgd``.
        optimizer: optional torch optimizer. When given, it is used instead
            of ``sgd``.
    """
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            raw_loss = loss(y_hat, y)
            l = raw_loss.sum()
            # Clear gradients left over from the previous step.
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None:
                for param in params:
                    if param.grad is not None:
                        param.grad.zero_()
            l.backward()
            if optimizer is None:
                sgd(params, lr, batch_size)
            else:
                optimizer.step()
            batch_n = y.shape[0]
            train_l_sum += _batch_loss_total(loss, raw_loss, l, batch_n)
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += batch_n
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
class FlattenLayer(nn.Module):
    """Flatten every dimension after the first: (batch, ...) -> (batch, -1)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
main.py
import torch
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
# Load Fashion-MNIST as mini-batch iterators.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# Model parameters: 784 inputs (28*28 pixels), 256 hidden units, 10 classes.
# Weights are drawn from N(0, 0.01) with NumPy; biases start at zero.
num_inputs, num_hiddens, num_outputs = 784, 256, 10
w1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens)),
                  dtype=torch.float)
b1 = torch.zeros(num_hiddens, dtype=torch.float)
w2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens, num_outputs)),
                  dtype=torch.float)
b2 = torch.zeros(num_outputs, dtype=torch.float)
params = [w1, b1, w2, b2]
# Turn on gradient tracking for every parameter, in place.
for param in params:
    param.requires_grad_(True)
def relu(X):
    """ReLU activation: element-wise maximum of X and 0.0."""
    zero = torch.tensor(0.0)
    return torch.maximum(X, zero)
def net(X):
    """Two-layer MLP: flatten to (-1, num_inputs), ReLU hidden layer, linear output."""
    flat = X.view(-1, num_inputs)
    hidden = relu(flat @ w1 + b1)
    return hidden @ w2 + b2
# Loss: softmax cross-entropy. PyTorch's CrossEntropyLoss averages over the batch by default.
loss = torch.nn.CrossEntropyLoss()

# Train. The learning rate is unusually large because d2l.sgd also divides
# the gradient by batch_size, on top of the loss's own batch averaging.
num_epochs = 5
lr = 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=params, lr=lr)
边栏推荐
- 使用tesseract.js-offline识别图片文字记录
- Data Lake solutions of various manufacturers
- LeetCode 2325. 解密消息(map)
- 开发第一个Flink应用
- The R language uses the plot function in the native package (basic import package, graphics) to visualize the scatter plot
- 通过中序遍历和前序遍历,后续遍历来构建二叉树
- Domestic flagship mobile phones have a serious price foam, and second-hand phones are more cost-effective than new ones, or buy iPhones
- unity3d中的旋转
- JSP based novel writing and creation website
- 分类任务中的类别不平衡问题
猜你喜欢
![Effectively understand FreeSQL wheredynamicfilter and deeply understand the original design intention [.net orm]](/img/cb/76200539c59bb865e60e5ea1121feb.png)
Effectively understand FreeSQL wheredynamicfilter and deeply understand the original design intention [.net orm]

Pytorch框架 学习记录1 CIFAR-10分类

NJCTF 2017messager

Data Lake solutions of various manufacturers

LeetCode 2315. 统计星号(字符串)

STL中stack和queue的使用以及模拟实现

win10开始键点击无响应

因果学习将开启下一代AI浪潮?九章云极DataCanvas正式发布YLearn因果学习开源项目

从预测到决策,九章云极DataCanvas推出YLearn因果学习开源项目

Zhongke Panyun - Cyberspace Security packet capture topic b.pcap
随机推荐
[makefile] some notes on the use of makefile
LeetCode 2319. 判断矩阵是否是一个 X 矩阵
yarn(cdh)中的虚拟cpu和内存
[Niuke swipe questions] / *c language realizes left-hand rotation of strings*/
Find balanced binary tree
LeetCode 2331. 计算布尔二叉树的值(树的遍历)
Take a look at this ugly face | MathWorks account unavailable - technical issue
bazel使用教程 转
使用tesseract.js-offline识别图片文字记录
NJCTF 2017messager
The select function of dplyr package in R language deletes the data columns in dataframe containing the specified string content (drop columns in dataframe)
2022 Shaanxi secondary vocational group "Cyberspace Security" - packet analysis
架构实战营|模块7
Idea display service port --service
二分类学习推广到多分类学习
Domestic flagship mobile phones have a serious price foam, and second-hand phones are more cost-effective than new ones, or buy iPhones
unity3d如何利用asset store下载一些有用的资源包
十分钟从 PyTorch 转 MXNet(转)
读已提交级别下 注解事务+分布式锁结合引起的事故--活动购买机会的错乱
c# treeView 树形结构递归处理(企业集团型层次树形展示)