提交 2774de58 authored 作者: Hengjean's avatar Hengjean

Modified c_extract and c_declare in tutorial and scalar.

上级 960ccba1
...@@ -373,7 +373,7 @@ class Shape_i(gof.Op): ...@@ -373,7 +373,7 @@ class Shape_i(gof.Op):
itype = node.inputs[0].type.__class__ itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version: if itype in self.c_code_and_version:
sc = """ sc = """
if (%(i)s>PyArray_NDIM(%(iname)s){ if (%(i)s>=PyArray_NDIM(%(iname)s)){
PyErr_SetString(PyExc_TypeError, "Number of dimensions lower than expected"); PyErr_SetString(PyExc_TypeError, "Number of dimensions lower than expected");
%(fail)s %(fail)s
} }
......
...@@ -306,11 +306,11 @@ def get_nothing(r, name, sub): ...@@ -306,11 +306,11 @@ def get_nothing(r, name, sub):
return "" return ""
def get_c_declare(r, name, sub, check_input=True): def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name""" """Wrapper around c_declare that declares py_name"""
if any([c == 'output' or getattr(c.op, 'check_input', True) for (c, _) if any([c == 'output' or getattr(c.op, 'check_input', True) for (c, _)
in r.clients]) or r.owner and getattr(r.owner.op, in r.clients]) or (r.owner and getattr(r.owner.op,
'check_input', True): 'check_input', True)):
c_declare = r.type.c_declare(name, sub, True) c_declare = r.type.c_declare(name, sub, True)
else: else:
...@@ -334,8 +334,7 @@ def get_c_extract(r, name, sub): ...@@ -334,8 +334,7 @@ def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage.""" """Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', True) for (c, _) in r.clients]): if any([getattr(c.op, 'check_input', True) for (c, _) in r.clients]):
c_extract = r.type.c_extract(name, sub, c_extract = r.type.c_extract(name, sub, True)
True)
else: else:
c_extract = r.type.c_extract(name, sub, False) c_extract = r.type.c_extract(name, sub, False)
......
...@@ -255,8 +255,13 @@ class Scalar(Type): ...@@ -255,8 +255,13 @@ class Scalar(Type):
return str(data) return str(data)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ if(check_input):
%(dtype)s %(name)s; pre = """
%(dtype)s %(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1])
else:
pre = ""
return pre + """
typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead. typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead.
typedef %(dtype)s dtype_%(name)s; typedef %(dtype)s dtype_%(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1]) """ % dict(name=name, dtype=self.dtype_specs()[1])
...@@ -268,18 +273,23 @@ class Scalar(Type): ...@@ -268,18 +273,23 @@ class Scalar(Type):
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ if(check_input):
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s)) pre = """
{ if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
PyErr_Format(PyExc_ValueError, {
"Scalar check failed (%(dtype)s)"); PyErr_Format(PyExc_ValueError,
%(fail)s "Scalar check failed (%(dtype)s)");
} %(fail)s
}
""" % dict(sub,
name=name,
dtype=specs[1],
pyarr_type='Py%sArrType_Type' % specs[2])
else:
pre = ""
return pre + """
PyArray_ScalarAsCtype(py_%(name)s, &%(name)s); PyArray_ScalarAsCtype(py_%(name)s, &%(name)s);
""" % dict(sub, """ % dict(sub, name=name)
name=name,
dtype=specs[1],
pyarr_type='Py%sArrType_Type' % specs[2])
def c_sync(self, name, sub): def c_sync(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
......
...@@ -245,19 +245,20 @@ class T_extending(unittest.TestCase): ...@@ -245,19 +245,20 @@ class T_extending(unittest.TestCase):
""" % dict(name = name) """ % dict(name = name)
double.c_init = c_init double.c_init = c_init
def c_extract(name, sub, check_input=True): def c_extract(name, sub, check_input=True):
return """ if(check_input):
if (!PyFloat_Check(py_%(name)s)) { pre = """
PyErr_SetString(PyExc_TypeError, "expected a float"); if (!PyFloat_Check(py_%(name)s)) {
%(fail)s PyErr_SetString(PyExc_TypeError, "expected a float");
} %(fail)s
}""" % dict(name = name, fail = sub['fail'])
else:
pre = ""
return pre + """
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(name = name, fail = sub['fail']) """ % dict(name = name, fail = sub['fail'])
double.c_extract = c_extract double.c_extract = c_extract
def c_sync( name, sub): def c_sync( name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
...@@ -309,13 +310,18 @@ class T_extending(unittest.TestCase): ...@@ -309,13 +310,18 @@ class T_extending(unittest.TestCase):
""" % dict(name = name) """ % dict(name = name)
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
return """ if(check_input):
if (!PyFloat_Check(py_%(name)s)) { pre = """
PyErr_SetString(PyExc_TypeError, "expected a float"); if (!PyFloat_Check(py_%(name)s)) {
%(fail)s PyErr_SetString(PyExc_TypeError, "expected a float");
} %(fail)s
}
""" % dict(sub, name=name)
else:
pre = ""
return pre + """
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(sub, name = name) """ % dict(sub, name=name)
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论