提交 63e45998 authored 作者: Hengjean's avatar Hengjean

Added check-input.

上级 64922a8b
...@@ -213,6 +213,8 @@ class Shape(gof.Op): ...@@ -213,6 +213,8 @@ class Shape(gof.Op):
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
check_input = True
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
......
...@@ -280,7 +280,7 @@ class CudaNdarrayType(Type): ...@@ -280,7 +280,7 @@ class CudaNdarrayType(Type):
def c_init(self, name, sub): def c_init(self, name, sub):
return "%(name)s = NULL;" % locals() return "%(name)s = NULL;" % locals()
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
sio = StringIO() sio = StringIO()
fail = sub['fail'] fail = sub['fail']
nd = self.ndim nd = self.ndim
...@@ -303,7 +303,7 @@ class CudaNdarrayType(Type): ...@@ -303,7 +303,7 @@ class CudaNdarrayType(Type):
} }
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n"; //std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals() """ % locals()
if(theano.compile.ops.Shape.check_input): if(check_input):
for i, b in enumerate(self.broadcastable): for i, b in enumerate(self.broadcastable):
if b: if b:
print >> sio, """ print >> sio, """
......
...@@ -431,9 +431,9 @@ class TensorType(Type): ...@@ -431,9 +431,9 @@ class TensorType(Type):
type_num_%(name)s = %(type_num)s; type_num_%(name)s = %(type_num)s;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
"""Override `CLinkerType.c_extract` """ """Override `CLinkerType.c_extract` """
if(theano.compile.ops.Shape.check_input): if(check_input):
check = """ check = """
%(name)s = NULL; %(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论