提交 f8d06511 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use inheritance in local optimizer Op tracking

This commit introduces a `LocalOptTracker` object that performs an MRO-based lookup of `LocalOptimizer`s that track `Op` types.
上级 d475421b
......@@ -29,8 +29,8 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
def transform(self, fgraph, op, context_name, inputs, outputs):
if len(self.opts) == 0:
return
opts = self.track_map[type(op)] + self.track_map[op] + self.track_map[None]
for opt in opts:
for opt in self.tracker.get_trackers(op):
opt_start = time.time()
new_repl = opt.transform(fgraph, op, context_name, inputs, outputs)
opt_finish = time.time()
......
......@@ -6,6 +6,7 @@ amount of useful generic optimization tools.
import abc
import contextlib
import copy
import functools
import inspect
import logging
import pdb
......@@ -13,9 +14,10 @@ import sys
import time
import traceback
import warnings
from collections import OrderedDict, UserList, defaultdict, deque
from collections import UserList, defaultdict, deque
from collections.abc import Iterable
from functools import partial, reduce
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -1269,32 +1271,90 @@ def local_optimizer(
return decorator
class LocalOptTracker:
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance."""
def __init__(self):
self.tracked_instances = {}
self.tracked_types = {}
self.untracked_opts = []
def add_tracker(self, rw: LocalOptimizer):
"""Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally."""
tracks = rw.tracks()
if tracks is None:
self.untracked_opts.append(rw)
else:
for c in tracks:
if isinstance(c, type):
self.tracked_types.setdefault(c, []).append(rw)
else:
self.tracked_instances.setdefault(c, []).append(rw)
def _find_impl(self, cls):
r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance.
This based on `functools._find_impl`.
"""
mro = functools._compose_mro(cls, self.tracked_types.keys())
matches = []
for t in mro:
match = self.tracked_types.get(t, None)
if match:
matches.extend(match)
return matches
@functools.lru_cache()
def get_trackers(self, op: Op) -> List[LocalOptimizer]:
"""Get all the rewrites applicable to `op`."""
return (
self._find_impl(type(op))
+ self.tracked_instances.get(op, [])
+ self.untracked_opts
)
def get_rewriters(self):
return chain(
chain.from_iterable(
chain(self.tracked_types.values(), self.tracked_instances.values())
),
self.untracked_opts,
)
class LocalOptGroup(LocalOptimizer):
r"""An optimizer that applies a list of `LocalOptimizer`\s to a node.
Parameters
----------
optimizers :
A list of optimizers to be applied to nodes.
apply_all_opts : bool (Default False)
If ``False``, it will return after the new node after the first optimizer
applied. Otherwise, it will start again with the new node until no new
optimization apply.
profile :
Whether or not to profile the optimizations.
Attributes
----------
reentrant : bool
Some global optimizer like `NavigatorOptimizer` can use this value to
determine if it ignore new nodes during a pass on the nodes. Sometimes,
``ignore_newtrees`` is not reentrant.
Some global optimizers, like `NavigatorOptimizer`, use this value to
determine if they should ignore new nodes.
retains_inputs : bool
States whether or not the inputs of a transformed node are transferred
to the outputs.
"""
def __init__(self, *optimizers, apply_all_opts=False, profile=False):
def __init__(
self, *optimizers, apply_all_opts: bool = False, profile: bool = False
):
"""
Parameters
----------
optimizers
A list of optimizers to be applied to nodes.
apply_all_opts
If ``False``, it will return after the first successfully applied
rewrite; otherwise, it will apply every applicable rewrite
incrementally.
profile
Whether or not to profile the optimizations.
"""
super().__init__()
if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])
......@@ -1307,26 +1367,25 @@ class LocalOptGroup(LocalOptimizer):
)
self.apply_all_opts = apply_all_opts
self.profile = profile
self.track_map = defaultdict(lambda: [])
if self.profile:
self.time_opts = {}
self.process_count = {}
self.applied_true = {}
self.node_created = {}
self.tracker = LocalOptTracker()
for o in self.opts:
self.tracker.add_tracker(o)
if self.profile:
self.time_opts.setdefault(o, 0)
self.process_count.setdefault(o, 0)
self.applied_true.setdefault(o, 0)
self.node_created.setdefault(o, 0)
tracks = o.tracks()
if tracks is None:
self.track_map[None].append(o)
else:
for c in tracks:
self.track_map[c].append(o)
def __str__(self):
return getattr(
......@@ -1346,13 +1405,12 @@ class LocalOptGroup(LocalOptimizer):
def transform(self, fgraph, node):
if len(self.opts) == 0:
return
repl = None
while True:
opts = (
self.track_map[type(node.op)]
+ self.track_map[node.op]
+ self.track_map[None]
)
opts = self.tracker.get_trackers(node.op)
new_repl = None
for opt in opts:
opt_start = time.time()
......@@ -2333,38 +2391,27 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super().__init__(
None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback
)
self.local_optimizers_map = OrderedDict()
self.local_optimizers_all = []
self.global_optimizers = []
self.final_optimizers = []
self.cleanup_optimizers = []
self.tracks_on_change_inputs = tracks_on_change_inputs
self.local_tracker = LocalOptTracker()
for opt in optimizers:
if isinstance(opt, LocalOptimizer):
if opt.tracks() is None:
self.local_optimizers_all.append(opt)
else:
for c in opt.tracks():
self.local_optimizers_map.setdefault(c, []).append(opt)
self.local_tracker.add_tracker(opt)
else:
self.global_optimizers.append(opt)
if final_optimizers:
self.final_optimizers = final_optimizers
if cleanup_optimizers:
self.cleanup_optimizers = cleanup_optimizers
self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, "max_use_ratio has to be a number"
def get_local_optimizers(self):
for opt in self.local_optimizers_all:
yield opt
# if repeat is not a problem we can drop the set
s = set()
for lopt in self.local_optimizers_map.values():
for opt in lopt:
if opt not in s:
yield opt
s.add(opt)
yield from self.local_tracker.get_rewriters()
def add_requirements(self, fgraph):
super().add_requirements(fgraph)
......@@ -2496,11 +2543,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if node not in fgraph.apply_nodes:
continue
current_node = node
for lopt in (
self.local_optimizers_all
+ self.local_optimizers_map.get(type(node.op), [])
+ self.local_optimizers_map.get(node.op, [])
):
for lopt in self.local_tracker.get_trackers(node.op):
nb = change_tracker.nb_imported
t_opt = time.time()
lopt_change = self.process_node(fgraph, node, lopt)
......
......@@ -8,6 +8,7 @@ from aesara.graph.op import Op
from aesara.graph.opt import (
EquilibriumOptimizer,
LocalOptGroup,
LocalOptTracker,
MergeOptimizer,
OpKeyOptimizer,
OpSub,
......@@ -755,3 +756,59 @@ def test_local_optimizer():
# This is not allowed by `tracks`
local_opt_1.transform(fgraph, fgraph.outputs[2].owner)
assert hits[0] == 2
def test_TrackingLocalOptimizer():
@local_optimizer(None)
def local_opt_1(fgraph, node):
pass
@local_optimizer([op1])
def local_opt_2(fgraph, node):
pass
@local_optimizer([Op])
def local_opt_3(fgraph, node):
pass
@local_optimizer([MyOp])
def local_opt_4(fgraph, node):
pass
@local_optimizer([MyOp])
def local_opt_5(fgraph, node):
pass
tracker = LocalOptTracker()
tracker.add_tracker(local_opt_1)
tracker.add_tracker(local_opt_2)
tracker.add_tracker(local_opt_3)
tracker.add_tracker(local_opt_4)
tracker.add_tracker(local_opt_5)
assert tracker.tracked_instances == {op1: [local_opt_2]}
assert tracker.tracked_types == {
Op: [local_opt_3],
MyOp: [local_opt_4, local_opt_5],
}
assert tracker.untracked_opts == [local_opt_1]
res = tracker.get_trackers(op1)
assert res == [local_opt_4, local_opt_5, local_opt_3, local_opt_2, local_opt_1]
class MyNewOp(Op):
def perform(self, *args):
pass
new_op = MyNewOp()
res = tracker.get_trackers(new_op)
assert res == [local_opt_3, local_opt_1]
assert list(tracker.get_rewriters()) == [
local_opt_3,
local_opt_4,
local_opt_5,
local_opt_2,
local_opt_1,
]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论