提交 88eac16c 作者: Frédéric Bastien

Merge pull request #3092 from t13m/opt_local_fill_sink

Optimize local_fill_sink
...@@ -579,6 +579,7 @@ class MergeFeature(object): ...@@ -579,6 +579,7 @@ class MergeFeature(object):
if node.inputs: if node.inputs:
assert len(node.inputs[0].clients) > 0 assert len(node.inputs[0].clients) > 0
assert (node, 0) in node.inputs[0].clients assert (node, 0) in node.inputs[0].clients
merge_candidates = [c for (c, i) in node.inputs[0].clients merge_candidates = [c for (c, i) in node.inputs[0].clients
if c in self.nodes_seen] if c in self.nodes_seen]
......
...@@ -1349,6 +1349,60 @@ theano.compile.mode.optdb.register('ShapeOpt', ShapeOptimizer(), ...@@ -1349,6 +1349,60 @@ theano.compile.mode.optdb.register('ShapeOpt', ShapeOptimizer(),
0.1, 'fast_run', 'fast_compile') 0.1, 'fast_run', 'fast_compile')
@gof.local_optimizer([T.Elemwise])
def local_fill_sink(node):
    """
    Push `fill` below an elemwise op, then try to keep pushing it
    through the elemwise clients of that op.

    f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))

    `f` must be an Elemwise other than `fill` itself.  The fill models
    are re-applied in input order, so the model of the last filled
    input ends up outermost; a model whose type already equals the type
    of the current result is not re-applied.

    Returns False when `node` does not qualify or none of its inputs is
    produced by `fill`.  Otherwise returns a dict mapping existing
    variables to their replacements.
    """
    # Guard clauses, evaluated in the same order as a short-circuit `or`.
    if not hasattr(node, 'op'):
        return False
    if not isinstance(node.op, T.Elemwise):
        return False
    if node.op == T.fill:
        return False

    # Separate each filled input into its model (fill's first input)
    # and the value being filled (fill's second input).
    fill_models = []
    bare_inputs = []
    for var in node.inputs:
        producer = var.owner
        if producer and producer.op == T.fill:
            fill_models.append(producer.inputs[0])
            bare_inputs.append(producer.inputs[1])
        else:
            bare_inputs.append(var)
    if not fill_models:
        return False

    # Apply the op to the unfilled values, then wrap the result back in
    # the fills whose model type differs from the current result type.
    sunk = node.op(*bare_inputs)
    for model in fill_models:
        if model.type != sunk.type:
            sunk = T.fill(model, sunk)

    # `sunk` was just built and has no clients of its own yet, so the
    # clients are taken from the output it is going to replace.
    old_out = node.outputs[0]
    replacements = {old_out: sunk}
    every_client_rewritten = True
    for client, input_pos in old_out.clients:
        # Only elemwise clients other than `fill` are candidates; a
        # client without an `op` attribute is left alone.
        sinkable = (hasattr(client, 'op') and
                    isinstance(client.op, T.Elemwise) and
                    not client.op == T.fill)
        if not sinkable:
            every_client_rewritten = False
            continue

        # Rebuild the client with `sunk` in place of the old output.
        rebuilt_inputs = client.inputs[:]
        rebuilt_inputs[input_pos] = sunk
        rebuilt = client.op(*rebuilt_inputs)

        # Hand the original client's `clients` list to the rebuilt
        # output (same list object, not a copy) so that the recursive
        # call below can walk further down the graph.
        rebuilt.owner.outputs[0].clients = client.outputs[0].clients
        nested = local_fill_sink.transform(rebuilt.owner)
        if nested:
            replacements.update(nested)
        else:
            every_client_rewritten = False

    # When every client was rewritten, the replacement for the node's
    # own output is dropped -- presumably nothing is left that uses it.
    if every_client_rewritten:
        replacements.pop(old_out, None)
    return replacements

register_canonicalize(local_fill_sink)
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
...@@ -3320,6 +3374,7 @@ def local_useless_switch(node): ...@@ -3320,6 +3374,7 @@ def local_useless_switch(node):
return False return False
@register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_switch_sink(node): def local_mul_switch_sink(node):
...@@ -3349,6 +3404,7 @@ def local_mul_switch_sink(node): ...@@ -3349,6 +3404,7 @@ def local_mul_switch_sink(node):
return False return False
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.owner and i.owner.op == T.switch: if i.owner and i.owner.op == T.switch:
# import ipdb;ipdb.set_trace()
switch = i.owner switch = i.owner
try: try:
if (get_scalar_constant_value( if (get_scalar_constant_value(
...@@ -3648,38 +3704,11 @@ register_canonicalize(local_fill_cut) ...@@ -3648,38 +3704,11 @@ register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy') register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy')
@gof.local_optimizer([T.Elemwise])
def local_fill_sink(node):
    """
    Push `fill` below an elemwise op.

    f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))

    `f` must be an Elemwise other than `fill` itself.  The fill models
    are re-applied in input order, so the model of the last filled
    input ends up outermost.

    Returns False when `node` does not qualify or none of its inputs is
    produced by `fill`; otherwise returns a one-element list holding
    the replacement output.
    """
    # Guard clauses, evaluated in the same order as a short-circuit `or`.
    if not isinstance(node.op, T.Elemwise):
        return False
    if node.op == T.fill:
        return False

    # Separate each filled input into its model (fill's first input)
    # and the value being filled (fill's second input).
    fill_models = []
    bare_inputs = []
    for var in node.inputs:
        producer = var.owner
        if producer and producer.op == T.fill:
            fill_models.append(producer.inputs[0])
            bare_inputs.append(producer.inputs[1])
        else:
            bare_inputs.append(var)
    if not fill_models:
        return False

    # Apply the op to the unfilled values and wrap the result in every
    # collected fill model, unconditionally.
    sunk = node.op(*bare_inputs)
    for model in fill_models:
        sunk = T.fill(model, sunk)
    return [sunk]

register_canonicalize(local_fill_sink)
################ ################
# Canonization # # Canonization #
################ ################
class Canonizer(gof.LocalOptimizer): class Canonizer(gof.LocalOptimizer):
""" """
Simplification tool. Simplification tool.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论