提交 60a906e7 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6081 from nouiz/params_gemm

Params for Gemm, CGer, CGemv
......@@ -421,7 +421,7 @@ def get_c_extract_out(r, name, sub):
# `c_extract_out` is used to extract an output variable from
# the compute map, to be used as pre-allocated memory for `r`
# before its value gets computed.
# If the node producing `r` has `check_inputs=True`, it may
# If the node producing `r` has `check_input=True`, it may
# also perform type checks on the initial value of the output,
# so we need to pass `check_input=True` to `c_extract_out`.
# However, that code is not used by potential clients of `r`,
......
......@@ -145,9 +145,11 @@ from theano.gof import (utils, Op, view_roots,
InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError)
from theano.gof.params_type import ParamsType
from theano.printing import pprint, FunctionPrinter, debugprint
from theano.compile.mode import optdb
import theano.scalar
from theano.scalar import bool as bool_t
from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version
......@@ -243,7 +245,7 @@ class Gemv(Op):
raise TypeError('gemv requires vector for y', y.type)
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
def perform(self, node, inputs, out_storage, params=None):
y, alpha, A, x, beta = inputs
if (have_fblas and y.shape[0] != 0 and x.shape[0] != 0 and
y.dtype in _blas_gemv_fns):
......@@ -333,7 +335,7 @@ class Ger(Op):
raise TypeError('only float and complex types supported', x.dtype)
return Apply(self, [A, alpha, x, y], [A.type()])
def perform(self, node, inp, out):
def perform(self, node, inp, out, params=None):
cA, calpha, cx, cy = inp
cZ, = out
if self.destructive:
......@@ -522,7 +524,11 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
"""
    # implement this if you don't have an inplace property
# setup_z_Nz_Sz = None
# otherwise implement
# setup_z_Nz_Sz_inplace = None
# setup_z_Nz_Sz_outplace = None
check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) {
......@@ -755,11 +761,16 @@ class GemmRelated(Op):
"""
def build_gemm_call(self):
if hasattr(self, 'inplace'):
setup_z_Nz_Sz = "if(%%(params)s->inplace){%s}else{%s}" % (
self.setup_z_Nz_Sz_inplace, self.setup_z_Nz_Sz_outplace)
else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz
return reduce(str.__add__, (
self.declare_NS,
self.check_xyz_rank2,
self.setup_z_Nz_Sz,
setup_z_Nz_Sz,
self.check_xyz_double_or_float,
self.check_ab_double_or_float,
self.check_dims,
......@@ -809,14 +820,13 @@ class Gemm(GemmRelated):
E_float = 'gemm requires floating-point dtypes'
__props__ = ('inplace',)
params_type = ParamsType(inplace=bool_t,)
check_input = False
def __init__(self, inplace):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
def __str__(self):
if self.inplace:
......@@ -827,10 +837,6 @@ class Gemm(GemmRelated):
def __setstate__(self, dct):
self.__dict__.update(dct)
if self.inplace:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_inplace
else:
self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
        # Correctly reload older pickles where destroy_map was not
        # saved
......@@ -841,7 +847,7 @@ class Gemm(GemmRelated):
rval = self.__dict__.copy()
# Do not serialize the setup code, it will be restored in __setstate__
# depending on the value of 'inplace'
rval.pop('setup_z_Nz_Sz')
rval.pop('setup_z_Nz_Sz', None)
return rval
def make_node(self, *inputs):
......@@ -892,12 +898,12 @@ class Gemm(GemmRelated):
output = z.type()
return Apply(self, inputs, [output])
def perform(self, node, inp, out):
def perform(self, node, inp, out, params):
z, a, x, y, b = inp
zout, = out
assert a.shape == ()
assert b.shape == ()
if not self.inplace:
if not params.inplace:
z = z.copy() # the original z will not be changed
if z.shape == ():
z.itemset(z * a + b * np.dot(x, y))
......@@ -1039,7 +1045,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self):
gv = self.build_gemm_version()
if gv:
return (5,) + gv
return (6,) + gv
else:
return gv
......@@ -1522,6 +1528,7 @@ class Dot22(GemmRelated):
This is a specialization of the more general Dot().
"""
check_input = False
def make_node(self, x, y):
dtypes = ('float16', 'float32', 'float64', 'complex64', 'complex128')
......@@ -1784,6 +1791,7 @@ class Dot22Scalar(GemmRelated):
compute scalar*dot(x,y).
"""
check_input = False
def make_node(self, x, y, a):
if a.ndim != 0:
......
from __future__ import absolute_import, print_function, division
from theano import config
from theano.gof.params_type import ParamsType
from theano.scalar import bool as bool_t
from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer
......@@ -30,7 +32,7 @@ class BaseBLAS(object):
# GER
# ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail):
def ger_c_code(A, a, x, y, Z, fail, params):
return """
int elemsize ;
......@@ -71,7 +73,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
// copy A if !self.destructive or A is fully strided
if (!%(destructive)s
if (!%(params)s->destructive
|| (PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
......@@ -311,16 +313,18 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
class CGer(BaseBLAS, Ger):
params_type = ParamsType(destructive=bool_t,)
def c_code(self, node, name, inp, out, sub):
A, a, x, y = inp
Z, = out
code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive),
fail=sub['fail'])
fail=sub['fail'],
params=sub['params'])
return code
def c_code_cache_version(self):
return (10, blas_header_version())
return (11, blas_header_version())
cger_inplace = CGer(True)
cger_no_inplace = CGer(False)
......@@ -349,8 +353,8 @@ def make_c_ger_destructive(node):
# ##### ####### #######
def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
force_init_beta=False):
def gemv_c_code(y, A, x, z, alpha, beta, fail,
force_init_beta=False, params=None):
"""
z <- beta * y + alpha * dot(A, x)
......@@ -385,7 +389,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
// copy y if not destructive
if (!%(destructive)s)
if (!%(params)s->inplace)
{
if ((NULL == %(z)s)
|| (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(y)s)[0]))
......@@ -593,6 +597,8 @@ def gemv_c_code(y, A, x, z, alpha, beta, destructive, fail,
class CGemv(BaseBLAS, Gemv):
params_type = ParamsType(inplace=bool_t,)
def __init__(self, inplace):
super(CGemv, self).__init__(inplace)
......@@ -601,14 +607,14 @@ class CGemv(BaseBLAS, Gemv):
z, = out
code = gemv_c_code(
y, A, x, z, alpha, beta,
destructive=int(self.inplace),
fail=sub['fail'],
force_init_beta=check_force_gemv_init()
force_init_beta=check_force_gemv_init(),
params=sub['params'],
)
return code
def c_code_cache_version(self):
return (13, blas_header_version(), check_force_gemv_init())
return (14, blas_header_version(), check_force_gemv_init())
cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False)
......
......@@ -14,7 +14,9 @@ from six.moves import xrange
import six.moves.builtins as builtins
import theano
from theano import gof, OpenMPOp, tensor, Variable, Apply
from theano.gof.params_type import ParamsType
from theano.gradient import DisconnectedType
from theano.scalar import bool as bool_t
def max_pool_2d_same_size(input, patch_size):
......@@ -294,6 +296,7 @@ class Pool(OpenMPOp):
"""
__props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(ignore_border=bool_t,)
@staticmethod
def out_shape(imgshape, ws=None, ignore_border=False, stride=None, pad=None,
......@@ -508,7 +511,7 @@ class Pool(OpenMPOp):
out = tensor.TensorType(x.dtype, broad)
return gof.Apply(self, [x, ws, stride, pad], [out()])
def perform(self, node, inp, out):
def perform(self, node, inp, out, params):
x, ws, stride, pad = inp
z, = out
nd = self.ndim
......@@ -516,8 +519,8 @@ class Pool(OpenMPOp):
if len(x.shape) < nd:
raise NotImplementedError(
'Pool requires input with {} or more dimensions'.format(nd))
z_shape = self.out_shape(x.shape, ws, self.ignore_border, stride, pad, nd)
if not self.ignore_border:
z_shape = self.out_shape(x.shape, ws, params.ignore_border, stride, pad, nd)
if not params.ignore_border:
assert all(z > 0 for z in z_shape[-nd:])
if (z[0] is None) or (z[0].shape != z_shape):
z[0] = np.empty(z_shape, dtype=x.dtype)
......@@ -617,7 +620,7 @@ class Pool(OpenMPOp):
total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd
fail = sub['fail']
ignore_border = int(self.ignore_border)
params = sub['params']
if self.openmp:
# run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector) schedule(static)'
......@@ -661,13 +664,13 @@ class Pool(OpenMPOp):
if (pd[i]>0)
nonzero_padding = 1;
}
if (!%(ignore_border)s && nonzero_padding)
if (!%(params)s->ignore_border && nonzero_padding)
{
PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False");
%(fail)s;
}
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
for (int i=0; i<%(nd)s; i++)
{
......@@ -801,13 +804,13 @@ class Pool(OpenMPOp):
r_st[%(i)s] -= pd[%(i)s];
r_end[%(i)s] -= pd[%(i)s];
// handle the case where no padding, ignore border is True
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s];
}
// use the index to find the correct position in the output
o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s];
""" % dict(i=i, ignore_border=ignore_border, non_pool_ndim=non_pool_ndim)
""" % dict(i=i, non_pool_ndim=non_pool_ndim, params=sub['params'])
ccode += """
// get a pointer to the correct position in the output
......@@ -907,7 +910,7 @@ class Pool(OpenMPOp):
return ccode % locals()
def c_code_cache_version(self):
return (0, 6, 8, 7, self.openmp)
return (9, self.openmp)
class PoolGrad(OpenMPOp):
......@@ -1089,6 +1092,8 @@ class PoolGrad(OpenMPOp):
class MaxPoolGrad(PoolGrad):
    # No params_type needed: ignore_border does not change the C code
def __init__(self, ignore_border, ndim=2, openmp=None):
PoolGrad.__init__(self, ignore_border, mode='max', ndim=ndim, openmp=openmp)
......@@ -1191,7 +1196,7 @@ class MaxPoolGrad(PoolGrad):
total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd
fail = sub['fail']
ignore_border = int(self.ignore_border)
if self.openmp:
# run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, maximum) schedule(static)'
......@@ -1404,6 +1409,8 @@ class MaxPoolGrad(PoolGrad):
class AveragePoolGrad(PoolGrad):
    # ignore_border is used by perform but not by the C code, so no params_type is needed
def __init__(self, ignore_border, mode='average_inc_pad', ndim=2):
assert mode in ['sum', 'average_inc_pad', 'average_exc_pad']
PoolGrad.__init__(self, ignore_border, mode, ndim)
......@@ -1859,7 +1866,7 @@ class DownsampleFactorMaxGradGrad(OpenMPOp):
total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd
fail = sub['fail']
ignore_border = int(self.ignore_border)
if self.openmp:
# run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, maximum) schedule(static)'
......@@ -2064,6 +2071,7 @@ class MaxPoolRop(OpenMPOp):
"""
__props__ = ('ignore_border', 'mode', 'ndim')
params_type = ParamsType(ignore_border=bool_t,)
def __init__(self, ignore_border=False, mode='max', ndim=2, openmp=None):
super(MaxPoolRop, self).__init__(openmp=openmp)
......@@ -2108,7 +2116,7 @@ class MaxPoolRop(OpenMPOp):
out = tensor.TensorType(eval_point.dtype, broad)
return gof.Apply(self, [x, eval_point, ws, stride, pad], [out()])
def perform(self, node, inp, out):
def perform(self, node, inp, out, params):
x, ex, ws, stride, pad = inp
z, = out
nd = self.ndim
......@@ -2116,7 +2124,7 @@ class MaxPoolRop(OpenMPOp):
if len(x.shape) < nd:
raise NotImplementedError(
'Pool requires input with {} or more dimensions'.format(nd))
z_shape = Pool.out_shape(x.shape, ws, self.ignore_border, stride, pad, nd)
z_shape = Pool.out_shape(x.shape, ws, params.ignore_border, stride, pad, nd)
if not self.ignore_border:
assert all(z > 0 for z in z_shape[-nd:])
if (z[0] is None) or (z[0].shape != z_shape):
......@@ -2179,7 +2187,8 @@ class MaxPoolRop(OpenMPOp):
total_ndim = node.inputs[0].ndim
non_pool_ndim = total_ndim - nd
fail = sub['fail']
ignore_border = int(self.ignore_border)
params = sub['params']
if self.openmp:
# run in parallel over each pooling block
omp_parallel = '#pragma omp parallel for private(r_st, r_end, r_idx, i_idx, o_idx, collector, eval_collector) schedule(static)'
......@@ -2228,13 +2237,13 @@ class MaxPoolRop(OpenMPOp):
if (pd[i]>0)
nonzero_padding = 1;
}
if (!%(ignore_border)s && nonzero_padding)
if (!%(params)s->ignore_border && nonzero_padding)
{
PyErr_SetString(PyExc_ValueError,
"padding must be zero when ignore border is False");
%(fail)s;
}
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
for (int i=0; i<%(nd)s; i++)
{
......@@ -2369,13 +2378,13 @@ class MaxPoolRop(OpenMPOp):
r_st[%(i)s] -= pd[%(i)s];
r_end[%(i)s] -= pd[%(i)s];
// handle the case where no padding, ignore border is True
if (%(ignore_border)s)
if (%(params)s->ignore_border)
{
r_end[%(i)s] = r_end[%(i)s] > r[%(i)s] ? r[%(i)s] : r_end[%(i)s];
}
// use the index to find the correct position in the output
o_idx[%(non_pool_ndim)s + %(i)s] = r_idx[%(i)s];
""" % dict(i=i, ignore_border=ignore_border, non_pool_ndim=non_pool_ndim)
""" % dict(i=i, params=sub['params'], non_pool_ndim=non_pool_ndim)
ccode += """
// get a pointer to the correct position in the output
......@@ -2444,4 +2453,4 @@ class MaxPoolRop(OpenMPOp):
return ccode % locals()
def c_code_cache_version(self):
return (0, self.openmp)
return (1, self.openmp)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论