提交 2774de58 authored 作者: Hengjean's avatar Hengjean

Modified c_extract and c_declare in tutorial and scalar.

上级 960ccba1
......@@ -373,7 +373,7 @@ class Shape_i(gof.Op):
itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
sc = """
if (%(i)s>PyArray_NDIM(%(iname)s){
if (%(i)s>=PyArray_NDIM(%(iname)s)){
PyErr_SetString(PyExc_TypeError, "Number of dimensions lower than expected");
%(fail)s
}
......
......@@ -306,11 +306,11 @@ def get_nothing(r, name, sub):
return ""
def get_c_declare(r, name, sub, check_input=True):
def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name"""
if any([c == 'output' or getattr(c.op, 'check_input', True) for (c, _)
in r.clients]) or r.owner and getattr(r.owner.op,
'check_input', True):
in r.clients]) or (r.owner and getattr(r.owner.op,
'check_input', True)):
c_declare = r.type.c_declare(name, sub, True)
else:
......@@ -334,8 +334,7 @@ def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', True) for (c, _) in r.clients]):
c_extract = r.type.c_extract(name, sub,
True)
c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, False)
......
......@@ -255,8 +255,13 @@ class Scalar(Type):
return str(data)
def c_declare(self, name, sub, check_input=True):
return """
%(dtype)s %(name)s;
if(check_input):
pre = """
%(dtype)s %(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1])
else:
pre = ""
return pre + """
typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead.
typedef %(dtype)s dtype_%(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1])
......@@ -268,18 +273,23 @@ class Scalar(Type):
def c_extract(self, name, sub, check_input=True):
specs = self.dtype_specs()
return """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
{
PyErr_Format(PyExc_ValueError,
"Scalar check failed (%(dtype)s)");
%(fail)s
}
if(check_input):
pre = """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
{
PyErr_Format(PyExc_ValueError,
"Scalar check failed (%(dtype)s)");
%(fail)s
}
""" % dict(sub,
name=name,
dtype=specs[1],
pyarr_type='Py%sArrType_Type' % specs[2])
else:
pre = ""
return pre + """
PyArray_ScalarAsCtype(py_%(name)s, &%(name)s);
""" % dict(sub,
name=name,
dtype=specs[1],
pyarr_type='Py%sArrType_Type' % specs[2])
""" % dict(sub, name=name)
def c_sync(self, name, sub):
specs = self.dtype_specs()
......
......@@ -245,19 +245,20 @@ class T_extending(unittest.TestCase):
""" % dict(name = name)
double.c_init = c_init
def c_extract(name, sub, check_input=True):
return """
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float");
%(fail)s
}
if(check_input):
pre = """
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float");
%(fail)s
}""" % dict(name = name, fail = sub['fail'])
else:
pre = ""
return pre + """
%(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(name = name, fail = sub['fail'])
double.c_extract = c_extract
def c_sync( name, sub):
return """
Py_XDECREF(py_%(name)s);
......@@ -309,13 +310,18 @@ class T_extending(unittest.TestCase):
""" % dict(name = name)
def c_extract(self, name, sub, check_input=True):
return """
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float");
%(fail)s
}
if(check_input):
pre = """
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float");
%(fail)s
}
""" % dict(sub, name=name)
else:
pre = ""
return pre + """
%(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(sub, name = name)
""" % dict(sub, name=name)
def c_sync(self, name, sub):
return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论