提交 63e45998 authored 作者: Hengjean's avatar Hengjean

Added check-input.

上级 64922a8b
......@@ -213,6 +213,8 @@ class Shape(gof.Op):
# the output variable is %(oname)s.
c_code_and_version = {}
check_input = True
def __hash__(self):
return hash(type(self))
......
......@@ -280,7 +280,7 @@ class CudaNdarrayType(Type):
def c_init(self, name, sub):
return "%(name)s = NULL;" % locals()
def c_extract(self, name, sub):
def c_extract(self, name, sub, check_input=True):
sio = StringIO()
fail = sub['fail']
nd = self.ndim
......@@ -303,7 +303,7 @@ class CudaNdarrayType(Type):
}
//std::cerr << "c_extract " << %(name)s << " nd check passed\\n";
""" % locals()
if(theano.compile.ops.Shape.check_input):
if(check_input):
for i, b in enumerate(self.broadcastable):
if b:
print >> sio, """
......
......@@ -431,9 +431,9 @@ class TensorType(Type):
type_num_%(name)s = %(type_num)s;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub):
def c_extract(self, name, sub, check_input=True):
"""Override `CLinkerType.c_extract` """
if(theano.compile.ops.Shape.check_input):
if(check_input):
check = """
%(name)s = NULL;
if (py_%(name)s == Py_None) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论