提交 33c4b2cc authored 作者: Frederic's avatar Frederic

parametrise the helper fct in Subtensor to allow reusing it in GpuSubtensor.

Also add missing decref in error handling code.
上级 ae13b006
...@@ -3952,9 +3952,21 @@ class Subtensor(Op): ...@@ -3952,9 +3952,21 @@ class Subtensor(Op):
return "%s{%s}" % (self.__class__.__name__, ", ".join(indices)) return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list): def helper_c_code(node, name, inputs, outputs, sub, idx_list,
if not isinstance(node.inputs[0].type, TensorType): c_prefix="PyArray",
raise NotImplementedError() update_flags=("PyArray_UpdateFlags(xview,"
" NPY_ARRAY_C_CONTIGUOUS|"
"NPY_ARRAY_F_CONTIGUOUS);"),
set_data='PyArray_set_data',
set_dim='PyArray_set_dim',
set_stride='PyArray_set_stride',
strides_mul=1,
):
"""The parameters c_prefix, update_flags, set_data, set_dim,
set_stride and strides_mul are there to allow reusing this
function on PyArray and CudaNdarray object.
"""
# #
# two arrays are created in C code: # two arrays are created in C code:
# is_slice: len == ndim, 0 means int, 1 means slice # is_slice: len == ndim, 0 means int, 1 means slice
...@@ -4029,6 +4041,10 @@ class Subtensor(Op): ...@@ -4029,6 +4041,10 @@ class Subtensor(Op):
z, = outputs z, = outputs
rval = """ rval = """
#define PyArray_set_dim(obj, idx, d) PyArray_DIMS(obj)[idx]=d
#define PyArray_set_stride(obj, idx, d) PyArray_STRIDES(obj)[idx]=d
#define PyArray_set_data(obj, ptr, base) PyArray_BYTES(obj)=ptr
// The subtensor is created by iterating over the dimensions // The subtensor is created by iterating over the dimensions
// and updating stride, shape, and data pointers // and updating stride, shape, and data pointers
...@@ -4039,22 +4055,30 @@ class Subtensor(Op): ...@@ -4039,22 +4055,30 @@ class Subtensor(Op):
int inner_ii = 0; // the current dimension of zview int inner_ii = 0; // the current dimension of zview
int outer_ii = 0; // current dimension of z int outer_ii = 0; // current dimension of z
if ((PyArray_DIMS(xview) == PyArray_DIMS(%(x)s)) char* ptr = (char*) %(c_prefix)s_BYTES(xview);
&& (PyArray_DIMS(%(x)s) != NULL))
if ((%(c_prefix)s_DIMS(xview) == %(c_prefix)s_DIMS(%(x)s))
&& (%(c_prefix)s_DIMS(%(x)s) != NULL))
{ {
PyErr_Format(PyExc_ValueError, "x and xview" PyErr_Format(PyExc_ValueError, "x and xview"
"(with %%d dims) have the same dimensions" "(with %%d dims) have the same dimensions"
" pointers: %%p and %%p", " pointers: %%p and %%p",
PyArray_NDIM(%(x)s), PyArray_DIMS(xview), PyArray_DIMS(%(x)s)); %(c_prefix)s_NDIM(%(x)s),
%(c_prefix)s_DIMS(xview),
%(c_prefix)s_DIMS(%(x)s));
Py_XDECREF(xview);
%(fail)s; %(fail)s;
} }
if (PyArray_STRIDES(xview) == PyArray_STRIDES(%(x)s) if (%(c_prefix)s_STRIDES(xview) == %(c_prefix)s_STRIDES(%(x)s)
&& (PyArray_DIMS(%(x)s) != NULL)) && (%(c_prefix)s_DIMS(%(x)s) != NULL))
{ {
PyErr_Format(PyExc_ValueError, "x and xview" PyErr_Format(PyExc_ValueError, "x and xview"
"(with %%d dims) have the same strides" "(with %%d dims) have the same strides"
" pointers: %%p and %%p", " pointers: %%p and %%p",
PyArray_NDIM(%(x)s), PyArray_STRIDES(xview), PyArray_STRIDES(%(x)s)); %(c_prefix)s_NDIM(%(x)s),
%(c_prefix)s_STRIDES(xview),
%(c_prefix)s_STRIDES(%(x)s));
Py_XDECREF(xview);
%(fail)s; %(fail)s;
} }
...@@ -4062,7 +4086,7 @@ class Subtensor(Op): ...@@ -4062,7 +4086,7 @@ class Subtensor(Op):
{ {
if (is_slice[outer_ii]) if (is_slice[outer_ii])
{ {
npy_intp length = PyArray_DIMS(%(x)s)[outer_ii]; npy_intp length = %(c_prefix)s_DIMS(%(x)s)[outer_ii];
npy_intp slicelength; npy_intp slicelength;
npy_intp start = subtensor_spec[spec_pos+0]; npy_intp start = subtensor_spec[spec_pos+0];
npy_intp stop = subtensor_spec[spec_pos+1]; npy_intp stop = subtensor_spec[spec_pos+1];
...@@ -4079,6 +4103,7 @@ class Subtensor(Op): ...@@ -4079,6 +4103,7 @@ class Subtensor(Op):
Py_DECREF(xview); Py_DECREF(xview);
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"slice step cannot be zero"); "slice step cannot be zero");
Py_XDECREF(xview);
%(fail)s; %(fail)s;
} }
...@@ -4126,9 +4151,12 @@ class Subtensor(Op): ...@@ -4126,9 +4151,12 @@ class Subtensor(Op):
} }
assert (slicelength <= length); assert (slicelength <= length);
xview->data += PyArray_STRIDES(%(x)s)[outer_ii] * start;
PyArray_DIMS(xview)[inner_ii] = slicelength; ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * start *
PyArray_STRIDES(xview)[inner_ii] = PyArray_STRIDES(%(x)s)[outer_ii] * step; %(strides_mul)s;
%(set_dim)s(xview, inner_ii, slicelength);
%(set_stride)s(xview, inner_ii,
%(c_prefix)s_STRIDES(%(x)s)[outer_ii] * step);
inner_ii += 1; inner_ii += 1;
spec_pos += 3; spec_pos += 3;
...@@ -4136,47 +4164,54 @@ class Subtensor(Op): ...@@ -4136,47 +4164,54 @@ class Subtensor(Op):
else // tuple coord `outer_ii` is an int else // tuple coord `outer_ii` is an int
{ {
int idx = subtensor_spec[spec_pos]; int idx = subtensor_spec[spec_pos];
if (idx < 0) idx += PyArray_DIMS(%(x)s)[outer_ii]; if (idx < 0) idx += %(c_prefix)s_DIMS(%(x)s)[outer_ii];
if (idx >= 0) if (idx >= 0)
{ {
if (idx < PyArray_DIMS(%(x)s)[outer_ii]) if (idx < %(c_prefix)s_DIMS(%(x)s)[outer_ii])
{ {
xview->data += PyArray_STRIDES(%(x)s)[outer_ii] * idx; ptr += %(c_prefix)s_STRIDES(%(x)s)[outer_ii] * idx *
%(strides_mul)s;
} }
else else
{ {
PyErr_Format(PyExc_IndexError,"index out of bounds"); PyErr_Format(PyExc_IndexError,"index out of bounds");
Py_XDECREF(xview);
%(fail)s; %(fail)s;
} }
} }
else else
{ {
PyErr_Format(PyExc_IndexError,"index out of bounds"); PyErr_Format(PyExc_IndexError,"index out of bounds");
Py_XDECREF(xview);
%(fail)s; %(fail)s;
} }
spec_pos += 1; spec_pos += 1;
} }
} }
assert (inner_ii <= PyArray_NDIM(xview)); %(set_data)s(xview, ptr, (PyObject*)NULL);
while (inner_ii < PyArray_NDIM(xview)) assert (inner_ii <= %(c_prefix)s_NDIM(xview));
while (inner_ii < %(c_prefix)s_NDIM(xview))
{ {
assert (outer_ii < PyArray_NDIM(%(x)s)); assert (outer_ii < %(c_prefix)s_NDIM(%(x)s));
PyArray_DIMS(xview)[inner_ii] = PyArray_DIMS(%(x)s)[outer_ii]; %(set_dim)s(xview, inner_ii, %(c_prefix)s_DIMS(%(x)s)[outer_ii]);
PyArray_STRIDES(xview)[inner_ii] = PyArray_STRIDES(%(x)s)[outer_ii]; %(set_stride)s(xview, inner_ii, %(c_prefix)s_STRIDES(%(x)s)[outer_ii]);
inner_ii += 1; inner_ii += 1;
outer_ii += 1; outer_ii += 1;
} }
PyArray_UpdateFlags(xview, NPY_ARRAY_C_CONTIGUOUS|NPY_F_CONTIGUOUS); %(update_flags)s
""" % locals() """ % locals()
# print rval # print rval
return rval return rval
@staticmethod @staticmethod
def helper_c_code_cache_version(): def helper_c_code_cache_version():
return (4,) return (5,)
def c_code(self, node, name, inputs, outputs, sub): # DEBUG def c_code(self, node, name, inputs, outputs, sub): # DEBUG
if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError()
x = inputs[0] x = inputs[0]
z, = outputs z, = outputs
view_ndim = node.outputs[0].ndim view_ndim = node.outputs[0].ndim
...@@ -4219,7 +4254,7 @@ class Subtensor(Op): ...@@ -4219,7 +4254,7 @@ class Subtensor(Op):
# have a versioned version of this op's C code. # have a versioned version of this op's C code.
if len(hv) == 0: if len(hv) == 0:
return () return ()
return (1, hv) return (2, hv)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# Subtensor is not differentiable wrt to its indices, therefore we # Subtensor is not differentiable wrt to its indices, therefore we
...@@ -4479,6 +4514,8 @@ class IncSubtensor(Op): ...@@ -4479,6 +4514,8 @@ class IncSubtensor(Op):
out[0] = x out[0] = x
def c_code(self, node, name, inputs, outputs, sub): # DEBUG def c_code(self, node, name, inputs, outputs, sub): # DEBUG
if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError()
if self.inplace: # convert bool to int if self.inplace: # convert bool to int
inplace = 1 inplace = 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论