提交 b537db77 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #4865 from nouiz/pre_cleanup

Pre cleanup
...@@ -2022,20 +2022,22 @@ def local_useless_elemwise(node): ...@@ -2022,20 +2022,22 @@ def local_useless_elemwise(node):
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
def zeros_like(node, in_idx): def zeros_like(node, in_idx):
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx], ret = T.fill(node.inputs[in_idx],
T.constant(0.0, dtype=node.outputs[0].type.dtype))] T.constant(0.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return [ret]
def ones_like(node, in_idx): def ones_like(node, in_idx):
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx], ret = T.fill(node.inputs[in_idx],
T.constant(1.0, dtype=node.outputs[0].type.dtype))] T.constant(1.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return [ret]
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2: if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
ret = [T.fill(node.inputs[0], ret = ones_like(node, 0)
T.constant(1.0,
dtype=node.outputs[0].type.dtype))]
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
...@@ -2043,9 +2045,7 @@ def local_useless_elemwise(node): ...@@ -2043,9 +2045,7 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2: elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false # it is the same var in the graph. That will always be false
ret = [T.fill(node.inputs[0], ret = zeros_like(node, 0)
T.constant(0.0,
dtype=node.outputs[0].type.dtype))]
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
...@@ -5021,14 +5021,25 @@ def local_useless_elemwise_comparison(node): ...@@ -5021,14 +5021,25 @@ def local_useless_elemwise_comparison(node):
return return
if node.op.scalar_op.nin != 2: if node.op.scalar_op.nin != 2:
return return
def zeros_like(model, dtype):
ret = T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return ret
def ones_like(model, dtype):
ret = T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)
ret = pre_greedy_local_optimizer([local_useless_fill], ret)
return ret
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \ if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \ if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[{minimum,maximum}](X, X) -> X # Elemwise[{minimum,maximum}](X, X) -> X
if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \ if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
...@@ -5039,13 +5050,13 @@ def local_useless_elemwise_comparison(node): ...@@ -5039,13 +5050,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
...@@ -5063,13 +5074,13 @@ def local_useless_elemwise_comparison(node): ...@@ -5063,13 +5074,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)] return [zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)]
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, scalar.LT) and \ if isinstance(node.op.scalar_op, scalar.LT) and \
...@@ -5080,7 +5091,7 @@ def local_useless_elemwise_comparison(node): ...@@ -5080,7 +5091,7 @@ def local_useless_elemwise_comparison(node):
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
...@@ -5089,7 +5100,7 @@ def local_useless_elemwise_comparison(node): ...@@ -5089,7 +5100,7 @@ def local_useless_elemwise_comparison(node):
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1], only_process_constants=True) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[EQ](Subtensor(Shape(x)), -N) # Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](somegraph that only depend of shape, -N) # Elemwise[EQ](somegraph that only depend of shape, -N)
...@@ -5122,7 +5133,7 @@ def local_useless_elemwise_comparison(node): ...@@ -5122,7 +5133,7 @@ def local_useless_elemwise_comparison(node):
cst = get_scalar_constant_value(node.inputs[1], cst = get_scalar_constant_value(node.inputs[1],
only_process_constants=True) only_process_constants=True)
if cst < 0: if cst < 0:
return [T.zeros_like(node.inputs[0], return [zeros_like(node.inputs[0],
dtype=node.outputs[0].dtype)] dtype=node.outputs[0].dtype)]
except NotScalarConstantError: except NotScalarConstantError:
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论