提交 0be9c072 authored 作者: Frederic's avatar Frederic

Make Subtensor use the new NumPy interface

IncSugtensor and the GPU code still need update.
上级 b7880c5f
......@@ -532,46 +532,25 @@ class Subtensor(Op):
"""
return {
"c_prefix": "PyArray",
"update_flags": ("PyArray_UpdateFlags(%(view_name)s,"
" NPY_ARRAY_C_CONTIGUOUS|"
"NPY_ARRAY_F_CONTIGUOUS);"),
"set_data": "PyArray_set_data",
"set_dim": "PyArray_set_dim",
"set_stride": "PyArray_set_stride",
"strides_mul": 1,
"view_name": "xview"}
"c_prefix": "PyArray",
"strides_mul": 1,
"view_name": "xview"}
@staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list,
def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim,
c_prefix=None,
update_flags=None,
set_data=None,
set_dim=None,
set_stride=None,
strides_mul=None,
view_name=None
):
"""
The parameters c_prefix, update_flags, set_data, set_dim,
set_stride and strides_mul are there to allow reusing this
The parameters c_prefix are there to allow reusing this
function on PyArray and CudaNdarray object.
This fct take as input the x,
"""
default_args = Subtensor.default_helper_c_code_args()
if update_flags is None:
update_flags = default_args['update_flags']
if set_data is None:
set_data = default_args['set_data']
if set_dim is None:
set_dim = default_args['set_dim']
if set_stride is None:
set_stride = default_args['set_stride']
if strides_mul is None:
strides_mul = default_args['strides_mul']
......@@ -581,9 +560,6 @@ class Subtensor(Op):
if view_name is None:
view_name = default_args['view_name']
#update_flags may depend on view_name
update_flags = update_flags % locals()
#
# two arrays are created in C code:
# is_slice: len == ndim, 0 means int, 1 means slice
......@@ -658,11 +634,8 @@ class Subtensor(Op):
z, = outputs
xview = view_name
rval = """
#define PyArray_set_dim(obj, idx, d) PyArray_DIMS(obj)[idx]=d
#define PyArray_set_stride(obj, idx, d) PyArray_STRIDES(obj)[idx]=d
#define PyArray_set_data(obj, ptr, base) PyArray_BYTES(obj)=ptr
// The subtensor is created by iterating over the dimensions
// and updating stride, shape, and data pointers
......@@ -674,32 +647,10 @@ class Subtensor(Op):
int inner_ii = 0; // the current dimension of zview
int outer_ii = 0; // current dimension of z
char* ptr = (char*) %(c_prefix)s_BYTES(%(xview)s);
if ((%(c_prefix)s_DIMS(%(xview)s) == %(c_prefix)s_DIMS(%(x)s))
&& (%(c_prefix)s_DIMS(%(x)s) != NULL))
{
PyErr_Format(PyExc_ValueError, "x and %(xview)s"
"(with %%d dims) have the same dimensions"
" pointers: %%p and %%p",
%(c_prefix)s_NDIM(%(x)s),
%(c_prefix)s_DIMS(%(xview)s),
%(c_prefix)s_DIMS(%(x)s));
Py_XDECREF(%(xview)s);
%(fail)s;
}
if (%(c_prefix)s_STRIDES(%(xview)s) == %(c_prefix)s_STRIDES(%(x)s)
&& (%(c_prefix)s_DIMS(%(x)s) != NULL))
{
PyErr_Format(PyExc_ValueError, "x and %(xview)s"
"(with %%d dims) have the same strides"
" pointers: %%p and %%p",
%(c_prefix)s_NDIM(%(x)s),
%(c_prefix)s_STRIDES(%(xview)s),
%(c_prefix)s_STRIDES(%(x)s));
Py_XDECREF(%(xview)s);
%(fail)s;
}
// Argument of the view
char* xview_ptr = (char*) %(c_prefix)s_BYTES(%(x)s);
ssize_t xview_dims[%(view_ndim)s];
ssize_t xview_strides[%(view_ndim)s];
for (; outer_ii < %(len_is_slice)s; ++outer_ii)
{
......@@ -719,10 +670,8 @@ class Subtensor(Op):
// PySlice_GetIndicesEx in python source
if (!step)
{
Py_DECREF(%(xview)s);
PyErr_Format(PyExc_ValueError,
"slice step cannot be zero");
Py_XDECREF(%(xview)s);
%(fail)s;
}
......@@ -771,11 +720,10 @@ class Subtensor(Op):
assert (slicelength <= length);
ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start *
xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start *
%(strides_mul)s;
%(set_dim)s(%(xview)s, inner_ii, slicelength);
%(set_stride)s(%(xview)s, inner_ii,
%(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step);
xview_dims[inner_ii] = slicelength;
xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step;
inner_ii += 1;
spec_pos += 3;
......@@ -788,45 +736,42 @@ class Subtensor(Op):
{
if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii])
{
ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
%(strides_mul)s;
}
else
{
PyErr_Format(PyExc_IndexError,"index out of bounds");
Py_XDECREF(%(xview)s);
%(fail)s;
}
}
else
{
PyErr_Format(PyExc_IndexError,"index out of bounds");
Py_XDECREF(%(xview)s);
%(fail)s;
}
spec_pos += 1;
}
}
%(set_data)s(%(xview)s, ptr, (PyObject*)NULL);
assert (inner_ii <= %(c_prefix)s_NDIM(%(xview)s));
while (inner_ii < %(c_prefix)s_NDIM(%(xview)s))
assert (inner_ii <= %(view_ndim)s);
while (inner_ii < %(view_ndim)s)
{
assert (outer_ii < %(c_prefix)s_NDIM(%(x)s));
%(set_dim)s(%(xview)s, inner_ii,
%(c_prefix)s_DIMS(%(x)s)[outer_ii]);
%(set_stride)s(%(xview)s, inner_ii,
%(c_prefix)s_STRIDES(%(x)s)[outer_ii]);
xview_dims[inner_ii] = %(c_prefix)s_DIMS(%(x)s)[outer_ii];
xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii];
inner_ii += 1;
outer_ii += 1;
}
%(update_flags)s
//(update_flags)s
""" % locals()
# print rval
return rval
@staticmethod
def helper_c_code_cache_version():
return ()
return (5,)
def c_code(self, node, name, inputs, outputs, sub): # DEBUG
......@@ -838,29 +783,35 @@ class Subtensor(Op):
view_ndim = node.outputs[0].ndim
fail = sub['fail']
decl = "PyArrayObject * xview = NULL;"
get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list, view_ndim)
build_view = """
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(x)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr(
xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(x)s),
%(view_ndim)s,
PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s),
xview_dims, //PyArray_DIMS(%(x)s),
xview_strides, //PyArray_STRIDES(%(x)s),
xview_ptr, //PyArray_DATA(%(x)s),
PyArray_FLAGS(%(x)s),
NULL);
assert (PyArray_NDIM(xview) == %(view_ndim)s);
if (!xview)
{
%(fail)s;
}
""" % locals()
get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list)
finish_view = """
if (%(z)s) Py_DECREF(%(z)s);
//Is this update needed? The doc of PyArray_NewFromDescr do not
//tell us the information needed
PyArray_UpdateFlags(xview, NPY_ARRAY_C_CONTIGUOUS| NPY_ARRAY_F_CONTIGUOUS);
Py_XDECREF(%(z)s);
Py_INCREF(py_%(x)s);
#if NPY_API_VERSION < 0x00000007
PyArray_BASE(xview) = py_%(x)s;
......@@ -871,7 +822,7 @@ class Subtensor(Op):
%(z)s = xview;
""" % locals()
return build_view + "{" + get_xview + "}" + finish_view
return decl + get_xview + build_view + finish_view
def c_code_cache_version(self):
hv = self.helper_c_code_cache_version()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论