提交 a79c9a16 作者: Frederic

Refactor the helper C code shared by Subtensor and IncSubtensor

The view-building C code now lives in the callers, so the helper can be reused for GpuSubtensor.
上级 85f9247a
......@@ -4019,7 +4019,6 @@ class Subtensor(Op):
assert len(is_slice) <= node.inputs[0].ndim, node.inputs[0].ndim
len_is_slice = len(is_slice)
view_ndim = node.inputs[0].ndim - (numpy.asarray(is_slice) == 0).sum()
len_subtensor_spec = spec_pos()
......@@ -4040,23 +4039,6 @@ class Subtensor(Op):
int inner_ii = 0; // the current dimension of zview
int outer_ii = 0; // current dimension of z
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(x)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(x)s),
%(view_ndim)s,
PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s),
%(x)s->flags,
NULL);
if (!xview)
{
%(fail)s;
}
if ((PyArray_DIMS(xview) == PyArray_DIMS(%(x)s))
&& (PyArray_DIMS(%(x)s) != NULL))
{
......@@ -4195,12 +4177,33 @@ class Subtensor(Op):
return (4,)
def c_code(self, node, name, inputs, outputs, sub): # DEBUG
part0 = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list)
x = inputs[0]
z, = outputs
part1 = """
view_ndim = node.outputs[0].ndim
fail = sub['fail']
build_view = """
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(x)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(x)s),
%(view_ndim)s,
PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s),
PyArray_DATA(%(x)s),
%(x)s->flags,
NULL);
if (!xview)
{
%(fail)s;
}
""" % locals()
get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list)
finish_view = """
if (%(z)s) Py_DECREF(%(z)s);
Py_INCREF(py_%(x)s);
PyArray_BASE(xview) = py_%(x)s;
......@@ -4208,7 +4211,7 @@ class Subtensor(Op):
%(z)s = xview;
""" % locals()
return part0 + part1
return build_view + "{" + get_xview + "}" + finish_view
def c_code_cache_version(self):
hv = self.helper_c_code_cache_version()
......@@ -4489,7 +4492,9 @@ class IncSubtensor(Op):
else:
op_is_set = 0
fail = sub['fail']
view_ndim = (node.inputs[0].ndim -
numpy.sum([not isinstance(idx, slice)
for idx in self.idx_list]))
copy_input_if_necessary = """
if (%(inplace)s)
{
......@@ -4508,6 +4513,25 @@ class IncSubtensor(Op):
}
""" % locals()
#Make a first view on the output, as we will write into it.
build_view = """
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR(%(z)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR(%(z)s),
%(view_ndim)s,
PyArray_DIMS(%(z)s),
PyArray_STRIDES(%(z)s),
PyArray_DATA(%(z)s),
%(z)s->flags,
NULL);
if (!xview)
{
%(fail)s;
}
""" % locals()
# make xview actually a view of %(z)s
get_xview = Subtensor.helper_c_code(node, name,
outputs[:1] + inputs[2:],
......@@ -4541,7 +4565,8 @@ class IncSubtensor(Op):
""" % locals()
return (copy_input_if_necessary
+ get_xview
+ build_view
+ "{" + get_xview + "}"
+ make_modification
+ "Py_DECREF(xview);"
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论