Commit f8d06511 authored by Brandon T. Willard, committed by Brandon T. Willard

Use inheritance in local optimizer Op tracking

This commit introduces a `LocalOptTracker` object that performs an MRO-based lookup of `LocalOptimizer`s that track `Op` types.
Parent d475421b
...@@ -29,8 +29,8 @@ class GraphToGPULocalOptGroup(LocalOptGroup): ...@@ -29,8 +29,8 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
def transform(self, fgraph, op, context_name, inputs, outputs): def transform(self, fgraph, op, context_name, inputs, outputs):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None]
for opt in opts: for opt in self.tracker.get_trackers(op):
opt_start = time.time() opt_start = time.time()
new_repl = opt.transform(fgraph, op, context_name, inputs, outputs) new_repl = opt.transform(fgraph, op, context_name, inputs, outputs)
opt_finish = time.time() opt_finish = time.time()
......
...@@ -6,6 +6,7 @@ amount of useful generic optimization tools. ...@@ -6,6 +6,7 @@ amount of useful generic optimization tools.
import abc import abc
import contextlib import contextlib
import copy import copy
import functools
import inspect import inspect
import logging import logging
import pdb import pdb
...@@ -13,9 +14,10 @@ import sys ...@@ -13,9 +14,10 @@ import sys
import time import time
import traceback import traceback
import warnings import warnings
from collections import OrderedDict, UserList, defaultdict, deque from collections import UserList, defaultdict, deque
from collections.abc import Iterable from collections.abc import Iterable
from functools import partial, reduce from functools import partial, reduce
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -1269,32 +1271,90 @@ def local_optimizer( ...@@ -1269,32 +1271,90 @@ def local_optimizer(
return decorator return decorator
class LocalOptTracker:
    r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance.

    Lookups follow the `Op` type's MRO, so a rewrite registered for a base
    `Op` type also applies to nodes whose `Op` is a subclass of that type.
    """

    def __init__(self):
        # Rewrites keyed by specific `Op` *instances*.
        self.tracked_instances: Dict = {}
        # Rewrites keyed by `Op` *types* (matched via the MRO in `_find_impl`).
        self.tracked_types: Dict = {}
        # Rewrites that apply to every node (i.e. `tracks()` returned `None`).
        self.untracked_opts: List = []
        # Per-instance memo for `get_trackers`, invalidated by `add_tracker`.
        # (An `lru_cache` on the method would key on `self`, keep this object
        # alive for the cache's lifetime, and--without invalidation--return
        # stale results after new rewrites are registered.)
        self._tracker_cache: Dict = {}

    def add_tracker(self, rw: "LocalOptimizer"):
        """Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally."""
        tracks = rw.tracks()

        if tracks is None:
            self.untracked_opts.append(rw)
        else:
            for c in tracks:
                if isinstance(c, type):
                    self.tracked_types.setdefault(c, []).append(rw)
                else:
                    self.tracked_instances.setdefault(c, []).append(rw)

        # New registrations change the applicable-rewrite sets, so previously
        # computed lookups are no longer valid.
        self._tracker_cache.clear()

    def _find_impl(self, cls) -> "List[LocalOptimizer]":
        r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance.

        This based on `functools._find_impl`.
        """
        # NOTE: `functools._compose_mro` is a private CPython helper (used by
        # `singledispatch`); it also accounts for ABC virtual subclasses.
        mro = functools._compose_mro(cls, self.tracked_types.keys())
        matches = []
        for t in mro:
            match = self.tracked_types.get(t, None)
            if match:
                matches.extend(match)
        return matches

    def get_trackers(self, op: "Op") -> "List[LocalOptimizer]":
        """Get all the rewrites applicable to `op`.

        Results are ordered: type-based matches (most-derived first), then
        instance matches, then the generally applicable rewrites.  Results
        are memoized per `op` until the next `add_tracker` call.
        """
        try:
            return self._tracker_cache[op]
        except KeyError:
            trackers = (
                self._find_impl(type(op))
                + self.tracked_instances.get(op, [])
                + self.untracked_opts
            )
            self._tracker_cache[op] = trackers
            return trackers

    def get_rewriters(self):
        """Iterate over all registered rewrites (types, then instances, then generic)."""
        return chain(
            chain.from_iterable(
                chain(self.tracked_types.values(), self.tracked_instances.values())
            ),
            self.untracked_opts,
        )
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(LocalOptimizer):
r"""An optimizer that applies a list of `LocalOptimizer`\s to a node. r"""An optimizer that applies a list of `LocalOptimizer`\s to a node.
Parameters
----------
optimizers :
A list of optimizers to be applied to nodes.
apply_all_opts : bool (Default False)
If ``False``, it will return after the new node after the first optimizer
applied. Otherwise, it will start again with the new node until no new
optimization apply.
profile :
Whether or not to profile the optimizations.
Attributes Attributes
---------- ----------
reentrant : bool reentrant : bool
Some global optimizer like `NavigatorOptimizer` can use this value to Some global optimizers, like `NavigatorOptimizer`, use this value to
determine if it ignore new nodes during a pass on the nodes. Sometimes, determine if they should ignore new nodes.
``ignore_newtrees`` is not reentrant.
retains_inputs : bool retains_inputs : bool
States whether or not the inputs of a transformed node are transferred States whether or not the inputs of a transformed node are transferred
to the outputs. to the outputs.
""" """
def __init__(self, *optimizers, apply_all_opts=False, profile=False): def __init__(
self, *optimizers, apply_all_opts: bool = False, profile: bool = False
):
"""
Parameters
----------
optimizers
A list of optimizers to be applied to nodes.
apply_all_opts
If ``False``, it will return after the first successfully applied
rewrite; otherwise, it will apply every applicable rewrite
incrementally.
profile
Whether or not to profile the optimizations.
"""
super().__init__()
if len(optimizers) == 1 and isinstance(optimizers[0], list): if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB. # This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0]) optimizers = tuple(optimizers[0])
...@@ -1307,26 +1367,25 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1307,26 +1367,25 @@ class LocalOptGroup(LocalOptimizer):
) )
self.apply_all_opts = apply_all_opts self.apply_all_opts = apply_all_opts
self.profile = profile self.profile = profile
self.track_map = defaultdict(lambda: [])
if self.profile: if self.profile:
self.time_opts = {} self.time_opts = {}
self.process_count = {} self.process_count = {}
self.applied_true = {} self.applied_true = {}
self.node_created = {} self.node_created = {}
self.tracker = LocalOptTracker()
for o in self.opts: for o in self.opts:
self.tracker.add_tracker(o)
if self.profile: if self.profile:
self.time_opts.setdefault(o, 0) self.time_opts.setdefault(o, 0)
self.process_count.setdefault(o, 0) self.process_count.setdefault(o, 0)
self.applied_true.setdefault(o, 0) self.applied_true.setdefault(o, 0)
self.node_created.setdefault(o, 0) self.node_created.setdefault(o, 0)
tracks = o.tracks()
if tracks is None:
self.track_map[None].append(o)
else:
for c in tracks:
self.track_map[c].append(o)
def __str__(self): def __str__(self):
return getattr( return getattr(
...@@ -1346,13 +1405,12 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1346,13 +1405,12 @@ class LocalOptGroup(LocalOptimizer):
def transform(self, fgraph, node): def transform(self, fgraph, node):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
repl = None repl = None
while True: while True:
opts = ( opts = self.tracker.get_trackers(node.op)
self.track_map[type(node.op)]
+ self.track_map[node.op]
+ self.track_map[None]
)
new_repl = None new_repl = None
for opt in opts: for opt in opts:
opt_start = time.time() opt_start = time.time()
...@@ -2333,38 +2391,27 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2333,38 +2391,27 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super().__init__( super().__init__(
None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback
) )
self.local_optimizers_map = OrderedDict()
self.local_optimizers_all = []
self.global_optimizers = [] self.global_optimizers = []
self.final_optimizers = [] self.final_optimizers = []
self.cleanup_optimizers = [] self.cleanup_optimizers = []
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.local_tracker = LocalOptTracker()
for opt in optimizers: for opt in optimizers:
if isinstance(opt, LocalOptimizer): if isinstance(opt, LocalOptimizer):
if opt.tracks() is None: self.local_tracker.add_tracker(opt)
self.local_optimizers_all.append(opt)
else:
for c in opt.tracks():
self.local_optimizers_map.setdefault(c, []).append(opt)
else: else:
self.global_optimizers.append(opt) self.global_optimizers.append(opt)
if final_optimizers: if final_optimizers:
self.final_optimizers = final_optimizers self.final_optimizers = final_optimizers
if cleanup_optimizers: if cleanup_optimizers:
self.cleanup_optimizers = cleanup_optimizers self.cleanup_optimizers = cleanup_optimizers
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, "max_use_ratio has to be a number"
def get_local_optimizers(self): def get_local_optimizers(self):
for opt in self.local_optimizers_all: yield from self.local_tracker.get_rewriters()
yield opt
# if repeat is not a problem we can drop the set
s = set()
for lopt in self.local_optimizers_map.values():
for opt in lopt:
if opt not in s:
yield opt
s.add(opt)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super().add_requirements(fgraph) super().add_requirements(fgraph)
...@@ -2496,11 +2543,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2496,11 +2543,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if node not in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
continue continue
current_node = node current_node = node
for lopt in ( for lopt in self.local_tracker.get_trackers(node.op):
self.local_optimizers_all
+ self.local_optimizers_map.get(type(node.op), [])
+ self.local_optimizers_map.get(node.op, [])
):
nb = change_tracker.nb_imported nb = change_tracker.nb_imported
t_opt = time.time() t_opt = time.time()
lopt_change = self.process_node(fgraph, node, lopt) lopt_change = self.process_node(fgraph, node, lopt)
......
...@@ -8,6 +8,7 @@ from aesara.graph.op import Op ...@@ -8,6 +8,7 @@ from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
LocalOptGroup, LocalOptGroup,
LocalOptTracker,
MergeOptimizer, MergeOptimizer,
OpKeyOptimizer, OpKeyOptimizer,
OpSub, OpSub,
...@@ -755,3 +756,59 @@ def test_local_optimizer(): ...@@ -755,3 +756,59 @@ def test_local_optimizer():
# This is not allowed by `tracks` # This is not allowed by `tracks`
local_opt_1.transform(fgraph, fgraph.outputs[2].owner) local_opt_1.transform(fgraph, fgraph.outputs[2].owner)
assert hits[0] == 2 assert hits[0] == 2
def test_TrackingLocalOptimizer():
    """Check `LocalOptTracker` registration, MRO-based lookup, and iteration order."""

    @local_optimizer(None)
    def local_opt_1(fgraph, node):
        pass

    @local_optimizer([op1])
    def local_opt_2(fgraph, node):
        pass

    @local_optimizer([Op])
    def local_opt_3(fgraph, node):
        pass

    @local_optimizer([MyOp])
    def local_opt_4(fgraph, node):
        pass

    @local_optimizer([MyOp])
    def local_opt_5(fgraph, node):
        pass

    tracker = LocalOptTracker()
    for rewrite in (local_opt_1, local_opt_2, local_opt_3, local_opt_4, local_opt_5):
        tracker.add_tracker(rewrite)

    # `None`-tracking rewrites are generic; instances and types are keyed separately.
    assert tracker.tracked_instances == {op1: [local_opt_2]}
    assert tracker.tracked_types == {Op: [local_opt_3], MyOp: [local_opt_4, local_opt_5]}
    assert tracker.untracked_opts == [local_opt_1]

    # Most-derived type matches come first, then instance matches, then generic ones.
    assert tracker.get_trackers(op1) == [
        local_opt_4,
        local_opt_5,
        local_opt_3,
        local_opt_2,
        local_opt_1,
    ]

    class MyNewOp(Op):
        def perform(self, *args):
            pass

    new_op = MyNewOp()
    # An unseen `Op` subclass only matches the `Op`-typed and generic rewrites.
    assert tracker.get_trackers(new_op) == [local_opt_3, local_opt_1]

    assert list(tracker.get_rewriters()) == [
        local_opt_3,
        local_opt_4,
        local_opt_5,
        local_opt_2,
        local_opt_1,
    ]
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment