提交 069b58ac authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Brandon T. Willard

Pass kwargs instead of list of tuples to change_flags

上级 f7898c43
......@@ -295,7 +295,7 @@ class TestAssertConvShape:
class TestAssertShape:
@change_flags([("conv__assert_shape", True)])
@change_flags(conv__assert_shape=True)
def test_basic(self):
x = tensor.tensor4()
s1 = tensor.iscalar()
......@@ -318,7 +318,7 @@ class TestAssertShape:
with pytest.raises(AssertionError):
f(v, 7, 7)
@change_flags([("conv__assert_shape", True)])
@change_flags(conv__assert_shape=True)
def test_shape_check_conv2d(self):
input = tensor.tensor4()
filters = tensor.tensor4()
......@@ -340,7 +340,7 @@ class TestAssertShape:
np.zeros((7, 5, 2, 2), dtype="float32"),
)
@change_flags([("conv__assert_shape", True)])
@change_flags(conv__assert_shape=True)
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
def test_shape_check_conv3d(self):
input = tensor.tensor5()
......@@ -363,7 +363,7 @@ class TestAssertShape:
np.zeros((7, 5, 2, 2, 2), dtype="float32"),
)
@change_flags([("conv__assert_shape", True)])
@change_flags(conv__assert_shape=True)
def test_shape_check_conv2d_grad_wrt_inputs(self):
output_grad = tensor.tensor4()
filters = tensor.tensor4()
......@@ -382,7 +382,7 @@ class TestAssertShape:
np.zeros((7, 6, 3, 3), dtype="float32"),
)
@change_flags([("conv__assert_shape", True)])
@change_flags(conv__assert_shape=True)
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
def test_shape_check_conv3d_grad_wrt_inputs(self):
output_grad = tensor.tensor5()
......@@ -402,7 +402,7 @@ class TestAssertShape:
np.zeros((7, 6, 3, 3, 3), dtype="float32"),
)
@change_flags([("conv__assert_shape", True)])
@change_flags(conv__assert_shape=True)
def test_shape_check_conv2d_grad_wrt_weights(self):
input = tensor.tensor4()
output_grad = tensor.tensor4()
......@@ -421,7 +421,7 @@ class TestAssertShape:
np.zeros((3, 7, 5, 9), dtype="float32"),
)
@change_flags([("conv__assert_shape", True)])
@change_flags(conv__assert_shape=True)
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
def test_shape_check_conv3d_grad_wrt_weights(self):
input = tensor.tensor5()
......
......@@ -874,7 +874,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]):
with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
......@@ -911,7 +911,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]):
with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
......@@ -949,7 +949,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]):
with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论