提交 d888cd81 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add test on compute_test_value with __mul__. Fix test with this flag and make it…

Add test on compute_test_value with __mul__. Fix test with this flag and make it more resistent to change in code.
上级 866818d2
...@@ -173,6 +173,21 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -173,6 +173,21 @@ class TestComputeTestValue(unittest.TestCase):
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
def test_overided_function(self):
# We need to test those as they mess with Exception
# And we don't want the exception to be changed.
orig_compute_test_value = theano.config.compute_test_value
try:
config.compute_test_value = "raise"
x = T.matrix()
x.tag.test_value = numpy.zeros((2,3))
y = T.matrix()
y.tag.test_value = numpy.zeros((2,2))
self.assertRaises(ValueError, x.__mul__, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_scan(self): def test_scan(self):
""" """
Test the compute_test_value mechanism Scan. Test the compute_test_value mechanism Scan.
...@@ -269,13 +284,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -269,13 +284,7 @@ class TestComputeTestValue(unittest.TestCase):
n_steps=k) n_steps=k)
assert False assert False
except ValueError, e: except ValueError, e:
# Get traceback assert e.message.startswith("shape mismatch")
tb = sys.exc_info()[2]
# Get last frame info
frame_info = traceback.extract_tb(tb)[-1]
# We should be in scan_op.py, function 'perform'
assert os.path.split(frame_info[0])[1] == 'scan_op.py'
assert frame_info[2] == 'perform'
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论