提交 acb2a9d0 authored 作者: Hengjean's avatar Hengjean

Multiple fixes.

上级 63e45998
...@@ -293,6 +293,9 @@ class CudaNdarrayType(Type): ...@@ -293,6 +293,9 @@ class CudaNdarrayType(Type):
//fprintf(stderr, "c_extract CNDA object w refcnt %%p %%i\\n", py_%(name)s, (py_%(name)s->ob_refcnt)); //fprintf(stderr, "c_extract CNDA object w refcnt %%p %%i\\n", py_%(name)s, (py_%(name)s->ob_refcnt));
%(name)s = (CudaNdarray*)py_%(name)s; %(name)s = (CudaNdarray*)py_%(name)s;
//std::cerr << "c_extract " << %(name)s << '\\n'; //std::cerr << "c_extract " << %(name)s << '\\n';
""" % locals()
if(check_input):
print >> sio, """
if (%(name)s->nd != %(nd)s) if (%(name)s->nd != %(nd)s)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
...@@ -303,7 +306,6 @@ class CudaNdarrayType(Type): ...@@ -303,7 +306,6 @@ class CudaNdarrayType(Type):
} }
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n"; //std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals() """ % locals()
if(check_input):
for i, b in enumerate(self.broadcastable): for i, b in enumerate(self.broadcastable):
if b: if b:
print >> sio, """ print >> sio, """
......
...@@ -416,14 +416,21 @@ class TensorType(Type): ...@@ -416,14 +416,21 @@ class TensorType(Type):
return str(self) return str(self)
#"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) #"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
"""Override `CLinkerType.c_declare` """ """Override `CLinkerType.c_declare` """
return """ if(check_input):
check = """
typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1])
else:
check = ""
declaration = """
PyArrayObject* %(name)s; PyArrayObject* %(name)s;
int type_num_%(name)s; int type_num_%(name)s;
typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1]) """ % dict(sub, name=name, dtype=self.dtype_specs()[1])
return declaration + check
def c_init(self, name, sub): def c_init(self, name, sub):
"""Override `CLinkerType.c_init` """ """Override `CLinkerType.c_init` """
return """ return """
...@@ -485,7 +492,9 @@ class TensorType(Type): ...@@ -485,7 +492,9 @@ class TensorType(Type):
} }
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
else: else:
check = "" check = """
type_num_%(name)s = PyArray_TYPE((PyArrayObject*) py_%(name)s);
"""
return check + """ return check + """
%(name)s = (PyArrayObject*)(py_%(name)s); %(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s); Py_XINCREF(%(name)s);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论