提交 7a63d84d authored 作者: --global's avatar --global

Add additional checks to the validation function

上级 34617034
......@@ -209,6 +209,48 @@ class Scan(PureOp):
"type '%s' and '%s' respectively." %
(self.name, type_input, type_output))
# If scan has the flag 'gpu' set to false (meaning that is shouldn't
# use the CUDA gpu backend ), ensure that is has no input and no
# output with type CudaNdarrayType
from theano.sandbox.cuda import CudaNdarrayType
if not self.info.get("gpu", False):
for inp in self.inputs:
if isinstance(inp.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(out.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
# If scan has the flag 'gpua' set to false (meaning that is shouldn't
# use the gpuarray gpu backend ), ensure that is has no input and no
# output with type GpuArrayType
from theano.sandbox.gpuarray import GpuArrayType
if not self.info.get("gpua", False):
for inp in self.inputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
def __setstate__(self, d):
self.__dict__.update(d)
self.validate_inner_graph()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论