提交 bf0df19f authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Rename missed instances.

上级 2f32d9ed
...@@ -1592,9 +1592,9 @@ class GpuDnnReduction(DnnBase): ...@@ -1592,9 +1592,9 @@ class GpuDnnReduction(DnnBase):
self.c_axis = self._convert_axis(axis) self.c_axis = self._convert_axis(axis)
# axis is a list of axes to reduce on # axis is a list of axes to reduce on
self.axis = axis self.axis = axis
if arg and (red_op != 'max' and red_op != 'min'): if return_indices and (red_op != 'max' and red_op != 'min'):
raise ValueError("Can't request indices for something other than min or max") raise ValueError("Can't request indices for something other than min or max")
self.arg = arg self.return_indices = return_indices
def _convert_axis(self, axis): def _convert_axis(self, axis):
if axis is None: if axis is None:
...@@ -1623,7 +1623,7 @@ class GpuDnnReduction(DnnBase): ...@@ -1623,7 +1623,7 @@ class GpuDnnReduction(DnnBase):
if not (self.c_axis & (1 << i)): if not (self.c_axis & (1 << i)):
bcast.append(inp.broadcastable[i]) bcast.append(inp.broadcastable[i])
outs = [inp.type.clone(dtype=self.dtype, broadcastable=bcast)()] outs = [inp.type.clone(dtype=self.dtype, broadcastable=bcast)()]
if self.arg: if self.return_indices:
outs.append(GpuArrayType(dtype='uint32', broadcastable=bcast, outs.append(GpuArrayType(dtype='uint32', broadcastable=bcast,
context_name=ctx_name)()) context_name=ctx_name)())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论