提交 f6297134 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Set the global mode during compilation

上级 bdf98ca9
...@@ -32,6 +32,7 @@ from pytensor.link.utils import raise_with_op ...@@ -32,6 +32,7 @@ from pytensor.link.utils import raise_with_op
if TYPE_CHECKING: if TYPE_CHECKING:
from pytensor.compile.mode import Mode
from pytensor.link.vm import VM from pytensor.link.vm import VM
...@@ -1391,9 +1392,16 @@ class FunctionMaker: ...@@ -1391,9 +1392,16 @@ class FunctionMaker:
@staticmethod @staticmethod
def prepare_fgraph( def prepare_fgraph(
inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile inputs,
outputs,
additional_outputs,
fgraph: FunctionGraph,
mode: "Mode",
profile,
): ):
rewriter = mode.optimizer
try: try:
start_rewriter = time.perf_counter() start_rewriter = time.perf_counter()
...@@ -1401,6 +1409,7 @@ class FunctionMaker: ...@@ -1401,6 +1409,7 @@ class FunctionMaker:
rewrite_time = None rewrite_time = None
with config.change_flags( with config.change_flags(
mode=mode,
compute_test_value=config.compute_test_value_opt, compute_test_value=config.compute_test_value_opt,
traceback__limit=config.traceback__compile_limit, traceback__limit=config.traceback__compile_limit,
): ):
...@@ -1440,7 +1449,7 @@ class FunctionMaker: ...@@ -1440,7 +1449,7 @@ class FunctionMaker:
stacklevel=3, stacklevel=3,
) )
if not hasattr(linker, "accept"): if not hasattr(mode.linker, "accept"):
raise ValueError( raise ValueError(
"'linker' parameter of FunctionMaker should be " "'linker' parameter of FunctionMaker should be "
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}" f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}"
...@@ -1511,12 +1520,8 @@ class FunctionMaker: ...@@ -1511,12 +1520,8 @@ class FunctionMaker:
self.fgraph = fgraph self.fgraph = fgraph
rewriter, linker = mode.optimizer, copy.copy(mode.linker)
if not no_fgraph_prep: if not no_fgraph_prep:
self.prepare_fgraph( self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile)
inputs, outputs, found_updates, fgraph, rewriter, linker, profile
)
assert len(fgraph.outputs) == len(outputs + found_updates) assert len(fgraph.outputs) == len(outputs + found_updates)
...@@ -1528,6 +1533,8 @@ class FunctionMaker: ...@@ -1528,6 +1533,8 @@ class FunctionMaker:
if not spec.borrow if not spec.borrow
] ]
linker = copy.copy(mode.linker)
if no_borrow: if no_borrow:
self.linker = linker.accept( self.linker = linker.accept(
fgraph, fgraph,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论