提交 4af1e536 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix fallout in tests.

上级 76a9e27f
...@@ -3180,7 +3180,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs): ...@@ -3180,7 +3180,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs):
return return
max, arg = GpuDnnReduction('maximum', op.axis, inputs[0].dtype, max, arg = GpuDnnReduction('maximum', op.axis, inputs[0].dtype,
inputs[0].dtype, True) inputs[0].dtype, True)(*inputs)
return [as_gpuarray_variable(arg.astype('int64'), ctx_name)] return [as_gpuarray_variable(arg.astype('int64'), ctx_name)]
......
...@@ -1333,9 +1333,9 @@ def test_argmax_pushdown(): ...@@ -1333,9 +1333,9 @@ def test_argmax_pushdown():
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 1 assert len(fgraph.toposort()) == 1
assert fgraph.toposort()[0].op == tensor.basic._argmax assert isinstance(fgraph.toposort()[0].op, tensor.basic.Argmax)
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=tensor.basic._argmax) fgraph, ops_to_check=tensor.basic.Argmax)
x = tensor.matrix() x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used # test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax( out = tensor.max_and_argmax(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论