提交 d323c6cd authored 作者: Frederic Bastien's avatar Frederic Bastien

Add params to CumOp

上级 9520fdec
...@@ -6,9 +6,10 @@ import theano ...@@ -6,9 +6,10 @@ import theano
from theano.tensor import basic from theano.tensor import basic
from theano.tensor import nlinalg # noqa from theano.tensor import nlinalg # noqa
from theano import gof, scalar from theano import gof, scalar
from theano.gof import Generic from theano.gof import Generic, ParamsType, EnumList
from theano import gradient from theano import gradient
from theano.gradient import DisconnectedType, disconnected_type from theano.gradient import DisconnectedType, disconnected_type
from theano.scalar import int32 as int_t
tensor = basic tensor = basic
...@@ -245,6 +246,10 @@ class CumOp(theano.Op): ...@@ -245,6 +246,10 @@ class CumOp(theano.Op):
# See function cumsum/cumprod for docstring # See function cumsum/cumprod for docstring
__props__ = ("axis", "mode") __props__ = ("axis", "mode")
check_input = False
params_type = ParamsType(c_axis=int_t,
mode=EnumList(('MODE_ADD', 'add'),
('MODE_MUL', 'mul')))
def __init__(self, axis=None, mode='add'): def __init__(self, axis=None, mode='add'):
if mode not in ('add', 'mul'): if mode not in ('add', 'mul'):
...@@ -252,6 +257,8 @@ class CumOp(theano.Op): ...@@ -252,6 +257,8 @@ class CumOp(theano.Op):
self.axis = axis self.axis = axis
self.mode = mode self.mode = mode
c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)
def make_node(self, x): def make_node(self, x):
x = basic.as_tensor_variable(x) x = basic.as_tensor_variable(x)
out_type = x.type() out_type = x.type()
...@@ -263,7 +270,7 @@ class CumOp(theano.Op): ...@@ -263,7 +270,7 @@ class CumOp(theano.Op):
return theano.Apply(self, [x], [out_type]) return theano.Apply(self, [x], [out_type])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage, params):
x = inputs[0] x = inputs[0]
z = output_storage[0] z = output_storage[0]
z[0] = {'add': np.cumsum, 'mul': np.cumprod}[self.mode](x, axis=self.axis) z[0] = {'add': np.cumsum, 'mul': np.cumprod}[self.mode](x, axis=self.axis)
...@@ -311,49 +318,43 @@ class CumOp(theano.Op): ...@@ -311,49 +318,43 @@ class CumOp(theano.Op):
z, = onames z, = onames
axis = self.axis axis = self.axis
fail = sub['fail'] fail = sub['fail']
func = dict(mul='CumProd', add='CumSum')[self.mode] params = sub['params']
if self.axis is None or (self.axis == 0 and node.inputs[0].ndim == 1):
code = """ code = """
int axis = %(params)s->c_axis;
if (axis == 0 && PyArray_NDIM(%(x)s) == 1)
axis = NPY_MAXDIMS;
npy_intp shape[1] = { PyArray_SIZE(%(x)s) }; npy_intp shape[1] = { PyArray_SIZE(%(x)s) };
if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0])) if(axis == NPY_MAXDIMS && !(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s)); %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) else if(axis != NPY_MAXDIMS && !(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
%(fail)s;
{
PyObject * t = PyArray_%(func)s(
%(x)s, NPY_MAXDIMS,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
// Because PyArray_%(func)s returns a newly created reference on t.
Py_XDECREF(t);
}
""" % locals()
else:
code = """
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s)); %(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE(%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyObject * t = PyArray_%(func)s( PyObject * t = NULL;
%(x)s, %(axis)s, if(%(params)s->mode == MODE_ADD)
t = PyArray_CumSum(
%(x)s, axis,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s); PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
else if(%(params)s->mode == MODE_MUL)
t = PyArray_CumProd(
%(x)s, axis,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){ if (!t){
%(fail)s; %(fail)s;
} }
// Because PyArray_%(func)s returns a newly created reference on t. // Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t); Py_XDECREF(t);
} }
""" % locals() """ % locals()
...@@ -361,7 +362,7 @@ class CumOp(theano.Op): ...@@ -361,7 +362,7 @@ class CumOp(theano.Op):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
def __str__(self): def __str__(self):
return "%s{%s, %s}" % (self.__class__.__name__, self.axis, self.mode) return "%s{%s, %s}" % (self.__class__.__name__, self.axis, self.mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论