提交 8dcc5fc6 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6029 from lamblin/fix_useless_bitwise

Fix issue in optimizations with bitwise operations
...@@ -181,8 +181,9 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -181,8 +181,9 @@ def as_tensor_variable(x, name=None, ndim=None):
if isinstance(x, bool): if isinstance(x, bool):
raise AsTensorError( raise AsTensorError(
"Cannot cast True or False as a tensor variable. Please use 1 or " "Cannot cast True or False as a tensor variable. Please use "
"0. This error might be caused by using the == operator on " "np.array(True) or np.array(False) if you need these constants. "
"This error might be caused by using the == operator on "
"Variables. v == w does not do what you think it does, " "Variables. v == w does not do what you think it does, "
"use theano.tensor.eq(v, w) instead.") "use theano.tensor.eq(v, w) instead.")
......
...@@ -2087,16 +2087,16 @@ def local_subtensor_make_vector(node): ...@@ -2087,16 +2087,16 @@ def local_subtensor_make_vector(node):
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_useless_elemwise(node): def local_useless_elemwise(node):
""" """
eq(x,x) -> 1 eq(x, x) -> 1
neq(x,x) -> 0 neq(x, x) -> 0
mul(x) -> x mul(x) -> x
add(x) -> x add(x) -> x
identity(x) -> x identity(x) -> x
and(x,1) -> x and(x, 1) -> x (if x.dtype == 'bool')
and(x,0) -> zeros_like(x) and(x, 0) -> zeros_like(x)
or(x,0) -> x or(x, 0) -> x
or(x,1) -> ones_like(x) or(x, 1) -> ones_like(x) (if x.dtype == 'bool')
xor(x,x) -> zeros_like(x) xor(x, x) -> zeros_like(x)
""" """
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
...@@ -2141,7 +2141,9 @@ def local_useless_elemwise(node): ...@@ -2141,7 +2141,9 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [T.zeros_like(node.inputs[1], dtype=dtype, return [T.zeros_like(node.inputs[1], dtype=dtype,
opt=True)] opt=True)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
return [node.inputs[1].astype(node.outputs[0].dtype)] return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
...@@ -2150,7 +2152,9 @@ def local_useless_elemwise(node): ...@@ -2150,7 +2152,9 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [T.zeros_like(node.inputs[0], dtype=dtype, return [T.zeros_like(node.inputs[0], dtype=dtype,
opt=True)] opt=True)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif (isinstance(node.op.scalar_op, scalar.OR) and elif (isinstance(node.op.scalar_op, scalar.OR) and
...@@ -2161,7 +2165,9 @@ def local_useless_elemwise(node): ...@@ -2161,7 +2165,9 @@ def local_useless_elemwise(node):
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[1].astype(node.outputs[0].dtype)] return [node.inputs[1].astype(node.outputs[0].dtype)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
return [T.ones_like(node.inputs[1], dtype=dtype, return [T.ones_like(node.inputs[1], dtype=dtype,
opt=True)] opt=True)]
...@@ -2170,7 +2176,9 @@ def local_useless_elemwise(node): ...@@ -2170,7 +2176,9 @@ def local_useless_elemwise(node):
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
return [T.ones_like(node.inputs[0], dtype=dtype, return [T.ones_like(node.inputs[0], dtype=dtype,
opt=True)] opt=True)]
......
...@@ -3493,8 +3493,12 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3493,8 +3493,12 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
x_val = 10 if f.outputs[0].variable.dtype == 'bool':
assert f(x_val) == x_val x_vals = [0, 1]
else:
x_vals = [0, 1, 10]
for x_val in x_vals:
assert f(x_val) == x_val
def test_inequality_with_self(self): def test_inequality_with_self(self):
x = T.scalar('x', dtype=config.floatX) x = T.scalar('x', dtype=config.floatX)
...@@ -3601,11 +3605,14 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3601,11 +3605,14 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
assert (f([3, 3]) == 0).all() assert (f([3, 3]) == 0).all()
def test_and(self): def test_and(self):
# bitwise "and" with 0 should give 0 for both bool and int
# bitwise "and" with 1 should only simplify for bool
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
for dtype, zero, one in [('bool', np.array(False), np.array(True)),
('int8', np.int8(0), np.int8(1)),
('int8', 0, 1)]:
x = T.scalar('x', dtype=dtype)
x = T.scalar('x', dtype='int8')
for zero, one in [(np.int8(0), np.int8(1)), (0, 1)]:
f = theano.function([x], T.and_(x, zero), mode=mode) f = theano.function([x], T.and_(x, zero), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
...@@ -3613,38 +3620,54 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3613,38 +3620,54 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(x, one), mode=mode) f = theano.function([x], T.and_(x, one), mode=mode)
if f.outputs[0].variable.dtype == x.dtype: if dtype == 'bool':
self.assert_identity(f) self.assert_identity(f)
f = theano.function([x], T.and_(one, x), mode=mode) f = theano.function([x], T.and_(one, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype: if dtype == 'bool':
self.assert_identity(f) self.assert_identity(f)
def test_and_int(self):
# Test that bitwise "and" is correctly computed on int constants.
f = theano.function([], T.and_(5, 6))
assert f() == 4
def test_or(self): def test_or(self):
# bitwise "or" with 0 should simplify for both bool and int
# bitwise "or" with 1 should only give 1 for bool
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8') for dtype, zero, one in [('bool', np.array(False), np.array(True)),
('int8', np.int8(0), np.int8(1)),
('int8', 0, 1)]:
x = T.scalar('x', dtype=dtype)
for zero, one in [(np.int8(0), np.int8(1)), (0, 1)]:
f = theano.function([x], T.or_(x, one), mode=mode) f = theano.function([x], T.or_(x, one), mode=mode)
self.assert_eqs_const(f, 1) if dtype == 'bool':
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(one, x), mode=mode) f = theano.function([x], T.or_(one, x), mode=mode)
self.assert_eqs_const(f, 1) if dtype == 'bool':
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(x, zero), mode=mode) f = theano.function([x], T.or_(x, zero), mode=mode)
if f.outputs[0].variable.dtype == x.dtype: self.assert_identity(f)
self.assert_identity(f)
f = theano.function([x], T.or_(zero, x), mode=mode) f = theano.function([x], T.or_(zero, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype: self.assert_identity(f)
self.assert_identity(f)
def test_or_int(self):
# Test that bitwise "or" is correctly computed on int constants.
f = theano.function([], T.or_(5, 6))
assert f() == 7
def test_xor(self): def test_xor(self):
# bitwise "xor" with itself should always give 0 for both bool and int.
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8') for dtype in ('bool', 'int8'):
x = T.scalar('x', dtype=dtype)
f = theano.function([x], T.xor(x, x), mode=mode) f = theano.function([x], T.xor(x, x), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
def test_stacktrace(self): def test_stacktrace(self):
mode = theano.compile.get_default_mode().including( mode = theano.compile.get_default_mode().including(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论