提交 ba9de0e0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add more informative message in verify_grad when input is not float.

上级 185d62c2
......@@ -5149,6 +5149,11 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
assert isinstance(pt, (list,tuple))
pt = [numpy.array(p) for p in pt]
for i, p in enumerate(pt):
if p.dtype not in ('float32', 'float64'):
raise TypeError(('verify_grad can work only with floating point '
'inputs, but input %i has dtype "%s".') % (i, p.dtype))
_type_tol = dict( # relativ error tolerances for different types
float32=1e-2,
float64=1e-4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论