提交 d323c6cd authored 作者: Frederic Bastien's avatar Frederic Bastien

Add params to CumOp

上级 9520fdec
......@@ -6,9 +6,10 @@ import theano
from theano.tensor import basic
from theano.tensor import nlinalg # noqa
from theano import gof, scalar
from theano.gof import Generic
from theano.gof import Generic, ParamsType, EnumList
from theano import gradient
from theano.gradient import DisconnectedType, disconnected_type
from theano.scalar import int32 as int_t
tensor = basic
......@@ -245,6 +246,10 @@ class CumOp(theano.Op):
# See function cumsum/cumprod for docstring
__props__ = ("axis", "mode")
check_input = False
params_type = ParamsType(c_axis=int_t,
mode=EnumList(('MODE_ADD', 'add'),
('MODE_MUL', 'mul')))
def __init__(self, axis=None, mode='add'):
if mode not in ('add', 'mul'):
......@@ -252,6 +257,8 @@ class CumOp(theano.Op):
self.axis = axis
self.mode = mode
c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)
def make_node(self, x):
x = basic.as_tensor_variable(x)
out_type = x.type()
......@@ -263,7 +270,7 @@ class CumOp(theano.Op):
return theano.Apply(self, [x], [out_type])
def perform(self, node, inputs, output_storage):
def perform(self, node, inputs, output_storage, params):
x = inputs[0]
z = output_storage[0]
z[0] = {'add': np.cumsum, 'mul': np.cumprod}[self.mode](x, axis=self.axis)
......@@ -311,49 +318,43 @@ class CumOp(theano.Op):
z, = onames
axis = self.axis
fail = sub['fail']
func = dict(mul='CumProd', add='CumSum')[self.mode]
params = sub['params']
if self.axis is None or (self.axis == 0 and node.inputs[0].ndim == 1):
code = """
code = """
int axis = %(params)s->c_axis;
if (axis == 0 && PyArray_NDIM(%(x)s) == 1)
axis = NPY_MAXDIMS;
npy_intp shape[1] = { PyArray_SIZE(%(x)s) };
if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
if(axis == NPY_MAXDIMS && !(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
}
if (!%(z)s)
%(fail)s;
{
PyObject * t = PyArray_%(func)s(
%(x)s, NPY_MAXDIMS,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
// Because PyArray_%(func)s returns a newly created reference on t.
Py_XDECREF(t);
}
""" % locals()
else:
code = """
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
else if(axis != NPY_MAXDIMS && !(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE(%(x)s));
}
if (!%(z)s)
%(fail)s;
{
PyObject * t = PyArray_%(func)s(
%(x)s, %(axis)s,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
PyObject * t = NULL;
if(%(params)s->mode == MODE_ADD)
t = PyArray_CumSum(
%(x)s, axis,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
else if(%(params)s->mode == MODE_MUL)
t = PyArray_CumProd(
%(x)s, axis,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
// Because PyArray_%(func)s returns a newly created reference on t.
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
}
""" % locals()
......@@ -361,7 +362,7 @@ class CumOp(theano.Op):
return code
def c_code_cache_version(self):
return (7,)
return (8,)
def __str__(self):
return "%s{%s, %s}" % (self.__class__.__name__, self.axis, self.mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论