提交 59196995 authored 作者: Shawn Tan's avatar Shawn Tan

Modify tests for nlinalg ExtractDiag

- Added ValueError exception - Remove failure test for tensor3
上级 92460ef4
...@@ -349,15 +349,6 @@ class test_diag(unittest.TestCase): ...@@ -349,15 +349,6 @@ class test_diag(unittest.TestCase):
y = extract_diag(x) y = extract_diag(x)
assert y.owner.op.__class__ == ExtractDiag assert y.owner.op.__class__ == ExtractDiag
# other types should raise error
x = theano.tensor.tensor3()
ok = False
try:
y = extract_diag(x)
except TypeError:
ok = True
assert ok
# not testing the view=True case since it is not used anywhere. # not testing the view=True case since it is not used anywhere.
def test_extract_diag(self): def test_extract_diag(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
...@@ -384,6 +375,8 @@ class test_diag(unittest.TestCase): ...@@ -384,6 +375,8 @@ class test_diag(unittest.TestCase):
extract_diag(xx) extract_diag(xx)
except TypeError: except TypeError:
ok = True ok = True
except ValueError:
ok = True
assert ok assert ok
# Test infer_shape # Test infer_shape
...@@ -429,6 +422,9 @@ def test_trace(): ...@@ -429,6 +422,9 @@ def test_trace():
trace(xx) trace(xx)
except TypeError: except TypeError:
ok = True ok = True
except ValueError:
ok = True
assert ok assert ok
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论