提交 ad1af2ea authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Delay creation of Scan inner mode

上级 de75bbad
......@@ -58,7 +58,7 @@ from pytensor import tensor as pt
from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
from pytensor.compile.function.pfunc import pfunc
from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.mode import Mode, get_mode
from pytensor.compile.profiling import register_profiler_printer
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
......@@ -761,18 +761,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.profile = profile
self.allow_gc = allow_gc
self.strict = strict
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = f"{self.name} sub profile"
else:
message = "Scan sub profile"
self.mode = get_default_mode() if mode is None else mode
self.mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc), message=message
)
self.mode = mode
# build a list of output types for any Apply node using this op.
self.output_types = []
......@@ -1445,10 +1434,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
elif self.profile:
profile = self.profile
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)
self._fn = pfunc(
wrapped_inputs,
wrapped_outputs,
mode=self.mode_instance,
mode=mode_instance,
accept_inplace=False,
profile=profile,
on_unused_input="ignore",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论