提交 f6297134 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Set the global mode during compilation

上级 bdf98ca9
......@@ -32,6 +32,7 @@ from pytensor.link.utils import raise_with_op
if TYPE_CHECKING:
from pytensor.compile.mode import Mode
from pytensor.link.vm import VM
......@@ -1391,9 +1392,16 @@ class FunctionMaker:
@staticmethod
def prepare_fgraph(
inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile
inputs,
outputs,
additional_outputs,
fgraph: FunctionGraph,
mode: "Mode",
profile,
):
rewriter = mode.optimizer
try:
start_rewriter = time.perf_counter()
......@@ -1401,6 +1409,7 @@ class FunctionMaker:
rewrite_time = None
with config.change_flags(
mode=mode,
compute_test_value=config.compute_test_value_opt,
traceback__limit=config.traceback__compile_limit,
):
......@@ -1440,7 +1449,7 @@ class FunctionMaker:
stacklevel=3,
)
if not hasattr(linker, "accept"):
if not hasattr(mode.linker, "accept"):
raise ValueError(
"'linker' parameter of FunctionMaker should be "
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}"
......@@ -1511,12 +1520,8 @@ class FunctionMaker:
self.fgraph = fgraph
rewriter, linker = mode.optimizer, copy.copy(mode.linker)
if not no_fgraph_prep:
self.prepare_fgraph(
inputs, outputs, found_updates, fgraph, rewriter, linker, profile
)
self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile)
assert len(fgraph.outputs) == len(outputs + found_updates)
......@@ -1528,6 +1533,8 @@ class FunctionMaker:
if not spec.borrow
]
linker = copy.copy(mode.linker)
if no_borrow:
self.linker = linker.accept(
fgraph,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论