提交 ad1af2ea authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Delay creation of Scan inner mode

上级 de75bbad
...@@ -58,7 +58,7 @@ from pytensor import tensor as pt ...@@ -58,7 +58,7 @@ from pytensor import tensor as pt
from pytensor.compile.builders import construct_nominal_fgraph, infer_shape from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
from pytensor.compile.function.pfunc import pfunc from pytensor.compile.function.pfunc import pfunc
from pytensor.compile.io import In, Out from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.mode import Mode, get_mode
from pytensor.compile.profiling import register_profiler_printer from pytensor.compile.profiling import register_profiler_printer
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
...@@ -761,18 +761,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -761,18 +761,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.profile = profile self.profile = profile
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.strict = strict self.strict = strict
self.mode = mode
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = f"{self.name} sub profile"
else:
message = "Scan sub profile"
self.mode = get_default_mode() if mode is None else mode
self.mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc), message=message
)
# build a list of output types for any Apply node using this op. # build a list of output types for any Apply node using this op.
self.output_types = [] self.output_types = []
...@@ -1445,10 +1434,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1445,10 +1434,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
elif self.profile: elif self.profile:
profile = self.profile profile = self.profile
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)
self._fn = pfunc( self._fn = pfunc(
wrapped_inputs, wrapped_inputs,
wrapped_outputs, wrapped_outputs,
mode=self.mode_instance, mode=mode_instance,
accept_inplace=False, accept_inplace=False,
profile=profile, profile=profile,
on_unused_input="ignore", on_unused_input="ignore",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论