提交 6dde460c authored 作者: nouiz's avatar nouiz

Merge pull request #396 from lamblin/fix_truediv_test

Fix test in TrueDivTester
......@@ -243,16 +243,19 @@ def makeTester(name, op, expected, checks={}, good={}, bad_build={},
or variable.shape != expected.shape
or numpy.any(abs(variable - expected) > eps)):
self.fail(("Test %s::%s: Output %s gave the wrong"
" value. With inputs %s, expected %s, got %s."
" numpy.allclose return %s %s") % (
" value. With inputs %s, expected %s (dtype %s),"
" got %s (dtype %s)."
" numpy.allclose returns %s %s") % (
self.op,
testname,
i,
inputs,
expected,
variable,
numpy.allclose(variable, expected, atol=eps),
numpy.allclose(variable, expected)))
expected.dtype,
variable,
variable.dtype,
numpy.allclose(variable, expected, atol=eps),
numpy.allclose(variable, expected)))
for description, check in self.checks.items():
if not check(inputs, variables):
......@@ -609,10 +612,22 @@ if config.floatX=='float32':
# This is probably caused by our way of computing the gradient error.
div_grad_rtol=0.025
def _numpy_true_div(x, y):
"""Performs true division, and cast the result in the type we expect.
We define that function so we can use it in TrueDivTester.expected,
because simply calling numpy.true_divide could cause a dtype mismatch.
"""
out = numpy.true_divide(x, y)
# Use floatX as the result of int / int
if x.dtype in tensor.discrete_dtypes and y.dtype in tensor.discrete_dtypes:
out = theano._asarray(out, dtype=config.floatX)
return out
TrueDivTester = makeBroadcastTester(
op=tensor.true_div,
expected = (lambda x, y:
check_floatX((x, y), numpy.true_divide(x, y))),
expected=_numpy_true_div,
good=_good_broadcast_div_mod_normal_float,
grad=_grad_broadcast_div_mod_normal,
grad_rtol=div_grad_rtol,
......@@ -620,7 +635,7 @@ TrueDivTester = makeBroadcastTester(
TrueDivInplaceTester = makeBroadcastTester(
op=inplace.true_div_inplace,
expected=(lambda x, y: numpy.true_divide(x, y)),
expected=_numpy_true_div,
good=copymod(
_good_broadcast_div_mod_normal_float_inplace,
# The output is now in float, we cannot work inplace on an int.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论