提交 3136ac1c authored 作者: orhanf's avatar orhanf

including/excluding 'unsafe' modes added, 'unsafe' registered in tensor/opt, tests updated

上级 b359a356
...@@ -59,10 +59,14 @@ def test_local_assert(): ...@@ -59,10 +59,14 @@ def test_local_assert():
def test_local_remove_all_assert(): def test_local_remove_all_assert():
x = theano.tensor.fmatrix() x = theano.tensor.fmatrix()
a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any()) a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any())
f = theano.function([x], a, mode=mode_with_gpu) f = theano.function([x], a, mode=mode_with_gpu.including('unsafe'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)] a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 0 assert len(a_op) == 0
f = theano.function([x], a, mode=mode_with_gpu.excluding('unsafe'))
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 1
def test_int_pow(): def test_int_pow():
......
...@@ -28,10 +28,14 @@ def test_local_assert(): ...@@ -28,10 +28,14 @@ def test_local_assert():
def test_local_remove_all_assert(): def test_local_remove_all_assert():
x = theano.tensor.fmatrix() x = theano.tensor.fmatrix()
a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any()) a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any())
f = theano.function([x], a, mode=mode_with_gpu) f = theano.function([x], a, mode=mode_with_gpu.including('unsafe'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)] a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 0 assert len(a_op) == 0
f = theano.function([x], a, mode=mode_with_gpu.excluding('unsafe'))
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 1
def test_flatten(): def test_flatten():
......
...@@ -1670,12 +1670,15 @@ def local_remove_all_assert(node): ...@@ -1670,12 +1670,15 @@ def local_remove_all_assert(node):
# Disabled by default # Disabled by default
compile.optdb['canonicalize'].register('local_remove_all_assert', compile.optdb['canonicalize'].register('local_remove_all_assert',
local_remove_all_assert, local_remove_all_assert,
'unsafe',
use_db_name_as_tag=False) use_db_name_as_tag=False)
compile.optdb['stabilize'].register('local_remove_all_assert', compile.optdb['stabilize'].register('local_remove_all_assert',
local_remove_all_assert, local_remove_all_assert,
'unsafe',
use_db_name_as_tag=False) use_db_name_as_tag=False)
compile.optdb['specialize'].register('local_remove_all_assert', compile.optdb['specialize'].register('local_remove_all_assert',
local_remove_all_assert, local_remove_all_assert,
'unsafe',
use_db_name_as_tag=False) use_db_name_as_tag=False)
......
...@@ -3553,6 +3553,17 @@ class test_assert(utt.InferShapeTester): ...@@ -3553,6 +3553,17 @@ class test_assert(utt.InferShapeTester):
assert len(topo) == 1, topo assert len(topo) == 1, topo
assert topo[0].op == deep_copy_op, topo assert topo[0].op == deep_copy_op, topo
mode = compile.mode.get_default_mode()
a = theano.tensor.opt.assert_op(x, T.eq(x, 0).any())
f = theano.function([x], a, mode=mode.including('unsafe'))
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, T.opt.Assert)]
assert len(a_op) == 0
f = theano.function([x], a, mode=mode.excluding('unsafe'))
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, T.opt.Assert)]
assert len(a_op) == 1
def test_infer_shape(self): def test_infer_shape(self):
adscal = dscalar() adscal = dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论