提交 acb2a9d0 authored 作者: Hengjean's avatar Hengjean

Multiple fixes.

上级 63e45998
......@@ -293,17 +293,19 @@ class CudaNdarrayType(Type):
//fprintf(stderr, "c_extract CNDA object w refcnt %%p %%i\\n", py_%(name)s, (py_%(name)s->ob_refcnt));
%(name)s = (CudaNdarray*)py_%(name)s;
//std::cerr << "c_extract " << %(name)s << '\\n';
if (%(name)s->nd != %(nd)s)
{
PyErr_Format(PyExc_RuntimeError,
"c_extract: Some CudaNdarray has rank %%i, it was supposed to have rank %(nd)s",
%(name)s->nd);
%(name)s = NULL;
%(fail)s;
}
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals()
if(check_input):
print >> sio, """
if (%(name)s->nd != %(nd)s)
{
PyErr_Format(PyExc_RuntimeError,
"c_extract: Some CudaNdarray has rank %%i, it was supposed to have rank %(nd)s",
%(name)s->nd);
%(name)s = NULL;
%(fail)s;
}
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals()
for i, b in enumerate(self.broadcastable):
if b:
print >> sio, """
......
......@@ -416,14 +416,21 @@ class TensorType(Type):
return str(self)
#"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub):
def c_declare(self, name, sub, check_input=True):
"""Override `CLinkerType.c_declare` """
return """
if(check_input):
check = """
typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1])
else:
check = ""
declaration = """
PyArrayObject* %(name)s;
int type_num_%(name)s;
typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1])
return declaration + check
def c_init(self, name, sub):
"""Override `CLinkerType.c_init` """
return """
......@@ -485,7 +492,9 @@ class TensorType(Type):
}
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
else:
check = ""
check = """
type_num_%(name)s = PyArray_TYPE((PyArrayObject*) py_%(name)s);
"""
return check + """
%(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论