提交 cdfa5e85 authored 作者: orhanf's avatar orhanf

remove unsafe by default, add tests respectively

上级 5cfff95b
...@@ -96,7 +96,11 @@ register_opt()(theano.tensor.opt.local_track_shape_i) ...@@ -96,7 +96,11 @@ register_opt()(theano.tensor.opt.local_track_shape_i)
register_opt(name='gpu_constant_folding')( register_opt(name='gpu_constant_folding')(
tensor.opt.constant_folding) tensor.opt.constant_folding)
register_opt()(theano.tensor.opt.local_subtensor_make_vector) register_opt()(theano.tensor.opt.local_subtensor_make_vector)
register_opt('unsafe')(theano.tensor.opt.local_remove_all_assert)
# Register local_remove_all_assert as a global opt
gpu_optimizer.register('local_remove_all_assert',
theano.tensor.opt.local_remove_all_assert,
'unsafe')
# This is a partial list of CPU ops that can be in some circonstance # This is a partial list of CPU ops that can be in some circonstance
......
...@@ -59,10 +59,20 @@ def test_local_assert(): ...@@ -59,10 +59,20 @@ def test_local_assert():
def test_local_remove_all_assert(): def test_local_remove_all_assert():
x = theano.tensor.fmatrix() x = theano.tensor.fmatrix()
a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any()) a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any())
# By default `unsafe` should not be there
f = theano.function([x], a, mode=mode_with_gpu)
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 1
# Put `unsafe`
f = theano.function([x], a, mode=mode_with_gpu.including('unsafe')) f = theano.function([x], a, mode=mode_with_gpu.including('unsafe'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)] a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 0 assert len(a_op) == 0
# Remove `unsafe`
f = theano.function([x], a, mode=mode_with_gpu.excluding('unsafe')) f = theano.function([x], a, mode=mode_with_gpu.excluding('unsafe'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)] a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
......
...@@ -59,7 +59,10 @@ def register_opt(*tags, **kwargs): ...@@ -59,7 +59,10 @@ def register_opt(*tags, **kwargs):
return f return f
register_opt('fast_compile')(theano.tensor.opt.local_track_shape_i) register_opt('fast_compile')(theano.tensor.opt.local_track_shape_i)
register_opt('unsafe')(theano.tensor.opt.local_remove_all_assert)
gpu_optimizer.register('local_remove_all_assert',
theano.tensor.opt.local_remove_all_assert,
'unsafe')
def safe_to_gpu(x): def safe_to_gpu(x):
......
...@@ -28,10 +28,20 @@ def test_local_assert(): ...@@ -28,10 +28,20 @@ def test_local_assert():
def test_local_remove_all_assert(): def test_local_remove_all_assert():
x = theano.tensor.fmatrix() x = theano.tensor.fmatrix()
a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any()) a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any())
# By default `unsafe` should not be there
f = theano.function([x], a, mode=mode_with_gpu)
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 1
# Put `unsafe`
f = theano.function([x], a, mode=mode_with_gpu.including('unsafe')) f = theano.function([x], a, mode=mode_with_gpu.including('unsafe'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)] a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 0 assert len(a_op) == 0
# Remove `unsafe`
f = theano.function([x], a, mode=mode_with_gpu.excluding('unsafe')) f = theano.function([x], a, mode=mode_with_gpu.excluding('unsafe'))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)] a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论