提交 99f22fb4 authored 作者: Frederic Bastien's avatar Frederic Bastien

pre-clean up some fill inside opt.

上级 e4e08782
......@@ -2004,20 +2004,22 @@ def local_useless_elemwise(node):
if isinstance(node.op, T.Elemwise):
def zeros_like(node, in_idx):
# it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx],
T.constant(0.0, dtype=node.outputs[0].type.dtype))]
ret = T.fill(node.inputs[in_idx],
T.constant(0.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer(local_useless_fill, ret)
return [ret]
def ones_like(node, in_idx):
# it is the same var in the graph. That will always be true
return [T.fill(node.inputs[in_idx],
T.constant(1.0, dtype=node.outputs[0].type.dtype))]
ret = T.fill(node.inputs[in_idx],
T.constant(1.0, dtype=node.outputs[0].type.dtype))
ret = pre_greedy_local_optimizer(local_useless_fill, ret)
return [ret]
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true
ret = [T.fill(node.inputs[0],
T.constant(1.0,
dtype=node.outputs[0].type.dtype))]
ret = ones_like(node, 0)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
......@@ -2025,9 +2027,7 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false
ret = [T.fill(node.inputs[0],
T.constant(0.0,
dtype=node.outputs[0].type.dtype))]
ret = zeros_like(node, 0)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论