提交 99f22fb4 authored 作者: Frederic Bastien's avatar Frederic Bastien

pre-clean up some fill inside opt.

上级 e4e08782
...@@ -2004,20 +2004,22 @@ def local_useless_elemwise(node): ...@@ -2004,20 +2004,22 @@ def local_useless_elemwise(node):
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
def zeros_like(node, in_idx): def zeros_like(node, in_idx):
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx], ret = T.fill(node.inputs[in_idx],
T.constant(0.0, dtype=node.outputs[0].type.dtype))] T.constant(0.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer(local_useless_fill, ret)
return [ret]
def ones_like(node, in_idx): def ones_like(node, in_idx):
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx], ret = T.fill(node.inputs[in_idx],
T.constant(1.0, dtype=node.outputs[0].type.dtype))] T.constant(1.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer(local_useless_fill, ret)
return [ret]
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2: if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
ret = [T.fill(node.inputs[0], ret = ones_like(node, 0)
T.constant(1.0,
dtype=node.outputs[0].type.dtype))]
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
...@@ -2025,9 +2027,7 @@ def local_useless_elemwise(node): ...@@ -2025,9 +2027,7 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2: elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false # it is the same var in the graph. That will always be false
ret = [T.fill(node.inputs[0], ret = zeros_like(node, 0)
T.constant(0.0,
dtype=node.outputs[0].type.dtype))]
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论