提交 1817ae13 authored 作者: Li's avatar Li 提交者: Frederic

check return value for failure

上级 5f2ed7aa
......@@ -80,7 +80,7 @@ class CumsumOp(theano.Op):
""" % locals()
else:
code = """
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) ))
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
......@@ -89,7 +89,13 @@ class CumsumOp(theano.Op):
if (!%(z)s)
%(fail)s;
{
PyArray_CumSum(%(x)s, %(axis)s, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
PyObject * t = PyArray_CumSum(
%(x)s, %(axis)s,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
}
""" % locals()
......@@ -97,7 +103,7 @@ class CumsumOp(theano.Op):
return code
def c_code_cache_version(self):
return (3,)
return (4,)
def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论