提交 6a9aa55d authored 作者: Frederic Bastien's avatar Frederic Bastien

Speed up opt. Now constant folding is done in the eq, no need to traverse the…

Speed up opt. Now constant folding is done in the eq, no need to traverse the graph to get constant at each node.
上级 d6b3dff4
...@@ -3344,7 +3344,7 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) ...@@ -3344,7 +3344,7 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
########################## ##########################
def extract_constant(x, elemwise=True): def extract_constant(x, elemwise=True, only_process_constants=False):
""" """
This function is basically a call to tensor.get_scalar_constant_value. This function is basically a call to tensor.get_scalar_constant_value.
...@@ -3356,7 +3356,9 @@ def extract_constant(x, elemwise=True): ...@@ -3356,7 +3356,9 @@ def extract_constant(x, elemwise=True):
""" """
try: try:
x = get_scalar_constant_value(x, elemwise=elemwise) x = get_scalar_constant_value(x,
elemwise,
only_process_constants)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if ((isinstance(x, scal.ScalarVariable) or if ((isinstance(x, scal.ScalarVariable) or
......
...@@ -4665,23 +4665,23 @@ def local_useless_elemwise_comparison(node): ...@@ -4665,23 +4665,23 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.LT) and \ if isinstance(node.op.scalar_op, scalar.LT) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [node.inputs[0]] return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i] # Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
T.extract_constant(node.inputs[0]) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [node.inputs[1]] return [node.inputs[1]]
...@@ -4689,11 +4689,11 @@ def local_useless_elemwise_comparison(node): ...@@ -4689,11 +4689,11 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0]) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)]
...@@ -4705,7 +4705,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4705,7 +4705,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \ isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
...@@ -4715,7 +4715,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4715,7 +4715,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \ isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论