提交 cd55efb3 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add helper functions for specifying the "expected" values

Otherwise, the expected values could be computed in float16, which would not be precise enough, and which would not be supported as a dtype by Theano.
上级 2aedf238
......@@ -189,6 +189,50 @@ def safe_make_node(op, *inputs):
return node.owner
def upcast_float16_ufunc(fn):
"""Decorator that enforces computation is not done in float16 by NumPy.
Some ufuncs in NumPy will compute float values on int8 and uint8
in half-precision (float16), which is not enough, and not compatible
with the C code.
:param fn: numpy ufunc
:returns: function similar to fn.__call__, computing the same
value with a minimum floating-point precision of float32
"""
def ret(*args, **kwargs):
out_dtype = numpy.find_common_type(
[a.dtype for a in args], [numpy.float16])
if out_dtype == 'float16':
# Force everything to float32
sig = 'f' * fn.nin + '->' + 'f' * fn.nout
kwargs.update(sig=sig)
return fn(*args, **kwargs)
return ret
def upcast_int8_nfunc(fn):
"""Decorator that upcasts input of dtype int8 to float32.
This is so that floating-point computation is not carried using
half-precision (float16), as some NumPy functions do.
:param fn: function computing a floating-point value from inputs
:returns: function similar to fn, but upcasting its uint8 and int8
inputs before carrying out the computation.
"""
def ret(*args, **kwargs):
args = list(args)
for i, a in enumerate(args):
if getattr(a, 'dtype', None) in ('int8', 'uint8'):
args[i] = a.astype('float32')
return fn(*args, **kwargs)
return ret
def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
bad_runtime=None, grad=None, mode=None, grad_rtol=None,
eps=1e-10, skip=False, test_memmap=True, check_name=True):
......@@ -321,7 +365,8 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
expecteds = self.expected(*inputs)
eps = 1e-10
if any([i.dtype == 'float32' for i in inputs]):
if any([i.dtype in ('float32', 'int8', 'uint8')
for i in inputs]):
eps = 1e-6
eps = numpy.max([eps, _eps])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论