提交 8f9c55c7 authored 作者: James Bergstra's avatar James Bergstra

Fixed important bug in TensorType C extract code that could cause segfault.

上级 ceb89e2d
......@@ -570,7 +570,6 @@ class TensorType(Type):
# input received.
return """
%(name)s = NULL;
type_num_%(name)s = ((PyArrayObject*)py_%(name)s)->descr->type_num; //we expect %(type_num)s
if (py_%(name)s == Py_None) {
// We can either fail here or set %(name)s to NULL and rely on Ops using
// tensors to handle the NULL case, but if they fail to do so they'll end up
......@@ -578,18 +577,17 @@ class TensorType(Type):
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%(fail)s
}
else if (!PyArray_Check(py_%(name)s)) {
if (!PyArray_Check(py_%(name)s)) {
PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%(fail)s
}
else if (type_num_%(name)s != %(type_num)s) {
type_num_%(name)s = ((PyArrayObject*)py_%(name)s)->descr->type_num; //we expect %(type_num)s
if (type_num_%(name)s != %(type_num)s) {
PyErr_SetString(PyExc_ValueError, "expected %(type_num)s");
%(fail)s
}
else {
%(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s);
}
%(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s);
""" % dict(sub, name = name, type_num = self.dtype_specs()[2])
def c_cleanup(self, name, sub):
......@@ -631,7 +629,7 @@ class TensorType(Type):
def c_code_cache_version(self):
scalar_version = scal.Scalar(self.dtype).c_code_cache_version()
if scalar_version:
return (1,) + scalar_version
return (2,) + scalar_version
else:
return ()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论