提交 8b91e067 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

corrections

上级 7d1022f5
......@@ -4229,14 +4229,14 @@ def local_useless_elemwise_comparison(node):
# Comparing shape to 0 can be constant
Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X)
Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
# Elemwise[minimum](X.shape[i], 0) -> 0
# Elemwise[minimum](0, X.shape[i]) -> 0
Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
Elemwise[minimum](X.shape[i], 0) -> 0
Elemwise[minimum](0, X.shape[i]) -> 0
# The shape can be replaced with sum of shapes
Elemwise[LT](sum([anything that is shapes]), 0) -> Elemwise[zeros](X)
Elemwise[GE](sum([anything that is shapes]), 0) -> Elemwise[ones](X)
Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
"""
if not isinstance(node.op, T.Elemwise):
......@@ -4246,7 +4246,7 @@ def local_useless_elemwise_comparison(node):
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \
node.inputs[0] is node.inputs[1]:
return [T.zeros_like(node.outputs[0], dtype=node.outputs[0].type.dtype)]
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].type.dtype)]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \
node.inputs[0] is node.inputs[1]:
......@@ -4261,13 +4261,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
return [T.zeros_like(node.outputs[0])]
return [T.zeros_like(node.inputs[0], dtype=node.inputs[1].type.dtype)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
return [T.ones_like(node.outputs[0])]
return [T.ones_like(node.inputs[0], dtype=node.inputs[1].type.dtype)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \
......@@ -4285,32 +4285,33 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
return [T.zeros_like(node.outputs[0])]
return [T.zeros_like(node.inputs[0], dtype=node.inputs[1].type.dtype)]
# Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0]) == 0 and \
node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.outputs[0])]
return [T.zeros_like(node.inputs[1], dtype=node.inputs[0].type.dtype)]
# Elemwise[LT](sum([anything that is shapes]), 0) -> Elemwise[zeros](X)
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, scalar.LT) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Elemwise) and \
isinstance(node.inputs[0].owner.op, scalar.Add) and \
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](sum([anything that is shapes]), 0) -> Elemwise[ones](X)
return [T.zeros_like(node.inputs[0], dtype=node.inputs[1].dtype)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Elemwise) and \
isinstance(node.inputs[0].owner.op, scalar.Add) and \
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return [T.ones_like(node.inputs[0], dtype=node.inputs[1].dtype)]
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论