提交 629fa735 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix optimization that is incorrect for non-bool

Since tensor.or_ and tensor.and_ are bitwise operations (for non-booleans at least), some "logical" simplifications are wrong.
上级 579707bb
...@@ -2087,16 +2087,16 @@ def local_subtensor_make_vector(node): ...@@ -2087,16 +2087,16 @@ def local_subtensor_make_vector(node):
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_useless_elemwise(node): def local_useless_elemwise(node):
""" """
eq(x,x) -> 1 eq(x, x) -> 1
neq(x,x) -> 0 neq(x, x) -> 0
mul(x) -> x mul(x) -> x
add(x) -> x add(x) -> x
identity(x) -> x identity(x) -> x
and(x,1) -> x and(x, 1) -> x (if x.dtype == 'bool')
and(x,0) -> zeros_like(x) and(x, 0) -> zeros_like(x)
or(x,0) -> x or(x, 0) -> x
or(x,1) -> ones_like(x) or(x, 1) -> ones_like(x) (if x.dtype == 'bool')
xor(x,x) -> zeros_like(x) xor(x, x) -> zeros_like(x)
""" """
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
...@@ -2141,7 +2141,9 @@ def local_useless_elemwise(node): ...@@ -2141,7 +2141,9 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [T.zeros_like(node.inputs[1], dtype=dtype, return [T.zeros_like(node.inputs[1], dtype=dtype,
opt=True)] opt=True)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
return [node.inputs[1].astype(node.outputs[0].dtype)] return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
...@@ -2150,7 +2152,9 @@ def local_useless_elemwise(node): ...@@ -2150,7 +2152,9 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [T.zeros_like(node.inputs[0], dtype=dtype, return [T.zeros_like(node.inputs[0], dtype=dtype,
opt=True)] opt=True)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise AND,
# and this optimization would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif (isinstance(node.op.scalar_op, scalar.OR) and elif (isinstance(node.op.scalar_op, scalar.OR) and
...@@ -2161,7 +2165,9 @@ def local_useless_elemwise(node): ...@@ -2161,7 +2165,9 @@ def local_useless_elemwise(node):
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[1].astype(node.outputs[0].dtype)] return [node.inputs[1].astype(node.outputs[0].dtype)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
return [T.ones_like(node.inputs[1], dtype=dtype, return [T.ones_like(node.inputs[1], dtype=dtype,
opt=True)] opt=True)]
...@@ -2170,7 +2176,9 @@ def local_useless_elemwise(node): ...@@ -2170,7 +2176,9 @@ def local_useless_elemwise(node):
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
else: elif node.outputs[0].dtype == 'bool':
# If the output is not Boolean, it is the bitwise OR,
# and this optimization would be wrong
return [T.ones_like(node.inputs[0], dtype=dtype, return [T.ones_like(node.inputs[0], dtype=dtype,
opt=True)] opt=True)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论