提交 c44b8b5c authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix conditions and comment

上级 285abe3b
...@@ -1605,17 +1605,17 @@ def local_useless_elemwise(node): ...@@ -1605,17 +1605,17 @@ def local_useless_elemwise(node):
if isinstance(node.inputs[0], T.TensorConstant): if isinstance(node.inputs[0], T.TensorConstant):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0])
if const_val == 1: if const_val == 0:
return [node.inputs[1]]
elif const_val == 0:
return zeros_like(node, 1) return zeros_like(node, 1)
else:
return [node.inputs[1]]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1])
if const_val == 1: if const_val == 0:
return [node.inputs[0]]
elif const_val == 0:
return zeros_like(node, 0) return zeros_like(node, 0)
else:
return [node.inputs[0]]
elif (isinstance(node.op.scalar_op, scalar.OR) and elif (isinstance(node.op.scalar_op, scalar.OR) and
len(node.inputs) == 2): len(node.inputs) == 2):
...@@ -1624,14 +1624,14 @@ def local_useless_elemwise(node): ...@@ -1624,14 +1624,14 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0])
if const_val == 0: if const_val == 0:
return [node.inputs[1]] return [node.inputs[1]]
elif const_val == 1: else:
return ones_like(node, 1) return ones_like(node, 1)
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1])
if const_val == 0: if const_val == 0:
return [node.inputs[0]] return [node.inputs[0]]
if const_val == 1: else:
return ones_like(node, 0) return ones_like(node, 0)
elif (isinstance(node.op.scalar_op, scalar.XOR) and elif (isinstance(node.op.scalar_op, scalar.XOR) and
...@@ -4216,8 +4216,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4216,8 +4216,7 @@ def local_useless_elemwise_comparison(node):
"""... """...
:note: These cases appear in the graph generated by scan. :note: These cases appear in the graph generated by scan.
These optimizations will not reduce computation, These optimizations will make the graph easier to read.
but will make the graph easier to read.
# Comparing to itself is constant # Comparing to itself is constant
Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论