当前位置:网站首页>【pytorch】线性神经网络
【pytorch】线性神经网络
2022-07-15 23:59:00 【李峻枫】
除了使用自己定义模型的方法,也可以用torch提供的神经网络模型。
可以将其理解为只有输入层和输出层的全连接网络。
神经网络
from torch import nn
net = nn.Sequential(nn.Linear(4 , 1))
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.001)
根据全连接网络的特性,不难发现,这就是一个线性回归模型。这也解释了,为什么现在的神经网络层数要如此只深。
因为多层的神经网络可以拟合任意函数,这一点也是被证明了的。
完整代码
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import numpy as np
import torch
from torch.utils import data
# In[2]:
def data_maker(w, b, n_size): # y=w*x+b,n个数据
X = torch.normal(0 , 1 , (n_size , len(w))) # n*len(w)的参数
y = torch.matmul(X , w) + b
y = y + torch.normal(0, 0.01 , y.shape)
return X , y.reshape((-1 , 1))
W = torch.tensor([4.0 , 2.3 , 121313,312233])
B = 1
x , y = data_maker(W, B, 1000)
# In[3]:
def load_array(data_arrays , batch_size , is_train = True):
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset , batch_size , shuffle = is_train)
# In[4]:
batch_size = 10
data_iter = load_array((x , y) , batch_size)
# In[5]:
from torch import nn
net = nn.Sequential(nn.Linear(4 , 1))
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.001)
epochs = 600
# In[6]:
for epoch in range(epochs):
for X, Y in data_iter:
l = loss(net(X), Y)
trainer.zero_grad()
l.backward()
trainer.step()
pass
l = loss(net(x) , y)
print(f'epoch{
epoch + 1}, loss {
l}')
边栏推荐
- Reading a data driven graph generic model for temporary interaction networks
- Flink(一)概述
- PHP conversion hours ago
- Torch code template
- Thinkphp5 read multiline text, read files, and split multiline text
- jol-core
- Redis_Linux安装
- Complete process of invention patent application (from application to authorization)
- In class practice of software quality assurance and testing
- JDBC connection to mysql8.0 driver
猜你喜欢

Codeworks 5 questions per day (average 1500) - day 16

【Unity】Animator动画倒播,与StartRecording动画录制

Uploading and downloading of files

Xunwei domestic development board three development boards worth starting with

【vulnhub】DC9

Flink(二)时间和窗口

Emqx server establishes ssl/tls secure connection, one-way and two-way

Flink(四)分流合流

Reading the pointpillar code of openpcdet -- Part 2: network structure

转本结束暑假2022.6.29-7.13我的深圳之行(体验)
随机推荐
Experiment 4 shell programming
Flink (II) time and window
(manual) [sqli labs42, 43] post injection, stack injection, error echo, character injection
【sql面试题】求连续点击三次的用户数,而且中间不能有别人的点击
解密静态路由,一文分析静态路由优缺点!
Thinkphp5 read multiline text, read files, and split multiline text
Xunwei domestic development board three development boards worth starting with
【漫步刷题路】- 逆序字符串II
PHP conversion hours ago
Vertical/Column text select in PyCharm
网络安全从业人员应具备的职业素养
不同的评估方法适合哪种机器学习模型?
Flink(四)分流合流
ThinkPHP 3 adds word segmentation weight search function and phpanalysis plug-in
[development tutorial 2] crazy shell arm function mobile phone - Introduction to test program
Memory management: memory allocation and recycling
第1章:初识数据库与MySQL----MySQL安装
Record a segmented data multithread optimization process
In class practice of software quality assurance and testing
发明专利申请完整流程(从申请到授权)