提交 939bfa73 authored 作者: abergeron's avatar abergeron

Merge pull request #1792 from nouiz/tracks

Faster optimizer
......@@ -775,6 +775,8 @@ class LocalOptimizer(object):
or
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
- dict(old variables -> new variables). A dictionary that map
from old variables to new variables to replace.
:type node: an Apply instance
......@@ -1015,8 +1017,10 @@ class PatternSub(LocalOptimizer):
(scrabble, 'x'))
"""
def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False):
def __init__(self, in_pattern, out_pattern,
allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
......@@ -1026,8 +1030,21 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param skip_identities_fn: TODO
:param name: Allow to override this optimizer name
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
:param tracks: Optional. The values that self.tracks() will
return. Useful to speed up optimization some times.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that take the
tracked node and return a list of node on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so this will make this optimizer
tried less frequently,
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
......@@ -1046,18 +1063,33 @@ class PatternSub(LocalOptimizer):
if name:
self.__name__ = name
self.pdb = pdb
self._tracks = tracks
self.get_nodes = get_nodes
if tracks != ():
assert get_nodes
def op_key(self):
return self.op
def tracks(self):
if self._tracks != ():
return self._tracks
return [self.op]
def transform(self, node):
def transform(self, node, get_nodes=True):
"""
Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
"""
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(node):
if real_node == "output":
continue
ret = self.transform(real_node, get_nodes=False)
if ret is not False and ret is not None:
assert len(real_node.outputs) == len(ret)
return dict(zip(real_node.outputs, ret))
if node.op != self.op:
return False
#TODO: if we remove pdb, do this speed things up?
......@@ -1330,20 +1362,24 @@ class NavigatorOptimizer(Optimizer):
raise
if replacements is False or replacements is None:
return False
if not isinstance(replacements, (tuple, list)):
old_vars = node.outputs
if isinstance(replacements, dict):
old_vars = replacements.keys()
replacements = replacements.values()
elif not isinstance(replacements, (tuple, list)):
raise TypeError('Optimizer %s gave wrong type of replacement. '
'Expected list or tuple.' % lopt)
if len(node.outputs) != len(replacements):
if len(old_vars) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt)
# None in the replacement mean that this variable isn't used
# and we want to remove it
for r, rnew in zip(node.outputs, replacements):
for r, rnew in zip(old_vars, replacements):
if rnew is None and len(r.clients) > 0:
raise ValueError("A local optimizer tried to remove a Variable that is used")
# If an output would be replaced by itself, no need to perform
# the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
repl_pairs = [(r, rnew) for r, rnew in zip(old_vars, replacements)
if rnew is not r and rnew is not None]
if len(repl_pairs) == 0:
......
......@@ -4072,13 +4072,29 @@ def _is_minus1(expr):
except NotScalarConstantError:
return False
def get_clients(node):
"Used by erf/erfc opt to track less frequent op"
return [c for c, i in node.outputs[0].clients
if c != "output"]
def get_clients2(node):
"Used by erf/erfc opt to track less frequent op"
l = []
for c, i in node.outputs[0].clients:
if c != "output":
for var in c.outputs:
l.extend([cc for cc, ii in var.clients if cc != "output"])
return l
#1+erf(x)=>erfc(-x)
local_one_plus_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1),
(T.erf, 'x')),
(T.erfc, (T.neg, 'x')),
allow_multiple_clients=True,
name='local_one_plus_erf')
name='local_one_plus_erf',
tracks=[T.erf],
get_nodes=get_clients)
register_canonicalize(local_one_plus_erf)
register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf)
......@@ -4111,7 +4127,9 @@ local_one_plus_neg_erf = gof.PatternSub((T.add,
(T.neg, (T.erf, 'x'))),
(T.erfc, 'x'),
allow_multiple_clients=True,
name='local_one_plus_neg_erf')
name='local_one_plus_neg_erf',
tracks=[T.erf],
get_nodes=get_clients2)
register_canonicalize(local_one_plus_neg_erf)
register_stabilize(local_one_plus_neg_erf)
register_specialize(local_one_plus_neg_erf)
......@@ -4123,7 +4141,9 @@ local_erf_minus_one = gof.PatternSub((T.add,
(T.erf, 'x')),
(T.neg, (T.erfc, 'x')),
allow_multiple_clients=True,
name='local_erf_minus_one')
name='local_erf_minus_one',
tracks=[T.erf],
get_nodes=get_clients)
register_canonicalize(local_erf_minus_one)
register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one)
......@@ -4134,7 +4154,9 @@ local_one_minus_erfc = gof.PatternSub((T.sub,
(T.erfc, 'x')),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erfc')
name='local_one_minus_erfc',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)
......@@ -4144,7 +4166,9 @@ local_one_minus_erfc2 = gof.PatternSub((T.add,
(T.neg, (T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erfc2')
name='local_one_minus_erfc2',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2)
......@@ -4154,7 +4178,9 @@ local_one_minus_erfc3 = gof.PatternSub((T.add,
(T.mul, -1, (T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erfc3')
name='local_one_minus_erfc3',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_minus_erfc3)
register_stabilize(local_one_minus_erfc3)
register_specialize(local_one_minus_erfc3)
......@@ -4166,7 +4192,10 @@ local_one_add_neg_erfc = gof.PatternSub((T.add,
(T.neg, (T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_one_add_neg_erfc')
name='local_one_add_neg_erfc',
tracks=[T.erfc],
get_nodes=get_clients2)
register_canonicalize(local_one_add_neg_erfc)
register_stabilize(local_one_add_neg_erfc)
register_specialize(local_one_add_neg_erfc)
......@@ -4177,7 +4206,9 @@ local_erf_neg_minus_one = gof.PatternSub((T.add,
(T.erfc, (T.neg, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_erf_neg_minus_one')
name='local_erf_neg_minus_one',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_erf_neg_minus_one)
register_stabilize(local_erf_neg_minus_one)
register_specialize(local_erf_neg_minus_one)
......@@ -4188,7 +4219,9 @@ local_erf_neg_minus_one2 = gof.PatternSub((T.add,
(T.erfc, (T.mul, -1, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
name='local_erf_neg_minus_one2')
name='local_erf_neg_minus_one2',
tracks=[T.erfc],
get_nodes=get_clients)
register_canonicalize(local_erf_neg_minus_one2)
register_stabilize(local_erf_neg_minus_one2)
register_specialize(local_erf_neg_minus_one2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论