RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor)
2022-07-19 03:20:00 【lucky-wz】
Two solutions
Option 1: Check whether the network and the data are both on the GPU
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
The message means the input and the weights must have the same type. Here the weights are CUDA tensors because the model was trained on the GPU. The input, which is the data to be tested, is not a CUDA tensor. A model whose weights live on the GPU cannot be applied directly to CPU data, so the prediction has to run on the GPU.
If the model parameters were trained on CUDA (the GPU), then at test time the data to be tested must also be moved onto the GPU, i.e. data.cuda().
Carefully check that both the neural network and all related data have been placed on the GPU!
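The sketch below shows a device-consistent inference step. It is a minimal example under stated assumptions. The single Conv2d layer is a hypothetical stand-in for a trained network, and the code falls back to the CPU when no GPU is present, so it also runs on CPU-only machines. The point is that model.to(device) moves only the weights, and the input tensor must be moved to the same device separately.

```python
import torch
import torch.nn as nn

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for a trained network; .to(device) moves its weights.
model = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3).to(device)
model.eval()

# New tensors are created on the CPU by default.
data = torch.randn(1, 3, 32, 32)

# On a GPU machine, calling model(data) at this point would raise the error above:
# the weights are CUDA tensors but the input is not. Move the input as well
# (same effect as data.cuda() when device is "cuda").
data = data.to(device)

with torch.no_grad():
    out = model(data)

print(out.shape)  # torch.Size([1, 8, 30, 30])
```

Using a single device variable for both model and data means the same script works on machines with or without a GPU. Hard-coding .cuda() fails outright when CUDA is unavailable.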
Option 2: Define every model layer in the class initializer
In addition, according to an expert's analysis, the model may not be defined properly. In that case part of it is never moved to CUDA. (The detailed test is omitted here.)
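The post does not say which definition mistake causes this, so the following is an assumption about one common way it happens. PyTorch only registers layers that are assigned directly as nn.Module attributes or placed in containers such as nn.ModuleList. Layers kept in a plain Python list are not registered. That means model.parameters() and model.to(device) / model.cuda() silently skip them, and they stay on the CPU. The class names ListModel and ModuleListModel are hypothetical.

```python
import torch.nn as nn

class ListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: these layers are NOT registered as submodules,
        # so model.parameters() and model.cuda() / model.to(device) skip them.
        self.layers = [nn.Conv2d(3, 8, 3), nn.Conv2d(8, 16, 3)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class ModuleListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer, so .to(device) moves all of them.
        self.layers = nn.ModuleList([nn.Conv2d(3, 8, 3), nn.Conv2d(8, 16, 3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print(len(list(ListModel().parameters())))        # 0
print(len(list(ModuleListModel().parameters())))  # 4 (weight + bias per conv)
```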
In general, if the method above does not work, try defining every network layer in __init__(), or at least every layer that has trainable parameters. Write only the runtime logic in forward(), for example:
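```python
import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        # Every layer with trainable parameters is created here, so that
        # model.cuda() / model.to(device) can find and move it.
        self.conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.b_module = B()  # a submodule assigned as an attribute is registered too

    def forward(self, x):
        # forward() only describes the computation; no layers are created here.
        out = self.conv(x)
        out = self.relu(out)
        out = self.b_module(out)
        return out


class B(nn.Module):
    def __init__(self):
        super(B, self).__init__()
        self.conv = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out
```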
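To confirm that a model really ended up on one device, you can list any parameters or buffers that are somewhere else. The helper below is a hypothetical sketch, not a PyTorch API. It compares only the device type ("cuda" vs "cpu"), not the GPU index. The small nn.Sequential model stands in for your own network.

```python
import torch
import torch.nn as nn

def off_device_tensors(model: nn.Module, device: torch.device) -> list:
    """Return the names of parameters and buffers whose device type differs from `device`."""
    names = []
    for name, t in list(model.named_parameters()) + list(model.named_buffers()):
        if t.device.type != device.type:
            names.append(name)
    return names

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).to(device)
print(off_device_tensors(model, device))  # []
```

This check only sees registered tensors. A layer kept in a plain Python list, or created inside forward(), is invisible to it, just as it is to model.to(device). That is why defining every layer in __init__() matters.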