提交 4b63cf4a authored 作者: Hengjean's avatar Hengjean

Added check_input where missing and changed get_c_declare and get_c_extract.

上级 84935fcf
...@@ -294,6 +294,7 @@ shape = Shape() ...@@ -294,6 +294,7 @@ shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
#pprint.assign(_shape, printing.MemberPrinter('shape')) #pprint.assign(_shape, printing.MemberPrinter('shape'))
class Shape_i(gof.Op): class Shape_i(gof.Op):
""" """
L{Op} to return the shape of a matrix. L{Op} to return the shape of a matrix.
...@@ -305,6 +306,8 @@ class Shape_i(gof.Op): ...@@ -305,6 +306,8 @@ class Shape_i(gof.Op):
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
check_input = False
def __init__(self, i): def __init__(self, i):
self.i = i self.i = i
...@@ -350,6 +353,9 @@ class Shape_i(gof.Op): ...@@ -350,6 +353,9 @@ class Shape_i(gof.Op):
return () return ()
version.append((str(t), v)) version.append((str(t), v))
if version:
version.append(1)
return tuple(version) return tuple(version)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
......
...@@ -306,13 +306,12 @@ def get_nothing(r, name, sub): ...@@ -306,13 +306,12 @@ def get_nothing(r, name, sub):
return "" return ""
def get_c_declare(r, name, sub): def get_c_declare(r, name, sub, check_input=True):
"""Wrapper around c_declare that declares py_name""" """Wrapper around c_declare that declares py_name"""
if r.owner: if any([c == 'output' or getattr(c.op, 'check_input', True) for (c, _) in r.clients]):
c_declare = r.type.c_declare(name, sub,
getattr(r.owner.op, 'check_input', True))
else:
c_declare = r.type.c_declare(name, sub, True) c_declare = r.type.c_declare(name, sub, True)
else:
c_declare = r.type.c_declare(name, sub, False)
pre = """ pre = """
PyObject* py_%(name)s; PyObject* py_%(name)s;
""" % locals() """ % locals()
...@@ -330,11 +329,12 @@ def get_c_init(r, name, sub): ...@@ -330,11 +329,12 @@ def get_c_init(r, name, sub):
def get_c_extract(r, name, sub): def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage.""" """Wrapper around c_extract that initializes py_name from storage."""
if r.owner: if any([getattr(c.op, 'check_input', True) for (c, _) in r.clients]):
c_extract = r.type.c_extract(name, sub, c_extract = r.type.c_extract(name, sub,
getattr(r.owner.op, 'check_input', True)) True)
else: else:
c_extract = r.type.c_extract(name, sub, True) c_extract = r.type.c_extract(name, sub, False)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
...@@ -345,11 +345,8 @@ def get_c_extract(r, name, sub): ...@@ -345,11 +345,8 @@ def get_c_extract(r, name, sub):
def get_c_extract_out(r, name, sub): def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage.""" """Wrapper around c_extract_out that initializes py_name from storage."""
if r.owner: c_extract = r.type.c_extract_out(name, sub,
c_extract = r.type.c_extract_out(name, sub,
getattr(r.owner.op, 'check_input', True)) getattr(r.owner.op, 'check_input', True))
else:
c_extract = r.type.c_extract_out(name, sub, True)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
......
...@@ -22,7 +22,7 @@ class TDouble(Type): ...@@ -22,7 +22,7 @@ class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return "double %(name)s; void* %(name)s_bad_thing;" % locals() return "double %(name)s; void* %(name)s_bad_thing;" % locals()
def c_init(self, name, sub): def c_init(self, name, sub):
...@@ -35,7 +35,7 @@ class TDouble(Type): ...@@ -35,7 +35,7 @@ class TDouble(Type):
def c_literal(self, data): def c_literal(self, data):
return str(data) return str(data)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "not a double!"); PyErr_SetString(PyExc_TypeError, "not a double!");
......
...@@ -444,7 +444,7 @@ class Generic(SingletonType): ...@@ -444,7 +444,7 @@ class Generic(SingletonType):
%(name)s = NULL; %(name)s = NULL;
""" % locals() """ % locals()
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ return """
Py_INCREF(py_%(name)s); Py_INCREF(py_%(name)s);
%(name)s = py_%(name)s; %(name)s = py_%(name)s;
......
...@@ -298,7 +298,7 @@ class T_extending(unittest.TestCase): ...@@ -298,7 +298,7 @@ class T_extending(unittest.TestCase):
def __str__(self): def __str__(self):
return "double" return "double"
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
double %(name)s; double %(name)s;
""" % dict(name = name) """ % dict(name = name)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论