提交 07fa43dc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update tests.

This also reflects the difference between bool and int8
上级 629fa735
...@@ -3493,7 +3493,11 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3493,7 +3493,11 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
x_val = 10 if f.outputs[0].variable.dtype == 'bool':
x_vals = [0, 1]
else:
x_vals = [0, 1, 10]
for x_val in x_vals:
assert f(x_val) == x_val assert f(x_val) == x_val
def test_inequality_with_self(self): def test_inequality_with_self(self):
...@@ -3601,11 +3605,14 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3601,11 +3605,14 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
assert (f([3, 3]) == 0).all() assert (f([3, 3]) == 0).all()
def test_and(self): def test_and(self):
# bitwise "and" with 0 should give 0 for both bool and int
# bitwise "and" with 1 should only simplify for bool
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
for dtype, zero, one in [('bool', np.array(False), np.array(True)),
('int8', np.int8(0), np.int8(1)),
('int8', 0, 1)]:
x = T.scalar('x', dtype=dtype)
x = T.scalar('x', dtype='int8')
for zero, one in [(np.int8(0), np.int8(1)), (0, 1)]:
f = theano.function([x], T.and_(x, zero), mode=mode) f = theano.function([x], T.and_(x, zero), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
...@@ -3613,35 +3620,51 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3613,35 +3620,51 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(x, one), mode=mode) f = theano.function([x], T.and_(x, one), mode=mode)
if f.outputs[0].variable.dtype == x.dtype: if dtype == 'bool':
self.assert_identity(f) self.assert_identity(f)
f = theano.function([x], T.and_(one, x), mode=mode) f = theano.function([x], T.and_(one, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype: if dtype == 'bool':
self.assert_identity(f) self.assert_identity(f)
def test_and_int(self):
# Test that bitwise "and" is correctly computed on int constants.
f = theano.function([], T.and_(5, 6))
assert f() == 4
def test_or(self): def test_or(self):
# bitwise "or" with 0 should simplify for both bool and int
# bitwise "or" with 1 should only give 1 for bool
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8') for dtype, zero, one in [('bool', np.array(False), np.array(True)),
('int8', np.int8(0), np.int8(1)),
('int8', 0, 1)]:
x = T.scalar('x', dtype=dtype)
for zero, one in [(np.int8(0), np.int8(1)), (0, 1)]:
f = theano.function([x], T.or_(x, one), mode=mode) f = theano.function([x], T.or_(x, one), mode=mode)
if dtype == 'bool':
self.assert_eqs_const(f, 1) self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(one, x), mode=mode) f = theano.function([x], T.or_(one, x), mode=mode)
if dtype == 'bool':
self.assert_eqs_const(f, 1) self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(x, zero), mode=mode) f = theano.function([x], T.or_(x, zero), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f) self.assert_identity(f)
f = theano.function([x], T.or_(zero, x), mode=mode) f = theano.function([x], T.or_(zero, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f) self.assert_identity(f)
def test_or_int(self):
# Test that bitwise "or" is correctly computed on int constants.
f = theano.function([], T.or_(5, 6))
assert f() == 7
def test_xor(self): def test_xor(self):
# bitwise "xor" with itself should always give 0 for both bool and int.
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8') for dtype in ('bool', 'int8'):
x = T.scalar('x', dtype=dtype)
f = theano.function([x], T.xor(x, x), mode=mode) f = theano.function([x], T.xor(x, x), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论