提交 c398b58d authored 作者: Harm de Vries's avatar Harm de Vries

test softmax_op and softmax_ghraph in test_argmax_pushdown

上级 24f6a4e5
......@@ -1127,49 +1127,49 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
def test_argmax_pushdown():
x = tensor.matrix()
for softmax in [softmax_graph, softmax_op]:
# test that the max_and_argmax is pushed down if the max is not used
out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[1]
fgraph = gof.FunctionGraph(
[x],
[out])
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
# test that the max_and_argmax is pushed down if the max is not used
out = tensor.max_and_argmax(
softmax_graph(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[1]
fgraph = gof.FunctionGraph(
[x],
[out])
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 2 # an output_guard is second
assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(
softmax_op(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[0]
fgraph = gof.FunctionGraph(
[x],
[out])
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 2 # an output_guard is second
assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[0]
fgraph = gof.FunctionGraph(
[x],
[out])
backup = config.warn.argmax_pushdown_bug
config.warn.argmax_pushdown_bug = False
try:
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
finally:
config.warn.argmax_pushdown_bug = backup
backup = config.warn.argmax_pushdown_bug
config.warn.argmax_pushdown_bug = False
try:
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
finally:
config.warn.argmax_pushdown_bug = backup
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 4 # an output_guard is second
assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 4 # an output_guard is second
assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论