提交 a2ab9bd9 authored 作者: Caglar's avatar Caglar

fixed the condition in the local optimization.

上级 9ec3b896
......@@ -446,12 +446,11 @@ def local_gpu_dot_to_dot22(node):
shape_out))]
return False
@local_optimizer(None)
def local_assert_no_cpu_op(node):
if not isinstance(node.op, GpuOp) and all([var.owner and isinstance(var.owner.op,
HostFromGpu) for var in node.inputs]) and all([var.owner and
isinstance(var.owner.op, GpuFromHost) for var in node.outputs]):
HostFromGpu) for var in node.inputs]) and any([[c for c in var.clients
if isinstance(c[0].op, GpuFromHost)] for var in node.outputs]):
if config.assert_no_cpu_op == "warn":
_logger.warning(("CPU op %s is detected in the computational"
" graph") % node)
......@@ -459,6 +458,7 @@ def local_assert_no_cpu_op(node):
raise RuntimeError("The op %s is on CPU." % node)
elif config.assert_no_cpu_op == "pdb":
pdb.set_trace()
return None
#Register the local_assert_no_cpu_op:
......
......@@ -107,6 +107,7 @@ def test_local_assert_no_cpu_op():
# If the flag is raise
try:
config.assert_no_cpu_op = 'raise'
assert_raises(RuntimeError, theano.function,
[], out, mode=mode_local_assert)
finally:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论