提交 04eb6bf9 authored 作者: Firstname Lastname's avatar Firstname Lastname

branch merge

...@@ -173,6 +173,21 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -173,6 +173,21 @@ class TestComputeTestValue(unittest.TestCase):
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
def test_overided_function(self):
# We need to test those as they mess with Exception
# And we don't want the exception to be changed.
orig_compute_test_value = theano.config.compute_test_value
try:
config.compute_test_value = "raise"
x = T.matrix()
x.tag.test_value = numpy.zeros((2,3))
y = T.matrix()
y.tag.test_value = numpy.zeros((2,2))
self.assertRaises(ValueError, x.__mul__, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_scan(self): def test_scan(self):
""" """
Test the compute_test_value mechanism Scan. Test the compute_test_value mechanism Scan.
...@@ -269,13 +284,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -269,13 +284,7 @@ class TestComputeTestValue(unittest.TestCase):
n_steps=k) n_steps=k)
assert False assert False
except ValueError, e: except ValueError, e:
# Get traceback assert e.message.startswith("shape mismatch")
tb = sys.exc_info()[2]
# Get last frame info
frame_info = traceback.extract_tb(tb)[-1]
# We should be in scan_op.py, function 'perform'
assert os.path.split(frame_info[0])[1] == 'scan_op.py'
assert frame_info[2] == 'perform'
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
......
...@@ -1132,11 +1132,13 @@ class _tensor_py_operators: ...@@ -1132,11 +1132,13 @@ class _tensor_py_operators:
def __add__(self,other): def __add__(self,other):
try: try:
return add(self,other) return add(self,other)
# We should catch only NotImplementedError. # We should catch the minimum number of exception here.
# As otherwise this will transfer error to the wrong error # Otherwise this will convert error when Theano flags
# Theano flags compute_test_value need to be able to raise # compute_test_value is used
# other type of error. # Evidently, we need to catch NotImplementedError
except NotImplementedError, e: # But we also need to catch TypeError
# Oterwise TensorVariable * SparseVariable won't work!
except (NotImplementedError, TypeError), e:
# We must return NotImplemented and not an # We must return NotImplemented and not an
# NotImplementedError or raise an NotImplementedError. # NotImplementedError or raise an NotImplementedError.
# That way python will give a good error message like this # That way python will give a good error message like this
...@@ -1148,14 +1150,14 @@ class _tensor_py_operators: ...@@ -1148,14 +1150,14 @@ class _tensor_py_operators:
# adn the return value in that case # adn the return value in that case
try: try:
return sub(self,other) return sub(self,other)
except NotImplementedError, e: except (NotImplementedError, TypeError), e:
return NotImplemented return NotImplemented
def __mul__(self,other): def __mul__(self,other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
# adn the return value in that case # adn the return value in that case
try: try:
return mul(self,other) return mul(self,other)
except NotImplementedError, e: except (NotImplementedError, TypeError), e:
return NotImplemented return NotImplemented
def __div__(self,other): def __div__(self,other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
...@@ -1166,14 +1168,14 @@ class _tensor_py_operators: ...@@ -1166,14 +1168,14 @@ class _tensor_py_operators:
# This is to raise the exception that occurs when trying to divide # This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden). # two integer arrays (currently forbidden).
raise raise
except NotImplementedError, e: except (NotImplementedError, TypeError), e:
return NotImplemented return NotImplemented
def __pow__(self,other): def __pow__(self,other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
# adn the return value in that case # adn the return value in that case
try: try:
return pow(self,other) return pow(self,other)
except NotImplementedError, e: except (NotImplementedError, TypeError), e:
return NotImplemented return NotImplemented
def __mod__(self,other): def __mod__(self,other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
...@@ -1184,7 +1186,7 @@ class _tensor_py_operators: ...@@ -1184,7 +1186,7 @@ class _tensor_py_operators:
# This is to raise the exception that occurs when trying to compute # This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number. # x % y with either x or y a complex number.
raise raise
except NotImplementedError, e: except (NotImplementedError, TypeError), e:
return NotImplemented return NotImplemented
def __truediv__(self,other): return true_div(self, other) def __truediv__(self,other): return true_div(self, other)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论