提交 356ea69a authored 作者: Ziye Fan's avatar Ziye Fan

register local_fill_sink before local_fill_to_alloc

上级 6d7ece16
......@@ -1349,6 +1349,54 @@ theano.compile.mode.optdb.register('ShapeOpt', ShapeOptimizer(),
0.1, 'fast_run', 'fast_compile')
@gof.local_optimizer([T.Elemwise])
def local_fill_sink(node):
"""
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
f need to be an elemwise
"""
if (not hasattr(node, 'op') or
not isinstance(node.op, T.Elemwise) or
node.op == T.fill):
return False
models = []
inputs = []
for input in node.inputs:
if input.owner and input.owner.op == T.fill:
models.append(input.owner.inputs[0])
inputs.append(input.owner.inputs[1])
else:
inputs.append(input)
if not models:
return False
c = node.op(*inputs)
for model in models:
c = T.fill(model, c)
# The newly created node c doesn't has 'clients',
# so this iteration is took place with node.outputs[0]
replacements = {node.outputs[0]: c}
for client, cl_idx in node.outputs[0].clients:
if (hasattr(client, 'op') and
isinstance(client.op, T.Elemwise) and
not client.op == T.fill):
client_inputs = client.inputs[:]
client_inputs[cl_idx] = c
new_client = client.op(*client_inputs)
# Add clients to new_client
new_client.owner.outputs[0].clients = client.outputs[0].clients
r = local_fill_sink.transform(new_client.owner)
if r is False:
continue
replacements.pop(node.outputs[0], None)
replacements.update(r)
return replacements
register_canonicalize(local_fill_sink)
@register_specialize
@register_stabilize
@register_canonicalize
......@@ -3569,53 +3617,6 @@ register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy')
@gof.local_optimizer([T.Elemwise])
def local_fill_sink(node):
"""
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
f need to be an elemwise
"""
if (not hasattr(node, 'op') or
not isinstance(node.op, T.Elemwise) or
node.op == T.fill):
return False
models = []
inputs = []
for input in node.inputs:
if input.owner and input.owner.op == T.fill:
models.append(input.owner.inputs[0])
inputs.append(input.owner.inputs[1])
else:
inputs.append(input)
if not models:
return False
c = node.op(*inputs)
for model in models:
c = T.fill(model, c)
# The newly created node c doesn't has 'clients',
# so this iteration is took place with node.outputs[0]
replacements = {node.outputs[0]: c}
for client, cl_idx in node.outputs[0].clients:
if (hasattr(client, 'op') and
isinstance(client.op, T.Elemwise) and
not client.op == T.fill):
client_inputs = client.inputs[:]
client_inputs[cl_idx] = c
new_client = client.op(*client_inputs)
# Add clients to new_client
new_client.owner.outputs[0].clients = client.outputs[0].clients
r = local_fill_sink.transform(new_client.owner)
if r is False:
continue
replacements.pop(node.outputs[0], None)
replacements.update(r)
return replacements
register_canonicalize(local_fill_sink)
################
# Canonization #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论