提交 8ce2395b authored 作者: orhanf's avatar orhanf

Registered on sandbox/gpuArray/opt.py and sandbox/cuda/opt.py, tests added

上级 b24d1a8d
......@@ -96,6 +96,7 @@ register_opt()(theano.tensor.opt.local_track_shape_i)
register_opt(name='gpu_constant_folding')(
tensor.opt.constant_folding)
register_opt()(theano.tensor.opt.local_subtensor_make_vector)
register_opt('unsafe')(theano.tensor.opt.local_remove_all_assert)
# This is a partial list of CPU ops that can be in some circonstance
......
......@@ -56,6 +56,15 @@ def test_local_assert():
assert isinstance(a_op[0].inputs[0].type, CudaNdarrayType)
def test_local_remove_all_assert():
x = theano.tensor.fmatrix()
a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any())
f = theano.function([x], a, mode=mode_with_gpu)
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 0
def test_int_pow():
a = CudaNdarrayType([False])()
......
......@@ -59,7 +59,7 @@ def register_opt(*tags, **kwargs):
return f
register_opt('fast_compile')(theano.tensor.opt.local_track_shape_i)
register_opt('unsafe')(theano.tensor.opt.local_remove_all_assert)
def safe_to_gpu(x):
if isinstance(x.type, tensor.TensorType):
......
......@@ -25,6 +25,15 @@ def test_local_assert():
assert isinstance(a_op[0].inputs[0].type, GpuArrayType)
def test_local_remove_all_assert():
x = theano.tensor.fmatrix()
a = theano.tensor.opt.assert_op(x, theano.tensor.eq(x, 0).any())
f = theano.function([x], a, mode=mode_with_gpu)
topo = f.maker.fgraph.toposort()
a_op = [n for n in topo if isinstance(n.op, theano.tensor.opt.Assert)]
assert len(a_op) == 0
def test_flatten():
m = theano.tensor.fmatrix()
f = theano.function([m], m.flatten(), mode=mode_with_gpu)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论