提交 aa636171 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add optimizer for GpuMaxAndArgmax.

上级 d5ba6134
...@@ -37,6 +37,7 @@ from .basic_ops import (as_gpuarray_variable, infer_context_name, ...@@ -37,6 +37,7 @@ from .basic_ops import (as_gpuarray_variable, infer_context_name,
gpu_contiguous, GpuAllocEmpty, gpu_contiguous, GpuAllocEmpty,
empty_like, GpuArrayType, HostFromGpu) empty_like, GpuArrayType, HostFromGpu)
from .elemwise import GpuElemwise, GpuCAReduceCuda from .elemwise import GpuElemwise, GpuCAReduceCuda
from .reduction import GpuMaxAndArgmax
# These don't exist in gpuarray # These don't exist in gpuarray
# GpuDownsampleFactorMax, GpuDownsampleFactorMaxGrad # GpuDownsampleFactorMax, GpuDownsampleFactorMaxGrad
...@@ -1592,8 +1593,9 @@ class GpuDnnReduction(DnnBase): ...@@ -1592,8 +1593,9 @@ class GpuDnnReduction(DnnBase):
self.c_axis = self._convert_axis(axis) self.c_axis = self._convert_axis(axis)
# axis is a list of axes to reduce on # axis is a list of axes to reduce on
self.axis = axis self.axis = axis
if return_indices and (red_op != 'max' and red_op != 'min'): if return_indices and (red_op != 'maximum' and red_op != 'minimum'):
raise ValueError("Can't request indices for something other than min or max") raise ValueError("Can't request indices for something other than"
" minimum or maximum")
self.return_indices = return_indices self.return_indices = return_indices
def _convert_axis(self, axis): def _convert_axis(self, axis):
...@@ -3122,6 +3124,38 @@ def local_dnn_reduction(node): ...@@ -3122,6 +3124,38 @@ def local_dnn_reduction(node):
node.op.dtype, node.op.dtype,
False)(node.inputs[0]),) False)(node.inputs[0]),)
@register_opt('cudnn')
@local_optimizer([GpuMaxAndArgmax])
def local_cudnn_maxandargmax(node):
    """Rewrite a ``GpuMaxAndArgmax`` node into a single cuDNN reduction.

    Returns a ``(max, argmax)`` pair computed by ``GpuDnnReduction`` when
    cuDNN (>= v6) supports the input; returns ``None`` to leave the node
    unchanged otherwise.
    """
    if not isinstance(node.op, GpuMaxAndArgmax):
        return
    if not dnn_available(node.inputs[0].type.context_name):
        return
    # Reductions that also return indices require cuDNN v6 or later.
    if version(raises=False) < 6000:
        return
    # cuDNN tensor descriptors support at most 8 dimensions.
    if node.inputs[0].ndim > 8:
        return
    if node.inputs[0].dtype != node.outputs[0].dtype:
        return
    if node.inputs[0].dtype not in ['float16', 'float32', 'float64']:
        return
    # The order of the axes influences the output indices, so only handle
    # the case where the reduction axes are already in sorted order.
    if tuple(sorted(node.op.axis)) != node.op.axis:
        return
    # Renamed from `max`/`arg` to avoid shadowing the `max` builtin.
    max_out, argmax_out = GpuDnnReduction('maximum', node.op.axis,
                                          node.outputs[0].dtype,
                                          node.outputs[0].dtype,
                                          True)(node.inputs[0])
    # cudnn can only return int32 indices; cast to the int64 the
    # original output expects.
    return (max_out, as_gpuarray_variable(argmax_out.astype('int64'),
                                          node.outputs[1].type.context_name))
class NoCuDNNRaise(Optimizer): class NoCuDNNRaise(Optimizer):
......
...@@ -37,8 +37,8 @@ class GpuMaxAndArgmax(Op): ...@@ -37,8 +37,8 @@ class GpuMaxAndArgmax(Op):
broadcastable = [b for i, b in enumerate(X.type.broadcastable) broadcastable = [b for i, b in enumerate(X.type.broadcastable)
if i not in all_axes] if i not in all_axes]
inputs = [as_gpuarray_variable(X, context_name)] inputs = [as_gpuarray_variable(X, context_name)]
outputs = [GpuArrayType(X.type.dtype, broadcastable, context_name=context_name, name='max')(), outputs = [GpuArrayType(X.type.dtype, broadcastable, context_name=context_name)(),
GpuArrayType(self.argmax_dtype, broadcastable, context_name=context_name, name='argmax')()] GpuArrayType(self.argmax_dtype, broadcastable, context_name=context_name)()]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def c_headers(self): def c_headers(self):
......
...@@ -1506,6 +1506,34 @@ def test_dnn_reduction_opt(): ...@@ -1506,6 +1506,34 @@ def test_dnn_reduction_opt():
yield dnn_reduction, 2, idtype, adtype, odtype yield dnn_reduction, 2, idtype, adtype, odtype
def dnn_maxargmax(nd, idtype, axis):
    """Compile max_and_argmax on the GPU and check it uses GpuDnnReduction."""
    x = T.TensorType(idtype, (False,) * nd)()
    outputs = T.max_and_argmax(x, axis=axis)
    fn = theano.function([x], outputs, mode=mode_with_gpu)
    graph_nodes = fn.maker.fgraph.apply_nodes
    assert any(isinstance(apply.op, dnn.GpuDnnReduction)
               for apply in graph_nodes)
def test_dnn_maxandargmax_opt():
    """Yield sub-tests covering ranks 1-8, the float dtypes and axis subsets."""
    if not dnn.dnn_available(test_ctx_name) or dnn.version(raises=False) < 6000:
        raise SkipTest(dnn.dnn_available.msg)
    # Full reduction for every supported rank.
    for ndim in range(1, 9):
        yield dnn_maxargmax, ndim, 'float32', None
    # Other supported float dtypes on a 2-d input.
    for dtype in ('float64', 'float16'):
        yield dnn_maxargmax, 2, dtype, None
    # Each sorted subset of axes on a 3-d input.
    for axes in ((0, 1), (0, 2), (1, 2), (0, 1, 2), (0,), (1,), (2,), ()):
        yield dnn_maxargmax, 3, 'float32', axes
def test_dnn_batchnorm_train(): def test_dnn_batchnorm_train():
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论