提交 ddef757d authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Amjad Almahairi

Added tests for optimization of inequalities comparing a variable with itself.

上级 10855815
...@@ -4295,7 +4295,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4295,7 +4295,7 @@ def local_useless_elemwise_comparison(node):
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if (isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and if (isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and
node.inputs[0] is node.inputs[1]): node.inputs[0] is node.inputs[1]):
return [T.zeros_like(node.outputs[0])] return [T.zeros_like(node.outputs[0], dtype=node.outputs[0].type.dtype)]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if (isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and if (isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and
node.inputs[0] is node.inputs[1]): node.inputs[0] is node.inputs[1]):
......
...@@ -3136,20 +3136,45 @@ def test_local_fill_useless(): ...@@ -3136,20 +3136,45 @@ def test_local_fill_useless():
f(m_, x_) f(m_, x_)
def test_local_useless_elemwise_comparison(): def assert_eqs_const(topo, val):
# TODO: test each case individually. elem = topo[0]
# The following case is what made me discover those cases. assert len(topo) == 1
X = T.matrix('X') assert elem.op == deep_copy_op
Y = T.vector('Y') assert len(elem.inputs) == 1
X_sum, updates = theano.scan(fn=lambda x: x.sum(), assert isinstance(elem.inputs[0], T.TensorConstant)
outputs_info=None, assert T.extract_constant(elem.inputs[0]) == val
sequences=[X],
non_sequences=None)
Z = X_sum + Y class Test_local_useless_elemwise_comparison(unittest.TestCase):
theano.printing.debugprint(Z) def test_local_useless_elemwise_comparison(self):
mode = theano.compile.get_default_mode().excluding('fusion') # TODO: test each case individually.
f = theano.function([X, Y], Z, mode=mode) # The following case is what made me discover those cases.
theano.printing.debugprint(f, print_type=True) X = T.matrix('X')
Y = T.vector('Y')
X_sum, updates = theano.scan(fn=lambda x: x.sum(),
outputs_info=None,
sequences=[X],
non_sequences=None)
Z = X_sum + Y
theano.printing.debugprint(Z)
mode = theano.compile.get_default_mode().excluding('fusion')
f = theano.function([X, Y], Z, mode=mode)
theano.printing.debugprint(f, print_type=True)
def test_inequality_with_self(self):
x = T.scalar('x', dtype=config.floatX)
f = theano.function([x], T.lt(x, x))
assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.le(x, x))
assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.gt(x, x))
assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.ge(x, x))
assert_eqs_const(f.maker.fgraph.toposort(), 1)
class Test_local_useless_alloc(unittest.TestCase): class Test_local_useless_alloc(unittest.TestCase):
...@@ -3924,28 +3949,30 @@ class T_useless_elemwise(unittest.TestCase): ...@@ -3924,28 +3949,30 @@ class T_useless_elemwise(unittest.TestCase):
assert len(topo) == 1 assert len(topo) == 1
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
def assert_eqs_const(self, topo, val):
elem = topo[0]
assert len(topo) == 1
assert elem.op == deep_copy_op
assert len(elem.inputs) == 1
assert isinstance(elem.inputs[0], T.TensorConstant)
assert T.extract_constant(elem.inputs[0]) == val
def assert_identity(self, f): def assert_identity(self, f):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
x_val = numpy.random.randint(256) x_val = -128
assert f(x_val) == x_val
x_val = -1
assert f(x_val) == x_val
x_val = 0
assert f(x_val) == x_val
x_val = 1
assert f(x_val) == x_val
x_val = 127
assert f(x_val) == x_val
x_val = numpy.random.randint(255)-128
assert f(x_val) == x_val assert f(x_val) == x_val
def test_and(self): def test_and(self):
x = T.scalar('x', dtype='int64') x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=self.mode) f = theano.function([x], T.and_(x, 0), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 0) assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(0, x), mode=self.mode) f = theano.function([x], T.and_(0, x), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 0) assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(x, 1), mode=self.mode) f = theano.function([x], T.and_(x, 1), mode=self.mode)
self.assert_identity(f) self.assert_identity(f)
...@@ -3954,13 +3981,13 @@ class T_useless_elemwise(unittest.TestCase): ...@@ -3954,13 +3981,13 @@ class T_useless_elemwise(unittest.TestCase):
self.assert_identity(f) self.assert_identity(f)
def test_or(self): def test_or(self):
x = T.scalar('x', dtype='int64') x = T.scalar('x', dtype='int8')
f = theano.function([x], T.or_(x, 1), mode=self.mode) f = theano.function([x], T.or_(x, 1), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 1) assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(1, x), mode=self.mode) f = theano.function([x], T.or_(1, x), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 1) assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(x, 0), mode=self.mode) f = theano.function([x], T.or_(x, 0), mode=self.mode)
self.assert_identity(f) self.assert_identity(f)
...@@ -3969,10 +3996,10 @@ class T_useless_elemwise(unittest.TestCase): ...@@ -3969,10 +3996,10 @@ class T_useless_elemwise(unittest.TestCase):
self.assert_identity(f) self.assert_identity(f)
def test_xor(self): def test_xor(self):
x = T.scalar('x', dtype='int64') x = T.scalar('x', dtype='int8')
f = theano.function([x], T.xor(x, x), mode=self.mode) f = theano.function([x], T.xor(x, x), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 0) assert_eqs_const(f.maker.fgraph.toposort(), 0)
class T_cast_cast(unittest.TestCase): class T_cast_cast(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论