提交 25dec06f authored 作者: Hengjean's avatar Hengjean

Added check input parameters for all c_extract.

上级 2beab0e4
......@@ -329,7 +329,7 @@ def get_c_extract(r, name, sub):
c_extract = r.type.c_extract(name, sub,
getattr(r.owner.op, 'check_input', True))
else:
c_extract = r.type.c_extract(name, sub)
c_extract = r.type.c_extract(name, sub, True)
pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
......@@ -344,7 +344,7 @@ def get_c_extract_out(r, name, sub):
c_extract = r.type.c_extract(name, sub,
getattr(r.owner.op, 'check_input', True))
else:
c_extract = r.type.c_extract(name, sub)
c_extract = r.type.c_extract(name, sub, True)
pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
......
......@@ -96,7 +96,7 @@ class CLinkerType(CLinkerObject):
"""
raise MethodNotDefined("c_init", type(self), self.__class__.__name__)
def c_extract(self, name, sub):
def c_extract(self, name, sub, check_input=True):
"""Required: Return c code to extract a PyObject * instance.
The code returned from this function must be templated using
......
......@@ -163,7 +163,7 @@ class GpuArrayType(Type):
def c_init(self, name, sub):
return "%s = NULL;" % (name,)
def c_extract(self, name, sub):
def c_extract(self, name, sub, check_input=True):
# TODO I don't check broadcast stuff for now.
return """
%(name)s = NULL;
......
......@@ -266,7 +266,7 @@ class Scalar(Type):
%(name)s = 0;
""" % locals()
def c_extract(self, name, sub):
def c_extract(self, name, sub, check_input=True):
specs = self.dtype_specs()
return """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
......
......@@ -247,7 +247,7 @@ class T_extending(unittest.TestCase):
def c_extract(name, sub):
def c_extract(name, sub, check_input=True):
return """
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float");
......@@ -308,7 +308,7 @@ class T_extending(unittest.TestCase):
%(name)s = 0.0;
""" % dict(name = name)
def c_extract(self, name, sub):
def c_extract(self, name, sub, check_input=True):
return """
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论