提交 75f8fee4 authored 作者: Caglar's avatar Caglar

made the config var overridable.

上级 465b124c
...@@ -131,7 +131,7 @@ AddConfigVar( ...@@ -131,7 +131,7 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
'assert_no_cpu_op', 'assert_no_cpu_op',
"Raise an error/warning if there is a CPU op in the computational graph.", "Raise an error/warning if there is a CPU op in the computational graph.",
EnumStr('ignore', 'warn', 'raise', 'pdb', allow_override=False), EnumStr('ignore', 'warn', 'raise', 'pdb', allow_override=True),
in_c_key=False) in_c_key=False)
......
...@@ -447,7 +447,7 @@ def local_gpu_dot_to_dot22(node): ...@@ -447,7 +447,7 @@ def local_gpu_dot_to_dot22(node):
return False return False
@local_optimizer([theano.gof.Op]) @local_optimizer(None)
def local_assert_no_cpu_op(node): def local_assert_no_cpu_op(node):
if not isinstance(node.op, GpuOp) and all([var.owner and isinstance(var.owner.op, if not isinstance(node.op, GpuOp) and all([var.owner and isinstance(var.owner.op,
HostFromGpu) for var in node.inputs]) and all([var.owner and HostFromGpu) for var in node.inputs]) and all([var.owner and
......
...@@ -102,13 +102,23 @@ def test_local_assert_no_cpu_op(): ...@@ -102,13 +102,23 @@ def test_local_assert_no_cpu_op():
mode_local_assert = mode_with_gpu.including("local_assert_no_cpu_op") mode_local_assert = mode_with_gpu.including("local_assert_no_cpu_op")
mode_local_assert = mode_local_assert.excluding("local_gpu_elemwise_1") mode_local_assert = mode_local_assert.excluding("local_gpu_elemwise_1")
old = config.assert_no_cpu_op
#If the flag is raise #If the flag is raise
try: try:
mode_local_assert = \ config.assert_no_cpu_op = 'raise'
mode_local_assert.including("assert_no_cpu_op='%s'" % flag) assert_raises(AssertionError,
f = theano.function([], out, mode=mode_local_assert) theano.function([], out,
except Exception as expt: mode=mode_local_assert))
finally:
#If the flag is ignore #If the flag is ignore
config.assert_no_cpu_op = old
try:
config.assert_no_cpu_op = 'ignore'
fn = theano.function([], out, mode=mode_local_assert))
finally:
#If the flag is ignore
config.assert_no_cpu_op = old
def test_int_pow(): def test_int_pow():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论