提交 a2ab9bd9 authored 作者: Caglar's avatar Caglar

fixed the condition in the local optimization.

上级 9ec3b896
...@@ -446,12 +446,11 @@ def local_gpu_dot_to_dot22(node): ...@@ -446,12 +446,11 @@ def local_gpu_dot_to_dot22(node):
shape_out))] shape_out))]
return False return False
@local_optimizer(None) @local_optimizer(None)
def local_assert_no_cpu_op(node): def local_assert_no_cpu_op(node):
if not isinstance(node.op, GpuOp) and all([var.owner and isinstance(var.owner.op, if not isinstance(node.op, GpuOp) and all([var.owner and isinstance(var.owner.op,
HostFromGpu) for var in node.inputs]) and all([var.owner and HostFromGpu) for var in node.inputs]) and any([[c for c in var.clients
isinstance(var.owner.op, GpuFromHost) for var in node.outputs]): if isinstance(c[0].op, GpuFromHost)] for var in node.outputs]):
if config.assert_no_cpu_op == "warn": if config.assert_no_cpu_op == "warn":
_logger.warning(("CPU op %s is detected in the computational" _logger.warning(("CPU op %s is detected in the computational"
" graph") % node) " graph") % node)
...@@ -459,6 +458,7 @@ def local_assert_no_cpu_op(node): ...@@ -459,6 +458,7 @@ def local_assert_no_cpu_op(node):
raise RuntimeError("The op %s is on CPU." % node) raise RuntimeError("The op %s is on CPU." % node)
elif config.assert_no_cpu_op == "pdb": elif config.assert_no_cpu_op == "pdb":
pdb.set_trace() pdb.set_trace()
return None return None
#Register the local_assert_no_cpu_op: #Register the local_assert_no_cpu_op:
......
...@@ -107,6 +107,7 @@ def test_local_assert_no_cpu_op(): ...@@ -107,6 +107,7 @@ def test_local_assert_no_cpu_op():
# If the flag is raise # If the flag is raise
try: try:
config.assert_no_cpu_op = 'raise' config.assert_no_cpu_op = 'raise'
assert_raises(RuntimeError, theano.function, assert_raises(RuntimeError, theano.function,
[], out, mode=mode_local_assert) [], out, mode=mode_local_assert)
finally: finally:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论