提交 25dec06f authored 作者: Hengjean's avatar Hengjean

Added check input parameters for all c_extract.

上级 2beab0e4
...@@ -329,7 +329,7 @@ def get_c_extract(r, name, sub): ...@@ -329,7 +329,7 @@ def get_c_extract(r, name, sub):
c_extract = r.type.c_extract(name, sub, c_extract = r.type.c_extract(name, sub,
getattr(r.owner.op, 'check_input', True)) getattr(r.owner.op, 'check_input', True))
else: else:
c_extract = r.type.c_extract(name, sub) c_extract = r.type.c_extract(name, sub, True)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
...@@ -344,7 +344,7 @@ def get_c_extract_out(r, name, sub): ...@@ -344,7 +344,7 @@ def get_c_extract_out(r, name, sub):
c_extract = r.type.c_extract(name, sub, c_extract = r.type.c_extract(name, sub,
getattr(r.owner.op, 'check_input', True)) getattr(r.owner.op, 'check_input', True))
else: else:
c_extract = r.type.c_extract(name, sub) c_extract = r.type.c_extract(name, sub, True)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
......
...@@ -96,7 +96,7 @@ class CLinkerType(CLinkerObject): ...@@ -96,7 +96,7 @@ class CLinkerType(CLinkerObject):
""" """
raise MethodNotDefined("c_init", type(self), self.__class__.__name__) raise MethodNotDefined("c_init", type(self), self.__class__.__name__)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
"""Required: Return c code to extract a PyObject * instance. """Required: Return c code to extract a PyObject * instance.
The code returned from this function must be templated using The code returned from this function must be templated using
......
...@@ -163,7 +163,7 @@ class GpuArrayType(Type): ...@@ -163,7 +163,7 @@ class GpuArrayType(Type):
def c_init(self, name, sub): def c_init(self, name, sub):
return "%s = NULL;" % (name,) return "%s = NULL;" % (name,)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
# TODO I don't check broadcast stuff for now. # TODO I don't check broadcast stuff for now.
return """ return """
%(name)s = NULL; %(name)s = NULL;
......
...@@ -266,7 +266,7 @@ class Scalar(Type): ...@@ -266,7 +266,7 @@ class Scalar(Type):
%(name)s = 0; %(name)s = 0;
""" % locals() """ % locals()
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ return """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s)) if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
......
...@@ -247,7 +247,7 @@ class T_extending(unittest.TestCase): ...@@ -247,7 +247,7 @@ class T_extending(unittest.TestCase):
def c_extract(name, sub): def c_extract(name, sub, check_input=True):
return """ return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float"); PyErr_SetString(PyExc_TypeError, "expected a float");
...@@ -308,7 +308,7 @@ class T_extending(unittest.TestCase): ...@@ -308,7 +308,7 @@ class T_extending(unittest.TestCase):
%(name)s = 0.0; %(name)s = 0.0;
""" % dict(name = name) """ % dict(name = name)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float"); PyErr_SetString(PyExc_TypeError, "expected a float");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论