提交 e53b9b32 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make the optimization database configurable in Mode

上级 66006d27
...@@ -2669,6 +2669,7 @@ class DebugMode(Mode): ...@@ -2669,6 +2669,7 @@ class DebugMode(Mode):
check_preallocated_output=None, check_preallocated_output=None,
require_matching_strides=None, require_matching_strides=None,
linker=None, linker=None,
db=None,
): ):
""" """
If any of these arguments (except optimizer) is not None, it overrides If any of these arguments (except optimizer) is not None, it overrides
...@@ -2685,7 +2686,7 @@ class DebugMode(Mode): ...@@ -2685,7 +2686,7 @@ class DebugMode(Mode):
linker, linker,
) )
super().__init__(optimizer=optimizer, linker=linker) super().__init__(optimizer=optimizer, linker=linker, db=db)
if stability_patience is not None: if stability_patience is not None:
self.stability_patience = stability_patience self.stability_patience = stability_patience
......
...@@ -5,7 +5,7 @@ WRITEME ...@@ -5,7 +5,7 @@ WRITEME
import logging import logging
import warnings import warnings
from typing import Tuple, Union from typing import Optional, Tuple, Union
import aesara import aesara
from aesara.compile.function.types import Supervisor from aesara.compile.function.types import Supervisor
...@@ -20,11 +20,12 @@ from aesara.graph.opt import ( ...@@ -20,11 +20,12 @@ from aesara.graph.opt import (
from aesara.graph.optdb import ( from aesara.graph.optdb import (
EquilibriumDB, EquilibriumDB,
LocalGroupDB, LocalGroupDB,
OptimizationDatabase,
OptimizationQuery, OptimizationQuery,
SequenceDB, SequenceDB,
TopoDB, TopoDB,
) )
from aesara.link.basic import PerformLinker from aesara.link.basic import Linker, PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.link.jax.linker import JAXLinker from aesara.link.jax.linker import JAXLinker
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
...@@ -281,12 +282,15 @@ class Mode: ...@@ -281,12 +282,15 @@ class Mode:
Parameters Parameters
---------- ----------
optimizer : a structure of type Optimizer optimizer: a structure of type Optimizer
An Optimizer may simplify the math, put similar computations together, An Optimizer may simplify the math, put similar computations together,
improve numerical stability and various other improvements. improve numerical stability and various other improvements.
linker : a structure of type Linker linker: a structure of type Linker
A Linker decides which implementations to use (C or Python, for example) A Linker decides which implementations to use (C or Python, for example)
and how to string them together to perform the computation. and how to string them together to perform the computation.
db:
The ``OptimizationDatabase`` used by this ``Mode``. Note: This value
is *not* part of a ``Mode`` instance's pickled state.
See Also See Also
-------- --------
...@@ -296,12 +300,24 @@ class Mode: ...@@ -296,12 +300,24 @@ class Mode:
""" """
def __init__(self, linker=None, optimizer="default"): def __init__(
self,
linker: Optional[Union[str, Linker]] = None,
optimizer: Union[str, OptimizationQuery] = "default",
db: OptimizationDatabase = None,
):
if linker is None: if linker is None:
linker = config.linker linker = config.linker
if type(optimizer) == str and optimizer == "default": if type(optimizer) == str and optimizer == "default":
optimizer = config.optimizer optimizer = config.optimizer
Mode.__setstate__(self, (linker, optimizer))
self.__setstate__((linker, optimizer))
if db is None:
global optdb
self.optdb = optdb
else:
self.optdb = db
# self.provided_optimizer - typically the `optimizer` arg. # self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined # But if the `optimizer` arg is keyword corresponding to a predefined
...@@ -316,7 +332,10 @@ class Mode: ...@@ -316,7 +332,10 @@ class Mode:
return (self.provided_linker, self.provided_optimizer) return (self.provided_linker, self.provided_optimizer)
def __setstate__(self, state): def __setstate__(self, state):
global optdb
linker, optimizer = state linker, optimizer = state
self.optdb = optdb
self.provided_linker = linker self.provided_linker = linker
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
if isinstance(linker, str) or linker is None: if isinstance(linker, str) or linker is None:
...@@ -331,15 +350,16 @@ class Mode: ...@@ -331,15 +350,16 @@ class Mode:
self.fn_time = 0 self.fn_time = 0
def __str__(self): def __str__(self):
return "{}(linker = {}, optimizer = {})".format( return (
self.__class__.__name__, f"{self.__class__.__name__}("
self.provided_linker, f"linker={self.provided_linker}, "
self.provided_optimizer, f"optimizer={self.provided_optimizer}, "
f"optdb={self.optdb})"
) )
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, OptimizationQuery): if isinstance(self._optimizer, OptimizationQuery):
return optdb.query(self._optimizer) return self.optdb.query(self._optimizer)
else: else:
return self._optimizer return self._optimizer
......
...@@ -37,9 +37,9 @@ class MonitorMode(Mode): ...@@ -37,9 +37,9 @@ class MonitorMode(Mode):
""" """
def __init__(self, pre_func=None, post_func=None, optimizer="default", linker=None): def __init__(
self.pre_func = pre_func self, pre_func=None, post_func=None, optimizer="default", linker=None, db=None
self.post_func = post_func ):
wrap_linker = WrapLinkerMany([OpWiseCLinker()], [self.eval]) wrap_linker = WrapLinkerMany([OpWiseCLinker()], [self.eval])
if optimizer == "default": if optimizer == "default":
optimizer = config.optimizer optimizer = config.optimizer
...@@ -50,14 +50,22 @@ class MonitorMode(Mode): ...@@ -50,14 +50,22 @@ class MonitorMode(Mode):
linker, linker,
) )
super().__init__(wrap_linker, optimizer=optimizer) super().__init__(linker=wrap_linker, optimizer=optimizer, db=db)
self.pre_func = pre_func
self.post_func = post_func
def __getstate__(self): def __getstate__(self):
lnk, opt = super().__getstate__() lnk, opt = super().__getstate__()
return (lnk, opt, self.pre_func, self.post_func) return (lnk, opt, self.pre_func, self.post_func)
def __setstate__(self, state): def __setstate__(self, state):
lnk, opt, pre_func, post_func = state lnk, opt, *funcs = state
if funcs:
pre_func, post_func = funcs
else:
pre_func, post_func = None, None
self.pre_func = pre_func self.pre_func = pre_func
self.post_func = post_func self.post_func = post_func
super().__setstate__((lnk, opt)) super().__setstate__((lnk, opt))
......
...@@ -207,6 +207,7 @@ class NanGuardMode(Mode): ...@@ -207,6 +207,7 @@ class NanGuardMode(Mode):
big_is_error=None, big_is_error=None,
optimizer="default", optimizer="default",
linker=None, linker=None,
db=None,
): ):
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
if nan_is_error is None: if nan_is_error is None:
...@@ -298,4 +299,4 @@ class NanGuardMode(Mode): ...@@ -298,4 +299,4 @@ class NanGuardMode(Mode):
wrap_linker = aesara.link.vm.VMLinker( wrap_linker = aesara.link.vm.VMLinker(
callback=nan_check, callback_input=nan_check_input callback=nan_check, callback_input=nan_check_input
) )
super().__init__(wrap_linker, optimizer=self.provided_optimizer) super().__init__(linker=wrap_linker, optimizer=self.provided_optimizer, db=db)
import pytest
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import AddFeatureOptimizer, Mode from aesara.compile.mode import AddFeatureOptimizer, Mode
from aesara.configdefaults import config
from aesara.graph.features import NoOutputFromInplace from aesara.graph.features import NoOutputFromInplace
from aesara.graph.optdb import OptimizationQuery, SequenceDB
from aesara.tensor.math import dot, tanh from aesara.tensor.math import dot, tanh
from aesara.tensor.type import matrix from aesara.tensor.type import matrix
@pytest.mark.skipif( def test_Mode_basic():
not config.cxx, reason="G++ not available, so we need to skip this test." db = SequenceDB()
) mode = Mode(linker="py", optimizer=OptimizationQuery(include=None), db=db)
assert mode.optdb is db
assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery")
def test_no_output_from_implace(): def test_no_output_from_implace():
x = matrix() x = matrix()
y = matrix() y = matrix()
...@@ -26,7 +30,7 @@ def test_no_output_from_implace(): ...@@ -26,7 +30,7 @@ def test_no_output_from_implace():
# Ensure that the elemwise op that produces the output is not inplace when # Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization # using a mode that includes the optimization
opt = AddFeatureOptimizer(NoOutputFromInplace()) opt = AddFeatureOptimizer(NoOutputFromInplace())
mode_opt = Mode(linker="cvm", optimizer="fast_run").register((opt, 49.9)) mode_opt = Mode(linker="py", optimizer="fast_run").register((opt, 49.9))
fct_opt = function([x, y], b, mode=mode_opt) fct_opt = function([x, y], b, mode=mode_opt)
op = fct_opt.maker.fgraph.outputs[0].owner.op op = fct_opt.maker.fgraph.outputs[0].owner.op
...@@ -35,4 +39,7 @@ def test_no_output_from_implace(): ...@@ -35,4 +39,7 @@ def test_no_output_from_implace():
def test_including(): def test_including():
mode = Mode(optimizer="merge") mode = Mode(optimizer="merge")
mode.including("fast_compile") assert set(mode._optimizer.include) == {"merge"}
new_mode = mode.including("fast_compile")
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"}
...@@ -157,7 +157,7 @@ class RecordMode(Mode): ...@@ -157,7 +157,7 @@ class RecordMode(Mode):
self.record = record self.record = record
self.known_fgraphs = set() self.known_fgraphs = set()
def __init__(self, record=None, **kwargs): def __init__(self, record=None, db=None, **kwargs):
""" """
Takes either a Record object or the keyword arguments to make one. Takes either a Record object or the keyword arguments to make one.
...@@ -252,4 +252,4 @@ class RecordMode(Mode): ...@@ -252,4 +252,4 @@ class RecordMode(Mode):
linker = VMLinker(use_cloop=bool(config.cxx)) linker = VMLinker(use_cloop=bool(config.cxx))
wrap_linker = WrapLinkerMany([linker], [callback]) wrap_linker = WrapLinkerMany([linker], [callback])
super().__init__(wrap_linker, optimizer="fast_run") super().__init__(linker=wrap_linker, optimizer="fast_run", db=db)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论