提交 7c52a043 authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Implemented extensive tests for opt local_elemwise_sub_zero, including stack trace testss.

上级 e009e8f9
...@@ -3482,6 +3482,61 @@ def test_local_fill_useless(): ...@@ -3482,6 +3482,61 @@ def test_local_fill_useless():
f(m_, x_) f(m_, x_)
def test_local_elemwise_sub_zeros():
# Test opt local_elemwise_sub_zeros
# We test separetly for scalars, vectors and matrices
scalar = T.scalar()
vect = T.vector()
mat = T.matrix()
rng = numpy.random.RandomState(seed=utt.fetch_seed())
scalar_val = rng.rand(1)[0]
vect_val = rng.rand(5)
mat_val = rng.rand(3, 2)
mode = theano.compile.get_default_mode()\
.excluding('canonicalize', 'uncanonicalize',\
'ShapeOpt', 'local_fill_to_alloc',\
'local_elemwise_alloc')\
.including('local_elemwise_sub_zeros')
# Test scalar minus scalar
f = function([scalar], scalar-scalar, mode=mode)
# Check optimized graph is correct
assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise)
assert isinstance(f.maker.fgraph.toposort()[0].inputs[1],\
T.TensorConstant) or\
isinstance(f.maker.fgraph.toposort()[0].inputs[1],\
T.TensorConstant)
utt.assert_allclose(f(scalar_val), 0.0)
# Check stack trace is copied over
assert check_stack_trace(f, ops_to_check='all')
# Test vector minus vector
f = function([vect], vect-vect, mode=mode)
# Check optimized graph is correct
assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise)
assert isinstance(f.maker.fgraph.toposort()[0].inputs[1],\
T.TensorConstant) or\
isinstance(f.maker.fgraph.toposort()[0].inputs[1],\
T.TensorConstant)
utt.assert_allclose(f(vect_val), numpy.zeros(vect_val.shape))
# Check stack trace is copied over
assert check_stack_trace(f, ops_to_check='all')
# Test vector minus vector
f = function([mat], mat-mat, mode=mode)
# Check optimized graph is correct
assert isinstance(f.maker.fgraph.toposort()[0].op, T.Elemwise)
assert isinstance(f.maker.fgraph.toposort()[0].inputs[1],\
T.TensorConstant) or\
isinstance(f.maker.fgraph.toposort()[0].inputs[1],\
T.TensorConstant)
utt.assert_allclose(f(mat_val), numpy.zeros(mat_val.shape))
# Check stack trace is copied over
assert check_stack_trace(f, ops_to_check='all')
class Test_local_useless_elemwise_comparison(unittest.TestCase): class Test_local_useless_elemwise_comparison(unittest.TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed()) self.rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论