提交 0be9c072 authored 作者: Frederic's avatar Frederic

Make Subtensor use the new NumPy interface

IncSugtensor and the GPU code still need update.
上级 b7880c5f
...@@ -533,45 +533,24 @@ class Subtensor(Op): ...@@ -533,45 +533,24 @@ class Subtensor(Op):
return { return {
"c_prefix": "PyArray", "c_prefix": "PyArray",
"update_flags": ("PyArray_UpdateFlags(%(view_name)s,"
" NPY_ARRAY_C_CONTIGUOUS|"
"NPY_ARRAY_F_CONTIGUOUS);"),
"set_data": "PyArray_set_data",
"set_dim": "PyArray_set_dim",
"set_stride": "PyArray_set_stride",
"strides_mul": 1, "strides_mul": 1,
"view_name": "xview"} "view_name": "xview"}
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list, def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim,
c_prefix=None, c_prefix=None,
update_flags=None,
set_data=None,
set_dim=None,
set_stride=None,
strides_mul=None, strides_mul=None,
view_name=None view_name=None
): ):
""" """
The parameters c_prefix, update_flags, set_data, set_dim, The parameters c_prefix are there to allow reusing this
set_stride and strides_mul are there to allow reusing this
function on PyArray and CudaNdarray object. function on PyArray and CudaNdarray object.
This fct take as input the x,
""" """
default_args = Subtensor.default_helper_c_code_args() default_args = Subtensor.default_helper_c_code_args()
if update_flags is None:
update_flags = default_args['update_flags']
if set_data is None:
set_data = default_args['set_data']
if set_dim is None:
set_dim = default_args['set_dim']
if set_stride is None:
set_stride = default_args['set_stride']
if strides_mul is None: if strides_mul is None:
strides_mul = default_args['strides_mul'] strides_mul = default_args['strides_mul']
...@@ -581,9 +560,6 @@ class Subtensor(Op): ...@@ -581,9 +560,6 @@ class Subtensor(Op):
if view_name is None: if view_name is None:
view_name = default_args['view_name'] view_name = default_args['view_name']
#update_flags may depend on view_name
update_flags = update_flags % locals()
# #
# two arrays are created in C code: # two arrays are created in C code:
# is_slice: len == ndim, 0 means int, 1 means slice # is_slice: len == ndim, 0 means int, 1 means slice
...@@ -660,9 +636,6 @@ class Subtensor(Op): ...@@ -660,9 +636,6 @@ class Subtensor(Op):
xview = view_name xview = view_name
rval = """ rval = """
#define PyArray_set_dim(obj, idx, d) PyArray_DIMS(obj)[idx]=d
#define PyArray_set_stride(obj, idx, d) PyArray_STRIDES(obj)[idx]=d
#define PyArray_set_data(obj, ptr, base) PyArray_BYTES(obj)=ptr
// The subtensor is created by iterating over the dimensions // The subtensor is created by iterating over the dimensions
// and updating stride, shape, and data pointers // and updating stride, shape, and data pointers
...@@ -674,32 +647,10 @@ class Subtensor(Op): ...@@ -674,32 +647,10 @@ class Subtensor(Op):
int inner_ii = 0; // the current dimension of zview int inner_ii = 0; // the current dimension of zview
int outer_ii = 0; // current dimension of z int outer_ii = 0; // current dimension of z
char* ptr = (char*) %(c_prefix)s_BYTES(%(xview)s); // Argument of the view
char* xview_ptr = (char*) %(c_prefix)s_BYTES(%(x)s);
if ((%(c_prefix)s_DIMS(%(xview)s) == %(c_prefix)s_DIMS(%(x)s)) ssize_t xview_dims[%(view_ndim)s];
&& (%(c_prefix)s_DIMS(%(x)s) != NULL)) ssize_t xview_strides[%(view_ndim)s];
{
PyErr_Format(PyExc_ValueError, "x and %(xview)s"
"(with %%d dims) have the same dimensions"
" pointers: %%p and %%p",
%(c_prefix)s_NDIM(%(x)s),
%(c_prefix)s_DIMS(%(xview)s),
%(c_prefix)s_DIMS(%(x)s));
Py_XDECREF(%(xview)s);
%(fail)s;
}
if (%(c_prefix)s_STRIDES(%(xview)s) == %(c_prefix)s_STRIDES(%(x)s)
&& (%(c_prefix)s_DIMS(%(x)s) != NULL))
{
PyErr_Format(PyExc_ValueError, "x and %(xview)s"
"(with %%d dims) have the same strides"
" pointers: %%p and %%p",
%(c_prefix)s_NDIM(%(x)s),
%(c_prefix)s_STRIDES(%(xview)s),
%(c_prefix)s_STRIDES(%(x)s));
Py_XDECREF(%(xview)s);
%(fail)s;
}
for (; outer_ii < %(len_is_slice)s; ++outer_ii) for (; outer_ii < %(len_is_slice)s; ++outer_ii)
{ {
...@@ -719,10 +670,8 @@ class Subtensor(Op): ...@@ -719,10 +670,8 @@ class Subtensor(Op):
// PySlice_GetIndicesEx in python source // PySlice_GetIndicesEx in python source
if (!step) if (!step)
{ {
Py_DECREF(%(xview)s);
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"slice step cannot be zero"); "slice step cannot be zero");
Py_XDECREF(%(xview)s);
%(fail)s; %(fail)s;
} }
...@@ -771,11 +720,10 @@ class Subtensor(Op): ...@@ -771,11 +720,10 @@ class Subtensor(Op):
assert (slicelength <= length); assert (slicelength <= length);
ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start * xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start *
%(strides_mul)s; %(strides_mul)s;
%(set_dim)s(%(xview)s, inner_ii, slicelength); xview_dims[inner_ii] = slicelength;
%(set_stride)s(%(xview)s, inner_ii, xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step;
%(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step);
inner_ii += 1; inner_ii += 1;
spec_pos += 3; spec_pos += 3;
...@@ -788,45 +736,42 @@ class Subtensor(Op): ...@@ -788,45 +736,42 @@ class Subtensor(Op):
{ {
if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii]) if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii])
{ {
ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx * xview_ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
%(strides_mul)s; %(strides_mul)s;
} }
else else
{ {
PyErr_Format(PyExc_IndexError,"index out of bounds"); PyErr_Format(PyExc_IndexError,"index out of bounds");
Py_XDECREF(%(xview)s);
%(fail)s; %(fail)s;
} }
} }
else else
{ {
PyErr_Format(PyExc_IndexError,"index out of bounds"); PyErr_Format(PyExc_IndexError,"index out of bounds");
Py_XDECREF(%(xview)s);
%(fail)s; %(fail)s;
} }
spec_pos += 1; spec_pos += 1;
} }
} }
%(set_data)s(%(xview)s, ptr, (PyObject*)NULL); assert (inner_ii <= %(view_ndim)s);
assert (inner_ii <= %(c_prefix)s_NDIM(%(xview)s)); while (inner_ii < %(view_ndim)s)
while (inner_ii < %(c_prefix)s_NDIM(%(xview)s))
{ {
assert (outer_ii < %(c_prefix)s_NDIM(%(x)s)); assert (outer_ii < %(c_prefix)s_NDIM(%(x)s));
%(set_dim)s(%(xview)s, inner_ii, xview_dims[inner_ii] = %(c_prefix)s_DIMS(%(x)s)[outer_ii];
%(c_prefix)s_DIMS(%(x)s)[outer_ii]); xview_strides[inner_ii] = %(c_prefix)s_STRIDES(%(x)s)[outer_ii];
%(set_stride)s(%(xview)s, inner_ii,
%(c_prefix)s_STRIDES(%(x)s)[outer_ii]);
inner_ii += 1; inner_ii += 1;
outer_ii += 1; outer_ii += 1;
} }
%(update_flags)s //(update_flags)s
""" % locals() """ % locals()
# print rval # print rval
return rval return rval
@staticmethod @staticmethod
def helper_c_code_cache_version(): def helper_c_code_cache_version():
return ()
return (5,) return (5,)
def c_code(self, node, name, inputs, outputs, sub): # DEBUG def c_code(self, node, name, inputs, outputs, sub): # DEBUG
...@@ -838,29 +783,35 @@ class Subtensor(Op): ...@@ -838,29 +783,35 @@ class Subtensor(Op):
view_ndim = node.outputs[0].ndim view_ndim = node.outputs[0].ndim
fail = sub['fail'] fail = sub['fail']
decl = "PyArrayObject * xview = NULL;"
get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list, view_ndim)
build_view = """ build_view = """
//TODO: give this Op a second output so that this view can be cached //TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure //TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(x)s)); Py_INCREF(PyArray_DESCR(%(x)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr( xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type, &PyArray_Type,
PyArray_DESCR(%(x)s), PyArray_DESCR(%(x)s),
%(view_ndim)s, %(view_ndim)s,
PyArray_DIMS(%(x)s), xview_dims, //PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s), xview_strides, //PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s), xview_ptr, //PyArray_DATA(%(x)s),
PyArray_FLAGS(%(x)s), PyArray_FLAGS(%(x)s),
NULL); NULL);
assert (PyArray_NDIM(xview) == %(view_ndim)s);
if (!xview) if (!xview)
{ {
%(fail)s; %(fail)s;
} }
""" % locals() """ % locals()
get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list)
finish_view = """ finish_view = """
if (%(z)s) Py_DECREF(%(z)s); //Is this update needed? The doc of PyArray_NewFromDescr do not
//tell us the information needed
PyArray_UpdateFlags(xview, NPY_ARRAY_C_CONTIGUOUS| NPY_ARRAY_F_CONTIGUOUS);
Py_XDECREF(%(z)s);
Py_INCREF(py_%(x)s); Py_INCREF(py_%(x)s);
#if NPY_API_VERSION < 0x00000007 #if NPY_API_VERSION < 0x00000007
PyArray_BASE(xview) = py_%(x)s; PyArray_BASE(xview) = py_%(x)s;
...@@ -871,7 +822,7 @@ class Subtensor(Op): ...@@ -871,7 +822,7 @@ class Subtensor(Op):
%(z)s = xview; %(z)s = xview;
""" % locals() """ % locals()
return build_view + "{" + get_xview + "}" + finish_view return decl + get_xview + build_view + finish_view
def c_code_cache_version(self): def c_code_cache_version(self):
hv = self.helper_c_code_cache_version() hv = self.helper_c_code_cache_version()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论