提交 bd56a8f8 authored 作者: Frederic's avatar Frederic

Allow local optimizers to return a dict that maps variables to their replacements.

Make PatternSub use that feature to lower optimizer time by tracking less frequent nodes. This saves 85s out of 317s of optimizer time.
上级 12ec7339
......@@ -775,6 +775,8 @@ class LocalOptimizer(object):
or
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
- dict(old variables -> new variables). A dictionary that maps
old variables to the new variables that replace them.
:type node: an Apply instance
......@@ -1015,8 +1017,10 @@ class PatternSub(LocalOptimizer):
(scrabble, 'x'))
"""
def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False):
def __init__(self, in_pattern, out_pattern,
allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
......@@ -1026,8 +1030,21 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param skip_identities_fn: TODO
:param name: Allows overriding this optimizer's name.
:param pdb: if True, we invoke pdb when the first node in the
pattern matches.
:param tracks: Optional. The values that self.tracks() will
return. Sometimes useful to speed up optimization.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that takes the
tracked node and returns a list of nodes on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so that this optimizer is tried
less frequently.
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
......@@ -1046,18 +1063,33 @@ class PatternSub(LocalOptimizer):
if name:
self.__name__ = name
self.pdb = pdb
self._tracks = tracks
self.get_nodes = get_nodes
if tracks != ():
assert get_nodes
def op_key(self):
    """Return the key for this optimizer, which is its ``op`` attribute.

    NOTE(review): ``self.op`` is assigned outside the visible code;
    presumably it is the op at the root of ``in_pattern`` -- confirm.
    """
    key = self.op
    return key
def tracks(self):
    """Return the values this optimizer tracks.

    When a non-empty ``tracks`` value was stored in ``self._tracks``,
    that value is returned unchanged. Otherwise the result is a
    one-element list holding ``self.op``.
    """
    # `()` is the "nothing supplied" default, so anything else wins.
    return self._tracks if self._tracks != () else [self.op]
def transform(self, node):
def transform(self, node, get_nodes=True):
"""
Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
"""
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(node):
if real_node == "output":
continue
ret = self.transform(real_node, get_nodes=False)
if ret is not False and ret is not None:
assert len(real_node.outputs) == len(ret)
return dict(zip(real_node.outputs, ret))
if node.op != self.op:
return False
#TODO: if we remove pdb, do this speed things up?
......@@ -1330,20 +1362,24 @@ class NavigatorOptimizer(Optimizer):
raise
if replacements is False or replacements is None:
return False
if not isinstance(replacements, (tuple, list)):
old_vars = node.outputs
if isinstance(replacements, dict):
old_vars = replacements.keys()
replacements = replacements.values()
elif not isinstance(replacements, (tuple, list)):
raise TypeError('Optimizer %s gave wrong type of replacement. '
'Expected list or tuple.' % lopt)
if len(node.outputs) != len(replacements):
if len(old_vars) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt)
# None in the replacement mean that this variable isn't used
# and we want to remove it
for r, rnew in zip(node.outputs, replacements):
for r, rnew in zip(old_vars, replacements):
if rnew is None and len(r.clients) > 0:
raise ValueError("A local optimizer tried to remove a Variable that is used")
# If an output would be replaced by itself, no need to perform
# the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
repl_pairs = [(r, rnew) for r, rnew in zip(old_vars, replacements)
if rnew is not r and rnew is not None]
if len(repl_pairs) == 0:
......
......@@ -4072,13 +4072,29 @@ def _is_minus1(expr):
except NotScalarConstantError:
return False
def get_clients(node):
    """Return the clients of ``node``'s first output.

    Used by the erf/erfc optimizations to track a less frequent op.
    Entries whose client is the string ``"output"`` are left out; the
    index stored next to each client is ignored.
    """
    clients = []
    for client, _idx in node.outputs[0].clients:
        if client != "output":
            clients.append(client)
    return clients
def get_clients2(node):
    """Return the clients of the clients of ``node``'s first output.

    Used by the erf/erfc optimizations to track a less frequent op.
    For every client of ``node.outputs[0]``, each of that client's
    outputs is visited and its own clients are gathered, in order.
    The string ``"output"`` is skipped at both levels.
    """
    return [grandchild
            for child, _idx in node.outputs[0].clients
            if child != "output"
            for out_var in child.outputs
            for grandchild, _jdx in out_var.clients
            if grandchild != "output"]
#1+erf(x)=>erfc(-x)
local_one_plus_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1),
(T.erf, 'x')),
(T.erfc, (T.neg, 'x')),
allow_multiple_clients=True,
name='local_one_plus_erf')
name='local_one_plus_erf',
tracks=[T.erf],
get_nodes=get_clients)
register_canonicalize(local_one_plus_erf)
register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf)
......@@ -4111,7 +4127,9 @@ local_one_plus_neg_erf = gof.PatternSub((T.add,
(T.neg, (T.erf, 'x'))),
(T.erfc, 'x'),
allow_multiple_clients=True,
name='local_one_plus_neg_erf')
name='local_one_plus_neg_erf',
tracks=[T.erf],
get_nodes=get_clients2)
register_canonicalize(local_one_plus_neg_erf)
register_stabilize(local_one_plus_neg_erf)
register_specialize(local_one_plus_neg_erf)
......@@ -4123,7 +4141,9 @@ local_erf_minus_one = gof.PatternSub((T.add,
(T.erf, 'x')),
(T.neg, (T.erfc, 'x')),
allow_multiple_clients=True,
name='local_erf_minus_one')
name='local_erf_minus_one',
tracks=[T.erf],
get_nodes=get_clients)
register_canonicalize(local_erf_minus_one)
register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one)
......@@ -4134,7 +4154,9 @@ local_one_minus_erfc = gof.PatternSub((T.sub,
(T.erfc, 'x')),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erfc')
name='local_one_minus_erfc',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)
......@@ -4144,7 +4166,9 @@ local_one_minus_erfc2 = gof.PatternSub((T.add,
(T.neg, (T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erfc2')
name='local_one_minus_erfc2',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2)
......@@ -4154,7 +4178,9 @@ local_one_minus_erfc3 = gof.PatternSub((T.add,
(T.mul, -1, (T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erfc3')
name='local_one_minus_erfc3',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_minus_erfc3)
register_stabilize(local_one_minus_erfc3)
register_specialize(local_one_minus_erfc3)
......@@ -4166,7 +4192,10 @@ local_one_add_neg_erfc = gof.PatternSub((T.add,
(T.neg, (T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_add_neg_erfc')
name='local_one_add_neg_erfc',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_add_neg_erfc)
register_stabilize(local_one_add_neg_erfc)
register_specialize(local_one_add_neg_erfc)
......@@ -4177,7 +4206,9 @@ local_erf_neg_minus_one = gof.PatternSub((T.add,
(T.erfc, (T.neg, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_erf_neg_minus_one')
name='local_erf_neg_minus_one',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_erf_neg_minus_one)
register_stabilize(local_erf_neg_minus_one)
register_specialize(local_erf_neg_minus_one)
......@@ -4188,7 +4219,9 @@ local_erf_neg_minus_one2 = gof.PatternSub((T.add,
(T.erfc, (T.mul, -1, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_erf_neg_minus_one2')
name='local_erf_neg_minus_one2',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_erf_neg_minus_one2)
register_stabilize(local_erf_neg_minus_one2)
register_specialize(local_erf_neg_minus_one2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论