提交 ac213377 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename EquilibriumOptimizer to EquilibriumGraphRewriter

上级 e6635af8
......@@ -212,10 +212,10 @@ optdb.register(
"canonicalize_db",
position=1,
)
# Register in the canonizer Equilibrium as a clean up opt the merge opt.
# Register in the canonizer Equilibrium as a clean-up rewrite the merge rewrite.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global optimizer with
# final_opt=True.
# won't merge all nodes if it is set as a global rewriter with
# final_rewriter=True.
# We need a new instance of MergeOptimizer to don't have its name
# changed by other usage of it.
......
......@@ -1107,7 +1107,7 @@ def add_optimizer_configvars():
config.add(
"optdb__max_use_ratio",
"A ratio that prevent infinite loop in EquilibriumOptimizer.",
"A ratio that prevent infinite loop in EquilibriumGraphRewriter.",
FloatParam(8),
in_c_key=False,
)
......
......@@ -2227,26 +2227,26 @@ def merge_dict(d1, d2):
return d
class EquilibriumOptimizer(NodeProcessingGraphRewriter):
"""An `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached."""
class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
"""A `Rewriter` that applies its rewrites until a fixed-point/equilibrium is reached."""
def __init__(
self,
optimizers: Sequence[Rewriter],
rewriters: Sequence[Rewriter],
failure_callback: Optional[FailureCallbackType] = None,
ignore_newtrees: bool = True,
tracks_on_change_inputs: bool = False,
max_use_ratio: Optional[float] = None,
final_optimizers: Optional[Sequence[GraphRewriter]] = None,
cleanup_optimizers: Optional[Sequence[GraphRewriter]] = None,
final_rewriters: Optional[Sequence[GraphRewriter]] = None,
cleanup_rewriters: Optional[Sequence[GraphRewriter]] = None,
):
"""
Parameters
----------
optimizers
rewriters
Node or graph rewriters to apply until equilibrium.
The global optimizer will be run at the start of each iteration before
The global rewriter will be run at the start of each iteration before
the node rewriter.
failure_callback
See :attr:`NodeProcessingGraphRewriter.failure_callback`.
......@@ -2257,9 +2257,9 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
max_use_ratio
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
times.
final_optimizers
final_rewriters
Rewriters that will be run after each iteration.
cleanup_optimizers
cleanup_rewriters
Rewriters applied after all graph rewriters, then when one
`NodeRewriter` is applied, then after all final rewriters.
They should not traverse the entire graph, since they are called
......@@ -2270,27 +2270,27 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
super().__init__(
None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback
)
self.global_optimizers: List[GraphRewriter] = []
self.global_rewriters: List[GraphRewriter] = []
self.tracks_on_change_inputs = tracks_on_change_inputs
self.node_tracker = OpToRewriterTracker()
for opt in optimizers:
if isinstance(opt, NodeRewriter):
self.node_tracker.add_tracker(opt)
for rewriter in rewriters:
if isinstance(rewriter, NodeRewriter):
self.node_tracker.add_tracker(rewriter)
else:
assert isinstance(opt, GraphRewriter)
self.global_optimizers.append(opt)
assert isinstance(rewriter, GraphRewriter)
self.global_rewriters.append(rewriter)
if final_optimizers:
self.final_optimizers = list(final_optimizers)
if final_rewriters:
self.final_rewriters = list(final_rewriters)
else:
self.final_optimizers = []
self.final_rewriters = []
if cleanup_optimizers:
self.cleanup_optimizers = list(cleanup_optimizers)
if cleanup_rewriters:
self.cleanup_rewriters = list(cleanup_rewriters)
else:
self.cleanup_optimizers = []
self.cleanup_rewriters = []
self.max_use_ratio = max_use_ratio
......@@ -2307,14 +2307,14 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
def add_requirements(self, fgraph):
super().add_requirements(fgraph)
for opt in self.get_node_rewriters():
opt.add_requirements(fgraph)
for opt in self.global_optimizers:
opt.add_requirements(fgraph)
for opt in self.final_optimizers:
opt.add_requirements(fgraph)
for opt in self.cleanup_optimizers:
opt.add_requirements(fgraph)
for rewriter in self.get_node_rewriters():
rewriter.add_requirements(fgraph)
for rewriter in self.global_rewriters:
rewriter.add_requirements(fgraph)
for rewriter in self.final_rewriters:
rewriter.add_requirements(fgraph)
for rewriter in self.cleanup_rewriters:
rewriter.add_requirements(fgraph)
def apply(self, fgraph, start_from=None):
change_tracker = ChangeTracker()
......@@ -2327,7 +2327,7 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
changed = True
max_use_abort = False
opt_name = None
rewriter_name = None
global_process_count = {}
start_nb_nodes = len(fgraph.apply_nodes)
max_nb_nodes = len(fgraph.apply_nodes)
......@@ -2335,39 +2335,39 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
loop_timing = []
loop_process_count = []
global_opt_timing = []
time_opts = {}
global_rewriter_timing = []
time_rewriters = {}
io_toposort_timing = []
nb_nodes = []
node_created = {}
global_sub_profs = []
final_sub_profs = []
cleanup_sub_profs = []
for opt in (
self.global_optimizers
for rewriter in (
self.global_rewriters
+ list(self.get_node_rewriters())
+ self.final_optimizers
+ self.cleanup_optimizers
+ self.final_rewriters
+ self.cleanup_rewriters
):
global_process_count.setdefault(opt, 0)
time_opts.setdefault(opt, 0)
node_created.setdefault(opt, 0)
global_process_count.setdefault(rewriter, 0)
time_rewriters.setdefault(rewriter, 0)
node_created.setdefault(rewriter, 0)
def apply_cleanup(profs_dict):
changed = False
for copt in self.cleanup_optimizers:
for crewriter in self.cleanup_rewriters:
change_tracker.reset()
nb = change_tracker.nb_imported
t_opt = time.time()
sub_prof = copt.apply(fgraph)
time_opts[copt] += time.time() - t_opt
profs_dict[copt].append(sub_prof)
t_rewrite = time.time()
sub_prof = crewriter.apply(fgraph)
time_rewriters[crewriter] += time.time() - t_rewrite
profs_dict[crewriter].append(sub_prof)
if change_tracker.changed:
process_count.setdefault(copt, 0)
process_count[copt] += 1
global_process_count[copt] += 1
process_count.setdefault(crewriter, 0)
process_count[crewriter] += 1
global_process_count[crewriter] += 1
changed = True
node_created[copt] += change_tracker.nb_imported - nb
node_created[crewriter] += change_tracker.nb_imported - nb
return changed
while changed and not max_use_abort:
......@@ -2375,32 +2375,32 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
t0 = time.time()
changed = False
iter_cleanup_sub_profs = {}
for copt in self.cleanup_optimizers:
iter_cleanup_sub_profs[copt] = []
for crewrite in self.cleanup_rewriters:
iter_cleanup_sub_profs[crewrite] = []
# apply global optimizers
# Apply global rewriters
sub_profs = []
for gopt in self.global_optimizers:
for grewrite in self.global_rewriters:
change_tracker.reset()
nb = change_tracker.nb_imported
t_opt = time.time()
sub_prof = gopt.apply(fgraph)
time_opts[gopt] += time.time() - t_opt
t_rewrite = time.time()
sub_prof = grewrite.apply(fgraph)
time_rewriters[grewrite] += time.time() - t_rewrite
sub_profs.append(sub_prof)
if change_tracker.changed:
process_count.setdefault(gopt, 0)
process_count[gopt] += 1
global_process_count[gopt] += 1
process_count.setdefault(grewrite, 0)
process_count[grewrite] += 1
global_process_count[grewrite] += 1
changed = True
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
node_created[grewrite] += change_tracker.nb_imported - nb
if global_process_count[grewrite] > max_use:
max_use_abort = True
opt_name = getattr(gopt, "name", None) or getattr(
gopt, "__name__", ""
rewriter_name = getattr(grewrite, "name", None) or getattr(
grewrite, "__name__", ""
)
global_sub_profs.append(sub_profs)
global_opt_timing.append(float(time.time() - t0))
global_rewriter_timing.append(float(time.time() - t0))
changed |= apply_cleanup(iter_cleanup_sub_profs)
......@@ -2434,11 +2434,11 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
current_node = node
for node_rewriter in self.node_tracker.get_trackers(node.op):
nb = change_tracker.nb_imported
t_opt = time.time()
t_rewrite = time.time()
node_rewriter_change = self.process_node(
fgraph, node, node_rewriter
)
time_opts[node_rewriter] += time.time() - t_opt
time_rewriters[node_rewriter] += time.time() - t_rewrite
if not node_rewriter_change:
continue
process_count.setdefault(node_rewriter, 0)
......@@ -2449,48 +2449,48 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
changed |= apply_cleanup(iter_cleanup_sub_profs)
if global_process_count[node_rewriter] > max_use:
max_use_abort = True
opt_name = getattr(node_rewriter, "name", None) or getattr(
node_rewriter, "__name__", ""
)
rewriter_name = getattr(
node_rewriter, "name", None
) or getattr(node_rewriter, "__name__", "")
if node not in fgraph.apply_nodes:
# go to next node
break
finally:
self.detach_updater(fgraph, u)
# Apply final optimizers
# Apply final rewriters
sub_profs = []
t_before_final_opt = time.time()
for gopt in self.final_optimizers:
t_before_final_rewrites = time.time()
for grewrite in self.final_rewriters:
change_tracker.reset()
nb = change_tracker.nb_imported
t_opt = time.time()
sub_prof = gopt.apply(fgraph)
time_opts[gopt] += time.time() - t_opt
t_rewrite = time.time()
sub_prof = grewrite.apply(fgraph)
time_rewriters[grewrite] += time.time() - t_rewrite
sub_profs.append(sub_prof)
if change_tracker.changed:
process_count.setdefault(gopt, 0)
process_count[gopt] += 1
global_process_count[gopt] += 1
process_count.setdefault(grewrite, 0)
process_count[grewrite] += 1
global_process_count[grewrite] += 1
changed = True
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
node_created[grewrite] += change_tracker.nb_imported - nb
if global_process_count[grewrite] > max_use:
max_use_abort = True
opt_name = getattr(gopt, "name", None) or getattr(
gopt, "__name__", ""
rewriter_name = getattr(grewrite, "name", None) or getattr(
grewrite, "__name__", ""
)
final_sub_profs.append(sub_profs)
global_opt_timing[-1] += time.time() - t_before_final_opt
# apply clean up as final opt can have done changes that
# request that
global_rewriter_timing[-1] += time.time() - t_before_final_rewrites
changed |= apply_cleanup(iter_cleanup_sub_profs)
# merge clean up profiles during that iteration.
# Merge clean up profiles during that iteration
c_sub_profs = []
for copt, sub_profs in iter_cleanup_sub_profs.items():
for crewrite, sub_profs in iter_cleanup_sub_profs.items():
sub_prof = sub_profs[0]
for s_p in sub_profs[1:]:
sub_prof = copt.merge_profile(sub_prof, s_p)
sub_prof = crewrite.merge_profile(sub_prof, s_p)
c_sub_profs.append(sub_prof)
cleanup_sub_profs.append(c_sub_profs)
......@@ -2501,9 +2501,9 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
if max_use_abort:
msg = (
f"EquilibriumOptimizer max'ed out by '{opt_name}'"
+ ". You can safely raise the current threshold of "
+ "{config.optdb__max_use_ratio:f} with the aesara flag 'optdb__max_use_ratio'."
f"{type(self).__name__} max'ed out by {rewriter_name}."
"You can safely raise the current threshold of "
f"{config.optdb__max_use_ratio} with the option `optdb__max_use_ratio`."
)
if config.on_opt_error == "raise":
raise AssertionError(msg)
......@@ -2511,7 +2511,7 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
_logger.error(msg)
fgraph.remove_feature(change_tracker)
assert len(loop_process_count) == len(loop_timing)
assert len(loop_process_count) == len(global_opt_timing)
assert len(loop_process_count) == len(global_rewriter_timing)
assert len(loop_process_count) == len(nb_nodes)
assert len(loop_process_count) == len(io_toposort_timing)
assert len(loop_process_count) == len(global_sub_profs)
......@@ -2522,9 +2522,9 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
loop_timing,
loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing,
global_rewriter_timing,
nb_nodes,
time_opts,
time_rewriters,
io_toposort_timing,
node_created,
global_sub_profs,
......@@ -2543,16 +2543,16 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
stream, level=(level + 2), depth=(depth - 1)
)
@staticmethod
def print_profile(stream, prof, level=0):
@classmethod
def print_profile(cls, stream, prof, level=0):
(
opt,
rewrite,
loop_timing,
loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing,
global_rewrite_timing,
nb_nodes,
time_opts,
time_rewrites,
io_toposort_timing,
node_created,
global_sub_profs,
......@@ -2561,8 +2561,12 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
) = prof
blanc = " " * level
print(blanc, "EquilibriumOptimizer", end=" ", file=stream)
print(blanc, getattr(opt, "name", getattr(opt, "__name__", "")), file=stream)
print(blanc, cls.__name__, end=" ", file=stream)
print(
blanc,
getattr(rewrite, "name", getattr(rewrite, "__name__", "")),
file=stream,
)
print(
blanc,
f" time {sum(loop_timing):.3f}s for {len(loop_timing)} passes",
......@@ -2574,13 +2578,13 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
file=stream,
)
print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.get_node_rewriters())
s = sum(time_rewrites[o] for o in rewrite.get_node_rewriters())
print(blanc, f" time in node rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.global_optimizers)
s = sum(time_rewrites[o] for o in rewrite.global_rewriters)
print(blanc, f" time in graph rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.final_optimizers)
s = sum(time_rewrites[o] for o in rewrite.final_rewriters)
print(blanc, f" time in final rewriters {s:.3f}s", file=stream)
s = sum(time_opts[o] for o in opt.cleanup_optimizers)
s = sum(time_rewrites[o] for o in rewrite.cleanup_rewriters)
print(blanc, f" time in cleanup rewriters {s:.3f}s", file=stream)
for i in range(len(loop_timing)):
loop_times = ""
......@@ -2594,21 +2598,21 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
print(
blanc,
(
f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in graph rewriters, "
f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_rewrite_timing[i]:.3f}s in graph rewriters, "
f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {loop_times}"
),
file=stream,
)
count_opt = []
count_rewrite = []
not_used = []
not_used_time = 0
process_count = {}
for o in (
opt.global_optimizers
+ list(opt.get_node_rewriters())
+ list(opt.final_optimizers)
+ list(opt.cleanup_optimizers)
rewrite.global_rewriters
+ list(rewrite.get_node_rewriters())
+ list(rewrite.final_rewriters)
+ list(rewrite.cleanup_rewriters)
):
process_count.setdefault(o, 0)
for count in loop_process_count:
......@@ -2616,17 +2620,17 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
process_count[o] += v
for o, count in process_count.items():
if count > 0:
count_opt.append((time_opts[o], count, node_created[o], o))
count_rewrite.append((time_rewrites[o], count, node_created[o], o))
else:
not_used.append((time_opts[o], o))
not_used_time += time_opts[o]
not_used.append((time_rewrites[o], o))
not_used_time += time_rewrites[o]
if count_opt:
if count_rewrite:
print(
blanc, " times - times applied - nb node created - name:", file=stream
)
count_opt.sort()
for (t, count, n_created, o) in count_opt[::-1]:
count_rewrite.sort()
for (t, count, n_created, o) in count_rewrite[::-1]:
print(
blanc,
f" {t:.3f}s - {int(count)} - {int(n_created)} - {o}",
......@@ -2634,40 +2638,40 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
)
print(
blanc,
f" {not_used_time:.3f}s - in {len(not_used)} optimization that were not used (display only those with a runtime > 0)",
f" {not_used_time:.3f}s - in {len(not_used)} rewrites that were not used (i.e. those with a run-time of zero)",
file=stream,
)
not_used.sort(key=lambda nu: (nu[0], str(nu[1])))
for (t, o) in not_used[::-1]:
if t > 0:
# Skip opt that have 0 times, they probably wasn't even tried.
# Skip rewrites that have no run-times; they probably weren't even tried.
print(blanc + " ", f" {t:.3f}s - {o}", file=stream)
print(file=stream)
gf_opts = [
gf_rewrites = [
o
for o in (
opt.global_optimizers
+ list(opt.final_optimizers)
+ list(opt.cleanup_optimizers)
rewrite.global_rewrites
+ list(rewrite.final_rewriters)
+ list(rewrite.cleanup_rewriters)
)
if o.print_profile.__code__ is not GraphRewriter.print_profile.__code__
]
if not gf_opts:
if not gf_rewrites:
return
print(blanc, "Global, final and clean up optimizers", file=stream)
print(blanc, "Global, final, and clean up rewriters", file=stream)
for i in range(len(loop_timing)):
print(blanc, f"Iter {int(i)}", file=stream)
for o, prof in zip(opt.global_optimizers, global_sub_profs[i]):
for o, prof in zip(rewrite.global_rewriters, global_sub_profs[i]):
try:
o.print_profile(stream, prof, level + 2)
except NotImplementedError:
print(blanc, "merge not implemented for ", o)
for o, prof in zip(opt.final_optimizers, final_sub_profs[i]):
for o, prof in zip(rewrite.final_rewriters, final_sub_profs[i]):
try:
o.print_profile(stream, prof, level + 2)
except NotImplementedError:
print(blanc, "merge not implemented for ", o)
for o, prof in zip(opt.cleanup_optimizers, cleanup_sub_profs[i]):
for o, prof in zip(rewrite.cleanup_rewriters, cleanup_sub_profs[i]):
try:
o.print_profile(stream, prof, level + 2)
except NotImplementedError:
......@@ -2675,25 +2679,23 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
@staticmethod
def merge_profile(prof1, prof2):
# (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
node_rewriters = OrderedSet(prof1[0].get_node_rewriters()).union(
prof2[0].get_node_rewriters()
)
global_optimizers = OrderedSet(prof1[0].global_optimizers).union(
prof2[0].global_optimizers
global_rewriters = OrderedSet(prof1[0].global_rewriters).union(
prof2[0].global_rewriters
)
final_optimizers = list(
OrderedSet(prof1[0].final_optimizers).union(prof2[0].final_optimizers)
final_rewriters = list(
OrderedSet(prof1[0].final_rewriters).union(prof2[0].final_rewriters)
)
cleanup_optimizers = list(
OrderedSet(prof1[0].cleanup_optimizers).union(prof2[0].cleanup_optimizers)
cleanup_rewriters = list(
OrderedSet(prof1[0].cleanup_rewriters).union(prof2[0].cleanup_rewriters)
)
new_opt = EquilibriumOptimizer(
node_rewriters.union(global_optimizers),
new_rewriter = EquilibriumGraphRewriter(
node_rewriters.union(global_rewriters),
max_use_ratio=1,
final_optimizers=final_optimizers,
cleanup_optimizers=cleanup_optimizers,
final_rewriters=final_rewriters,
cleanup_rewriters=cleanup_rewriters,
)
def add_append_list(l1, l2):
......@@ -2720,29 +2722,27 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
else:
process_count[process] = count
def merge(opts, attr, idx):
def merge(rewriters, attr, idx):
tmp = []
for opt in opts:
for rewriter in rewriters:
o1 = getattr(prof1[0], attr)
o2 = getattr(prof2[0], attr)
if opt in o1 and opt in o2:
p1 = prof1[idx][i][o1.index(opt)]
p2 = prof2[idx][i][o2.index(opt)]
if rewriter in o1 and rewriter in o2:
p1 = prof1[idx][i][o1.index(rewriter)]
p2 = prof2[idx][i][o2.index(rewriter)]
m = None
if hasattr(opt, "merge_profile"):
m = opt.merge_profile(p1, p2)
elif opt in o1:
m = prof1[idx][i][o1.index(opt)]
if hasattr(rewriter, "merge_profile"):
m = rewriter.merge_profile(p1, p2)
elif rewriter in o1:
m = prof1[idx][i][o1.index(rewriter)]
else:
m = prof2[idx][i][o2.index(opt)]
m = prof2[idx][i][o2.index(rewriter)]
tmp.append(m)
return tmp
global_sub_profs.append(merge(global_optimizers, "global_optimizers", 9))
final_sub_profs.append(merge(final_optimizers, "final_optimizers", 10))
cleanup_sub_profs.append(
merge(cleanup_optimizers, "cleanup_optimizers", 11)
)
global_sub_profs.append(merge(global_rewriters, "global_rewriters", 9))
final_sub_profs.append(merge(final_rewriters, "final_rewriters", 10))
cleanup_sub_profs.append(merge(cleanup_rewriters, "cleanup_rewriters", 11))
# Add the iteration done by only one of the profile.
loop_process_count.extend(prof1[2][len(loop_process_count) :])
......@@ -2756,15 +2756,15 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
max_nb_nodes = max(prof1[3], prof2[3])
global_opt_timing = add_append_list(prof1[4], prof2[4])
global_rewrite_timing = add_append_list(prof1[4], prof2[4])
nb_nodes = add_append_list(prof1[5], prof2[5])
time_opts = merge_dict(prof1[6], prof2[6])
time_rewrites = merge_dict(prof1[6], prof2[6])
io_toposort_timing = add_append_list(prof1[7], prof2[7])
assert (
len(loop_timing)
== len(global_opt_timing)
== len(global_rewrite_timing)
== len(global_sub_profs)
== len(io_toposort_timing)
== len(nb_nodes)
......@@ -2773,13 +2773,13 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
node_created = merge_dict(prof1[8], prof2[8])
return (
new_opt,
new_rewriter,
loop_timing,
loop_process_count,
max_nb_nodes,
global_opt_timing,
global_rewrite_timing,
nb_nodes,
time_opts,
time_rewrites,
io_toposort_timing,
node_created,
global_sub_profs,
......@@ -3235,6 +3235,11 @@ DEPRECATED_NAMES = [
"`OpKeyOptimizer` is deprecated: use `OpKeyGraphRewriter` instead.",
OpKeyGraphRewriter,
),
(
"EquilibriumOptimizer",
"`EquilibriumOptimizer` is deprecated: use `EquilibriumGraphRewriter` instead.",
EquilibriumGraphRewriter,
),
]
......
......@@ -31,19 +31,18 @@ class OptimizationDatabase:
def register(
self,
name: str,
optimizer: Union["OptimizationDatabase", OptimizersType],
rewriter: Union["OptimizationDatabase", OptimizersType],
*tags: str,
use_db_name_as_tag=True,
**kwargs,
):
"""Register a new optimizer to the database.
"""Register a new rewriter to the database.
Parameters
----------
name:
Name of the optimizer.
opt:
The optimizer to register.
Name of the rewriter.
rewriter:
The rewriter to register.
tags:
Tag name that allow to select the optimizer.
use_db_name_as_tag:
......@@ -58,14 +57,14 @@ class OptimizationDatabase:
"""
if not isinstance(
optimizer,
rewriter,
(
OptimizationDatabase,
aesara_opt.GraphRewriter,
aesara_opt.NodeRewriter,
),
):
raise TypeError(f"{optimizer} is not a valid optimizer type.")
raise TypeError(f"{rewriter} is not a valid optimizer type.")
if name in self.__db__:
raise ValueError(f"The tag '{name}' is already present in the database.")
......@@ -74,18 +73,18 @@ class OptimizationDatabase:
if self.name is not None:
tags = tags + (self.name,)
optimizer.name = name
rewriter.name = name
# This restriction is there because in many place we suppose that
# something in the OptimizationDatabase is there only once.
if optimizer.name in self.__db__:
if rewriter.name in self.__db__:
raise ValueError(
f"Tried to register {optimizer.name} again under the new name {name}. "
f"Tried to register {rewriter.name} again under the new name {name}. "
"The same optimization cannot be registered multiple times in"
" an ``OptimizationDatabase``; use ProxyDB instead."
)
self.__db__[name] = OrderedSet([optimizer])
self.__db__[name] = OrderedSet([rewriter])
self._names.add(name)
self.__db__[optimizer.__class__.__name__].add(optimizer)
self.__db__[rewriter.__class__.__name__].add(rewriter)
self.add_tags(name, *tags)
def add_tags(self, name, *tags):
......@@ -292,11 +291,11 @@ class OptimizationQuery:
class EquilibriumDB(OptimizationDatabase):
"""A database of rewrites that should be applied until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
Canonicalize, Stabilize, and Specialize are all equilibrium rewriters.
Notes
-----
We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer`
We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumGraphRewriter`
supports both.
It is probably not a good idea to have both ``ignore_newtrees == False``
......@@ -322,33 +321,47 @@ class EquilibriumDB(OptimizationDatabase):
super().__init__()
self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__: Dict[str, aesara_opt.Rewriter] = {}
self.__cleanup__: Dict[str, aesara_opt.Rewriter] = {}
self.__final__: Dict[str, bool] = {}
self.__cleanup__: Dict[str, bool] = {}
def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs):
if final_opt and cleanup:
raise ValueError("`final_opt` and `cleanup` cannot both be true.")
super().register(name, obj, *tags, **kwargs)
self.__final__[name] = final_opt
def register(
self,
name: str,
rewriter: Union["OptimizationDatabase", OptimizersType],
*tags: str,
final_rewriter: bool = False,
cleanup: bool = False,
**kwargs,
):
if final_rewriter and cleanup:
raise ValueError("`final_rewriter` and `cleanup` cannot both be true.")
super().register(name, rewriter, *tags, **kwargs)
self.__final__[name] = final_rewriter
self.__cleanup__[name] = cleanup
def query(self, *tags, **kwtags):
_opts = super().query(*tags, **kwtags)
final_opts = [o for o in _opts if self.__final__.get(o.name, False)]
cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name, False)]
opts = [o for o in _opts if o not in final_opts and o not in cleanup_opts]
if len(final_opts) == 0:
final_opts = None
if len(cleanup_opts) == 0:
cleanup_opts = None
return aesara_opt.EquilibriumOptimizer(
opts,
_rewriters = super().query(*tags, **kwtags)
final_rewriters = [o for o in _rewriters if self.__final__.get(o.name, False)]
cleanup_rewriters = [
o for o in _rewriters if self.__cleanup__.get(o.name, False)
]
rewriters = [
o
for o in _rewriters
if o not in final_rewriters and o not in cleanup_rewriters
]
if len(final_rewriters) == 0:
final_rewriters = None
if len(cleanup_rewriters) == 0:
cleanup_rewriters = None
return aesara_opt.EquilibriumGraphRewriter(
rewriters,
max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees,
tracks_on_change_inputs=self.tracks_on_change_inputs,
failure_callback=aesara_opt.NodeProcessingGraphRewriter.warn_inplace,
final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts,
final_rewriters=final_rewriters,
cleanup_rewriters=cleanup_rewriters,
)
......@@ -372,8 +385,10 @@ class SequenceDB(OptimizationDatabase):
self.failure_callback = failure_callback
def register(self, name, obj, *tags, **kwargs):
super().register(name, obj, *tags, **kwargs)
position = kwargs.pop("position", "last")
super().register(name, obj, *tags, **kwargs)
if position == "last":
if len(self.__position__) == 0:
self.__position__[name] = 0
......
......@@ -2373,7 +2373,7 @@ optdb.register(
position=75,
)
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan", position=1)
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")
scan_seqopt1.register(
......@@ -2419,7 +2419,7 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scan_pushout_add",
# TODO: Perhaps this should be an `EquilibriumOptimizer`?
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out(push_out_add_scan, ignore_newtrees=False),
"fast_run",
"more_mem",
......@@ -2434,7 +2434,6 @@ scan_eqopt2.register(
in2out(basic_opt.constant_folding, ignore_newtrees=True),
"fast_run",
"scan",
position=1,
)
......@@ -2444,14 +2443,13 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
position=2,
)
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan", position=4)
scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan")
# After Merge optimization
scan_eqopt2.register(
......@@ -2460,7 +2458,6 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
position=5,
)
scan_eqopt2.register(
......@@ -2468,7 +2465,6 @@ scan_eqopt2.register(
in2out(scan_merge_inouts, ignore_newtrees=True),
"fast_run",
"scan",
position=6,
)
# After everything else
......@@ -2478,5 +2474,4 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan",
"fast_run",
"scan",
position=8,
)
......@@ -2802,10 +2802,10 @@ def constant_folding(fgraph, node):
topo_constant_folding = in2out(
constant_folding, ignore_newtrees=True, name="topo_constant_folding"
)
register_canonicalize(topo_constant_folding, "fast_compile", final_opt=True)
register_uncanonicalize(topo_constant_folding, "fast_compile", final_opt=True)
register_stabilize(topo_constant_folding, "fast_compile", final_opt=True)
register_specialize(topo_constant_folding, "fast_compile", final_opt=True)
register_canonicalize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_uncanonicalize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_stabilize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
......@@ -3096,7 +3096,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc
class FusionOptimizer(GraphRewriter):
"""Graph rewriter that simply runs node fusion operations.
TODO: This is basically an `EquilibriumOptimizer`; we should just use that.
TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.
"""
......
......@@ -146,7 +146,7 @@ from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.op import Op
from aesara.graph.opt import (
EquilibriumOptimizer,
EquilibriumGraphRewriter,
GraphRewriter,
copy_stack_trace,
in2out,
......@@ -1906,7 +1906,7 @@ blas_optdb.register(
blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10)
blas_optdb.register(
"local_gemm_to_gemv",
EquilibriumOptimizer(
EquilibriumGraphRewriter(
[
local_gemm_to_gemv,
local_gemm_to_ger,
......
......@@ -444,7 +444,7 @@ The following is an example that distributes dot products across additions.
import aesara
import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import EquilibriumOptimizer
from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.math import _dot
from etuples import etuple
......@@ -484,7 +484,7 @@ The following is an example that distributes dot products across additions.
)
dot_distribute_opt = EquilibriumOptimizer([KanrenRelationSub(dot_distributeo)], max_use_ratio=10)
dot_distribute_opt = EquilibriumGraphRewriter([KanrenRelationSub(dot_distributeo)], max_use_ratio=10)
Below, we apply `dot_distribute_opt` to a few example graphs. First we create simple test graph:
......@@ -531,7 +531,7 @@ relational properties.
To do that, we will create another :class:`Rewriter` that simply reverses the arguments
to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``:
>>> dot_gather_opt = EquilibriumOptimizer([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> dot_gather_opt = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_opt, clone=False)
>>> print(aesara.pprint(rev_res))
(A @ (x + (y + (B @ (z + w)))))
......@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`NodeRewriter`\s A, B
maybe optimization Y is an :class:`EquilibriumGraphRewriter` containing :class:`NodeRewriter`\s A, B
and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to
......@@ -599,7 +599,7 @@ optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When an :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumOptimizer`, which is returned. If the
inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places.
......@@ -859,8 +859,8 @@ This will output something like this:
0.028s for fgraph.validate()
0.131s for callback
time - (name, class, index) - validate time
0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s
EquilibriumOptimizer canonicalize
0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s
EquilibriumGraphRewriter canonicalize
time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s
......@@ -974,8 +974,8 @@ This will output something like this:
init io_toposort 0.00171804428101
loop time 0.000502109527588
callback_time 0.0
0.002257s - ('local_gemm_to_gemv', 'EquilibriumOptimizer', 3) - 0.000s
EquilibriumOptimizer local_gemm_to_gemv
0.002257s - ('local_gemm_to_gemv', 'EquilibriumGraphRewriter', 3) - 0.000s
EquilibriumGraphRewriter local_gemm_to_gemv
time 0.002s for 1 passes
nb nodes (start, end, max) 80 80 80
time io_toposort 0.001s
......@@ -994,8 +994,8 @@ This will output something like this:
init io_toposort 0.00138401985168
loop time 0.000202178955078
callback_time 0.0
0.031740s - ('specialize', 'EquilibriumOptimizer', 9) - 0.000s
EquilibriumOptimizer specialize
0.031740s - ('specialize', 'EquilibriumGraphRewriter', 9) - 0.000s
EquilibriumGraphRewriter specialize
time 0.031s for 2 passes
nb nodes (start, end, max) 80 78 80
time io_toposort 0.003s
......@@ -1080,8 +1080,8 @@ To understand this profile here is some explanation of how optimizations work:
.. code-block:: none
0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s
EquilibriumOptimizer canonicalize
0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s
EquilibriumGraphRewriter canonicalize
time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s
......@@ -1146,15 +1146,15 @@ To understand this profile here is some explanation of how optimizations work:
0.000s - local_subtensor_of_dot
0.000s - local_subtensor_merge
* ``0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s``
* ``0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s``
This line is from :class:`SequentialGraphRewriter`, and indicates information related
to a sub-optimizer. It means that this sub-optimizer took
a total of .7s. Its name is ``'canonicalize'``. It is an
:class:`EquilibriumOptimizer`. It was executed at index 4 by the
:class:`EquilibriumGraphRewriter`. It was executed at index 4 by the
:class:`SequentialGraphRewriter`. It spent 0.004s in the *validate* phase.
* All other lines are from the profiler of the :class:`EquilibriumOptimizer`.
* All other lines are from the profiler of the :class:`EquilibriumGraphRewriter`.
* An :class:`EquilibriumOptimizer` does multiple passes on the Apply nodes from
* An :class:`EquilibriumGraphRewriter` does multiple passes on the Apply nodes from
the graph, trying to apply local and global optimizations.
Conceptually, it tries to execute all global optimizations,
and to apply all local optimizations on all
......
......@@ -13,7 +13,7 @@ from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.op import Op
from aesara.graph.opt import EquilibriumOptimizer
from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.unify import eval_if_etuple
from aesara.tensor.math import Dot, _dot
......@@ -151,7 +151,7 @@ def test_KanrenRelationSub_dot():
),
)
distribute_opt = EquilibriumOptimizer(
distribute_opt = EquilibriumGraphRewriter(
[KanrenRelationSub(distributes)], max_use_ratio=10
)
......
......@@ -6,7 +6,7 @@ from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import (
EquilibriumOptimizer,
EquilibriumGraphRewriter,
MergeOptimizer,
OpKeyGraphRewriter,
OpToRewriterTracker,
......@@ -446,7 +446,7 @@ class TestEquilibrium:
e = op3(op4(x, y))
g = FunctionGraph([x, y, z], [e])
# print g
opt = EquilibriumOptimizer(
opt = EquilibriumGraphRewriter(
[
PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")),
PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")),
......@@ -463,7 +463,7 @@ class TestEquilibrium:
e = op1(op1(op3(x, y)))
g = FunctionGraph([x, y, z], [e])
# print g
opt = EquilibriumOptimizer(
opt = EquilibriumGraphRewriter(
[
PatternNodeRewriter((op1, (op2, "x", "y")), (op4, "x", "y")),
PatternNodeRewriter((op3, "x", "y"), (op4, "x", "y")),
......@@ -488,7 +488,7 @@ class TestEquilibrium:
oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL)
try:
opt = EquilibriumOptimizer(
opt = EquilibriumGraphRewriter(
[
PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")),
PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")),
......@@ -600,7 +600,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
e = op1(x)
fg = FunctionGraph([x], [e], clone=False)
opt = EquilibriumOptimizer(
opt = EquilibriumGraphRewriter(
[
PatternNodeRewriter(
(op1, "x"),
......@@ -633,7 +633,7 @@ def test_patternsub_invalid_dtype(out_pattern):
e = op_cast_type2(x)
fg = FunctionGraph([x], [e])
opt = EquilibriumOptimizer(
opt = EquilibriumGraphRewriter(
[
PatternNodeRewriter(
(op_cast_type2, "x"),
......
......@@ -45,8 +45,8 @@ class TestDB:
def test_EquilibriumDB(self):
eq_db = EquilibriumDB()
with pytest.raises(ValueError, match=r"`final_opt` and.*"):
eq_db.register("d", TestOpt(), final_opt=True, cleanup=True)
with pytest.raises(ValueError, match=r"`final_rewriter` and.*"):
eq_db.register("d", TestOpt(), final_rewriter=True, cleanup=True)
def test_SequenceDB(self):
seq_db = SequenceDB(failure_callback=None)
......
......@@ -7,7 +7,7 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import EquilibriumOptimizer
from aesara.graph.opt import EquilibriumGraphRewriter
from aesara.graph.optdb import OptimizationQuery
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.random.basic import (
......@@ -50,7 +50,7 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None
p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant))
]
mode = Mode("py", EquilibriumOptimizer([opt], max_use_ratio=100))
mode = Mode("py", EquilibriumGraphRewriter([opt], max_use_ratio=100))
f_opt = function(
f_inputs,
......@@ -519,7 +519,7 @@ def test_Subtensor_lift_restrictions():
z = x - y
fg = FunctionGraph([rng], [z], clone=False)
_ = EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
_ = EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert subtensor_node == y.owner
......@@ -531,7 +531,7 @@ def test_Subtensor_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z
assert fg.outputs[1] == x
......@@ -539,7 +539,7 @@ def test_Subtensor_lift_restrictions():
# The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift
fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
assert rv_node.op == normal
......@@ -557,7 +557,9 @@ def test_Dimshuffle_lift_restrictions():
z = x - y
fg = FunctionGraph([rng], [z, y], clone=False)
_ = EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
_ = EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(
fg
)
dimshuffle_node = fg.outputs[0].owner.inputs[1].owner
assert dimshuffle_node == y.owner
......@@ -569,7 +571,7 @@ def test_Dimshuffle_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z
assert fg.outputs[1] == x
......@@ -577,7 +579,7 @@ def test_Dimshuffle_lift_restrictions():
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift
fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner
assert rv_node.op == normal
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论