提交 6a9aa55d authored 作者: Frederic Bastien's avatar Frederic Bastien

Speed up opt. Now constant folding is done in the eq, no need to traverse the…

Speed up opt. Now constant folding is done in the eq, no need to traverse the graph to get constant at each node.
上级 d6b3dff4
......@@ -3344,7 +3344,7 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
##########################
def extract_constant(x, elemwise=True):
def extract_constant(x, elemwise=True, only_process_constants=False):
"""
This function is basically a call to tensor.get_scalar_constant_value.
......@@ -3356,7 +3356,9 @@ def extract_constant(x, elemwise=True):
"""
try:
x = get_scalar_constant_value(x, elemwise=elemwise)
x = get_scalar_constant_value(x,
elemwise,
only_process_constants)
except NotScalarConstantError:
pass
if ((isinstance(x, scal.ScalarVariable) or
......
......@@ -4665,23 +4665,23 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.LT) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \
T.extract_constant(node.inputs[0]) == 0 and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i):
return [node.inputs[1]]
......@@ -4689,11 +4689,11 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.Minimum) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0]) == 0 and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)]
......@@ -4705,7 +4705,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
......@@ -4715,7 +4715,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论