提交 ff40b63e authored 作者: notoraptor's avatar notoraptor

Wrap op params for theano.gpuarray.pool.GpuAveragePoolGrad.

上级 b12925dd
...@@ -213,6 +213,7 @@ class GpuAveragePoolGrad(CGpuKernelBase): ...@@ -213,6 +213,7 @@ class GpuAveragePoolGrad(CGpuKernelBase):
""" """
__props__ = ('ignore_border', 'mode', 'ndim') __props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(mode=PoolingMode_t, context=gpu_context_type)
def __init__(self, ignore_border, mode='max', ndim=2): def __init__(self, ignore_border, mode='max', ndim=2):
self.ndim = ndim self.ndim = ndim
...@@ -225,6 +226,9 @@ class GpuAveragePoolGrad(CGpuKernelBase): ...@@ -225,6 +226,9 @@ class GpuAveragePoolGrad(CGpuKernelBase):
assert mode in ('sum', 'average_inc_pad', 'average_exc_pad') assert mode in ('sum', 'average_inc_pad', 'average_exc_pad')
assert ndim in [2, 3] assert ndim in [2, 3]
def get_params(self, node):
return self.params_type.get_params(self, context=node.inputs[0].type.context)
def c_headers(self): def c_headers(self):
return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h'] return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h']
...@@ -266,12 +270,6 @@ class GpuAveragePoolGrad(CGpuKernelBase): ...@@ -266,12 +270,6 @@ class GpuAveragePoolGrad(CGpuKernelBase):
return Apply(self, [inp, out_grad, ws, stride, pad], [inp.type()]) return Apply(self, [inp, out_grad, ws, stride, pad], [inp.type()])
def get_op_params(self):
inc_pad = int(self.mode == 'average_inc_pad')
sum_mode = int(self.mode == 'sum')
return [('INC_PAD', inc_pad),
('SUM_MODE', sum_mode)]
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
return [in_shapes[0]] return [in_shapes[0]]
......
...@@ -115,7 +115,9 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x, ...@@ -115,7 +115,9 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x,
PyArrayObject *stride, PyArrayObject *stride,
PyArrayObject *pad, PyArrayObject *pad,
PyGpuArrayObject **gx, PyGpuArrayObject **gx,
PyGpuContextObject *ctx) { PARAMS_TYPE* params) {
bool inc_pad = (params->mode == POOLING_AVERAGE_COUNT_INCLUDE_PADDING);
bool sum_mode = (params->mode == POOLING_SUM);
if (!GpuArray_IS_C_CONTIGUOUS(&x->ga) if (!GpuArray_IS_C_CONTIGUOUS(&x->ga)
|| !GpuArray_IS_C_CONTIGUOUS(&gz->ga)) || !GpuArray_IS_C_CONTIGUOUS(&gz->ga))
{ {
...@@ -131,7 +133,7 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x, ...@@ -131,7 +133,7 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x,
return 1; return 1;
} }
if (theano_prep_output(gx, PyGpuArray_NDIM(x), PyGpuArray_DIMS(x), if (theano_prep_output(gx, PyGpuArray_NDIM(x), PyGpuArray_DIMS(x),
x->ga.typecode, GA_C_ORDER, ctx) != 0) x->ga.typecode, GA_C_ORDER, params->context) != 0)
{ {
PyErr_SetString(PyExc_RuntimeError, PyErr_SetString(PyExc_RuntimeError,
"GpuMaxPoolGrad: failed to allocate memory"); "GpuMaxPoolGrad: failed to allocate memory");
...@@ -161,7 +163,7 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x, ...@@ -161,7 +163,7 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x,
x->ga.data, x->ga.offset, x->ga.data, x->ga.offset,
gz->ga.data, gz->ga.offset, gz->ga.data, gz->ga.offset,
w[0], w[1], s[0], s[1], p[0], p[1], w[0], w[1], s[0], s[1], p[0], p[1],
INC_PAD, SUM_MODE, inc_pad, sum_mode,
(*gx)->ga.data, (*gx)->ga.offset); (*gx)->ga.data, (*gx)->ga.offset);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
...@@ -177,7 +179,7 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x, ...@@ -177,7 +179,7 @@ int APPLY_SPECIFIC(ave_pool_grad)(PyGpuArrayObject *x,
x->ga.data, x->ga.offset, x->ga.data, x->ga.offset,
gz->ga.data, gz->ga.offset, gz->ga.data, gz->ga.offset,
w[0], w[1], w[2], s[0], s[1], s[2], w[0], w[1], w[2], s[0], s[1], s[2],
p[0], p[1], p[2], INC_PAD, SUM_MODE, p[0], p[1], p[2], inc_pad, sum_mode,
(*gx)->ga.data, (*gx)->ga.offset); (*gx)->ga.data, (*gx)->ga.offset);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论