提交 83a8b1f4 authored 作者: Reyhane Askari's avatar Reyhane Askari

changed expectedFailure to assertFailure

上级 54713094
...@@ -11,7 +11,7 @@ from theano.gof.opt import (OpKeyOptimizer, PatternSub, NavigatorOptimizer, ...@@ -11,7 +11,7 @@ from theano.gof.opt import (OpKeyOptimizer, PatternSub, NavigatorOptimizer,
from theano.gof import destroyhandler from theano.gof import destroyhandler
from theano.gof.fg import FunctionGraph, InconsistencyError from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import ReplaceValidate from theano.gof.toolbox import ReplaceValidate
from theano.tests.unittest_tools import expectedFailure_fast from theano.tests.unittest_tools import assertFailure_fast
from theano.configparser import change_flags from theano.configparser import change_flags
...@@ -170,7 +170,7 @@ def test_misc(): ...@@ -170,7 +170,7 @@ def test_misc():
###################### ######################
@expectedFailure_fast @assertFailure_fast
def test_aliased_inputs_replacement(): def test_aliased_inputs_replacement():
x, y, z = inputs() x, y, z = inputs()
tv = transpose_view(x) tv = transpose_view(x)
...@@ -202,7 +202,7 @@ def test_indestructible(): ...@@ -202,7 +202,7 @@ def test_indestructible():
consistent(g) consistent(g)
@expectedFailure_fast @assertFailure_fast
def test_usage_loop_through_views_2(): def test_usage_loop_through_views_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(transpose_view(sigmoid(x))) e0 = transpose_view(transpose_view(sigmoid(x)))
...@@ -213,7 +213,7 @@ def test_usage_loop_through_views_2(): ...@@ -213,7 +213,7 @@ def test_usage_loop_through_views_2():
inconsistent(g) # we cut off the path to the sigmoid inconsistent(g) # we cut off the path to the sigmoid
@expectedFailure_fast @assertFailure_fast
def test_destroyers_loop(): def test_destroyers_loop():
# AddInPlace(x, y) and AddInPlace(y, x) should not coexist # AddInPlace(x, y) and AddInPlace(y, x) should not coexist
x, y, z = inputs() x, y, z = inputs()
...@@ -263,7 +263,7 @@ def test_aliased_inputs2(): ...@@ -263,7 +263,7 @@ def test_aliased_inputs2():
inconsistent(g) inconsistent(g)
@expectedFailure_fast @assertFailure_fast
def test_aliased_inputs_tolerate(): def test_aliased_inputs_tolerate():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_2(x, x) e = add_in_place_2(x, x)
...@@ -278,7 +278,7 @@ def test_aliased_inputs_tolerate2(): ...@@ -278,7 +278,7 @@ def test_aliased_inputs_tolerate2():
inconsistent(g) inconsistent(g)
@expectedFailure_fast @assertFailure_fast
def test_same_aliased_inputs_ignored(): def test_same_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, x) e = add_in_place_3(x, x)
...@@ -286,7 +286,7 @@ def test_same_aliased_inputs_ignored(): ...@@ -286,7 +286,7 @@ def test_same_aliased_inputs_ignored():
consistent(g) consistent(g)
@expectedFailure_fast @assertFailure_fast
def test_different_aliased_inputs_ignored(): def test_different_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, transpose_view(x)) e = add_in_place_3(x, transpose_view(x))
...@@ -321,7 +321,7 @@ def test_indirect(): ...@@ -321,7 +321,7 @@ def test_indirect():
inconsistent(g) inconsistent(g)
@expectedFailure_fast @assertFailure_fast
def test_indirect_2(): def test_indirect_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(x) e0 = transpose_view(x)
...@@ -333,7 +333,7 @@ def test_indirect_2(): ...@@ -333,7 +333,7 @@ def test_indirect_2():
consistent(g) consistent(g)
@expectedFailure_fast @assertFailure_fast
def test_long_destroyers_loop(): def test_long_destroyers_loop():
x, y, z = inputs() x, y, z = inputs()
e = dot(dot(add_in_place(x, y), e = dot(dot(add_in_place(x, y),
...@@ -375,7 +375,7 @@ def test_multi_destroyers(): ...@@ -375,7 +375,7 @@ def test_multi_destroyers():
pass pass
@expectedFailure_fast @assertFailure_fast
def test_multi_destroyers_through_views(): def test_multi_destroyers_through_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x)) e = dot(add(transpose_view(z), y), add(z, x))
...@@ -418,7 +418,7 @@ def test_usage_loop_through_views(): ...@@ -418,7 +418,7 @@ def test_usage_loop_through_views():
consistent(g) consistent(g)
@expectedFailure_fast @assertFailure_fast
def test_usage_loop_insert_views(): def test_usage_loop_insert_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add_in_place(x, add(y, z)), e = dot(add_in_place(x, add(y, z)),
...@@ -453,7 +453,7 @@ def test_value_repl_2(): ...@@ -453,7 +453,7 @@ def test_value_repl_2():
consistent(g) consistent(g)
@expectedFailure_fast @assertFailure_fast
def test_multiple_inplace(): def test_multiple_inplace():
# this tests issue #5223 # this tests issue #5223
# there were some problems with Ops that have more than # there were some problems with Ops that have more than
......
...@@ -1754,7 +1754,7 @@ def test_without_dnn_batchnorm_train_without_running_averages(): ...@@ -1754,7 +1754,7 @@ def test_without_dnn_batchnorm_train_without_running_averages():
f_abstract(X, Scale, Bias, Dy) f_abstract(X, Scale, Bias, Dy)
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_dnn_batchnorm_train_inplace(): def test_dnn_batchnorm_train_inplace():
# test inplace_running_mean and inplace_running_var # test inplace_running_mean and inplace_running_var
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
...@@ -1877,7 +1877,7 @@ def test_batchnorm_inference(): ...@@ -1877,7 +1877,7 @@ def test_batchnorm_inference():
utt.assert_allclose(outputs_abstract[5], outputs_ref[5], rtol=2e-3, atol=4e-5) # dvar utt.assert_allclose(outputs_abstract[5], outputs_ref[5], rtol=2e-3, atol=4e-5) # dvar
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_batchnorm_inference_inplace(): def test_batchnorm_inference_inplace():
# test inplace # test inplace
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
......
...@@ -175,7 +175,7 @@ class TestGpuCholesky(unittest.TestCase): ...@@ -175,7 +175,7 @@ class TestGpuCholesky(unittest.TestCase):
GpuCholesky(lower=True, inplace=False)(A) GpuCholesky(lower=True, inplace=False)(A)
self.assertRaises(AssertionError, invalid_input_func) self.assertRaises(AssertionError, invalid_input_func)
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_diag_chol(self): def test_diag_chol(self):
# Diagonal matrix input Cholesky test. # Diagonal matrix input Cholesky test.
for lower in [True, False]: for lower in [True, False]:
...@@ -184,7 +184,7 @@ class TestGpuCholesky(unittest.TestCase): ...@@ -184,7 +184,7 @@ class TestGpuCholesky(unittest.TestCase):
A_val = np.diag(np.random.uniform(size=5).astype("float32") + 1) A_val = np.diag(np.random.uniform(size=5).astype("float32") + 1)
self.compare_gpu_cholesky_to_np(A_val, lower=lower, inplace=inplace) self.compare_gpu_cholesky_to_np(A_val, lower=lower, inplace=inplace)
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_dense_chol_lower(self): def test_dense_chol_lower(self):
# Dense matrix input lower-triangular Cholesky test. # Dense matrix input lower-triangular Cholesky test.
for lower in [True, False]: for lower in [True, False]:
......
...@@ -582,7 +582,7 @@ def test_no_complex(): ...@@ -582,7 +582,7 @@ def test_no_complex():
mode=mode_with_gpu) mode=mode_with_gpu)
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_local_lift_solve(): def test_local_lift_solve():
if not cusolver_available: if not cusolver_available:
raise SkipTest('No cuSolver') raise SkipTest('No cuSolver')
...@@ -617,7 +617,7 @@ def test_gpu_solve_not_inplace(): ...@@ -617,7 +617,7 @@ def test_gpu_solve_not_inplace():
utt.assert_allclose(f_cpu(A_val, b_val), f_gpu(A_val, b_val)) utt.assert_allclose(f_cpu(A_val, b_val), f_gpu(A_val, b_val))
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_local_lift_cholesky(): def test_local_lift_cholesky():
if not cusolver_available: if not cusolver_available:
raise SkipTest('No cuSolver') raise SkipTest('No cuSolver')
......
...@@ -886,7 +886,7 @@ class T_Scan(unittest.TestCase): ...@@ -886,7 +886,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(numpy_out, theano_out) utt.assert_allclose(numpy_out, theano_out)
# simple rnn ; compute inplace version 1 # simple rnn ; compute inplace version 1
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_inplace1(self): def test_inplace1(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
vW = asarrayX(np.random.uniform()) vW = asarrayX(np.random.uniform())
...@@ -951,7 +951,7 @@ class T_Scan(unittest.TestCase): ...@@ -951,7 +951,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(theano_x1, numpy_x1) utt.assert_allclose(theano_x1, numpy_x1)
# simple rnn ; compute inplace version 2 # simple rnn ; compute inplace version 2
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_inplace2(self): def test_inplace2(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
vW = asarrayX(np.random.uniform()) vW = asarrayX(np.random.uniform())
...@@ -1023,7 +1023,7 @@ class T_Scan(unittest.TestCase): ...@@ -1023,7 +1023,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(theano_x0, numpy_x0) utt.assert_allclose(theano_x0, numpy_x0)
utt.assert_allclose(theano_x1, numpy_x1) utt.assert_allclose(theano_x1, numpy_x1)
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_inplace3(self): def test_inplace3(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
......
...@@ -500,7 +500,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], ...@@ -500,7 +500,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
raise raise
@unittest_tools.expectedFailure_fast @unittest_tools.assertFailure_fast
def test_gemm_opt0(): def test_gemm_opt0():
# Many subgraphs whose dots can be eliminated # Many subgraphs whose dots can be eliminated
X, Y, Z, a, b = XYZab() X, Y, Z, a, b = XYZab()
...@@ -529,7 +529,7 @@ def test_gemm_opt0(): ...@@ -529,7 +529,7 @@ def test_gemm_opt0():
just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)]) just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)])
@unittest_tools.expectedFailure_fast @unittest_tools.assertFailure_fast
def test_gemm_opt_double_gemm(): def test_gemm_opt_double_gemm():
# This is the pattern that shows up in the autoencoder # This is the pattern that shows up in the autoencoder
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar() X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
......
...@@ -1367,7 +1367,7 @@ class TestCompositeCodegen(unittest.TestCase): ...@@ -1367,7 +1367,7 @@ class TestCompositeCodegen(unittest.TestCase):
utt.assert_allclose(f([[1.]]), [[0.]]) utt.assert_allclose(f([[1.]]), [[0.]])
@utt.expectedFailure_fast @utt.assertFailure_fast
def test_log1p(): def test_log1p():
m = theano.config.mode m = theano.config.mode
if m == 'FAST_COMPILE': if m == 'FAST_COMPILE':
...@@ -1990,7 +1990,7 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -1990,7 +1990,7 @@ class test_local_subtensor_lift(unittest.TestCase):
assert len(prog) == 3 assert len(prog) == 3
f([4, 5]) # let debugmode test something f([4, 5]) # let debugmode test something
@utt.expectedFailure_fast @utt.assertFailure_fast
def test4(self): def test4(self):
# basic test that the optimization doesn't work with broadcasting # basic test that the optimization doesn't work with broadcasting
# ... It *could* be extended to, # ... It *could* be extended to,
......
...@@ -447,11 +447,16 @@ class AttemptManyTimes: ...@@ -447,11 +447,16 @@ class AttemptManyTimes:
return attempt_multiple_times return attempt_multiple_times
def expectedFailure_fast(f): def assertFailure_fast(f):
"""A Decorator to handle the test cases that are failing when """A Decorator to handle the test cases that are failing when
THEANO_FLAGS =cycle_detection='fast'. THEANO_FLAGS =cycle_detection='fast'.
""" """
if theano.config.cycle_detection == 'fast': if theano.config.cycle_detection == 'fast':
return unittest.expectedFailure(f) class TestAssertion(unittest.TestCase):
def runTest(self, *args, **kwargs):
with self.assertRaises(Exception):
f(*args, **kwargs)
test_assertion = TestAssertion()
return test_assertion
else: else:
return f return f
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论