提交 0959c943 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add optimization to lift Argmax.

上级 2f3d63cf
...@@ -3147,7 +3147,8 @@ def local_cudnn_maxandargmax(node): ...@@ -3147,7 +3147,8 @@ def local_cudnn_maxandargmax(node):
return return
# order of the axes influences the output indices # order of the axes influences the output indices
if tuple(sorted(node.op.axis)) != node.op.axis: if (node.op.axis is not None and
tuple(sorted(node.op.axis)) != node.op.axis):
return return
max, arg = GpuDnnReduction('maximum', node.op.axis, node.outputs[0].dtype, max, arg = GpuDnnReduction('maximum', node.op.axis, node.outputs[0].dtype,
...@@ -3158,6 +3159,32 @@ def local_cudnn_maxandargmax(node): ...@@ -3158,6 +3159,32 @@ def local_cudnn_maxandargmax(node):
node.outputs[1].type.context_name)) node.outputs[1].type.context_name))
@register_opt('cudnn', 'fast_compile')
@op_lifter([Argmax])
@register_opt2([Argmax], 'fast_compile', 'cudnn')
def local_dnn_argmax(op, ctx_name, inputs, outputs):
if not dnn_available(ctx_name):
return
if version(raises=False) < 6000:
return
if inputs[0].ndim > 8:
return
if inputs[0].dtype not in ['float16', 'float32', 'float64']:
return
# order of the axes influences the output indices
if op.axis is not None and tuple(sorted(op.axis)) != op.axis:
return
max, arg = GpuDnnReduction('maximum', op.axis, inputs[0].dtype,
inputs[0].dtype, True)
return [as_gpuarray_variable(arg.astype('int64'), ctx_name)]
class NoCuDNNRaise(Optimizer): class NoCuDNNRaise(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论