提交 1817ae13 authored 作者: Li's avatar Li 提交者: Frederic

check return value for failure

上级 5f2ed7aa
...@@ -80,7 +80,7 @@ class CumsumOp(theano.Op): ...@@ -80,7 +80,7 @@ class CumsumOp(theano.Op):
""" % locals() """ % locals()
else: else:
code = """ code = """
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) )) if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s)); %(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
...@@ -89,7 +89,13 @@ class CumsumOp(theano.Op): ...@@ -89,7 +89,13 @@ class CumsumOp(theano.Op):
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumSum(%(x)s, %(axis)s, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
PyObject * t = PyArray_CumSum(
%(x)s, %(axis)s,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -97,7 +103,7 @@ class CumsumOp(theano.Op): ...@@ -97,7 +103,7 @@ class CumsumOp(theano.Op):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def __str__(self): def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.axis) return "%s{%s}" % (self.__class__.__name__, self.axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论