提交 e53b9b32 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make the optimization database configurable in Mode

上级 66006d27
......@@ -2669,6 +2669,7 @@ class DebugMode(Mode):
check_preallocated_output=None,
require_matching_strides=None,
linker=None,
db=None,
):
"""
If any of these arguments (except optimizer) is not None, it overrides
......@@ -2685,7 +2686,7 @@ class DebugMode(Mode):
linker,
)
super().__init__(optimizer=optimizer, linker=linker)
super().__init__(optimizer=optimizer, linker=linker, db=db)
if stability_patience is not None:
self.stability_patience = stability_patience
......
......@@ -5,7 +5,7 @@ WRITEME
import logging
import warnings
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import aesara
from aesara.compile.function.types import Supervisor
......@@ -20,11 +20,12 @@ from aesara.graph.opt import (
from aesara.graph.optdb import (
EquilibriumDB,
LocalGroupDB,
OptimizationDatabase,
OptimizationQuery,
SequenceDB,
TopoDB,
)
from aesara.link.basic import PerformLinker
from aesara.link.basic import Linker, PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.link.jax.linker import JAXLinker
from aesara.link.numba.linker import NumbaLinker
......@@ -281,12 +282,15 @@ class Mode:
Parameters
----------
optimizer : a structure of type Optimizer
optimizer: a structure of type Optimizer
An Optimizer may simplify the math, put similar computations together,
improve numerical stability and various other improvements.
linker : a structure of type Linker
linker: a structure of type Linker
A Linker decides which implementations to use (C or Python, for example)
and how to string them together to perform the computation.
db:
The ``OptimizationDatabase`` used by this ``Mode``. Note: This value
is *not* part of a ``Mode`` instance's pickled state.
See Also
--------
......@@ -296,12 +300,24 @@ class Mode:
"""
def __init__(self, linker=None, optimizer="default"):
def __init__(
self,
linker: Optional[Union[str, Linker]] = None,
optimizer: Union[str, OptimizationQuery] = "default",
db: OptimizationDatabase = None,
):
if linker is None:
linker = config.linker
if type(optimizer) == str and optimizer == "default":
optimizer = config.optimizer
Mode.__setstate__(self, (linker, optimizer))
self.__setstate__((linker, optimizer))
if db is None:
global optdb
self.optdb = optdb
else:
self.optdb = db
# self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined
......@@ -316,7 +332,10 @@ class Mode:
return (self.provided_linker, self.provided_optimizer)
def __setstate__(self, state):
global optdb
linker, optimizer = state
self.optdb = optdb
self.provided_linker = linker
self.provided_optimizer = optimizer
if isinstance(linker, str) or linker is None:
......@@ -331,15 +350,16 @@ class Mode:
self.fn_time = 0
def __str__(self):
return "{}(linker = {}, optimizer = {})".format(
self.__class__.__name__,
self.provided_linker,
self.provided_optimizer,
return (
f"{self.__class__.__name__}("
f"linker={self.provided_linker}, "
f"optimizer={self.provided_optimizer}, "
f"optdb={self.optdb})"
)
def __get_optimizer(self):
if isinstance(self._optimizer, OptimizationQuery):
return optdb.query(self._optimizer)
return self.optdb.query(self._optimizer)
else:
return self._optimizer
......
......@@ -37,9 +37,9 @@ class MonitorMode(Mode):
"""
def __init__(self, pre_func=None, post_func=None, optimizer="default", linker=None):
self.pre_func = pre_func
self.post_func = post_func
def __init__(
self, pre_func=None, post_func=None, optimizer="default", linker=None, db=None
):
wrap_linker = WrapLinkerMany([OpWiseCLinker()], [self.eval])
if optimizer == "default":
optimizer = config.optimizer
......@@ -50,14 +50,22 @@ class MonitorMode(Mode):
linker,
)
super().__init__(wrap_linker, optimizer=optimizer)
super().__init__(linker=wrap_linker, optimizer=optimizer, db=db)
self.pre_func = pre_func
self.post_func = post_func
def __getstate__(self):
lnk, opt = super().__getstate__()
return (lnk, opt, self.pre_func, self.post_func)
def __setstate__(self, state):
lnk, opt, pre_func, post_func = state
lnk, opt, *funcs = state
if funcs:
pre_func, post_func = funcs
else:
pre_func, post_func = None, None
self.pre_func = pre_func
self.post_func = post_func
super().__setstate__((lnk, opt))
......
......@@ -207,6 +207,7 @@ class NanGuardMode(Mode):
big_is_error=None,
optimizer="default",
linker=None,
db=None,
):
self.provided_optimizer = optimizer
if nan_is_error is None:
......@@ -298,4 +299,4 @@ class NanGuardMode(Mode):
wrap_linker = aesara.link.vm.VMLinker(
callback=nan_check, callback_input=nan_check_input
)
super().__init__(wrap_linker, optimizer=self.provided_optimizer)
super().__init__(linker=wrap_linker, optimizer=self.provided_optimizer, db=db)
import pytest
from aesara.compile.function import function
from aesara.compile.mode import AddFeatureOptimizer, Mode
from aesara.configdefaults import config
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.optdb import OptimizationQuery, SequenceDB
from aesara.tensor.math import dot, tanh
from aesara.tensor.type import matrix
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_Mode_basic():
db = SequenceDB()
mode = Mode(linker="py", optimizer=OptimizationQuery(include=None), db=db)
assert mode.optdb is db
assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery")
def test_no_output_from_implace():
x = matrix()
y = matrix()
......@@ -26,7 +30,7 @@ def test_no_output_from_implace():
# Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization
opt = AddFeatureOptimizer(NoOutputFromInplace())
mode_opt = Mode(linker="cvm", optimizer="fast_run").register((opt, 49.9))
mode_opt = Mode(linker="py", optimizer="fast_run").register((opt, 49.9))
fct_opt = function([x, y], b, mode=mode_opt)
op = fct_opt.maker.fgraph.outputs[0].owner.op
......@@ -35,4 +39,7 @@ def test_no_output_from_implace():
def test_including():
mode = Mode(optimizer="merge")
mode.including("fast_compile")
assert set(mode._optimizer.include) == {"merge"}
new_mode = mode.including("fast_compile")
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"}
......@@ -157,7 +157,7 @@ class RecordMode(Mode):
self.record = record
self.known_fgraphs = set()
def __init__(self, record=None, **kwargs):
def __init__(self, record=None, db=None, **kwargs):
"""
Takes either a Record object or the keyword arguments to make one.
......@@ -252,4 +252,4 @@ class RecordMode(Mode):
linker = VMLinker(use_cloop=bool(config.cxx))
wrap_linker = WrapLinkerMany([linker], [callback])
super().__init__(wrap_linker, optimizer="fast_run")
super().__init__(linker=wrap_linker, optimizer="fast_run", db=db)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论