提交 60a906e7 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #6081 from nouiz/params_gemm

Params for Gemm, CGer, CGemv
...@@ -421,7 +421,7 @@ def get_c_extract_out(r, name, sub): ...@@ -421,7 +421,7 @@ def get_c_extract_out(r, name, sub):
# `c_extract_out` is used to extract an output variable from # `c_extract_out` is used to extract an output variable from
# the compute map, to be used as pre-allocated memory for `r` # the compute map, to be used as pre-allocated memory for `r`
# before its value gets computed. # before its value gets computed.
# If the node producing `r` has `check_inputs=True`, it may # If the node producing `r` has `check_input=True`, it may
# also perform type checks on the initial value of the output, # also perform type checks on the initial value of the output,
# so we need to pass `check_input=True` to `c_extract_out`. # so we need to pass `check_input=True` to `c_extract_out`.
# However, that code is not used by potential clients of `r`, # However, that code is not used by potential clients of `r`,
......
...@@ -145,9 +145,11 @@ from theano.gof import (utils, Op, view_roots, ...@@ -145,9 +145,11 @@ from theano.gof import (utils, Op, view_roots,
InconsistencyError, toolbox, SequenceDB, InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply, EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError) ReplacementDidntRemovedError)
from theano.gof.params_type import ParamsType
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb from theano.compile.mode import optdb
import theano.scalar import theano.scalar
from theano.scalar import bool as bool_t
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version from theano.tensor.blas_headers import blas_header_version
...@@ -243,7 +245,7 @@ class Gemv(Op): ...@@ -243,7 +245,7 @@ class Gemv(Op):
raise TypeError('gemv requires vector for y', y.type) raise TypeError('gemv requires vector for y', y.type)
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage, params=None):
y, alpha, A, x, beta = inputs y, alpha, A, x, beta = inputs
if (have_fblas and y.shape[0] != 0 and x.shape[0] != 0 and if (have_fblas and y.shape[0] != 0 and x.shape[0] != 0 and
y.dtype in _blas_gemv_fns): y.dtype in _blas_gemv_fns):
...@@ -333,7 +335,7 @@ class Ger(Op): ...@@ -333,7 +335,7 @@ class Ger(Op):
raise TypeError('only float and complex types supported', x.dtype) raise TypeError('only float and complex types supported', x.dtype)
return Apply(self, [A, alpha, x, y], [A.type()]) return Apply(self, [A, alpha, x, y], [A.type()])
def perform(self, node, inp, out): def perform(self, node, inp, out, params=None):
cA, calpha, cx, cy = inp cA, calpha, cx, cy = inp
cZ, = out cZ, = out
if self.destructive: if self.destructive:
...@@ -522,7 +524,11 @@ class GemmRelated(Op): ...@@ -522,7 +524,11 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
""" """
# Implement this attribute if the Op does not have an `inplace` prop:
# setup_z_Nz_Sz = None # setup_z_Nz_Sz = None
# Otherwise (the Op has an `inplace` prop), implement both of these instead:
# setup_z_Nz_Sz_inplace = None
# setup_z_Nz_Sz_outplace = None
check_xyz_rank2 = """ check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) { if (PyArray_NDIM(%(_x)s) != 2) {
...@@ -755,11 +761,16 @@ class GemmRelated(Op): ...@@ -755,11 +761,16 @@ class GemmRelated(Op):
""" """
def build_gemm_call(self): def build_gemm_call(self):
if hasattr(self, 'inplace'):
setup_z_Nz_Sz = "if(%%(params)s->inplace){%s}else{%s}" % (
self.setup_z_Nz_Sz_inplace, self.setup_z_Nz_Sz_outplace)
else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz
return reduce(str.__add__, ( return reduce(str.__add__, (
self.declare_NS, self.declare_NS,
self.check_xyz_rank2, self.check_xyz_rank2,
self.setup_z_Nz_Sz, setup_z_Nz_Sz,
self.check_xyz_double_or_float, self.check_xyz_double_or_float,
self.check_ab_double_or_float, self.check_ab_double_or_float,
self.check_dims, self.check_dims,
...@@ -809,14 +820,13 @@ class Gemm(GemmRelated): ...@@ -809,14 +820,13 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes' E_float = 'gemm requires floating-point dtypes'
__props__ = ('inplace',) __props__ = ('inplace',)
params_type = ParamsType(inplace=bool_t,)
check_input = False
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -827,10 +837,6 @@ class Gemm(GemmRelated): ...@@ -827,10 +837,6 @@ class Gemm(GemmRelated):
def __setstate__(self, dct): def __setstate__(self, dct):
self.__dict__.update(dct) self.__dict__.update(dct)
if self.inplace:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
# Correctly reload older pickles where destroy_map were not # Correctly reload older pickles where destroy_map were not
# saved # saved
...@@ -841,7 +847,7 @@ class Gemm(GemmRelated): ...@@ -841,7 +847,7 @@ class Gemm(GemmRelated):
rval = self.__dict__.copy() rval = self.__dict__.copy()
# Do not serialize the setup code, it will be restored in __setstate__ # Do not serialize the setup code, it will be restored in __setstate__
# depending on the value of 'inplace' # depending on the value of 'inplace'
rval.pop('setup_z_Nz_Sz') rval.pop('setup_z_Nz_Sz', None)
return rval return rval
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -892,12 +898,12 @@ class Gemm(GemmRelated): ...@@ -892,12 +898,12 @@ class Gemm(GemmRelated):
output = z.type() output = z.type()
return Apply(self, inputs, [output]) return Apply(self, inputs, [output])
def perform(self, node, inp, out): def perform(self, node, inp, out, params):
z, a, x, y, b = inp z, a, x, y, b = inp
zout, = out zout, = out
assert a.shape == () assert a.shape == ()
assert b.shape == () assert b.shape == ()
if not self.inplace: if not params.inplace:
z = z.copy() # the original z will not be changed z = z.copy() # the original z will not be changed
if z.shape == (): if z.shape == ():
z.itemset(z * a + b * np.dot(x, y)) z.itemset(z * a + b * np.dot(x, y))
...@@ -1039,7 +1045,7 @@ class Gemm(GemmRelated): ...@@ -1039,7 +1045,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
return (5,) + gv return (6,) + gv
else: else:
return gv return gv
...@@ -1522,6 +1528,7 @@ class Dot22(GemmRelated): ...@@ -1522,6 +1528,7 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot(). This is a specialization of the more general Dot().
""" """
check_input = False
def make_node(self, x, y): def make_node(self, x, y):
dtypes = ('float16', 'float32', 'float64', 'complex64', 'complex128') dtypes = ('float16', 'float32', 'float64', 'complex64', 'complex128')
...@@ -1784,6 +1791,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1784,6 +1791,7 @@ class Dot22Scalar(GemmRelated):
compute scalar*dot(x,y). compute scalar*dot(x,y).
""" """
check_input = False
def make_node(self, x, y, a): def make_node(self, x, y, a):
if a.ndim != 0: if a.ndim != 0:
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from theano import config from theano import config
from theano.gof.params_type import ParamsType
from theano.scalar import bool as bool_t
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer from theano.tensor.blas import blas_optdb, optdb, local_optimizer
...@@ -30,7 +32,7 @@ class BaseBLAS(object): ...@@ -30,7 +32,7 @@ class BaseBLAS(object):
# GER # GER
# ##### ####### ####### # ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail): def ger_c_code(A, a, x, y, Z, fail, params):
return """ return """
int elemsize ; int elemsize ;
...@@ -71,7 +73,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -71,7 +73,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
// copy A if !self.destructive or A is fully strided // copy A if !self.destructive or A is fully strided
if (!%(destructive)s if (!%(params)s->destructive
|| (PyArray_STRIDES(%(A)s)[0] < 0) || (PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0) || (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize) || ((PyArray_STRIDES(%(A)s)[0] != elemsize)
...@@ -311,16 +313,18 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -311,16 +313,18 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
class CGer(BaseBLAS, Ger): class CGer(BaseBLAS, Ger):
params_type = ParamsType(destructive=bool_t,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
A, a, x, y = inp A, a, x, y = inp
Z, = out Z, = out
code = ger_c_code(A, a, x, y, Z, code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive), fail=sub['fail'],
fail=sub['fail']) params=sub['params'])
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (10, blas_header_version()) return (11, blas_header_version())
cger_inplace = CGer(True) cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
...@@ -349,8 +353,8 @@ def make_c_ger_destructive(node): ...@@ -349,8 +353,8 @@ def make_c_ger_destructive(node):
# ##### ####### ####### # ##### ####### #######
def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail, def gemv_c_code(y, A, x, z, alpha, beta, fail,
force_init_beta=False): force_init_beta=False, params=None):
""" """
z <- beta * y + alpha * dot(A, x) z <- beta * y + alpha * dot(A, x)
...@@ -385,7 +389,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail, ...@@ -385,7 +389,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0]; fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
// copy y if not destructive // copy y if not destructive
if (!%(destructive)s) if (!%(params)s->inplace)
{ {
if ((NULL == %(z)s) if ((NULL == %(z)s)
|| (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0])) || (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0]))
...@@ -593,6 +597,8 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail, ...@@ -593,6 +597,8 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
class CGemv(BaseBLAS, Gemv): class CGemv(BaseBLAS, Gemv):
params_type = ParamsType(inplace=bool_t,)
def __init__(self, inplace): def __init__(self, inplace):
super(CGemv, self).__init__(inplace) super(CGemv, self).__init__(inplace)
...@@ -601,14 +607,14 @@ class CGemv(BaseBLAS, Gemv): ...@@ -601,14 +607,14 @@ class CGemv(BaseBLAS, Gemv):
z, = out z, = out
code = gemv_c_code( code = gemv_c_code(
y, A, x, z, alpha, beta, y, A, x, z, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'], fail=sub['fail'],
force_init_beta=check_force_gemv_init() force_init_beta=check_force_gemv_init(),
params=sub['params'],
) )
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (13, blas_header_version(), check_force_gemv_init()) return (14, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False) cgemv_no_inplace = CGemv(inplace=False)
......
...@@ -14,7 +14,9 @@ from six.moves import xrange ...@@ -14,7 +14,9 @@ from six.moves import xrange
import six.moves.builtins as builtins import six.moves.builtins as builtins
import theano import theano
from theano import gof, OpenMPOp, tensor, Variable, Apply from theano import gof, OpenMPOp, tensor, Variable, Apply
from theano.gof.params_type import ParamsType
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.scalar import bool as bool_t
def max_pool_2d_same_size(input, patch_size): def max_pool_2d_same_size(input, patch_size):
...@@ -294,6 +296,7 @@ class Pool(OpenMPOp): ...@@ -294,6 +296,7 @@ class Pool(OpenMPOp):
""" """
__props__ = ('ignore_border', 'mode', 'ndim') __props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(ignore_border=bool_t,)
@staticmethod @staticmethod
def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None, def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None,
...@@ -508,7 +511,7 @@ class Pool(OpenMPOp): ...@@ -508,7 +511,7 @@ class Pool(OpenMPOp):
out = tensor.TensorType(x.dtype, broad) out = tensor.TensorType(x.dtype, broad)
return gof.Apply(self, [x, ws, stride, pad], [out()]) return gof.Apply(self, [x, ws, stride, pad], [out()])
def perform(self, node, inp, out): def perform(self, node, inp, out, params):
x, ws, stride, pad = inp x, ws, stride, pad = inp
z, = out z, = out
nd = self.ndim nd = self.ndim
...@@ -516,8 +519,8 @@ class Pool(OpenMPOp): ...@@ -516,8 +519,8 @@ class Pool(OpenMPOp):
if len(x.shape) < nd: if len(x.shape) < nd:
raise NotImplementedError( raise NotImplementedError(
'Pool requires input with {} or more dimensions'.format(nd)) 'Pool requires input with {} or more dimensions'.format(nd))
z_shape = self.out_shape(x.shape, ws, self.ignore_border, stride, pad, nd) z_shape = self.out_shape(x.shape, ws, params.ignore_border, stride, pad, nd)
if not self.ignore_border: if not params.ignore_border:
assert all(z > 0 for z in z_shape[-nd:]) assert all(z > 0 for z in z_shape[-nd:])
if (z[0] is None) or (z[0].shape != z_shape): if (z[0] is None) or (z[0].shape != z_shape):
z[0] = np.empty(z_shape, dtype=x.dtype) z[0] = np.empty(z_shape, dtype=x.dtype)
...@@ -617,7 +620,7 @@ class Pool(OpenMPOp): ...@@ -617,7 +620,7 @@ class Pool(OpenMPOp):
total_ndim = node.inputs[0].ndim total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd non_pool_ndim = total_ndim - nd
fail = sub['fail'] fail = sub['fail']
ignore_border = int(self.ignore_border) params = sub['params']
if self.openmp: if self.openmp:
# run in parallel over each pooling block # run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector) schedule(static)' omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector) schedule(static)'
...@@ -661,13 +664,13 @@ class Pool(OpenMPOp): ...@@ -661,13 +664,13 @@ class Pool(OpenMPOp):
if (pd[i]>0) if (pd[i]>0)
nonzero_padding = 1; nonzero_padding = 1;
} }
if (!%(ignore_border)s && nonzero_padding) if (!%(params)s->ignore_border && nonzero_padding)
{ {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False"); "padding must be zero when ignore border is False");
%(fail)s; %(fail)s;
} }
if (%(ignore_border)s) if (%(params)s->ignore_border)
{ {
for (int i=0; i<%(nd)s; i++) for (int i=0; i<%(nd)s; i++)
{ {
...@@ -801,13 +804,13 @@ class Pool(OpenMPOp): ...@@ -801,13 +804,13 @@ class Pool(OpenMPOp):
r_st[%(i)s] -= pd[%(i)s]; r_st[%(i)s] -= pd[%(i)s];
r_end[%(i)s] -= pd[%(i)s]; r_end[%(i)s] -= pd[%(i)s];
// handle the case where no padding, ignore border is True // handle the case where no padding, ignore border is True
if (%(ignore_border)s) if (%(params)s->ignore_border)
{ {
r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s]; r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s];
} }
// use the index to find the correct position in the output // use the index to find the correct position in the output
o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s]; o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s];
""" % dict(i=i, ignore_border=ignore_border, non_pool_ndim=non_pool_ndim) """ % dict(i=i, non_pool_ndim=non_pool_ndim, params=sub['params'])
ccode += """ ccode += """
// get a pointer to the correct position in the output // get a pointer to the correct position in the output
...@@ -907,7 +910,7 @@ class Pool(OpenMPOp): ...@@ -907,7 +910,7 @@ class Pool(OpenMPOp):
return ccode % locals() return ccode % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 6, 8, 7, self.openmp) return (9, self.openmp)
class PoolGrad(OpenMPOp): class PoolGrad(OpenMPOp):
...@@ -1089,6 +1092,8 @@ class PoolGrad(OpenMPOp): ...@@ -1089,6 +1092,8 @@ class PoolGrad(OpenMPOp):
class MaxPoolGrad(PoolGrad): class MaxPoolGrad(PoolGrad):
# No params_type needed here: ignore_border does not change the C code.
def __init__(self, ignore_border, ndim=2, openmp=None): def __init__(self, ignore_border, ndim=2, openmp=None):
PoolGrad.__init__(self, ignore_border, mode='max', ndim=ndim, openmp=openmp) PoolGrad.__init__(self, ignore_border, mode='max', ndim=ndim, openmp=openmp)
...@@ -1191,7 +1196,7 @@ class MaxPoolGrad(PoolGrad): ...@@ -1191,7 +1196,7 @@ class MaxPoolGrad(PoolGrad):
total_ndim = node.inputs[0].ndim total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd non_pool_ndim = total_ndim - nd
fail = sub['fail'] fail = sub['fail']
ignore_border = int(self.ignore_border)
if self.openmp: if self.openmp:
# run in parallel over each pooling block # run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, maximum) schedule(static)' omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, maximum) schedule(static)'
...@@ -1404,6 +1409,8 @@ class MaxPoolGrad(PoolGrad): ...@@ -1404,6 +1409,8 @@ class MaxPoolGrad(PoolGrad):
class AveragePoolGrad(PoolGrad): class AveragePoolGrad(PoolGrad):
# ignore_border is used by perform() but not by the C code, so it does not need to be in params_type.
def __init__(self, ignore_border, mode='average_inc_pad', ndim=2): def __init__(self, ignore_border, mode='average_inc_pad', ndim=2):
assert mode in ['sum', 'average_inc_pad', 'average_exc_pad'] assert mode in ['sum', 'average_inc_pad', 'average_exc_pad']
PoolGrad.__init__(self, ignore_border, mode, ndim) PoolGrad.__init__(self, ignore_border, mode, ndim)
...@@ -1859,7 +1866,7 @@ class DownsampleFactorMaxGradGrad(OpenMPOp): ...@@ -1859,7 +1866,7 @@ class DownsampleFactorMaxGradGrad(OpenMPOp):
total_ndim = node.inputs[0].ndim total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd non_pool_ndim = total_ndim - nd
fail = sub['fail'] fail = sub['fail']
ignore_border = int(self.ignore_border)
if self.openmp: if self.openmp:
# run in parallel over each pooling block # run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, maximum) schedule(static)' omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, maximum) schedule(static)'
...@@ -2064,6 +2071,7 @@ class MaxPoolRop(OpenMPOp): ...@@ -2064,6 +2071,7 @@ class MaxPoolRop(OpenMPOp):
""" """
__props__ = ('ignore_border', 'mode', 'ndim') __props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(ignore_border=bool_t,)
def __init__(self, ignore_border=False, mode='max', ndim=2, openmp=None): def __init__(self, ignore_border=False, mode='max', ndim=2, openmp=None):
super(MaxPoolRop, self).__init__(openmp=openmp) super(MaxPoolRop, self).__init__(openmp=openmp)
...@@ -2108,7 +2116,7 @@ class MaxPoolRop(OpenMPOp): ...@@ -2108,7 +2116,7 @@ class MaxPoolRop(OpenMPOp):
out = tensor.TensorType(eval_point.dtype, broad) out = tensor.TensorType(eval_point.dtype, broad)
return gof.Apply(self, [x, eval_point, ws, stride, pad], [out()]) return gof.Apply(self, [x, eval_point, ws, stride, pad], [out()])
def perform(self, node, inp, out): def perform(self, node, inp, out, params):
x, ex, ws, stride, pad = inp x, ex, ws, stride, pad = inp
z, = out z, = out
nd = self.ndim nd = self.ndim
...@@ -2116,7 +2124,7 @@ class MaxPoolRop(OpenMPOp): ...@@ -2116,7 +2124,7 @@ class MaxPoolRop(OpenMPOp):
if len(x.shape) < nd: if len(x.shape) < nd:
raise NotImplementedError( raise NotImplementedError(
'Pool requires input with {} or more dimensions'.format(nd)) 'Pool requires input with {} or more dimensions'.format(nd))
z_shape = Pool.out_shape(x.shape, ws, self.ignore_border, stride, pad, nd) z_shape = Pool.out_shape(x.shape, ws, params.ignore_border, stride, pad, nd)
if not self.ignore_border: if not self.ignore_border:
assert all(z > 0 for z in z_shape[-nd:]) assert all(z > 0 for z in z_shape[-nd:])
if (z[0] is None) or (z[0].shape != z_shape): if (z[0] is None) or (z[0].shape != z_shape):
...@@ -2179,7 +2187,8 @@ class MaxPoolRop(OpenMPOp): ...@@ -2179,7 +2187,8 @@ class MaxPoolRop(OpenMPOp):
total_ndim = node.inputs[0].ndim total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd non_pool_ndim = total_ndim - nd
fail = sub['fail'] fail = sub['fail']
ignore_border = int(self.ignore_border) params = sub['params']
if self.openmp: if self.openmp:
# run in parallel over each pooling block # run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector, eval_collector) schedule(static)' omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector, eval_collector) schedule(static)'
...@@ -2228,13 +2237,13 @@ class MaxPoolRop(OpenMPOp): ...@@ -2228,13 +2237,13 @@ class MaxPoolRop(OpenMPOp):
if (pd[i]>0) if (pd[i]>0)
nonzero_padding = 1; nonzero_padding = 1;
} }
if (!%(ignore_border)s && nonzero_padding) if (!%(params)s->ignore_border && nonzero_padding)
{ {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False"); "padding must be zero when ignore border is False");
%(fail)s; %(fail)s;
} }
if (%(ignore_border)s) if (%(params)s->ignore_border)
{ {
for (int i=0; i<%(nd)s; i++) for (int i=0; i<%(nd)s; i++)
{ {
...@@ -2369,13 +2378,13 @@ class MaxPoolRop(OpenMPOp): ...@@ -2369,13 +2378,13 @@ class MaxPoolRop(OpenMPOp):
r_st[%(i)s] -= pd[%(i)s]; r_st[%(i)s] -= pd[%(i)s];
r_end[%(i)s] -= pd[%(i)s]; r_end[%(i)s] -= pd[%(i)s];
// handle the case where no padding, ignore border is True // handle the case where no padding, ignore border is True
if (%(ignore_border)s) if (%(params)s->ignore_border)
{ {
r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s]; r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s];
} }
// use the index to find the correct position in the output // use the index to find the correct position in the output
o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s]; o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s];
""" % dict(i=i, ignore_border=ignore_border, non_pool_ndim=non_pool_ndim) """ % dict(i=i, params=sub['params'], non_pool_ndim=non_pool_ndim)
ccode += """ ccode += """
// get a pointer to the correct position in the output // get a pointer to the correct position in the output
...@@ -2444,4 +2453,4 @@ class MaxPoolRop(OpenMPOp): ...@@ -2444,4 +2453,4 @@ class MaxPoolRop(OpenMPOp):
return ccode % locals() return ccode % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, self.openmp) return (1, self.openmp)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论