提交 4af1e536 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix fallout in tests.

上级 76a9e27f
......@@ -3180,7 +3180,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs):
return
max, arg = GpuDnnReduction('maximum', op.axis, inputs[0].dtype,
inputs[0].dtype, True)
inputs[0].dtype, True)(*inputs)
return [as_gpuarray_variable(arg.astype('int64'), ctx_name)]
......
......@@ -1333,9 +1333,9 @@ def test_argmax_pushdown():
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 1
assert fgraph.toposort()[0].op == tensor.basic._argmax
assert isinstance(fgraph.toposort()[0].op, tensor.basic.Argmax)
assert check_stack_trace(
fgraph, ops_to_check=tensor.basic._argmax)
fgraph, ops_to_check=tensor.basic.Argmax)
x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论