当前位置:网站首页>tensorflow clip对NaN、inf的效果
tensorflow clip对NaN、inf的效果
2022-07-17 23:11:00 【HackerTom】
一个训练不稳定的模型如 [1],训练时梯度可能会出现 NaN。之前为了 debug 加了 check_numerics[2],但这会直接报错退出。用 clipping 稳定一下训练,但要将 NaN 去掉先。
这里记录下 tensorflow 中两种 clipping[3,4] 在有 NaN 和 inf 时的效果。结论:
clip_by_value可以将 inf 截断成正常值,但对 NaN 无效;- 有 NaN 时,
clip_by_norm会将整个向量变成 NaN(因为算 norm 有 NaN,除 NaN 得 NaN); - 有 inf 无 NaN 时,
clip_by_norm好像没有效果?
也许可以将 inf、NaN 置零先,然而这样可能会影响优化方向,就再 clip 一下,避免一次大步过头越练越差?
Code
import math
import tensorflow as tf
def zero_inf_nan(grad):
"""将 NaN、inf 置零"""
if grad is None:
return grad
_cond = tf.is_nan(grad) | tf.is_inf(grad)
return tf.where(_cond, tf.zeros_like(grad), grad)
with tf.Session() as sess:
# 有 NaN 有 inf
a = tf.constant([10, math.nan, math.inf, - math.inf], tf.float32)
b = tf.clip_by_value(a, -1, 1)
c = tf.clip_by_norm(a, 5)
print(sess.run([a, b, c]))
# 去掉之后
d = zero_inf_nan(a)
e = tf.clip_by_value(d, -1, 1)
f = tf.clip_by_norm(d, 5)
print(sess.run([d, e, f]))
# 有 inf 无 NaN 用 clip_by_norm
g = tf.constant([20, math.inf, - math.inf], tf.float32)
h = tf.clip_by_norm(g, 5)
print(sess.run(g))
- 输出
[array([ 10., nan, inf, -inf], dtype=float32), array([ 1., nan, 1., -1.], dtype=float32), array([nan, nan, nan, nan], dtype=float32)]
[array([10., 0., 0., 0.], dtype=float32), array([1., 0., 0., 0.], dtype=float32), array([5., 0., 0., 0.], dtype=float32)]
[ 20. inf -inf]
References
边栏推荐
猜你喜欢

【花雕动手做】有趣好玩的音乐可视化项目(11)---WS2812幻彩灯带

Leetcode 1296. Divide the array into a set of consecutive numbers (solved)

06_ Service call feign

ZABBIX realizes the monitoring of redis

文档型全文检索知识库管理系统源码

Leetcode 1275. 找出井字棋的獲勝者

Google Earth engine - Classification and processing of UAV images

GYM103660H.Distance

Li Hongyi machine learning 2022.7.15 -- gradient descent

解决jupyter控制台出现中文乱码的问题
随机推荐
原始套接字
Wechat applet 7 cloud storage
马走斜日(回溯法)
天勤第九章课后习题代码
[microservice] microservice learning note 3: use feign to replace resttemplate to complete remote call
Codeforces round 807 (Div. 2) e. mark and Professor Koro binary / segment tree
UVA340 Master-Mind Hints
ARM系统调用异常 汇编
ICML2022 | 几何多模态对比表示学习
MySQL 安装
Field programmable logic gate array FPGA
全排列(深度优先,排列树)
Li Hongyi machine learning 2022.7.15 -- gradient descent
Leetcode 1275. Trouver le vainqueur de "Jingzi"
Leetcode 1275. 找出井字棋的获胜者
初试Dart,笔记
买股票开户应该选哪个证券公司?什么证券公司是更安全的
天天基金上买基金是安全的吗?在线等答案
[code attached] how to realize handwritten digit recognition with hog+svm
GYM103660E. Disjoint path on tree count