提交 06be9e5b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Rename the bitfield to c_axis and change axis to the list of axes from

the arguments.
上级 cceeafad
...@@ -1570,7 +1570,7 @@ class GpuDnnReduction(DnnBase): ...@@ -1570,7 +1570,7 @@ class GpuDnnReduction(DnnBase):
params_type = ParamsType(red_op=cudnn.cudnnReduceTensorOp_t, params_type = ParamsType(red_op=cudnn.cudnnReduceTensorOp_t,
acc_dtype=cudnn.cudnnDataType_t, acc_dtype=cudnn.cudnnDataType_t,
axis=uint32_t, c_axis=uint32_t,
handle=handle_type) handle=handle_type)
def __init__(self, red_op, axis, acc_dtype, dtype, arg): def __init__(self, red_op, axis, acc_dtype, dtype, arg):
...@@ -1587,7 +1587,11 @@ class GpuDnnReduction(DnnBase): ...@@ -1587,7 +1587,11 @@ class GpuDnnReduction(DnnBase):
raise ValueError('Too many axes to reduce on') raise ValueError('Too many axes to reduce on')
if any(a >= 8 for a in axis): if any(a >= 8 for a in axis):
raise ValueError('Axes larger than 8 not supported') raise ValueError('Axes larger than 8 not supported')
self.axis = self._convert_axis(axis) axis = tuple(axis)
# c_axis is a bitfield (1 to reduce)
self.c_axis = self._convert_axis(axis)
# axis is a list of axes to reduce on
self.axis = axis
if arg and (red_op != 'max' and red_op != 'min'): if arg and (red_op != 'max' and red_op != 'min'):
raise ValueError("Can't request indices for something other than min or max") raise ValueError("Can't request indices for something other than min or max")
self.arg = arg self.arg = arg
......
...@@ -72,7 +72,7 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input, ...@@ -72,7 +72,7 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input,
p = 0; p = 0;
rsz = 1; rsz = 1;
for (unsigned int i = 0; i < PyGpuArray_NDIM(input); i++) { for (unsigned int i = 0; i < PyGpuArray_NDIM(input); i++) {
if (!(params->axis & (1U << i))) { if (!(params->c_axis & (1U << i))) {
dims[p] = PyGpuArray_DIM(input, i); dims[p] = PyGpuArray_DIM(input, i);
p++; p++;
} else { } else {
...@@ -111,7 +111,7 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input, ...@@ -111,7 +111,7 @@ int APPLY_SPECIFIC(dnn_redux)(PyGpuArrayObject *input,
// We have to do some trickery to be able to pass it what it need. // We have to do some trickery to be able to pass it what it need.
p = 0; p = 0;
for (unsigned int i = 0; i < PyGpuArray_NDIM(input); i++) { for (unsigned int i = 0; i < PyGpuArray_NDIM(input); i++) {
if (params->axis & (1U << i)) { if (params->c_axis & (1U << i)) {
dims[i] = 1; dims[i] = 1;
strs[i] = 0; strs[i] = 0;
} else { } else {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论