提交 ebc0de09 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Validate compatible linker in Scan make_thunk

上级 7523caa4
......@@ -76,6 +76,7 @@ from pytensor.graph.traversal import graph_inputs
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker
from pytensor.link.vm import VMLinker
from pytensor.printing import op_debug_information
from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from pytensor.tensor.basic import as_tensor_variable
......@@ -884,16 +885,24 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.nit_sot_arg_offset = (
self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs
)
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
# Note: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
# TODO: These can be moved to thunk/function compilation
(
_,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
# Python and Cython perform methods provide the array location where a mitmot output should be
# stored to the VM as a symbolic update. This helper variable is used in the perform method for validation
mitmots_preallocated = [False] * info.n_mit_mot_outs
if config.scan__allow_output_prealloc:
for mitmot_idx in range(info.n_mit_mot):
for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output
output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
) + info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
mitmots_preallocated[output_idx] = True
self.mitmots_preallocated = tuple(mitmots_preallocated)
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
......@@ -908,39 +917,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
self._hash_inner_graph = hash(self._cmodule_key)
def _mitmot_preallocations(self):
    """Work out which MIT-MOT output taps can reuse an input tap's storage.

    A MIT-MOT output tap is considered preallocated when the same tap value
    also appears among the input taps of the same MIT-MOT. Output taps are
    numbered with a single flat index that runs over all MIT-MOTs in order.

    Returns
    -------
    preallocated_mitmot_outs : list of int
        Sorted flat indices of the preallocated MIT-MOT output taps. Empty
        when ``config.scan__allow_output_prealloc`` is disabled.
    mitmots_preallocated : list of bool
        One flag per MIT-MOT output tap (``info.n_mit_mot_outs`` entries),
        ``True`` where the tap index is in ``preallocated_mitmot_outs``.
    """
    # Bind `info` before branching: it is needed to size the flag list even
    # when preallocation is disabled. Binding it only inside the `if` branch
    # made the disabled path fail with an UnboundLocalError.
    info = self.info

    preallocated_mitmot_outs = []
    if config.scan__allow_output_prealloc:
        for mitmot_idx in range(info.n_mit_mot):
            out_taps = info.mit_mot_out_slices[mitmot_idx]
            for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
                if inp_tap in out_taps:
                    # Flat output index = number of output taps of all the
                    # preceding MIT-MOTs + position of the tap in this one
                    output_idx = sum(
                        len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
                    )
                    output_idx += out_taps.index(inp_tap)
                    preallocated_mitmot_outs.append(output_idx)
        preallocated_mitmot_outs.sort()
    # Otherwise output preallocation is not activated and every MIT-MOT
    # output tap is marked as not being preallocated.

    # Flag, for each MIT-MOT output tap, whether it can be preallocated
    preallocated_lookup = set(preallocated_mitmot_outs)
    mitmots_preallocated = [
        i in preallocated_lookup for i in range(info.n_mit_mot_outs)
    ]

    return preallocated_mitmot_outs, mitmots_preallocated
def __setstate__(self, d):
self.__dict__.update(d)
# Ensure that the graph associated with the inner function is valid.
......@@ -1483,11 +1459,26 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)
mode = self.mode
if mode in (None, "FAST_RUN"):
mode_instance = Mode("cvm", "fast_run")
elif mode == "FAST_COMPILE":
mode_instance = Mode(
VMLinker(use_cloop=False, c_thunks=False), "fast_compile"
)
else:
mode_instance = get_mode(mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)
# Scan python and cython perform relies on the VM being able to set updates for preallocated MIT-MOT,
# which only the VMs produced by VMLinker do
if any(self.mitmots_preallocated) and not isinstance(
mode_instance.linker, VMLinker
):
raise NotImplementedError(
f"Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker, got {mode_instance.linker}"
)
self._fn = pfunc(
wrapped_inputs,
wrapped_outputs,
......@@ -2007,6 +1998,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
new_var = inner_input_storage[inner_inp_idx].storage[0]
if old_var is new_var:
old_data = old_mitmot_input_data[mitmot_inp_idx]
# This check is only valid if the VM performs updates.
# Otherwise the output value may remain the same as the input,
# but that doesn't mean it has been set up correctly.
same_data = new_var.data == old_data
else:
same_data = False
......
......@@ -34,10 +34,12 @@ from pytensor.graph.replace import vectorize_graph
from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker
from pytensor.raise_op import assert_op
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import until
from pytensor.tensor import as_tensor
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
from pytensor.tensor.math import sum as pt_sum
......@@ -4308,3 +4310,91 @@ def test_return_updates_api_change():
with pytest.raises(ValueError, match=err_msg):
scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)
@pytest.mark.parametrize(
    "scan_mode",
    [
        None,
        "FAST_RUN",
        "FAST_COMPILE",
        Mode("cvm", optimizer=None),
        Mode("vm", optimizer=None),
        Mode("c", optimizer=None),
        Mode("py", optimizer=None),
    ],
)
def test_scan_mode_compatibility(scan_mode):
    """Regression test for case where using Scan with a non-updating VM failed.

    A Scan Op with one sequence and two MIT-MOTs is built by hand with the
    parametrized inner ``mode``. Compiling the outer function must raise
    ``NotImplementedError`` for explicit modes whose linker is not a
    ``VMLinker``; every other mode must compile and give the reference result.
    """
    # Scan signature: one sequence and two MIT-MOTs, each with input taps
    # (0, 1) and output tap (1,)
    info = ScanInfo(
        n_seqs=1,
        mit_mot_in_slices=((0, 1), (0, 1)),
        mit_mot_out_slices=((1,), (1,)),
        mit_sot_in_slices=(),
        sit_sot_in_slices=(),
        n_nit_sot=0,
        n_untraced_sit_sot_outs=0,
        n_non_seqs=0,
        as_while=False,
    )

    # Inner graph: a boolean scalar plus two matrix taps per MIT-MOT
    bool_seq = pt.scalar(dtype="bool")
    mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1 = (
        pt.matrix(shape=(2, 2)) for _ in range(4)
    )
    inner_inputs = [bool_seq, mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1]
    inner_outputs = [
        pt.add(bool_seq + mitmot_A0, mitmot_A1),
        pt.add(bool_seq * mitmot_B0, mitmot_B1),
    ]
    scan_op = Scan(inner_inputs, inner_outputs, info=info, mode=scan_mode)

    # Outer graph: symbolic inputs typed after the numerical test values
    n_steps = 5
    eye = np.eye(2)
    numerical_inputs = [
        np.array(n_steps, dtype="int64"),
        np.array([1, 1, 0, 1, 0], dtype="bool"),
        np.zeros(n_steps + 1)[:, None, None] * eye,
        np.arange(n_steps + 1)[:, None, None] * eye,
    ]
    tensor_inputs = []
    for value in numerical_inputs:
        tensor_inputs.append(as_tensor(value, dtype=value.dtype).type())
    tensor_outputs = [out.sum() for out in scan_op(*tensor_inputs)]

    no_opt_mode = Mode(linker="py", optimizer=None)

    # Abstract modes should never fail; only a specific, user-requested mode
    # with an incompatible (non-VM) linker is expected to be rejected.
    abstract_modes = (None, "FAST_RUN", "FAST_COMPILE")
    expect_failure = scan_mode not in abstract_modes and not isinstance(
        get_mode(scan_mode).linker, VMLinker
    )

    if expect_failure:
        # NotImplementedError should only be triggered when we try to compile
        # the function
        with pytest.raises(
            NotImplementedError,
            match="Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker",
        ):
            function(tensor_inputs, tensor_outputs, mode=no_opt_mode)
    else:
        fn = function(tensor_inputs, tensor_outputs, mode=no_opt_mode)

        # Check we have the expected Scan in the compiled function
        scan_ops_in_fn = [
            node.op
            for node in fn.maker.fgraph.apply_nodes
            if isinstance(node.op, Scan)
        ]
        [fn_scan_op] = scan_ops_in_fn
        assert fn_scan_op.info == info
        assert fn_scan_op.mitmots_preallocated == (True, True)

        # Expected value computed by running correct Scan once
        np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论