提交 911d7160 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added C code for the increment case

上级 51d2e95d
...@@ -2511,6 +2511,23 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2511,6 +2511,23 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
%(fail)s; %(fail)s;
}""" % locals() }""" % locals()
def add_to_xview(self, x, fail):
return """
PyObject * add_result = CudaNdarray_inplace_add(xview, py_%(x)s);
if (! add_result )
{
Py_DECREF(xview);
%(fail)s;
}
else
{
Py_DECREF(add_result);
}
""" % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
# TODO: cooperate with parent class' C code # TODO: cooperate with parent class' C code
return () return ()
......
...@@ -1740,7 +1740,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t ...@@ -1740,7 +1740,7 @@ CudaNdarray_inplace_elemwise(PyObject* py_self, PyObject * py_other, operator_t
* It returns py_self on success with an additional reference. Else NULL. * It returns py_self on success with an additional reference. Else NULL.
*/ */
// Will be called by __iadd__ in Python // Will be called by __iadd__ in Python
static PyObject * PyObject *
CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other) CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other)
{ {
if (CudaNdarray_inplace_elemwise(py_self, py_other, IADD)) if (CudaNdarray_inplace_elemwise(py_self, py_other, IADD))
......
...@@ -144,7 +144,7 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2); ...@@ -144,7 +144,7 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2);
/**** /****
* Set the idx'th dimension to value d. * Set the idx'th dimension to value d.
* *
* Updates the log2dim shaddow array. * Updates the log2dim shadow array.
* *
* Does not sync structure to host. * Does not sync structure to host.
*/ */
...@@ -441,6 +441,7 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self); ...@@ -441,6 +441,7 @@ int fprint_CudaNdarray(FILE * fd, const CudaNdarray *self);
PyObject * CudaNdarray_View(const CudaNdarray * self); PyObject * CudaNdarray_View(const CudaNdarray * self);
CudaNdarray_inplace_add(PyObject* py_self, PyObject * py_other);
#endif #endif
......
...@@ -4648,6 +4648,8 @@ class IncSubtensor(Op): ...@@ -4648,6 +4648,8 @@ class IncSubtensor(Op):
copy_into = self.copy_into("xview", y) copy_into = self.copy_into("xview", y)
add_to_xview = self.add_to_xview(y, fail)
make_modification = """ make_modification = """
if (%(op_is_set)s) if (%(op_is_set)s)
{ {
...@@ -4659,19 +4661,7 @@ class IncSubtensor(Op): ...@@ -4659,19 +4661,7 @@ class IncSubtensor(Op):
} }
else else
{ {
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd( %(add_to_xview)s
(PyObject*)xview, py_%(y)s);
if (add_rval)
{
assert (PyArray_Check((PyObject*)add_rval));
assert (PyArray_DATA(add_rval) == PyArray_DATA(xview));
Py_DECREF(add_rval);
}
else
{
Py_DECREF(xview);
%(fail)s;
}
} }
""" % locals() """ % locals()
...@@ -4773,6 +4763,24 @@ class IncSubtensor(Op): ...@@ -4773,6 +4763,24 @@ class IncSubtensor(Op):
# On CPU there is nothing to do # On CPU there is nothing to do
return "" return ""
def add_to_xview(self, x, fail):
""" Return C code to add x to xview. Should DECREF xview if the
add fails."""
return """
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)xview, py_%(x)s);
if (add_rval)
{
assert (PyArray_Check((PyObject*)add_rval));
assert (PyArray_DATA(add_rval) == PyArray_DATA(xview));
Py_DECREF(add_rval);
}
else
{
Py_DECREF(xview);
%(fail)s;
}""" % locals()
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论