提交 af9d3a0f authored 作者: Caglar's avatar Caglar

Changed the RuntimeError to AssertionError and updated the test.

上级 a2ab9bd9
...@@ -455,7 +455,7 @@ def local_assert_no_cpu_op(node): ...@@ -455,7 +455,7 @@ def local_assert_no_cpu_op(node):
_logger.warning(("CPU op %s is detected in the computational" _logger.warning(("CPU op %s is detected in the computational"
" graph") % node) " graph") % node)
elif config.assert_no_cpu_op == "raise": elif config.assert_no_cpu_op == "raise":
raise RuntimeError("The op %s is on CPU." % node) raise AssertionError("The op %s is on CPU." % node)
elif config.assert_no_cpu_op == "pdb": elif config.assert_no_cpu_op == "pdb":
pdb.set_trace() pdb.set_trace()
...@@ -464,7 +464,7 @@ def local_assert_no_cpu_op(node): ...@@ -464,7 +464,7 @@ def local_assert_no_cpu_op(node):
#Register the local_assert_no_cpu_op: #Register the local_assert_no_cpu_op:
assert_no_cpu_op = theano.tensor.opt.in2out(local_assert_no_cpu_op, name='assert_no_cpu_op') assert_no_cpu_op = theano.tensor.opt.in2out(local_assert_no_cpu_op, name='assert_no_cpu_op')
# 48.7 is after specialize device # 48.7 is after specialize device
theano.compile.optdb.register('assert_no_cpu_op', assert_no_cpu_op, 48.7) theano.compile.optdb.register('assert_no_cpu_op', assert_no_cpu_op, 49.2)
@register_opt() @register_opt()
......
...@@ -96,7 +96,7 @@ def test_local_assert_no_cpu_op(): ...@@ -96,7 +96,7 @@ def test_local_assert_no_cpu_op():
numpy.random.seed(1) numpy.random.seed(1)
m = numpy.random.uniform(-1, 1, (10, 10)).astype("float32") m = numpy.random.uniform(-1, 1, (10, 10)).astype("float32")
ms = cuda.shared_constructor(m, name="m_shared") ms = cuda.shared_constructor(m, name="m_shared")
out = theano.tensor.tanh(ms**2 + ms).dot(ms.T) out = theano.tensor.tanh(ms).dot(ms.T)
mode_local_assert = mode_with_gpu.including("assert_no_cpu_op") mode_local_assert = mode_with_gpu.including("assert_no_cpu_op")
mode_local_assert = mode_local_assert.excluding("local_gpu_elemwise_0") mode_local_assert = mode_local_assert.excluding("local_gpu_elemwise_0")
...@@ -108,7 +108,7 @@ def test_local_assert_no_cpu_op(): ...@@ -108,7 +108,7 @@ def test_local_assert_no_cpu_op():
try: try:
config.assert_no_cpu_op = 'raise' config.assert_no_cpu_op = 'raise'
assert_raises(RuntimeError, theano.function, assert_raises(AssertionError, theano.function,
[], out, mode=mode_local_assert) [], out, mode=mode_local_assert)
finally: finally:
config.assert_no_cpu_op = old config.assert_no_cpu_op = old
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论