提交 088eaedd authored 作者: AdeB's avatar AdeB 提交者: Pascal Lamblin

finish updating tests/nnet.py with chec_trace helper

上级 2acc0720
...@@ -1342,8 +1342,6 @@ def test_argmax_pushdown_bias(): ...@@ -1342,8 +1342,6 @@ def test_argmax_pushdown_bias():
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x, b], [x, b],
[out]) [out])
f = theano.function([x, b], out)
assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace')
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
...@@ -1351,11 +1349,13 @@ def test_argmax_pushdown_bias(): ...@@ -1351,11 +1349,13 @@ def test_argmax_pushdown_bias():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.MaxAndArgmax)
assert len(fgraph.toposort()) == 4 assert len(fgraph.toposort()) == 4
assert isinstance(fgraph.toposort()[0].op, tensor.DimShuffle) for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[1].op, tensor.Elemwise) assert isinstance(fgraph.toposort()[i].op, type)
assert isinstance(fgraph.toposort()[2].op, tensor.MaxAndArgmax)
assert str(fgraph.toposort()[3].op) == 'OutputGuard' assert str(fgraph.toposort()[3].op) == 'OutputGuard'
assert check_stack_trace(
fgraph, ops_to_check=lambda node: isinstance(node.op, types_to_check))
x = tensor.matrix() x = tensor.matrix()
b = tensor.vector() b = tensor.vector()
...@@ -1363,8 +1363,6 @@ def test_argmax_pushdown_bias(): ...@@ -1363,8 +1363,6 @@ def test_argmax_pushdown_bias():
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x, b], [x, b],
[out]) [out])
f = theano.function([x, b], out)
assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace')
backup = config.warn.argmax_pushdown_bug backup = config.warn.argmax_pushdown_bug
config.warn.argmax_pushdown_bug = False config.warn.argmax_pushdown_bug = False
...@@ -1382,7 +1380,10 @@ def test_argmax_pushdown_bias(): ...@@ -1382,7 +1380,10 @@ def test_argmax_pushdown_bias():
assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce) assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[2].op) == 'OutputGuard' assert str(fgraph.toposort()[2].op) == 'OutputGuard'
assert check_stack_trace(
fgraph, ops_to_check=lambda node: (
isinstance(node, (SoftmaxWithBias, tensor.CAReduce)) or
isinstance(node.op.scalar_op, theano.scalar.Maximum)))
def test_asymptotic_32(): def test_asymptotic_32():
""" """
...@@ -1455,7 +1456,7 @@ class Test_softmax_opt: ...@@ -1455,7 +1456,7 @@ class Test_softmax_opt:
# test that function contains softmax and no div. # test that function contains softmax and no div.
f = theano.function([c], p_y, mode=self.mode) f = theano.function([c], p_y, mode=self.mode)
assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace') assert check_stack_trace(f, ops_to_check=softmax_op)
f_ops = [n.op for n in f.maker.fgraph.toposort()] f_ops = [n.op for n in f.maker.fgraph.toposort()]
# print '--- f =' # print '--- f ='
...@@ -1472,7 +1473,7 @@ class Test_softmax_opt: ...@@ -1472,7 +1473,7 @@ class Test_softmax_opt:
# test that function contains softmax and no div. # test that function contains softmax and no div.
f = theano.function([c], p_y, mode=self.mode) f = theano.function([c], p_y, mode=self.mode)
assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace') assert check_stack_trace(f, ops_to_check=softmax_op)
f_ops = [n.op for n in f.maker.fgraph.toposort()] f_ops = [n.op for n in f.maker.fgraph.toposort()]
# print '--- f =' # print '--- f ='
...@@ -1492,7 +1493,6 @@ class Test_softmax_opt: ...@@ -1492,7 +1493,6 @@ class Test_softmax_opt:
config.warn.sum_div_dimshuffle_bug = False config.warn.sum_div_dimshuffle_bug = False
try: try:
g = theano.function([c, w], T.grad((p_y * w).sum(), c)) g = theano.function([c, w], T.grad((p_y * w).sum(), c))
hasattr(g.maker.fgraph.outputs[0].tag, 'trace')
finally: finally:
config.warn.sum_div_dimshuffle_bug = backup config.warn.sum_div_dimshuffle_bug = backup
g_ops = [n.op for n in g.maker.fgraph.toposort()] g_ops = [n.op for n in g.maker.fgraph.toposort()]
...@@ -1520,7 +1520,6 @@ class Test_softmax_opt: ...@@ -1520,7 +1520,6 @@ class Test_softmax_opt:
config.warn.sum_div_dimshuffle_bug = False config.warn.sum_div_dimshuffle_bug = False
try: try:
g = theano.function([c], T.grad(p_y.sum(), c)) g = theano.function([c], T.grad(p_y.sum(), c))
hasattr(g.maker.fgraph.outputs[0].tag, 'trace')
finally: finally:
config.warn.sum_div_dimshuffle_bug = backup config.warn.sum_div_dimshuffle_bug = backup
# printing.debugprint(g) # printing.debugprint(g)
...@@ -1533,7 +1532,6 @@ class Test_softmax_opt: ...@@ -1533,7 +1532,6 @@ class Test_softmax_opt:
# test that function contains softmax and no div. # test that function contains softmax and no div.
f = theano.function([c], p_y) f = theano.function([c], p_y)
hasattr(f.maker.fgraph.outputs[0].tag, 'trace')
# printing.debugprint(f) # printing.debugprint(f)
# test that function contains softmax and no div. # test that function contains softmax and no div.
...@@ -1541,7 +1539,6 @@ class Test_softmax_opt: ...@@ -1541,7 +1539,6 @@ class Test_softmax_opt:
config.warn.sum_div_dimshuffle_bug = False config.warn.sum_div_dimshuffle_bug = False
try: try:
g = theano.function([c], T.grad(p_y.sum(), c)) g = theano.function([c], T.grad(p_y.sum(), c))
hasattr(g.maker.fgraph.outputs[0].tag, 'trace')
finally: finally:
config.warn.sum_div_dimshuffle_bug = backup config.warn.sum_div_dimshuffle_bug = backup
# printing.debugprint(g) # printing.debugprint(g)
...@@ -1581,7 +1578,7 @@ def test_stabilize_log_softmax(): ...@@ -1581,7 +1578,7 @@ def test_stabilize_log_softmax():
z = theano.tensor.log(y) z = theano.tensor.log(y)
f = theano.function([x], z, mode=mode) f = theano.function([x], z, mode=mode)
assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace') assert check_stack_trace(f, ops_to_check='all')
# check that the softmax has been optimized out # check that the softmax has been optimized out
for node in f.maker.fgraph.toposort(): for node in f.maker.fgraph.toposort():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论