提交 911d7160 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added C code for the increment case

上级 51d2e95d
......@@ -2511,6 +2511,23 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
%(fail)s;
}""" % locals()
def add_to_xview(self, x, fail):
    """Return C code that adds ``py_<x>`` to ``xview`` in place on the GPU.

    :param x: name of the variable to add; the generated code refers to
        it as ``py_<x>``.
    :param fail: C code to run when the addition fails.

    When ``CudaNdarray_inplace_add`` returns NULL the generated code
    DECREFs ``xview`` and then runs ``fail``; otherwise it DECREFs the
    returned reference.
    """
    # CudaNdarray_inplace_add is declared with PyObject * parameters, so
    # xview is cast explicitly, as IncSubtensor.add_to_xview does for the
    # CPU path. The cast is a no-op if xview is already a PyObject *.
    return """
PyObject * add_result = CudaNdarray_inplace_add((PyObject *) xview, py_%(x)s);
if (! add_result )
{
Py_DECREF(xview);
%(fail)s;
}
else
{
Py_DECREF(add_result);
}
""" % {"x": x, "fail": fail}
def c_code_cache_version(self):
    """Return the version tuple for this op's generated C code.

    Always returns the empty tuple.
    NOTE(review): presumably an empty tuple tells the compilation cache
    not to reuse this op's C code -- confirm against the Op documentation.
    """
    # TODO: cooperate with parent class' C code
    return tuple()
......
......@@ -1740,7 +1740,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
* It returns py_self on success with an additional reference. Else NULL.
*/
// Will be called by __iadd__ in Python
static PyObject *
PyObject *
CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other)
{
if (CudaNdarray_inplace_elemwise(py_self, py_other, IADD))
......
......@@ -144,7 +144,7 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2);
/****
* Set the idx'th dimension to value d.
*
* Updates the log2dim shaddow array.
* Updates the log2dim shadow array.
*
* Does not sync structure to host.
*/
......@@ -441,6 +441,7 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self);
PyObject * CudaNdarray_View(const CudaNdarray * self);
CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other);
#endif
......
......@@ -4648,6 +4648,8 @@ class IncSubtensor(Op):
copy_into = self.copy_into("xview", y)
add_to_xview = self.add_to_xview(y, fail)
make_modification = """
if (%(op_is_set)s)
{
......@@ -4659,19 +4661,7 @@ class IncSubtensor(Op):
}
else
{
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)xview, py_%(y)s);
if (add_rval)
{
assert (PyArray_Check((PyObject*)add_rval));
assert (PyArray_DATA(add_rval) == PyArray_DATA(xview));
Py_DECREF(add_rval);
}
else
{
Py_DECREF(xview);
%(fail)s;
}
%(add_to_xview)s
}
""" % locals()
......@@ -4773,6 +4763,24 @@ class IncSubtensor(Op):
# On CPU there is nothing to do
return ""
def add_to_xview(self, x, fail):
    """Build the C snippet that performs ``xview += py_<x>``.

    :param x: name of the variable to add; the generated code refers to
        it as ``py_<x>``.
    :param fail: C code to run when the addition fails.

    The snippet calls ``PyNumber_InPlaceAdd`` on ``xview``. On success it
    asserts that the result is an array sharing ``xview``'s data pointer
    and releases the returned reference. On failure it DECREFs ``xview``
    before running ``fail``.
    """
    substitutions = {"x": x, "fail": fail}
    template = """
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)xview, py_%(x)s);
if (add_rval)
{
assert (PyArray_Check((PyObject*)add_rval));
assert (PyArray_DATA(add_rval) == PyArray_DATA(xview));
Py_DECREF(add_rval);
}
else
{
Py_DECREF(xview);
%(fail)s;
}"""
    return template % substitutions
def infer_shape(self, node, shapes):
    """Return the output shapes: one output, shaped like the first input.

    :param node: the apply node (not used here).
    :param shapes: sequence of input shapes; only ``shapes[0]`` is read.
    """
    first_input_shape = shapes[0]
    return [first_input_shape]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论