提交 5ad8bbc2 authored 作者: notoraptor's avatar notoraptor

Wrap op params for theano.tensor.elemwise.DimShuffle.

上级 1fe27f0f
......@@ -9,7 +9,7 @@ import theano
from theano import gof
from theano.compat import izip
from theano.configparser import change_flags
from theano.gof import Apply, Op, OpenMPOp
from theano.gof import Apply, Op, OpenMPOp, ParamsType
from theano import scalar
from theano.scalar import get_scalar_type
from theano.printing import pprint
......@@ -131,11 +131,24 @@ class DimShuffle(Op):
check_input = False
__props__ = ("input_broadcastable", "new_order", "inplace")
@property
def params_type(self):
# We can't directly create `params_type` as class attribute
# because of importation issues related to TensorType.
return ParamsType(input_broadcastable=TensorType(dtype='bool', broadcastable=(False,)),
_new_order=theano.tensor.lvector,
inplace=theano.scalar.bool)
@property
def _new_order(self):
# Param for C code.
# self.new_order may contain 'x', which is not a valid integer value.
# We replace it with -1.
return [(-1 if x == 'x' else x) for x in self.new_order]
def __init__(self, input_broadcastable, new_order, inplace=True):
input_broadcastable = tuple(input_broadcastable)
self.input_broadcastable = input_broadcastable
new_order = tuple(new_order)
self.new_order = new_order
self.input_broadcastable = tuple(input_broadcastable)
self.new_order = tuple(new_order)
if inplace is True:
self.inplace = inplace
else:
......@@ -222,7 +235,7 @@ class DimShuffle(Op):
else:
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
def perform(self, node, inp, out):
def perform(self, node, inp, out, params):
input, = inp
storage, = out
# drop
......@@ -265,98 +278,113 @@ class DimShuffle(Op):
res, = out
basename = input + '__view_or_copy'
def statements(lst):
return ';\n'.join(lst) + ';'
return """{
npy_bool* input_broadcastable;
npy_int64* new_order;
npy_intp nd_in;
npy_intp nd_out;
PyArrayObject* %(basename)s;
npy_intp* dimensions;
npy_intp* strides;
if (!PyArray_IS_C_CONTIGUOUS(%(params)s->input_broadcastable)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous.");
%(just_fail)s
}
if (!PyArray_IS_F_CONTIGUOUS(%(params)s->_new_order)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
%(just_fail)s
}
input_broadcastable = (npy_bool*) PyArray_DATA(%(params)s->input_broadcastable);
new_order = (npy_int64*) PyArray_DATA(%(params)s->_new_order);
nd_in = PyArray_SIZE(%(params)s->input_broadcastable);
nd_out = PyArray_SIZE(%(params)s->_new_order);
/* check_input_nd */
if (PyArray_NDIM(%(input)s) != nd_in) {
PyErr_SetString(PyExc_NotImplementedError, "input nd");
%(just_fail)s
}
/* clear_output */
if (%(res)s)
Py_XDECREF(%(res)s);
/* get_base */
if (%(params)s->inplace) {
%(basename)s = %(input)s;
Py_INCREF((PyObject*)%(basename)s);
} else {
%(basename)s =
(PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s,
NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL);
}
/* shape_statements and strides_statements */
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
if (dimensions == NULL || strides == NULL) {
PyErr_NoMemory();
%(fail)s
};
for (npy_intp i = 0; i < nd_out; ++i) {
if (new_order[i] != -1) {
dimensions[i] = PyArray_DIMS(%(basename)s)[new_order[i]];
strides[i] = PyArray_DIMS(%(basename)s)[new_order[i]] == 1 ?
0 : PyArray_STRIDES(%(basename)s)[new_order[i]];
} else {
dimensions[i] = 1;
strides[i] = 0;
}
}
/* set the strides of the broadcasted dimensions.
* This algorithm is from numpy: PyArray_Newshape() in
* cvs/numpy/numpy/core/src/multiarraymodule.c */
if (nd_out > 0) {
if (strides[nd_out - 1] == 0)
strides[nd_out - 1] = PyArray_DESCR(%(basename)s)->elsize;
for (npy_intp i = nd_out - 2; i > -1; --i) {
if (strides[i] == 0)
strides[i] = strides[i + 1] * dimensions[i + 1];
}
}
nd_in = len(self.input_broadcastable)
nd_out = len(self.new_order)
/* close_bracket */
// create a new array.
%(res)s = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
PyArray_TYPE(%(basename)s), strides,
PyArray_DATA(%(basename)s), PyArray_ITEMSIZE(%(basename)s),
// borrow only the writable flag from the base
// the NPY_OWNDATA flag will default to 0.
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(%(basename)s)),
NULL);
if (%(res)s == NULL) {
%(fail)s
}
check_input_nd = [('if (PyArray_NDIM(%(input)s) != ' + str(nd_in) + ')'
'{PyErr_SetString(PyExc_NotImplementedError, '
'"input nd"); %(fail)s;}')]
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
PyArray_UpdateFlags(%(res)s, NPY_ARRAY_UPDATE_ALL);
clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}']
// we are making a view in both inplace and non-inplace cases
PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
# get the copy / view of the input depending on whether we're doingi
# things inplace or not.
if self.inplace:
get_base = ['{ PyArrayObject * %(basename)s = %(input)s',
'Py_INCREF((PyObject*)%(basename)s)']
else:
get_base = [
('{ PyArrayObject * %(basename)s = '
'(PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s,'
' NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY,'
' NULL)')]
shape_statements = ['npy_intp dimensions[%i]' % nd_out]
for i, o in enumerate(self.new_order):
if o != 'x':
shape_statements += [('dimensions[' + str(
i) + '] = PyArray_DIMS(%(basename)s)[' + str(o) + ']')]
else:
shape_statements += [('dimensions[' + str(i) + '] = 1')]
strides_statements = ['npy_intp strides[%i]' % nd_out]
# set the strides of the non-broadcasted dimensions
for i, o in enumerate(self.new_order):
if o != 'x':
strides_statements += [('strides[' + str(i) +
'] = PyArray_DIMS(%(basename)s)[' +
str(o) +
'] == 1? 0 : '
'PyArray_STRIDES(%(basename)s)[' +
str(o) + ']')]
else:
strides_statements += [('strides[' + str(i) + '] = 0')]
# set the strides of the broadcasted dimensions
# this algorithm is from numpy: PyArray_Newshape() in
# cvs/numpy/numpy/core/src/multiarraymodule.c
if nd_out > 0:
strides_statements.append(
'if (strides[' +
str(nd_out) +
'-1] == 0) strides[' +
str(nd_out) +
'-1] = PyArray_DESCR(%(basename)s)->elsize'
)
for i in xrange(nd_out - 2, -1, -1):
strides_statements.append(
"if (strides[%(i)s] == 0) strides[%(i)s] = strides[%(i)s+1] * "
"dimensions[%(i)s+1]" % dict(i=str(i)))
close_bracket = [
# create a new array,
('%(res)s = (PyArrayObject*)PyArray_New(&PyArray_Type, '
'' + str(nd_out) + ', dimensions, '
'PyArray_TYPE(%(basename)s), strides, '
'PyArray_DATA(%(basename)s), PyArray_ITEMSIZE(%(basename)s), '
# borrow only the writable flag from the base
# the NPY_OWNDATA flag will default to 0.
'(NPY_ARRAY_WRITEABLE*PyArray_ISWRITEABLE(%(basename)s)), '
'NULL)'),
'if (%(res)s == NULL) %(fail)s;',
# recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
'PyArray_UpdateFlags(%(res)s, NPY_ARRAY_UPDATE_ALL)',
# we are making a view in both inplace and non-inplace cases
"""
PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
"""
'}']
full_code = statements(check_input_nd +
clear_output +
get_base +
shape_statements +
strides_statements +
close_bracket)
return full_code % dict(locals(), **sub)
free(strides);
free(dimensions);
}""" % dict(input=input, res=res,
basename=basename,
params=sub['params'],
just_fail=sub['fail'],
fail="""
free(strides);
free(dimensions);
%(fail)s
""" % dict(fail=sub['fail']))
def c_code_cache_version(self):
return (3,)
return (4,)
def grad(self, inp, grads):
x, = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论