Commit f2be9699 authored by Mridul Seth, committed by Brandon T. Willard

Remove optimize_graph_with_cache, it's not used

Fixes https://github.com/aesara-devs/aesara/issues/733. Also cleans up the tests and the cache_optimizations config variable.
Parent: 4355a58f
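With the removal, the compilation path in FunctionMaker is unconditional: the optimizer is always applied to the freshly built graph. A minimal sketch of what that means for users (not part of the commit, using only the public aesara API):

    import aesara
    import aesara.tensor as at

    x = at.dscalar("x")
    y = x ** 2
    # There is no cache_optimizations flag to consult anymore; compiling
    # always runs the graph optimizer directly on the new graph.
    f = aesara.function([x], y)
    assert f(3.0) == 9.0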
@@ -6,8 +6,6 @@ Driver of graph construction, optimization, and linking.
 import copy
 import copyreg
 import logging
-import os
-import pickle
 import time
 import warnings
 from itertools import chain
@@ -17,7 +15,6 @@ import numpy as np
 import aesara
 import aesara.compile.profiling
-from aesara.compile.compilelock import lock_ctx
 from aesara.compile.io import In, SymbolicInput, SymbolicOutput
 from aesara.compile.ops import deep_copy_op, view_op
 from aesara.configdefaults import config
@@ -27,13 +24,11 @@ from aesara.graph.basic import (
     ancestors,
     clone_get_equiv,
     graph_inputs,
-    vars_between,
 )
 from aesara.graph.destroyhandler import DestroyHandler
 from aesara.graph.features import PreserveVariableAttributes
 from aesara.graph.fg import FunctionGraph, InconsistencyError
 from aesara.graph.op import HasInnerGraph
-from aesara.graph.opt_utils import is_same_graph
 from aesara.graph.utils import get_variable_trace_string
 from aesara.link.basic import Container
 from aesara.link.utils import raise_with_op
@@ -1363,159 +1358,6 @@ class FunctionMaker:
         else:
             raise TypeError(f"Unknown output type: {type(output)} ({output})")
 
-    def optimize_graph_with_cache(self, optimizer, inputs, outputs):
-        # This function is not finished
-        graph_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
-
-        # the inputs, outputs, and size of the graph to be optimized
-        inputs_new = [inp.variable for inp in inputs]
-        outputs_new = [out.variable for out in outputs]
-        size_new = len(self.fgraph.apply_nodes)
-
-        # Beginning of cache optimizations.
-        # Could be refactored into different functions.
-        def load_graph_db():
-            if os.path.isfile(graph_db_file):
-                print("graph_db already exists")
-            else:
-                # create an empty graph_db file
-                with open(graph_db_file, "wb") as f:
-                    print(f"create new graph_db in {graph_db_file}")
-            # load the graph_db dictionary
-            try:
-                with open(graph_db_file, "rb") as f, config.change_flags(
-                    unpickle_function=False
-                ):
-                    # Temporary hack to allow tests.scan.test_scan.T_Scan to
-                    # finish. Should be changed in the definitive version.
-                    graph_db = pickle.load(f)
-                print("graph_db loaded and it is not empty")
-            except EOFError as e:
-                # the file has nothing in it
-                print(e)
-                print("graph_db loaded and it is empty")
-                graph_db = {}
-            return graph_db
-
-        def find_same_graph_in_db(graph_db):
-            # If found_graph_in_db stays None, we need to optimize;
-            # otherwise, return the graph found.
-            found_graph_in_db = None
-            # The sole purpose of this loop is to set found_graph_in_db by
-            # going through graph_db, looking for a graph that performs the
-            # same computation.
-            for graph_old, graph_optimized in graph_db.items():
-                inputs_old = graph_old.inputs
-                outputs_old = graph_old.outputs
-                size_old = len(graph_old.apply_nodes)
-                # Some heuristics to check whether the same graph has
-                # already been optimized before.
-                if len(inputs_new) != len(inputs_old):
-                    # If the numbers of inputs differ,
-                    # the two graphs are certainly different.
-                    print("need to optimize, because input size is different")
-                    continue
-                elif len(outputs_new) != len(outputs_old):
-                    # If the numbers of outputs differ,
-                    # the two graphs are certainly different.
-                    print("need to optimize, because output size is different")
-                    continue
-                elif not all(
-                    input_new.type == input_old.type
-                    for input_new, input_old in zip(inputs_new, inputs_old)
-                ):
-                    print("need to optimize, because inputs are of different types")
-                    continue
-                elif not all(
-                    output_new.type == output_old.type
-                    for output_new, output_old in zip(outputs_new, outputs_old)
-                ):
-                    print("need to optimize, because outputs are of different types")
-                    continue
-                elif size_old != size_new:
-                    print(
-                        "need to optimize, because numbers of nodes in graph"
-                        " are different"
-                    )
-                    continue
-                else:
-                    flags = []
-                    for i, (output_new, output_old) in enumerate(
-                        zip(outputs_new, outputs_old)
-                    ):
-                        print("loop through outputs node for both graphs")
-                        graph_old.variables = set(
-                            vars_between(graph_old.inputs, graph_old.outputs)
-                        )
-                        # Using clone avoids a lot of errors that
-                        # deep copy seemed to have.
-                        f2 = graph_old.clone(check_integrity=False)
-                        t1 = output_new
-                        t2 = f2.outputs[i]
-                        givens = dict(
-                            zip(
-                                graph_inputs([t1]),
-                                graph_inputs([t2]),
-                            )
-                        )
-                        temp = dict(
-                            zip(
-                                graph_inputs([t1]),
-                                graph_inputs([t2]),
-                            )
-                        )
-                        # Hack to remove inconsistent entries in givens.
-                        # It seems to work, but the source of the
-                        # inconsistency could be worth investigating.
-                        for key, value in temp.items():
-                            if key.type != value.type:
-                                del givens[key]
-                        flag = is_same_graph(t1, t2, givens=givens)
-                        flags.append(flag)
-                    is_same = all(flags)
-                    if is_same:
-                        # found the match
-                        print("found a match, no need to optimize")
-                        found_graph_in_db = graph_optimized
-                        break
-            return found_graph_in_db
-
-        with lock_ctx():
-            graph_db = load_graph_db()
-            print(f"loaded graph_db from {graph_db_file}, size={len(graph_db)}")
-            found_graph = find_same_graph_in_db(graph_db)
-            if found_graph:
-                self.fgraph = found_graph
-                optimizer_profile = None
-            else:
-                # this is a brand new graph; optimize it and save it to graph_db
-                print("graph not found in graph_db, optimizing the graph")
-                self.fgraph.variables = set(
-                    vars_between(self.fgraph.inputs, self.fgraph.outputs)
-                )
-                # The check_integrity parameter was added to ignore
-                # "excess cached variables" errors. It works that way, but
-                # once again the error could be worth investigating.
-                before_opt = self.fgraph.clone(check_integrity=False)
-                optimizer_profile = optimizer(self.fgraph)
-                graph_db.update({before_opt: self.fgraph})
-                with open(graph_db_file, "wb") as f:
-                    pickle.dump(graph_db, f, -1)
-                print("new graph saved into graph_db")
-        return optimizer_profile
-
     def __init__(
         self,
         inputs,
@@ -1609,13 +1451,7 @@ class FunctionMaker:
             compute_test_value=config.compute_test_value_opt,
             traceback__limit=config.traceback__compile_limit,
         ):
-            # now optimize the graph
-            if config.cache_optimizations:
-                optimizer_profile = self.optimize_graph_with_cache(
-                    optimizer, inputs, outputs
-                )
-            else:
-                optimizer_profile = optimizer(fgraph)
+            optimizer_profile = optimizer(fgraph)
 
         end_optimizer = time.time()
         opt_time = end_optimizer - start_optimizer
......
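For context, the removed cache decided hits with is_same_graph (whose import is deleted above), which tests whether two variables denote the same computation once their inputs are identified through givens. A minimal sketch of that check, assuming the pre-removal import path:

    import aesara.tensor as at
    from aesara.graph.opt_utils import is_same_graph

    x, y = at.vectors("x", "y")
    # Structurally identical graphs over distinct inputs do not match...
    assert not is_same_graph(x + 1, y + 1)
    # ...until the inputs are identified with each other via givens.
    assert is_same_graph(x + 1, y + 1, givens={y: x})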
@@ -1415,17 +1415,6 @@ def add_vm_configvars():
 def add_deprecated_configvars():
-    # TODO: remove this?
-    config.add(
-        "cache_optimizations",
-        "WARNING: work in progress, does not work yet. "
-        "Specify whether the optimization cache should be used. This cache "
-        "will store any optimized graph and its optimizations. It actually "
-        "slows down the first optimization a lot, and could possibly still "
-        "contain some bugs. Use at your own risk.",
-        BoolParam(False),
-        in_c_key=False,
-    )
     # TODO: remove this?
     config.add(
......
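For reference, the deleted entry used Aesara's standard pattern for registering a boolean config option. A sketch of that registration call with a hypothetical flag name (the BoolParam import path is an assumption):

    from aesara.configdefaults import config
    from aesara.configparser import BoolParam

    # Hypothetical flag, shown only to illustrate the config.add call that
    # the removed cache_optimizations entry used.
    config.add(
        "example_flag",
        "A boolean option used purely for illustration.",
        BoolParam(False),
        in_c_key=False,  # the flag does not affect the C compilation cache key
    )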
 import copy
-import os
-import pickle
 import time
@@ -28,7 +27,6 @@ from aesara.tensor.type import (
     dscalar,
     dscalars,
     dvector,
-    fmatrix,
     fscalar,
     iscalar,
     matrix,
@@ -1221,40 +1219,3 @@ def test_sync_update():
     d1 = t_1 - t_0
     d2 = t_2 - t_1
     assert d1 > d2, (d1, d2)
-
-
-def test_FunctionMaker_cache_optimizations():
-    opt_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
-    if os.path.exists(opt_db_file):
-        os.remove(opt_db_file)
-
-    floatX = "float32"
-    mode = config.mode
-    if mode in ["DEBUG_MODE", "DebugMode"]:
-        mode = "FAST_RUN"
-
-    graph_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
-    assert not os.path.exists(graph_db_file)
-
-    with config.change_flags(cache_optimizations=True):
-        a = fmatrix("a")
-        b = fmatrix("b")
-        c = shared(np.ones((10, 10), dtype=floatX))
-        d = shared(np.ones((10, 10), dtype=floatX))
-        e = aet_sum(aet_sum(aet_sum(a ** 2 + b) + c) + d)
-        f1 = function([a, b], e, mode=mode)
-
-        # FIXME: We can do much better at testing this.
-        assert os.path.exists(graph_db_file)
-
-        m = fmatrix("x1")
-        n = fmatrix("x2")
-        p = shared(np.ones((10, 10), dtype=floatX))
-        q = shared(np.ones((10, 10), dtype=floatX))
-        j = aet_sum(aet_sum(aet_sum(m ** 2 + n) + p) + q)
-        f2 = function([m, n], j, mode=mode)
-
-        in1 = np.ones((10, 10), dtype=floatX)
-        in2 = np.ones((10, 10), dtype=floatX)
-        assert f1(in1, in2) == f2(in1, in2)
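Since the feature never worked reliably, the only on-disk leftover is the pickle the cache wrote into the compile directory. A small cleanup sketch (the file name is taken from the removed code above):

    import os
    from aesara.configdefaults import config

    # Delete the stale cache file that the experimental feature left behind.
    stale = os.path.join(config.compiledir, "optimized_graphs.pkl")
    if os.path.exists(stale):
        os.remove(stale)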