提交 0f1edb05 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

test corrections + new tests

上级 8b91e067
......@@ -3135,15 +3135,6 @@ def test_local_fill_useless():
assert T.Alloc in ops
f(m_, x_)
def assert_eqs_const(topo, val):
elem = topo[0]
assert len(topo) == 1, topo
assert elem.op == deep_copy_op, elem.op
assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val
class Test_local_useless_elemwise_comparison(unittest.TestCase):
def test_local_useless_elemwise_comparison(self):
......@@ -3212,48 +3203,98 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
> |X[t] [@O] <TensorType(float64, vector)> -> [@E]
"""
def assert_eqs_const(self, f, val):
topo = f.maker.fgraph.toposort()
elem = topo[0]
assert len(topo) == 1, topo
assert elem.op == deep_copy_op, elem.op
assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val, val
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
x_val = 10
assert f(x_val) == x_val
#def assert_returns
def test_inequality_with_self(self):
x = T.scalar('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison')
f = theano.function([x], T.lt(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.le(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.gt(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.ge(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
self.assert_eqs_const(f, 1)
def assert_identity(self, f):
f = theano.function([x], T.minimum(x, x), mode=mode)
self.assert_identity(f)
f = theano.function([x], T.maximum(x, x), mode=mode)
self.assert_identity(f)
def test_shape_inequality_with_self(self):
x = T.vector('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison')
f = theano.function([x], T.lt(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.ge(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.maximum(x.shape[0], 0), mode=mode)
topo = f.maker.fgraph.toposort()
assert topo[0].op == deep_copy_op
x_val = -128
assert f(x_val) == x_val
x_val = -1
assert f(x_val) == x_val
x_val = 0
assert f(x_val) == x_val
x_val = 1
assert f(x_val) == x_val
x_val = 127
assert f(x_val) == x_val
x_val = numpy.random.randint(255)-128
assert f(x_val) == x_val
assert len(topo) == 1
assert isinstance(topo[0].op, Shape_i), topo[0].op
x_val = numpy.ones(100)
assert f(x_val) == x_val.shape[0]
f = theano.function([x], T.maximum(0, x.shape[0]), mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, Shape_i), topo[0].op
x_val = numpy.ones(100)
assert f(x_val) == x_val.shape[0]
f = theano.function([x], T.minimum(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.minimum(0, x.shape[0]), mode=mode)
self.assert_eqs_const(f, 0)
def test_shape_add_inequality(self):
x = T.vector('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison')
y = T.vector('y', dtype=config.floatX)
f = theano.function([x, y], T.lt(x.shape[0]+y.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x, y], T.ge(x.shape[0]+y.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 1)
def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(0, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(x, 1), mode=mode)
self.assert_identity(f)
......@@ -3266,10 +3307,10 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.or_(x, 1), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(1, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(x, 0), mode=mode)
self.assert_identity(f)
......@@ -3282,8 +3323,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.xor(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
self.assert_eqs_const(f, 0)
class Test_local_useless_alloc(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论