Commit c398b58d authored by Harm de Vries

test softmax_op and softmax_graph in test_argmax_pushdown

Parent 24f6a4e5
...@@ -1127,49 +1127,49 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -1127,49 +1127,49 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
def test_argmax_pushdown(): def test_argmax_pushdown():
x = tensor.matrix() x = tensor.matrix()
for softmax in [softmax_graph, softmax_op]:
# test that the max_and_argmax is pushed down if the max is not used
out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[1]
fgraph = gof.FunctionGraph(
[x],
[out])
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
# test that the max_and_argmax is pushed down if the max is not used # print 'AFTER'
out = tensor.max_and_argmax( # for node in fgraph.toposort():
softmax_graph(tensor.exp(tensor.tanh(sigmoid(x)))), # print node.op
axis=-1)[1] assert len(fgraph.toposort()) == 2 # an output_guard is second
fgraph = gof.FunctionGraph( assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax
[x], assert str(fgraph.toposort()[1].op) == 'OutputGuard'
[out]) x = tensor.matrix()
theano.compile.mode.optdb.query( # test that the max_and_argmax is not pushed down if the max is used
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
# print 'AFTER' axis=-1)[0]
# for node in fgraph.toposort(): fgraph = gof.FunctionGraph(
# print node.op [x],
assert len(fgraph.toposort()) == 2 # an output_guard is second [out])
assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(
softmax_op(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[0]
fgraph = gof.FunctionGraph(
[x],
[out])
backup = config.warn.argmax_pushdown_bug backup = config.warn.argmax_pushdown_bug
config.warn.argmax_pushdown_bug = False config.warn.argmax_pushdown_bug = False
try: try:
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
finally: finally:
config.warn.argmax_pushdown_bug = backup config.warn.argmax_pushdown_bug = backup
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 4 # an output_guard is second assert len(fgraph.toposort()) == 4 # an output_guard is second
assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise) assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax) assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce) assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard' assert str(fgraph.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias(): def test_argmax_pushdown_bias():
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment