提交 bc355bfe authored 作者: Frederic Bastien's avatar Frederic Bastien

Make tests less verbose

上级 f636d83a
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import sys
import numpy as np import numpy as np
from six.moves import StringIO
import theano import theano
...@@ -23,7 +26,12 @@ def test_detect_nan(): ...@@ -23,7 +26,12 @@ def test_detect_nan():
f = theano.function([x], [theano.tensor.log(x) * x], f = theano.function([x], [theano.tensor.log(x) * x],
mode=theano.compile.MonitorMode( mode=theano.compile.MonitorMode(
post_func=detect_nan)) post_func=detect_nan))
f(0) # log(0) * 0 = -inf * 0 = NaN try:
old_stdout = sys.stdout
sys.stdout = StringIO()
f(0) # log(0) * 0 = -inf * 0 = NaN
finally:
sys.stdout = old_stdout
assert nan_detected[0] assert nan_detected[0]
...@@ -49,7 +57,12 @@ def test_optimizer(): ...@@ -49,7 +57,12 @@ def test_optimizer():
mode=mode) mode=mode)
# Test that the fusion wasn't done # Test that the fusion wasn't done
assert len(f.maker.fgraph.apply_nodes) == 2 assert len(f.maker.fgraph.apply_nodes) == 2
f(0) # log(0) * 0 = -inf * 0 = NaN try:
old_stdout = sys.stdout
sys.stdout = StringIO()
f(0) # log(0) * 0 = -inf * 0 = NaN
finally:
sys.stdout = old_stdout
# Test that we still detect the nan # Test that we still detect the nan
assert nan_detected[0] assert nan_detected[0]
...@@ -83,7 +96,12 @@ def test_not_inplace(): ...@@ -83,7 +96,12 @@ def test_not_inplace():
# Test that the fusion wasn't done # Test that the fusion wasn't done
assert len(f.maker.fgraph.apply_nodes) == 5 assert len(f.maker.fgraph.apply_nodes) == 5
assert not f.maker.fgraph.toposort()[-1].op.destroy_map assert not f.maker.fgraph.toposort()[-1].op.destroy_map
f([0, 0]) # log(0) * 0 = -inf * 0 = NaN try:
old_stdout = sys.stdout
sys.stdout = StringIO()
f([0, 0]) # log(0) * 0 = -inf * 0 = NaN
finally:
sys.stdout = old_stdout
# Test that we still detect the nan # Test that we still detect the nan
assert nan_detected[0] assert nan_detected[0]
...@@ -59,7 +59,6 @@ def test_elemwise_pow(): ...@@ -59,7 +59,6 @@ def test_elemwise_pow():
assert exp.dtype == dtype_exp assert exp.dtype == dtype_exp
output = base ** exp output = base ** exp
f = theano.function([base], output, mode=mode_with_gpu) f = theano.function([base], output, mode=mode_with_gpu)
theano.printing.debugprint(f)
# We don't transfer to the GPU when the output dtype is int* # We don't transfer to the GPU when the output dtype is int*
n = len([n for n in f.maker.fgraph.apply_nodes n = len([n for n in f.maker.fgraph.apply_nodes
if isinstance(n.op, GpuElemwise)]) if isinstance(n.op, GpuElemwise)])
......
...@@ -820,7 +820,7 @@ def test_maximum_minimum_grad(): ...@@ -820,7 +820,7 @@ def test_maximum_minimum_grad():
for op in [tensor.maximum, tensor.minimum]: for op in [tensor.maximum, tensor.minimum]:
o = op(x, y) o = op(x, y)
g = theano.grad(o.sum(), [x, y]) g = theano.grad(o.sum(), [x, y])
theano.printing.debugprint(g)
f = theano.function([x, y], g) f = theano.function([x, y], g)
assert np.allclose(f([1], [1]), [[1], [0]]) assert np.allclose(f([1], [1]), [[1], [0]])
...@@ -7789,7 +7789,7 @@ class TestSpecifyShape(unittest.TestCase): ...@@ -7789,7 +7789,7 @@ class TestSpecifyShape(unittest.TestCase):
f(xval) f(xval)
xval = np.random.rand(3).astype(floatX) xval = np.random.rand(3).astype(floatX)
self.assertRaises(AssertionError, f, xval) self.assertRaises(AssertionError, f, xval)
theano.printing.debugprint(f)
assert isinstance([n for n in f.maker.fgraph.toposort() assert isinstance([n for n in f.maker.fgraph.toposort()
if isinstance(n.op, SpecifyShape)][0].inputs[0].type, if isinstance(n.op, SpecifyShape)][0].inputs[0].type,
self.input_type) self.input_type)
......
...@@ -1116,7 +1116,6 @@ class test_fusion(unittest.TestCase): ...@@ -1116,7 +1116,6 @@ class test_fusion(unittest.TestCase):
nb_elemwise, answer, out_dtype] in enumerate(cases): nb_elemwise, answer, out_dtype] in enumerate(cases):
if isinstance(out_dtype, dict): if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy] out_dtype = out_dtype[config.cast_policy]
print("new cases", id)
if shared_fn is None: if shared_fn is None:
f = compile.function(list(sym_inputs), g, mode=mode) f = compile.function(list(sym_inputs), g, mode=mode)
...@@ -1139,6 +1138,7 @@ class test_fusion(unittest.TestCase): ...@@ -1139,6 +1138,7 @@ class test_fusion(unittest.TestCase):
atol = 1e-6 atol = 1e-6
if not np.allclose(out, answer * nb_repeat, atol=atol): if not np.allclose(out, answer * nb_repeat, atol=atol):
fail1.append(id) fail1.append(id)
print("cases", id)
print(val_inputs) print(val_inputs)
print(out) print(out)
print(answer * nb_repeat) print(answer * nb_repeat)
...@@ -1163,7 +1163,8 @@ class test_fusion(unittest.TestCase): ...@@ -1163,7 +1163,8 @@ class test_fusion(unittest.TestCase):
fail4.append((id, out_dtype, out.dtype)) fail4.append((id, out_dtype, out.dtype))
failed = len(fail1 + fail2 + fail3 + fail4) failed = len(fail1 + fail2 + fail3 + fail4)
print("Executed", len(cases), "cases", "failed", failed) if failed > 0:
print("Executed", len(cases), "cases", "failed", failed)
if failed > 0: if failed > 0:
raise Exception("Failed %d cases" % failed, fail1, raise Exception("Failed %d cases" % failed, fail1,
fail2, fail3, fail4) fail2, fail3, fail4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论