提交 ebc0de09 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Validate compatible linker in Scan make_thunk

上级 7523caa4
...@@ -76,6 +76,7 @@ from pytensor.graph.traversal import graph_inputs ...@@ -76,6 +76,7 @@ from pytensor.graph.traversal import graph_inputs
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker from pytensor.link.c.basic import CLinker
from pytensor.link.vm import VMLinker
from pytensor.printing import op_debug_information from pytensor.printing import op_debug_information
from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
...@@ -884,16 +885,24 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -884,16 +885,24 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.nit_sot_arg_offset = ( self.nit_sot_arg_offset = (
self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs
) )
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count # Note: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs # of the number of outputs generated by taps with inputs
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
# TODO: These can be moved to thunk/function compilation # Python and Cython perform methods provide the array location where a mitmot output should be
( # stored to the VM as a symbolic update. This helper variable is used in the perform method for validation
_, mitmots_preallocated = [False] * info.n_mit_mot_outs
self.mitmots_preallocated, if config.scan__allow_output_prealloc:
) = self._mitmot_preallocations() for mitmot_idx in range(info.n_mit_mot):
for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output
output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
) + info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
mitmots_preallocated[output_idx] = True
self.mitmots_preallocated = tuple(mitmots_preallocated)
self.n_outer_inputs = info.n_outer_inputs self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs self.n_outer_outputs = info.n_outer_outputs
...@@ -908,39 +917,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -908,39 +917,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
def _mitmot_preallocations(self):
    """Work out which MIT-MOT output taps are preallocated.

    When ``config.scan__allow_output_prealloc`` is enabled, a MIT-MOT output
    tap is marked as preallocated if the same tap value also appears among
    the input taps of the same MIT-MOT. When the flag is disabled, no output
    tap is marked.

    Returns
    -------
    preallocated_mitmot_outs : list of int
        Sorted flat indices (counted across all MIT-MOTs, in order) of the
        output taps that are preallocated.
    mitmots_preallocated : list of bool
        One entry per MIT-MOT output (``info.n_mit_mot_outs`` entries in
        total); ``True`` where the corresponding index is present in
        ``preallocated_mitmot_outs``.
    """
    # `info` is needed on both paths (it sizes the boolean mask below), so it
    # must be bound before branching on the config flag.
    info = self.info

    # Stays empty when output preallocation is not activated, which marks
    # every mitmot output tap as not being preallocated.
    preallocated_mitmot_outs = []

    if config.scan__allow_output_prealloc:
        for mitmot_idx in range(info.n_mit_mot):
            for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
                if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
                    # Figure out the index of the corresponding output:
                    # outputs of all previous MIT-MOTs come first, followed
                    # by the position of the tap within this MIT-MOT.
                    output_idx = sum(
                        len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
                    )
                    output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
                    preallocated_mitmot_outs.append(output_idx)

        preallocated_mitmot_outs.sort()

    # Store, for every mitmot output tap, whether it has been altered so that
    # it can be preallocated.
    mitmots_preallocated = [
        i in preallocated_mitmot_outs for i in range(info.n_mit_mot_outs)
    ]

    return preallocated_mitmot_outs, mitmots_preallocated
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
# Ensure that the graph associated with the inner function is valid. # Ensure that the graph associated with the inner function is valid.
...@@ -1483,11 +1459,26 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1483,11 +1459,26 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Clone mode_instance, altering "allow_gc" for the linker, # Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile # and adding a message if we profile
mode_instance = get_mode(self.mode).clone( mode = self.mode
link_kwargs=dict(allow_gc=self.allow_gc), if mode in (None, "FAST_RUN"):
message=f"{self.name or 'Scan'} sub profile", mode_instance = Mode("cvm", "fast_run")
) elif mode == "FAST_COMPILE":
mode_instance = Mode(
VMLinker(use_cloop=False, c_thunks=False), "fast_compile"
)
else:
mode_instance = get_mode(mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc),
message=f"{self.name or 'Scan'} sub profile",
)
# The Python and Cython perform implementations of Scan rely on the VM being able to set updates
# for preallocated MIT-MOT outputs, which only the VMs produced by VMLinker do
if any(self.mitmots_preallocated) and not isinstance(
mode_instance.linker, VMLinker
):
raise NotImplementedError(
f"Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker, got {mode_instance.linker}"
)
self._fn = pfunc( self._fn = pfunc(
wrapped_inputs, wrapped_inputs,
wrapped_outputs, wrapped_outputs,
...@@ -2007,6 +1998,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2007,6 +1998,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
new_var = inner_input_storage[inner_inp_idx].storage[0] new_var = inner_input_storage[inner_inp_idx].storage[0]
if old_var is new_var: if old_var is new_var:
old_data = old_mitmot_input_data[mitmot_inp_idx] old_data = old_mitmot_input_data[mitmot_inp_idx]
# This check is only valid if the VM performs updates.
# Otherwise the output value may remain the same as the input,
# but that doesn't mean it has been set up correctly
same_data = new_var.data == old_data same_data = new_var.data == old_data
else: else:
same_data = False same_data = False
......
...@@ -34,10 +34,12 @@ from pytensor.graph.replace import vectorize_graph ...@@ -34,10 +34,12 @@ from pytensor.graph.replace import vectorize_graph
from pytensor.graph.rewriting.basic import MergeOptimizer from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.graph.traversal import ancestors from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import until from pytensor.scan.utils import until
from pytensor.tensor import as_tensor
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
...@@ -4308,3 +4310,91 @@ def test_return_updates_api_change(): ...@@ -4308,3 +4310,91 @@ def test_return_updates_api_change():
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False) scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)
@pytest.mark.parametrize(
    "scan_mode",
    [
        None,
        "FAST_RUN",
        "FAST_COMPILE",
        Mode("cvm", optimizer=None),
        Mode("vm", optimizer=None),
        Mode("c", optimizer=None),
        Mode("py", optimizer=None),
    ],
)
def test_scan_mode_compatibility(scan_mode):
    """Regression test for the case where using Scan with a non-updating VM failed.

    A Scan with one sequence and two MIT-MOTs is built by hand and compiled
    with the parametrized inner ``scan_mode``. Explicit modes whose linker is
    not a ``VMLinker`` are expected to raise ``NotImplementedError`` at
    compilation time; every other mode must compile and produce the
    reference result.
    """
    # Hand-written description of the scan: one sequence and two MIT-MOTs
    scan_info = ScanInfo(
        n_seqs=1,
        mit_mot_in_slices=((0, 1), (0, 1)),
        mit_mot_out_slices=((1,), (1,)),
        mit_sot_in_slices=(),
        sit_sot_in_slices=(),
        n_nit_sot=0,
        n_untraced_sit_sot_outs=0,
        n_non_seqs=0,
        as_while=False,
    )

    # Inner graph: the boolean sequence value and two input taps per MIT-MOT
    bool_seq = pt.scalar(dtype="bool")
    mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1 = [
        pt.matrix(shape=(2, 2)) for _ in range(4)
    ]
    inner_inputs = [bool_seq, mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1]
    inner_outputs = [
        pt.add(bool_seq + mitmot_A0, mitmot_A1),
        pt.add(bool_seq * mitmot_B0, mitmot_B1),
    ]
    scan_op = Scan(inner_inputs, inner_outputs, info=scan_info, mode=scan_mode)

    # Outer (numerical) inputs: number of steps, the sequence, and one buffer
    # per MIT-MOT
    n_steps = 5
    numerical_inputs = [
        np.array(n_steps, dtype="int64"),
        np.array([1, 1, 0, 1, 0], dtype="bool"),
        np.zeros(n_steps + 1)[:, None, None] * np.eye(2),
        np.arange(n_steps + 1)[:, None, None] * np.eye(2),
    ]
    tensor_inputs = [as_tensor(inp, dtype=inp.dtype).type() for inp in numerical_inputs]
    tensor_outputs = [out.sum() for out in scan_op(*tensor_inputs)]

    outer_mode = Mode(linker="py", optimizer=None)

    # Abstract modes should never fail; only a specific, user-provided mode
    # with an incompatible (non-VM) linker is expected to be rejected.
    abstract_modes = (None, "FAST_RUN", "FAST_COMPILE")
    expect_rejection = scan_mode not in abstract_modes and not isinstance(
        get_mode(scan_mode).linker, VMLinker
    )

    if expect_rejection:
        # NotImplementedError should only be triggered when we try to compile
        # the function
        with pytest.raises(
            NotImplementedError,
            match="Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker",
        ):
            function(tensor_inputs, tensor_outputs, mode=outer_mode)
    else:
        fn = function(tensor_inputs, tensor_outputs, mode=outer_mode)

        # Check we have the expected Scan in the compiled function
        (compiled_scan_op,) = [
            node.op
            for node in fn.maker.fgraph.apply_nodes
            if isinstance(node.op, Scan)
        ]
        assert compiled_scan_op.info == scan_info
        assert compiled_scan_op.mitmots_preallocated == (True, True)

        # Expected value computed by running correct Scan once
        np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论