提交 afca8297 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed some tests when config.cast_policy == numpy+floatX and config.floatX == float32

上级 00ecd70d
...@@ -201,7 +201,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, ...@@ -201,7 +201,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {},
if not isinstance(expecteds, (list, tuple)): if not isinstance(expecteds, (list, tuple)):
expecteds = (expecteds, ) expecteds = (expecteds, )
for i, (variable, expected) in enumerate(zip(variables, expecteds)): for i, (variable, expected) in enumerate(izip(variables, expecteds)):
if variable.dtype != expected.dtype or variable.shape != expected.shape or \ if variable.dtype != expected.dtype or variable.shape != expected.shape or \
numpy.any(numpy.abs(variable - expected) > eps): numpy.any(numpy.abs(variable - expected) > eps):
self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s. numpy.allclose return %s %s" self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s. numpy.allclose return %s %s"
...@@ -347,8 +347,31 @@ _grad_broadcast_binary_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)), ...@@ -347,8 +347,31 @@ _grad_broadcast_binary_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)),
) )
def check_floatX(inputs, rval):
"""
:param inputs: Inputs to a function that returned `rval` with these inputs.
:param rval: Value returned by a function with inputs set to `inputs`.
:returns: Either `rval` unchanged, or `rval` cast in float32. The idea is
that when a numpy function would have returned a float64, Theano may prefer
to return a float32 instead when `config.cast_policy` is set to
'numpy+floatX' and config.floatX to 'float32', and there was no float64
input.
"""
if (isinstance(rval, numpy.ndarray) and
rval.dtype == 'float64' and
config.cast_policy == 'numpy+floatX'
and config.floatX == 'float32' and
all(x.dtype != 'float64' for x in inputs)):
# Then we expect float32 instead of float64.
return rval.astype('float32')
else:
return rval
AddTester = makeBroadcastTester(op = add, AddTester = makeBroadcastTester(op = add,
expected = lambda *inputs: reduce(lambda x, y: x + y, inputs), expected = lambda *inputs: check_floatX(inputs, reduce(lambda x, y: x + y, inputs)),
good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)), good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)),
four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)), four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)),
**_good_broadcast_binary_normal), **_good_broadcast_binary_normal),
...@@ -364,7 +387,7 @@ AddInplaceTester = makeBroadcastTester(op = inplace.add_inplace, ...@@ -364,7 +387,7 @@ AddInplaceTester = makeBroadcastTester(op = inplace.add_inplace,
inplace = True) inplace = True)
SubTester = makeBroadcastTester(op = sub, SubTester = makeBroadcastTester(op = sub,
expected = lambda x, y: x - y, expected = lambda x, y: check_floatX((x, y), x - y),
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
bad_runtime = _bad_runtime_broadcast_binary_normal, bad_runtime = _bad_runtime_broadcast_binary_normal,
...@@ -379,7 +402,7 @@ SubInplaceTester = makeBroadcastTester(op = inplace.sub_inplace, ...@@ -379,7 +402,7 @@ SubInplaceTester = makeBroadcastTester(op = inplace.sub_inplace,
inplace = True) inplace = True)
MaximumTester = makeBroadcastTester(op = maximum, MaximumTester = makeBroadcastTester(op = maximum,
expected = numpy.maximum, expected = lambda *inputs: check_floatX(inputs, numpy.maximum(*inputs)),
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
bad_runtime = _bad_runtime_broadcast_binary_normal, bad_runtime = _bad_runtime_broadcast_binary_normal,
...@@ -394,7 +417,7 @@ MaximumInplaceTester = makeBroadcastTester(op = inplace.maximum_inplace, ...@@ -394,7 +417,7 @@ MaximumInplaceTester = makeBroadcastTester(op = inplace.maximum_inplace,
inplace = True) inplace = True)
MinimumTester = makeBroadcastTester(op = minimum, MinimumTester = makeBroadcastTester(op = minimum,
expected = numpy.minimum, expected = lambda *inputs: check_floatX(inputs, numpy.minimum(*inputs)),
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
bad_runtime = _bad_runtime_broadcast_binary_normal, bad_runtime = _bad_runtime_broadcast_binary_normal,
...@@ -409,7 +432,7 @@ MinimumInplaceTester = makeBroadcastTester(op = inplace.minimum_inplace, ...@@ -409,7 +432,7 @@ MinimumInplaceTester = makeBroadcastTester(op = inplace.minimum_inplace,
inplace = True) inplace = True)
MulTester = makeBroadcastTester(op = mul, MulTester = makeBroadcastTester(op = mul,
expected = lambda *inputs: reduce(lambda x, y: x * y, inputs), expected = lambda *inputs: check_floatX(inputs, reduce(lambda x, y: x * y, inputs)),
good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)), good = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)),
four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)), four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)),
**_good_broadcast_binary_normal), **_good_broadcast_binary_normal),
...@@ -475,7 +498,7 @@ if config.floatX=='float32': ...@@ -475,7 +498,7 @@ if config.floatX=='float32':
# This is probably caused by our way of computing the gradient error. # This is probably caused by our way of computing the gradient error.
div_grad_rtol=0.025 div_grad_rtol=0.025
DivTester = makeBroadcastTester(op = true_div, DivTester = makeBroadcastTester(op = true_div,
expected = lambda x, y: x / y, expected = lambda x, y: check_floatX((x, y), x / y),
good = _good_broadcast_div_mod_normal_float, good = _good_broadcast_div_mod_normal_float,
# integers = (randint(2, 3), randint_nonzero(2, 3)), # integers = (randint(2, 3), randint_nonzero(2, 3)),
# dtype_mixup_1 = (rand(2, 3), randint_nonzero(2, 3)), # dtype_mixup_1 = (rand(2, 3), randint_nonzero(2, 3)),
...@@ -527,7 +550,7 @@ _good_broadcast_pow_normal_float_pow = copy(_good_broadcast_pow_normal_float) ...@@ -527,7 +550,7 @@ _good_broadcast_pow_normal_float_pow = copy(_good_broadcast_pow_normal_float)
del _good_broadcast_pow_normal_float_pow["empty2"] del _good_broadcast_pow_normal_float_pow["empty2"]
PowTester = makeBroadcastTester(op = pow, PowTester = makeBroadcastTester(op = pow,
expected = lambda x, y: x ** y, expected = lambda x, y: check_floatX((x, y), x ** y),
good = _good_broadcast_pow_normal_float, good = _good_broadcast_pow_normal_float,
grad = _grad_broadcast_pow_normal) grad = _grad_broadcast_pow_normal)
PowInplaceTester = makeBroadcastTester(op = inplace.pow_inplace, PowInplaceTester = makeBroadcastTester(op = inplace.pow_inplace,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论