提交 0cef4bd1 authored 作者: Frederic's avatar Frederic

Apply the previous fix elsewhere.

上级 1817ae13
...@@ -74,7 +74,12 @@ class CumsumOp(theano.Op): ...@@ -74,7 +74,12 @@ class CumsumOp(theano.Op):
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumSum(%(x)s, NPY_MAXDIMS, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s); PyObject * t = PyArray_CumSum(
%(x)s, NPY_MAXDIMS,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -103,7 +108,7 @@ class CumsumOp(theano.Op): ...@@ -103,7 +108,7 @@ class CumsumOp(theano.Op):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return (5,)
def __str__(self): def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.axis) return "%s{%s}" % (self.__class__.__name__, self.axis)
...@@ -189,7 +194,12 @@ class CumprodOp(theano.Op): ...@@ -189,7 +194,12 @@ class CumprodOp(theano.Op):
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumProd(%(x)s, NPY_MAXDIMS, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s); PyObject * t = PyArray_CumProd(
%(x)s, NPY_MAXDIMS,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -204,7 +214,12 @@ class CumprodOp(theano.Op): ...@@ -204,7 +214,12 @@ class CumprodOp(theano.Op):
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumProd(%(x)s, %(axis)s, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s); PyObject * t = PyArray_CumProd(
%(x)s, %(axis)s,
PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
if (!t){
%(fail)s;
}
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -212,7 +227,7 @@ class CumprodOp(theano.Op): ...@@ -212,7 +227,7 @@ class CumprodOp(theano.Op):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
def __str__(self): def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.axis) return "%s{%s}" % (self.__class__.__name__, self.axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论