提交 dfe32419 authored 作者: AdeB's avatar AdeB 提交者: Pascal Lamblin

Remove lambda functions that are not anymore needed

上级 088eaedd
...@@ -76,9 +76,8 @@ def pyconv3d(signals, filters): ...@@ -76,9 +76,8 @@ def pyconv3d(signals, filters):
def check_diagonal_subtensor_view_traces(fn): def check_diagonal_subtensor_view_traces(fn):
def check_node_type(node): assert check_stack_trace(
return isinstance(node.op, (DiagonalSubtensor, IncDiagonalSubtensor)) fn, ops_to_check=(DiagonalSubtensor, IncDiagonalSubtensor),
assert check_stack_trace(fn, ops_to_check=check_node_type,
bug_print='ignore') bug_print='ignore')
......
...@@ -1354,8 +1354,7 @@ def test_argmax_pushdown_bias(): ...@@ -1354,8 +1354,7 @@ def test_argmax_pushdown_bias():
for i, type in enumerate(types_to_check): for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type) assert isinstance(fgraph.toposort()[i].op, type)
assert str(fgraph.toposort()[3].op) == 'OutputGuard' assert str(fgraph.toposort()[3].op) == 'OutputGuard'
assert check_stack_trace( assert check_stack_trace(fgraph, ops_to_check=types_to_check)
fgraph, ops_to_check=lambda node: isinstance(node.op, types_to_check))
x = tensor.matrix() x = tensor.matrix()
b = tensor.vector() b = tensor.vector()
...@@ -1381,9 +1380,8 @@ def test_argmax_pushdown_bias(): ...@@ -1381,9 +1380,8 @@ def test_argmax_pushdown_bias():
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[2].op) == 'OutputGuard' assert str(fgraph.toposort()[2].op) == 'OutputGuard'
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=lambda node: ( fgraph, ops_to_check=(SoftmaxWithBias, tensor.CAReduce))
isinstance(node, (SoftmaxWithBias, tensor.CAReduce)) or
isinstance(node.op.scalar_op, theano.scalar.Maximum)))
def test_asymptotic_32(): def test_asymptotic_32():
""" """
......
...@@ -368,9 +368,7 @@ class T_softplus_opts(unittest.TestCase): ...@@ -368,9 +368,7 @@ class T_softplus_opts(unittest.TestCase):
theano.tensor.nnet.sigm.ScalarSoftplus, theano.tensor.nnet.sigm.ScalarSoftplus,
theano.scalar.Neg) theano.scalar.Neg)
assert check_stack_trace( assert check_stack_trace(f, ops_to_check=types_to_check)
f, ops_to_check=
lambda x: isinstance(x.op.scalar_op, types_to_check))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 3 assert len(topo) == 3
for i, op in enumerate(types_to_check): for i, op in enumerate(types_to_check):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论