提交 64922a8b authored 作者: Hengjean's avatar Hengjean

Added check inut for cuda and tensor

上级 799e97dd
...@@ -303,6 +303,7 @@ class CudaNdarrayType(Type): ...@@ -303,6 +303,7 @@ class CudaNdarrayType(Type):
} }
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n"; //std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals() """ % locals()
if(theano.compile.ops.Shape.check_input):
for i, b in enumerate(self.broadcastable): for i, b in enumerate(self.broadcastable):
if b: if b:
print >> sio, """ print >> sio, """
...@@ -348,6 +349,12 @@ class CudaNdarrayType(Type): ...@@ -348,6 +349,12 @@ class CudaNdarrayType(Type):
} }
//std::cerr << "c_extract done " << %(name)s << '\\n'; //std::cerr << "c_extract done " << %(name)s << '\\n';
""" % locals() """ % locals()
else:
print >> sio, """
assert(%(name)s);
Py_INCREF(py_%(name)s);
}
""" % locals()
#print sio.getvalue() #print sio.getvalue()
return sio.getvalue() return sio.getvalue()
......
...@@ -433,7 +433,8 @@ class TensorType(Type): ...@@ -433,7 +433,8 @@ class TensorType(Type):
def c_extract(self, name, sub): def c_extract(self, name, sub):
"""Override `CLinkerType.c_extract` """ """Override `CLinkerType.c_extract` """
return """ if(theano.compile.ops.Shape.check_input):
check = """
%(name)s = NULL; %(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
// We can either fail here or set %(name)s to NULL and rely on Ops // We can either fail here or set %(name)s to NULL and rely on Ops
...@@ -482,6 +483,10 @@ class TensorType(Type): ...@@ -482,6 +483,10 @@ class TensorType(Type):
%(type_num)s, type_num_%(name)s); %(type_num)s, type_num_%(name)s);
%(fail)s %(fail)s
} }
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
else:
check = ""
return check + """
%(name)s = (PyArrayObject*)(py_%(name)s); %(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s); Py_XINCREF(%(name)s);
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论