提交 59fbcaee authored 作者: Hengjean's avatar Hengjean

Changed the use location of the global check_input flag. Removed no more relevant lines in doc.

上级 23104525
...@@ -141,8 +141,7 @@ The ``c_code`` method accepts variable names as arguments (``name``, ``inames``, ...@@ -141,8 +141,7 @@ The ``c_code`` method accepts variable names as arguments (``name``, ``inames``,
``onames``) and returns a C code fragment that computes the expression output. ``onames``) and returns a C code fragment that computes the expression output.
In case of error, the ``%(fail)s`` statement cleans up and returns properly. In case of error, the ``%(fail)s`` statement cleans up and returns properly.
The variables ``%(x)s`` and ``%(y)s`` are set up by the TensorType to be ``PyArrayObject`` pointers. The variables ``%(x)s`` and ``%(y)s`` are set up by the TensorType to be ``PyArrayObject`` pointers.
TensorType also set up ``dtype_%(x)s`` to be a typdef to the C type for ``x``, TensorType also set up ``dtype_%(x)s`` to be a typdef to the C type for ``x``.
``type_num_%(x)s`` is the corresponding NumPy type number.
In the first two lines of the C function, we make y point to a new array with In the first two lines of the C function, we make y point to a new array with
the correct size for the output. This is essentially simulating the line the correct size for the output. This is essentially simulating the line
......
...@@ -333,7 +333,8 @@ def get_c_init(r, name, sub): ...@@ -333,7 +333,8 @@ def get_c_init(r, name, sub):
def get_c_extract(r, name, sub): def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage.""" """Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', True) for (c, _) in r.clients]): if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in
r.clients]):
c_extract = r.type.c_extract(name, sub, True) c_extract = r.type.c_extract(name, sub, True)
else: else:
...@@ -349,7 +350,7 @@ def get_c_extract(r, name, sub): ...@@ -349,7 +350,7 @@ def get_c_extract(r, name, sub):
def get_c_extract_out(r, name, sub): def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage.""" """Wrapper around c_extract_out that initializes py_name from storage."""
c_extract = r.type.c_extract_out(name, sub, c_extract = r.type.c_extract_out(name, sub,
getattr(r.owner.op, 'check_input', True)) getattr(r.owner.op, 'check_input', config.check_input))
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
......
...@@ -294,7 +294,7 @@ class CudaNdarrayType(Type): ...@@ -294,7 +294,7 @@ class CudaNdarrayType(Type):
%(name)s = (CudaNdarray*)py_%(name)s; %(name)s = (CudaNdarray*)py_%(name)s;
//std::cerr << "c_extract " << %(name)s << '\\n'; //std::cerr << "c_extract " << %(name)s << '\\n';
""" % locals() """ % locals()
if(check_input and theano.config.check_input): if(check_input):
print >> sio, """ print >> sio, """
if (%(name)s->nd != %(nd)s) if (%(name)s->nd != %(nd)s)
{ {
......
...@@ -273,7 +273,7 @@ class Scalar(Type): ...@@ -273,7 +273,7 @@ class Scalar(Type):
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
specs = self.dtype_specs() specs = self.dtype_specs()
if(check_input and theano.config.check_input): if(check_input):
pre = """ pre = """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s)) if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
{ {
......
...@@ -438,7 +438,7 @@ class TensorType(Type): ...@@ -438,7 +438,7 @@ class TensorType(Type):
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
"""Override `CLinkerType.c_extract` """ """Override `CLinkerType.c_extract` """
if(check_input and theano.config.check_input): if(check_input):
check = """ check = """
%(name)s = NULL; %(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
......
...@@ -246,7 +246,7 @@ class T_extending(unittest.TestCase): ...@@ -246,7 +246,7 @@ class T_extending(unittest.TestCase):
double.c_init = c_init double.c_init = c_init
def c_extract(name, sub, check_input=True): def c_extract(name, sub, check_input=True):
if(check_input and theano.config.check_input): if(check_input):
pre = """ pre = """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float"); PyErr_SetString(PyExc_TypeError, "expected a float");
...@@ -310,7 +310,7 @@ class T_extending(unittest.TestCase): ...@@ -310,7 +310,7 @@ class T_extending(unittest.TestCase):
""" % dict(name = name) """ % dict(name = name)
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
if(check_input and theano.config.check_input): if(check_input):
pre = """ pre = """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float"); PyErr_SetString(PyExc_TypeError, "expected a float");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论