提交 bd56a8f8 authored 作者: Frederic's avatar Frederic

Allow local optimizers to return a dict that maps the variables to replace to their replacements.

Make PatternSub use that feature to lower optimization time by tracking less frequent nodes. This saves 85s out of 317s of optimizer time.
上级 12ec7339
...@@ -775,6 +775,8 @@ class LocalOptimizer(object): ...@@ -775,6 +775,8 @@ class LocalOptimizer(object):
or or
- <list of variables> to use in place of `node`'s outputs in the - <list of variables> to use in place of `node`'s outputs in the
greater graph. greater graph.
- dict(old variables -> new variables). A dictionary that maps
each old variable to the new variable that replaces it.
:type node: an Apply instance :type node: an Apply instance
...@@ -1015,8 +1017,10 @@ class PatternSub(LocalOptimizer): ...@@ -1015,8 +1017,10 @@ class PatternSub(LocalOptimizer):
(scrabble, 'x')) (scrabble, 'x'))
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False, def __init__(self, in_pattern, out_pattern,
skip_identities_fn=None, name=None, pdb=False): allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None):
""" """
Creates a PatternSub that replaces occurrences of Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern. in_pattern by occurrences of out_pattern.
...@@ -1026,8 +1030,21 @@ class PatternSub(LocalOptimizer): ...@@ -1026,8 +1030,21 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail :param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than if one of the subpatterns has more than
one client. one client.
:param skip_identities_fn: TODO
:param name: Allows overriding this optimizer's name.
:param pdb: if True, we invoke pdb when the first node in the :param pdb: if True, we invoke pdb when the first node in the
pattern match. pattern match.
:param tracks: Optional. The values that self.tracks() will
return. Sometimes useful to speed up optimization.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that takes the
tracked node and returns a list of nodes on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so that this optimizer is
tried less frequently.
""" """
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
...@@ -1046,18 +1063,33 @@ class PatternSub(LocalOptimizer): ...@@ -1046,18 +1063,33 @@ class PatternSub(LocalOptimizer):
if name: if name:
self.__name__ = name self.__name__ = name
self.pdb = pdb self.pdb = pdb
self._tracks = tracks
self.get_nodes = get_nodes
if tracks != ():
assert get_nodes
def op_key(self): def op_key(self):
return self.op return self.op
def tracks(self): def tracks(self):
if self._tracks != ():
return self._tracks
return [self.op] return [self.op]
def transform(self, node): def transform(self, node, get_nodes=True):
""" """
Checks if the graph from node corresponds to in_pattern. If it does, Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement. constructs out_pattern and performs the replacement.
""" """
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(node):
if real_node == "output":
continue
ret = self.transform(real_node, get_nodes=False)
if ret is not False and ret is not None:
assert len(real_node.outputs) == len(ret)
return dict(zip(real_node.outputs, ret))
if node.op != self.op: if node.op != self.op:
return False return False
#TODO: if we remove pdb, do this speed things up? #TODO: if we remove pdb, do this speed things up?
...@@ -1330,20 +1362,24 @@ class NavigatorOptimizer(Optimizer): ...@@ -1330,20 +1362,24 @@ class NavigatorOptimizer(Optimizer):
raise raise
if replacements is False or replacements is None: if replacements is False or replacements is None:
return False return False
if not isinstance(replacements, (tuple, list)): old_vars = node.outputs
if isinstance(replacements, dict):
old_vars = replacements.keys()
replacements = replacements.values()
elif not isinstance(replacements, (tuple, list)):
raise TypeError('Optimizer %s gave wrong type of replacement. ' raise TypeError('Optimizer %s gave wrong type of replacement. '
'Expected list or tuple.' % lopt) 'Expected list or tuple.' % lopt)
if len(node.outputs) != len(replacements): if len(old_vars) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements' raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt) % lopt)
# None in the replacement mean that this variable isn't used # None in the replacement mean that this variable isn't used
# and we want to remove it # and we want to remove it
for r, rnew in zip(node.outputs, replacements): for r, rnew in zip(old_vars, replacements):
if rnew is None and len(r.clients) > 0: if rnew is None and len(r.clients) > 0:
raise ValueError("A local optimizer tried to remove a Variable that is used") raise ValueError("A local optimizer tried to remove a Variable that is used")
# If an output would be replaced by itself, no need to perform # If an output would be replaced by itself, no need to perform
# the replacement # the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements) repl_pairs = [(r, rnew) for r, rnew in zip(old_vars, replacements)
if rnew is not r and rnew is not None] if rnew is not r and rnew is not None]
if len(repl_pairs) == 0: if len(repl_pairs) == 0:
......
...@@ -4072,13 +4072,29 @@ def _is_minus1(expr): ...@@ -4072,13 +4072,29 @@ def _is_minus1(expr):
except NotScalarConstantError: except NotScalarConstantError:
return False return False
def get_clients(node):
    """Return the distinct clients of ``node``'s first output.

    Used by the erf/erfc optimizations to track a less frequent op: the
    optimizer is registered on the rarer op and this helper yields the
    nodes on which the pattern should actually be tried.

    :param node: an object whose ``outputs[0].clients`` is an iterable of
        ``(client, index)`` pairs.
    :return: list of clients in first-seen order, without the ``"output"``
        entries (presumably the marker for a graph output -- confirm
        against the graph container) and without repeats.
    """
    seen_ids = set()
    clients = []
    for client, _ in node.outputs[0].clients:
        # A client using the variable in several input slots is listed
        # once per slot; trying the optimizer on it again is wasted work.
        # Identity is used so nothing is assumed about node hashing.
        if client != "output" and id(client) not in seen_ids:
            seen_ids.add(id(client))
            clients.append(client)
    return clients
def get_clients2(node):
    """Return the distinct clients of the clients of ``node``'s first output.

    Used by the erf/erfc optimizations to track a less frequent op when
    the node to rewrite sits two levels above the tracked op (e.g. the
    tracked op feeds a ``neg`` or ``mul`` that feeds the ``add``).

    :param node: an object whose ``outputs[0].clients`` is an iterable of
        ``(client, index)`` pairs; every client other than ``"output"``
        must itself expose ``outputs`` whose items have ``clients``.
    :return: list of second-level clients in first-seen order, without
        the ``"output"`` entries (presumably the marker for a graph
        output -- confirm against the graph container) and without
        repeats.
    """
    seen_ids = set()
    second_level = []
    for client, _ in node.outputs[0].clients:
        if client != "output":
            for var in client.outputs:
                for sub_client, _ in var.clients:
                    # The same node is reachable through several paths or
                    # input slots; keep it once so the optimizer is not
                    # retried on it. Identity avoids relying on hashing.
                    if sub_client != "output" and id(sub_client) not in seen_ids:
                        seen_ids.add(id(sub_client))
                        second_level.append(sub_client)
    return second_level
#1+erf(x)=>erfc(-x) #1+erf(x)=>erfc(-x)
local_one_plus_erf = gof.PatternSub((T.add, local_one_plus_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1), dict(pattern='y', constraint=_is_1),
(T.erf, 'x')), (T.erf, 'x')),
(T.erfc, (T.neg, 'x')), (T.erfc, (T.neg, 'x')),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_one_plus_erf') name='local_one_plus_erf',
tracks=[T.erf],
get_nodes=get_clients)
register_canonicalize(local_one_plus_erf) register_canonicalize(local_one_plus_erf)
register_stabilize(local_one_plus_erf) register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf) register_specialize(local_one_plus_erf)
...@@ -4111,7 +4127,9 @@ local_one_plus_neg_erf = gof.PatternSub((T.add, ...@@ -4111,7 +4127,9 @@ local_one_plus_neg_erf = gof.PatternSub((T.add,
(T.neg, (T.erf, 'x'))), (T.neg, (T.erf, 'x'))),
(T.erfc, 'x'), (T.erfc, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_one_plus_neg_erf') name='local_one_plus_neg_erf',
tracks=[T.erf],
get_nodes=get_clients2)
register_canonicalize(local_one_plus_neg_erf) register_canonicalize(local_one_plus_neg_erf)
register_stabilize(local_one_plus_neg_erf) register_stabilize(local_one_plus_neg_erf)
register_specialize(local_one_plus_neg_erf) register_specialize(local_one_plus_neg_erf)
...@@ -4123,7 +4141,9 @@ local_erf_minus_one = gof.PatternSub((T.add, ...@@ -4123,7 +4141,9 @@ local_erf_minus_one = gof.PatternSub((T.add,
(T.erf, 'x')), (T.erf, 'x')),
(T.neg, (T.erfc, 'x')), (T.neg, (T.erfc, 'x')),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_erf_minus_one') name='local_erf_minus_one',
tracks=[T.erf],
get_nodes=get_clients)
register_canonicalize(local_erf_minus_one) register_canonicalize(local_erf_minus_one)
register_stabilize(local_erf_minus_one) register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one) register_specialize(local_erf_minus_one)
...@@ -4134,7 +4154,9 @@ local_one_minus_erfc = gof.PatternSub((T.sub, ...@@ -4134,7 +4154,9 @@ local_one_minus_erfc = gof.PatternSub((T.sub,
(T.erfc, 'x')), (T.erfc, 'x')),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_one_minus_erfc') name='local_one_minus_erfc',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_one_minus_erfc) register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc) register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc) register_specialize(local_one_minus_erfc)
...@@ -4144,7 +4166,9 @@ local_one_minus_erfc2 = gof.PatternSub((T.add, ...@@ -4144,7 +4166,9 @@ local_one_minus_erfc2 = gof.PatternSub((T.add,
(T.neg, (T.erfc, 'x'))), (T.neg, (T.erfc, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_one_minus_erfc2') name='local_one_minus_erfc2',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_minus_erfc2) register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2) register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2) register_specialize(local_one_minus_erfc2)
...@@ -4154,7 +4178,9 @@ local_one_minus_erfc3 = gof.PatternSub((T.add, ...@@ -4154,7 +4178,9 @@ local_one_minus_erfc3 = gof.PatternSub((T.add,
(T.mul, -1, (T.erfc, 'x'))), (T.mul, -1, (T.erfc, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_one_minus_erfc3') name='local_one_minus_erfc3',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_minus_erfc3) register_canonicalize(local_one_minus_erfc3)
register_stabilize(local_one_minus_erfc3) register_stabilize(local_one_minus_erfc3)
register_specialize(local_one_minus_erfc3) register_specialize(local_one_minus_erfc3)
...@@ -4166,7 +4192,10 @@ local_one_add_neg_erfc = gof.PatternSub((T.add, ...@@ -4166,7 +4192,10 @@ local_one_add_neg_erfc = gof.PatternSub((T.add,
(T.neg, (T.erfc, 'x'))), (T.neg, (T.erfc, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_one_add_neg_erfc') name='local_one_add_neg_erfc',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_add_neg_erfc) register_canonicalize(local_one_add_neg_erfc)
register_stabilize(local_one_add_neg_erfc) register_stabilize(local_one_add_neg_erfc)
register_specialize(local_one_add_neg_erfc) register_specialize(local_one_add_neg_erfc)
...@@ -4177,7 +4206,9 @@ local_erf_neg_minus_one = gof.PatternSub((T.add, ...@@ -4177,7 +4206,9 @@ local_erf_neg_minus_one = gof.PatternSub((T.add,
(T.erfc, (T.neg, 'x'))), (T.erfc, (T.neg, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_erf_neg_minus_one') name='local_erf_neg_minus_one',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_erf_neg_minus_one) register_canonicalize(local_erf_neg_minus_one)
register_stabilize(local_erf_neg_minus_one) register_stabilize(local_erf_neg_minus_one)
register_specialize(local_erf_neg_minus_one) register_specialize(local_erf_neg_minus_one)
...@@ -4188,7 +4219,9 @@ local_erf_neg_minus_one2 = gof.PatternSub((T.add, ...@@ -4188,7 +4219,9 @@ local_erf_neg_minus_one2 = gof.PatternSub((T.add,
(T.erfc, (T.mul, -1, 'x'))), (T.erfc, (T.mul, -1, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_erf_neg_minus_one2') name='local_erf_neg_minus_one2',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_erf_neg_minus_one2) register_canonicalize(local_erf_neg_minus_one2)
register_stabilize(local_erf_neg_minus_one2) register_stabilize(local_erf_neg_minus_one2)
register_specialize(local_erf_neg_minus_one2) register_specialize(local_erf_neg_minus_one2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论