提交 c398b58d authored 作者: Harm de Vries's avatar Harm de Vries

test softmax_op and softmax_ghraph in test_argmax_pushdown

上级 24f6a4e5
...@@ -1127,10 +1127,10 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -1127,10 +1127,10 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
def test_argmax_pushdown(): def test_argmax_pushdown():
x = tensor.matrix() x = tensor.matrix()
for softmax in [softmax_graph, softmax_op]:
# test that the max_and_argmax is pushed down if the max is not used # test that the max_and_argmax is pushed down if the max is not used
out = tensor.max_and_argmax( out = tensor.max_and_argmax(
softmax_graph(tensor.exp(tensor.tanh(sigmoid(x)))), softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[1] axis=-1)[1]
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x], [x],
...@@ -1147,7 +1147,7 @@ def test_argmax_pushdown(): ...@@ -1147,7 +1147,7 @@ def test_argmax_pushdown():
x = tensor.matrix() x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used # test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax( out = tensor.max_and_argmax(
softmax_op(tensor.exp(tensor.tanh(sigmoid(x)))), softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[0] axis=-1)[0]
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x], [x],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论