提交 8dcc5fc6 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6029 from lamblin/fix_useless_bitwise

Fix issue in optimizations with bitwise operations
......@@ -181,8 +181,9 @@ def as_tensor_variable(x, name=None, ndim=None):
if isinstance(x, bool):
raise AsTensorError(
"Cannot cast True or False as a tensor variable. Please use 1 or "
"0. This error might be caused by using the == operator on "
"Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. "
"This error might be caused by using the == operator on "
"Variables. v == w does not do what you think it does, "
"use theano.tensor.eq(v, w) instead.")
......
......@@ -2087,16 +2087,16 @@ def local_subtensor_make_vector(node):
@gof.local_optimizer([T.Elemwise])
def local_useless_elemwise(node):
"""
eq(x,x) -> 1
neq(x,x) -> 0
eq(x, x) -> 1
neq(x, x) -> 0
mul(x) -> x
add(x) -> x
identity(x) -> x
and(x,1) -> x
and(x,0) -> zeros_like(x)
or(x,0) -> x
or(x,1) -> ones_like(x)
xor(x,x) -> zeros_like(x)
and(x, 1) -> x (if x.dtype == 'bool')
and(x, 0) -> zeros_like(x)
or(x, 0) -> x
or(x, 1) -> ones_like(x) (if x.dtype == 'bool')
xor(x, x) -> zeros_like(x)
"""
if isinstance(node.op, T.Elemwise):
......@@ -2141,7 +2141,9 @@ def local_useless_elemwise(node):
if const_val == 0:
return [T.zeros_like(node.inputs[1], dtype=dtype,
opt=True)]
else:
elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], T.TensorConstant):
......@@ -2150,7 +2152,9 @@ def local_useless_elemwise(node):
if const_val == 0:
return [T.zeros_like(node.inputs[0], dtype=dtype,
opt=True)]
else:
elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif (isinstance(node.op.scalar_op, scalar.OR) and
......@@ -2161,7 +2165,9 @@ def local_useless_elemwise(node):
if not isinstance(const_val, Variable):
if const_val == 0:
return [node.inputs[1].astype(node.outputs[0].dtype)]
else:
elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
return [T.ones_like(node.inputs[1], dtype=dtype,
opt=True)]
......@@ -2170,7 +2176,9 @@ def local_useless_elemwise(node):
if not isinstance(const_val, Variable):
if const_val == 0:
return [node.inputs[0].astype(node.outputs[0].dtype)]
else:
elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
return [T.ones_like(node.inputs[0], dtype=dtype,
opt=True)]
......
......@@ -3493,8 +3493,12 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
x_val = 10
assert f(x_val) == x_val
if f.outputs[0].variable.dtype == 'bool':
x_vals = [0, 1]
else:
x_vals = [0, 1, 10]
for x_val in x_vals:
assert f(x_val) == x_val
def test_inequality_with_self(self):
x = T.scalar('x', dtype=config.floatX)
......@@ -3601,11 +3605,14 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
assert (f([3, 3]) == 0).all()
def test_and(self):
# bitwise "and" with 0 should give 0 for both bool and int
# bitwise "and" with 1 should only simplify for bool
mode = theano.compile.get_default_mode().including('canonicalize')
for dtype, zero, one in [('bool', np.array(False), np.array(True)),
('int8', np.int8(0), np.int8(1)),
('int8', 0, 1)]:
x = T.scalar('x', dtype=dtype)
x = T.scalar('x', dtype='int8')
for zero, one in [(np.int8(0), np.int8(1)), (0, 1)]:
f = theano.function([x], T.and_(x, zero), mode=mode)
self.assert_eqs_const(f, 0)
......@@ -3613,38 +3620,54 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(x, one), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
if dtype == 'bool':
self.assert_identity(f)
f = theano.function([x], T.and_(one, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
if dtype == 'bool':
self.assert_identity(f)
def test_and_int(self):
# Test that bitwise "and" is correctly computed on int constants.
f = theano.function([], T.and_(5, 6))
assert f() == 4
def test_or(self):
# bitwise "or" with 0 should simplify for both bool and int
# bitwise "or" with 1 should only give 1 for bool
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
for dtype, zero, one in [('bool', np.array(False), np.array(True)),
('int8', np.int8(0), np.int8(1)),
('int8', 0, 1)]:
x = T.scalar('x', dtype=dtype)
for zero, one in [(np.int8(0), np.int8(1)), (0, 1)]:
f = theano.function([x], T.or_(x, one), mode=mode)
self.assert_eqs_const(f, 1)
if dtype == 'bool':
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(one, x), mode=mode)
self.assert_eqs_const(f, 1)
if dtype == 'bool':
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(x, zero), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
self.assert_identity(f)
f = theano.function([x], T.or_(zero, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
self.assert_identity(f)
def test_or_int(self):
# Test that bitwise "or" is correctly computed on int constants.
f = theano.function([], T.or_(5, 6))
assert f() == 7
def test_xor(self):
# bitwise "xor" with itself should always give 0 for both bool and int.
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
for dtype in ('bool', 'int8'):
x = T.scalar('x', dtype=dtype)
f = theano.function([x], T.xor(x, x), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.xor(x, x), mode=mode)
self.assert_eqs_const(f, 0)
def test_stacktrace(self):
mode = theano.compile.get_default_mode().including(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论