提交 069b58ac authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Pass kwargs instead of list of tuples to change_flags

上级 f7898c43
...@@ -295,7 +295,7 @@ class TestAssertConvShape: ...@@ -295,7 +295,7 @@ class TestAssertConvShape:
class TestAssertShape: class TestAssertShape:
@change_flags([("conv__assert_shape", True)]) @change_flags(conv__assert_shape=True)
def test_basic(self): def test_basic(self):
x = tensor.tensor4() x = tensor.tensor4()
s1 = tensor.iscalar() s1 = tensor.iscalar()
...@@ -318,7 +318,7 @@ class TestAssertShape: ...@@ -318,7 +318,7 @@ class TestAssertShape:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
f(v, 7, 7) f(v, 7, 7)
@change_flags([("conv__assert_shape", True)]) @change_flags(conv__assert_shape=True)
def test_shape_check_conv2d(self): def test_shape_check_conv2d(self):
input = tensor.tensor4() input = tensor.tensor4()
filters = tensor.tensor4() filters = tensor.tensor4()
...@@ -340,7 +340,7 @@ class TestAssertShape: ...@@ -340,7 +340,7 @@ class TestAssertShape:
np.zeros((7, 5, 2, 2), dtype="float32"), np.zeros((7, 5, 2, 2), dtype="float32"),
) )
@change_flags([("conv__assert_shape", True)]) @change_flags(conv__assert_shape=True)
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx") @pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
def test_shape_check_conv3d(self): def test_shape_check_conv3d(self):
input = tensor.tensor5() input = tensor.tensor5()
...@@ -363,7 +363,7 @@ class TestAssertShape: ...@@ -363,7 +363,7 @@ class TestAssertShape:
np.zeros((7, 5, 2, 2, 2), dtype="float32"), np.zeros((7, 5, 2, 2, 2), dtype="float32"),
) )
@change_flags([("conv__assert_shape", True)]) @change_flags(conv__assert_shape=True)
def test_shape_check_conv2d_grad_wrt_inputs(self): def test_shape_check_conv2d_grad_wrt_inputs(self):
output_grad = tensor.tensor4() output_grad = tensor.tensor4()
filters = tensor.tensor4() filters = tensor.tensor4()
...@@ -382,7 +382,7 @@ class TestAssertShape: ...@@ -382,7 +382,7 @@ class TestAssertShape:
np.zeros((7, 6, 3, 3), dtype="float32"), np.zeros((7, 6, 3, 3), dtype="float32"),
) )
@change_flags([("conv__assert_shape", True)]) @change_flags(conv__assert_shape=True)
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx") @pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
def test_shape_check_conv3d_grad_wrt_inputs(self): def test_shape_check_conv3d_grad_wrt_inputs(self):
output_grad = tensor.tensor5() output_grad = tensor.tensor5()
...@@ -402,7 +402,7 @@ class TestAssertShape: ...@@ -402,7 +402,7 @@ class TestAssertShape:
np.zeros((7, 6, 3, 3, 3), dtype="float32"), np.zeros((7, 6, 3, 3, 3), dtype="float32"),
) )
@change_flags([("conv__assert_shape", True)]) @change_flags(conv__assert_shape=True)
def test_shape_check_conv2d_grad_wrt_weights(self): def test_shape_check_conv2d_grad_wrt_weights(self):
input = tensor.tensor4() input = tensor.tensor4()
output_grad = tensor.tensor4() output_grad = tensor.tensor4()
...@@ -421,7 +421,7 @@ class TestAssertShape: ...@@ -421,7 +421,7 @@ class TestAssertShape:
np.zeros((3, 7, 5, 9), dtype="float32"), np.zeros((3, 7, 5, 9), dtype="float32"),
) )
@change_flags([("conv__assert_shape", True)]) @change_flags(conv__assert_shape=True)
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx") @pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
def test_shape_check_conv3d_grad_wrt_weights(self): def test_shape_check_conv3d_grad_wrt_weights(self):
input = tensor.tensor5() input = tensor.tensor5()
......
...@@ -874,7 +874,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -874,7 +874,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]): with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)]) fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
...@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]): with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)]) fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
...@@ -949,7 +949,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -949,7 +949,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]): with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)]) fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论