Commit f2be9699 authored by Mridul Seth, committed by Brandon T. Willard

Remove optimize_graph_with_cache, it's not used

Fixes https://github.com/aesara-devs/aesara/issues/733. Also cleans up the tests and the cache_optimizations config variable.
Parent 4355a58f
aesara/compile/function/types.py
@@ -6,8 +6,6 @@ Driver of graph construction, optimization, and linking.
import copy
import copyreg
import logging
-import os
-import pickle
import time
import warnings
from itertools import chain
@@ -17,7 +15,6 @@ import numpy as np
import aesara
import aesara.compile.profiling
-from aesara.compile.compilelock import lock_ctx
from aesara.compile.io import In, SymbolicInput, SymbolicOutput
from aesara.compile.ops import deep_copy_op, view_op
from aesara.configdefaults import config
@@ -27,13 +24,11 @@ from aesara.graph.basic import (
    ancestors,
    clone_get_equiv,
    graph_inputs,
-    vars_between,
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import HasInnerGraph
-from aesara.graph.opt_utils import is_same_graph
from aesara.graph.utils import get_variable_trace_string
from aesara.link.basic import Container
from aesara.link.utils import raise_with_op
@@ -1363,159 +1358,6 @@ class FunctionMaker:
        else:
            raise TypeError(f"Unknown output type: {type(output)} ({output})")
-    def optimize_graph_with_cache(self, optimizer, inputs, outputs):
-        # This function is not finished
-        graph_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
-
-        # the inputs, outputs, and size of the graph to be optimized
-        inputs_new = [inp.variable for inp in inputs]
-        outputs_new = [out.variable for out in outputs]
-        size_new = len(self.fgraph.apply_nodes)
-
-        # Beginning of cache optimizations.
-        # Could be refactored into different functions.
-        def load_graph_db():
-            if os.path.isfile(graph_db_file):
-                print("graph_db already exists")
-            else:
-                # create graph_db
-                with open(graph_db_file, "wb") as f:
-                    print(f"create new graph_db in {graph_db_file}")
-            # load the graph_db dictionary
-            try:
-                with open(graph_db_file, "rb") as f, config.change_flags(
-                    unpickle_function=False
-                ):
-                    # Temporary hack to allow
-                    # tests.scan.test_scan.T_Scan to
-                    # finish. Should be changed in the definitive version.
-                    graph_db = pickle.load(f)
-                print("graph_db loaded and it is not empty")
-            except EOFError as e:
-                # the file has nothing in it
-                print(e)
-                print("graph_db loaded and it is empty")
-                graph_db = {}
-            return graph_db
-
-        def find_same_graph_in_db(graph_db):
-            # If found_graph_in_db is None, then we need to optimize.
-            # Otherwise, return the graph found.
-            found_graph_in_db = None
-            # The sole purpose of this loop is to set found_graph_in_db by
-            # going through graph_db, looking for a graph that performs the
-            # same computation.
-            for graph_old, graph_optimized in graph_db.items():
-                inputs_old = graph_old.inputs
-                outputs_old = graph_old.outputs
-                size_old = len(graph_old.apply_nodes)
-                # Some heuristics to check whether the same graph has
-                # already been optimized before.
-                if len(inputs_new) != len(inputs_old):
-                    # If the inputs are of different sizes,
-                    # the two graphs are different for sure
-                    print("need to optimize, because input size is different")
-                    continue
-                elif len(outputs_new) != len(outputs_old):
-                    # If the outputs are of different sizes,
-                    # the two graphs are different for sure
-                    print("need to optimize, because output size is different")
-                    continue
-                elif not all(
-                    input_new.type == input_old.type
-                    for input_new, input_old in zip(inputs_new, inputs_old)
-                ):
-                    print("need to optimize, because inputs are of different types")
-                    continue
-                elif not all(
-                    output_new.type == output_old.type
-                    for output_new, output_old in zip(outputs_new, outputs_old)
-                ):
-                    print("need to optimize, because outputs are of different types")
-                    continue
-                elif size_old != size_new:
-                    print(
-                        "need to optimize, because the numbers of nodes in the"
-                        " graphs are different"
-                    )
-                    continue
-                else:
-                    flags = []
-                    for i, (output_new, output_old) in enumerate(
-                        zip(outputs_new, outputs_old)
-                    ):
-                        print("loop through output nodes for both graphs")
-                        graph_old.variables = set(
-                            vars_between(graph_old.inputs, graph_old.outputs)
-                        )
-                        # Using clone avoided a lot of errors that
-                        # deep copy seemed to have.
-                        f2 = graph_old.clone(check_integrity=False)
-                        t1 = output_new
-                        t2 = f2.outputs[i]
-                        givens = dict(
-                            zip(
-                                graph_inputs([t1]),
-                                graph_inputs([t2]),
-                            )
-                        )
-                        temp = dict(
-                            zip(
-                                graph_inputs([t1]),
-                                graph_inputs([t2]),
-                            )
-                        )
-                        # Hack to remove inconsistent entries in givens.
-                        # It seems to work, but the source of the
-                        # inconsistency could be worth investigating.
-                        for key, value in temp.items():
-                            if key.type != value.type:
-                                del givens[key]
-                        flag = is_same_graph(t1, t2, givens=givens)
-                        flags.append(flag)
-                    is_same = all(flags)
-                    if is_same:
-                        # found the match
-                        print("found a match, no need to optimize")
-                        found_graph_in_db = graph_optimized
-                        break
-            return found_graph_in_db
-
-        with lock_ctx():
-            graph_db = load_graph_db()
-            print(f"loaded graph_db from {graph_db_file}, size={len(graph_db)}")
-            found_graph = find_same_graph_in_db(graph_db)
-            if found_graph:
-                self.fgraph = found_graph
-                optimizer_profile = None
-            else:
-                # this is a brand new graph; optimize it and save it to graph_db
-                print("graph not found in graph_db, optimizing the graph")
-                self.fgraph.variables = set(
-                    vars_between(self.fgraph.inputs, self.fgraph.outputs)
-                )
-                # The check_integrity parameter was added to ignore
-                # "excess cached variables" errors. It works that way,
-                # but once again the error could be worth investigating.
-                before_opt = self.fgraph.clone(check_integrity=False)
-                optimizer_profile = optimizer(self.fgraph)
-                graph_db.update({before_opt: self.fgraph})
-                with open(graph_db_file, "wb") as f:
-                    pickle.dump(graph_db, f, -1)
-                print("new graph saved into graph_db")
-        return optimizer_profile
    def __init__(
        self,
        inputs,
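For context, the deleted lookup hinged on is_same_graph (whose import is also removed above) to decide whether two graphs perform the same computation. A minimal sketch of that check in isolation, with illustrative variables that are not part of this commit:

    import aesara.tensor as at
    from aesara.graph.opt_utils import is_same_graph

    x = at.vector("x")
    y = at.vector("y")

    # Two separately built graphs over the same input that perform the
    # same computation compare equal.
    assert is_same_graph(x + 1, x + 1)

    # givens substitutes inputs before comparing; the removed lookup used
    # this to match a new graph against a cached graph with its own inputs.
    assert is_same_graph(x + 1, y + 1, givens={x: y})

    # Different computations do not compare equal.
    assert not is_same_graph(x + 1, x * 2)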
@@ -1609,12 +1451,6 @@ class FunctionMaker:
            compute_test_value=config.compute_test_value_opt,
            traceback__limit=config.traceback__compile_limit,
        ):
-            # now optimize the graph
-            if config.cache_optimizations:
-                optimizer_profile = self.optimize_graph_with_cache(
-                    optimizer, inputs, outputs
-                )
-            else:
-                optimizer_profile = optimizer(fgraph)
+            optimizer_profile = optimizer(fgraph)

        end_optimizer = time.time()
...
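With the cached path gone, FunctionMaker always takes the single remaining path shown above: the mode's optimizer is called directly on the function graph. A minimal standalone sketch of such a call, using MergeOptimizer as a stand-in for the full optimizer pipeline (this setup is illustrative, not taken from the diff):

    import aesara.tensor as at
    from aesara.graph.fg import FunctionGraph
    from aesara.graph.opt import MergeOptimizer

    x = at.vector("x")
    y1 = (x + 1) * 2
    y2 = (x + 1) * 2  # structurally identical, built separately

    # FunctionGraph clones the graph by default; the optimizer then
    # rewrites it in place, merging the duplicated (x + 1) * 2 subgraphs.
    fgraph = FunctionGraph(inputs=[x], outputs=[y1, y2])
    optimizer = MergeOptimizer()
    optimizer_profile = optimizer(fgraph)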
aesara/configdefaults.py
@@ -1415,17 +1415,6 @@ def add_vm_configvars():
def add_deprecated_configvars():
-    # TODO: remove this?
-    config.add(
-        "cache_optimizations",
-        "WARNING: work in progress, does not work yet. "
-        "Specify if the optimization cache should be used. This cache will "
-        "store any optimized graph and its optimization. It actually slows "
-        "down the first optimization a lot, and could possibly still contain "
-        "some bugs. Use at your own risk.",
-        BoolParam(False),
-        in_c_key=False,
-    )

    # TODO: remove this?
    config.add(
...
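For context on the deleted block: config.add registers a configuration variable with a name, a help string, a ConfigParam instance such as BoolParam, and an in_c_key flag saying whether the option participates in the C compilation cache key. A hedged sketch with a hypothetical flag name, not an option that exists in Aesara:

    from aesara.configdefaults import config
    from aesara.configparser import BoolParam

    config.add(
        "my_experimental_flag",  # hypothetical name, for illustration only
        "Help text shown for this flag in aesara.config output.",
        BoolParam(False),  # boolean option defaulting to False
        in_c_key=False,  # does not affect the C compilation cache key
    )

    assert config.my_experimental_flag is False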
tests/compile/function/test_types.py
import copy
-import os
import pickle
import time
@@ -28,7 +27,6 @@ from aesara.tensor.type import (
    dscalar,
    dscalars,
    dvector,
-    fmatrix,
    fscalar,
    iscalar,
    matrix,
@@ -1221,40 +1219,3 @@ def test_sync_update():
    d1 = t_1 - t_0
    d2 = t_2 - t_1
    assert d1 > d2, (d1, d2)
-def test_FunctionMaker_cache_optimizations():
-    opt_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
-    if os.path.exists(opt_db_file):
-        os.remove(opt_db_file)
-
-    floatX = "float32"
-    mode = config.mode
-    if mode in ["DEBUG_MODE", "DebugMode"]:
-        mode = "FAST_RUN"
-
-    graph_db_file = os.path.join(config.compiledir, "optimized_graphs.pkl")
-    assert not os.path.exists(graph_db_file)
-
-    with config.change_flags(cache_optimizations=True):
-        a = fmatrix("a")
-        b = fmatrix("b")
-        c = shared(np.ones((10, 10), dtype=floatX))
-        d = shared(np.ones((10, 10), dtype=floatX))
-        e = aet_sum(aet_sum(aet_sum(a ** 2 + b) + c) + d)
-        f1 = function([a, b], e, mode=mode)
-
-        # FIXME: We can do much better about testing this.
-        assert os.path.exists(graph_db_file)
-
-        m = fmatrix("x1")
-        n = fmatrix("x2")
-        p = shared(np.ones((10, 10), dtype=floatX))
-        q = shared(np.ones((10, 10), dtype=floatX))
-        j = aet_sum(aet_sum(aet_sum(m ** 2 + n) + p) + q)
-        f2 = function([m, n], j, mode=mode)
-
-        in1 = np.ones((10, 10), dtype=floatX)
-        in2 = np.ones((10, 10), dtype=floatX)
-        assert f1(in1, in2) == f2(in1, in2)
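With the cache removed, the one assertion in this deleted test that still describes current behavior is the last line: two structurally identical graphs compile to functions that agree on the same inputs. A minimal sketch of that residual check as a hypothetical standalone test, not code added by this commit:

    import numpy as np

    from aesara import function, shared
    from aesara.tensor.math import sum as aet_sum
    from aesara.tensor.type import fmatrix

    floatX = "float32"

    a, b = fmatrix("a"), fmatrix("b")
    c = shared(np.ones((10, 10), dtype=floatX))
    e = aet_sum(aet_sum(a ** 2 + b) + c)
    f1 = function([a, b], e)

    # The same computation built from fresh variables.
    m, n = fmatrix("x1"), fmatrix("x2")
    p = shared(np.ones((10, 10), dtype=floatX))
    j = aet_sum(aet_sum(m ** 2 + n) + p)
    f2 = function([m, n], j)

    in1 = np.ones((10, 10), dtype=floatX)
    in2 = np.ones((10, 10), dtype=floatX)
    assert f1(in1, in2) == f2(in1, in2)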