提交 84935fcf authored 作者: Hengjean's avatar Hengjean

Added chck_input to c_declare

上级 991c794d
...@@ -308,10 +308,15 @@ def get_nothing(r, name, sub): ...@@ -308,10 +308,15 @@ def get_nothing(r, name, sub):
def get_c_declare(r, name, sub): def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name""" """Wrapper around c_declare that declares py_name"""
if r.owner:
c_declare = r.type.c_declare(name, sub,
getattr(r.owner.op, 'check_input', True))
else:
c_declare = r.type.c_declare(name, sub, True)
pre = """ pre = """
PyObject* py_%(name)s; PyObject* py_%(name)s;
""" % locals() """ % locals()
return pre + r.type.c_declare(name, sub) return pre + c_declare
def get_c_init(r, name, sub): def get_c_init(r, name, sub):
......
...@@ -44,7 +44,7 @@ class CLinkerType(CLinkerObject): ...@@ -44,7 +44,7 @@ class CLinkerType(CLinkerObject):
""" """
raise MethodNotDefined("c_literal", type(self), self.__class__.__name__) raise MethodNotDefined("c_literal", type(self), self.__class__.__name__)
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
"""Required: Return c code to declare variables that will be """Required: Return c code to declare variables that will be
instantiated by `c_extract`. instantiated by `c_extract`.
...@@ -434,7 +434,7 @@ class Generic(SingletonType): ...@@ -434,7 +434,7 @@ class Generic(SingletonType):
def is_valid_value(self, a): def is_valid_value(self, a):
return True return True
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
PyObject* %(name)s; PyObject* %(name)s;
""" % locals() """ % locals()
......
...@@ -274,7 +274,7 @@ class CudaNdarrayType(Type): ...@@ -274,7 +274,7 @@ class CudaNdarrayType(Type):
return str(self) return str(self)
#"CudaNdarrayType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) #"CudaNdarrayType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ CudaNdarray * %(name)s;""" % locals() return """ CudaNdarray * %(name)s;""" % locals()
def c_init(self, name, sub): def c_init(self, name, sub):
......
...@@ -155,7 +155,7 @@ class GpuArrayType(Type): ...@@ -155,7 +155,7 @@ class GpuArrayType(Type):
else: else:
return numpy.dtype(self.dtype).itemsize return numpy.dtype(self.dtype).itemsize
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
PyGpuArrayObject *%(name)s; PyGpuArrayObject *%(name)s;
""" % locals() """ % locals()
......
...@@ -254,7 +254,7 @@ class Scalar(Type): ...@@ -254,7 +254,7 @@ class Scalar(Type):
raise NotImplementedError("No literal for complex values.") raise NotImplementedError("No literal for complex values.")
return str(data) return str(data)
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
%(dtype)s %(name)s; %(dtype)s %(name)s;
typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead. typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead.
......
...@@ -232,7 +232,7 @@ class T_extending(unittest.TestCase): ...@@ -232,7 +232,7 @@ class T_extending(unittest.TestCase):
div = BinaryDoubleOp(name = 'div', div = BinaryDoubleOp(name = 'div',
fn = lambda x, y: x / y) fn = lambda x, y: x / y)
def c_declare(name, sub): def c_declare(name, sub, check_input=True):
return """ return """
double %(name)s; double %(name)s;
""" % dict(name = name) """ % dict(name = name)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论