提交 1d19c375 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow non-shared untraced SIT-SOT

上级 d10c61ba
......@@ -60,23 +60,23 @@ def jax_funcify_Scan(op: Scan, **kwargs):
mit_mot_init,
mit_sot_init,
sit_sot_init,
op.outer_shared(outer_inputs),
op.outer_untraced_sit_sot(outer_inputs),
op.outer_non_seqs(outer_inputs),
) # JAX `init`
def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func.
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs)
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, untraced SIT-SOT, non_seqs)
"""
# `carry` contains all inner taps, shared terms, and non_seqs
# `carry` contains all inner taps and non_seqs
(
i,
inner_mit_mot,
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_untraced_sit_sot,
inner_non_seqs,
) = carry
......@@ -108,7 +108,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
*mit_mot_flatten,
*mit_sot_flatten,
*inner_sit_sot,
*inner_shared,
*inner_untraced_sit_sot,
*inner_non_seqs,
)
......@@ -118,14 +118,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
):
"""Convert inner_scan_func outputs into format expected by JAX scan.
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
"""
(
i,
old_mit_mot,
old_mit_sot,
_old_sit_sot,
_old_shared,
_old_untraced_sit_sot,
inner_non_seqs,
) = old_carry
......@@ -133,7 +133,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
new_shared = op.inner_shared_outs(inner_scan_outs)
new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_scan_outs)
# New carry for next step
# Update MIT-MOT buffer at positions indicated by output taps
......@@ -150,14 +150,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
old_mit_sot, new_mit_sot_vals, strict=True
)
]
# For SIT-SOT, and shared just pass along the new value
# For SIT-SOT just pass along the new value
# Non-sequences remain unchanged
new_carry = (
i + 1,
new_mit_mot,
new_mit_sot,
new_sit_sot,
new_shared,
new_untraced_sit_sot,
inner_non_seqs,
)
......@@ -183,7 +183,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
final_mit_mot,
_final_mit_sot,
_final_sit_sot,
final_shared,
final_untraced_sit_sot,
_final_non_seqs,
),
traces,
......@@ -238,7 +238,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
scan_outs_final = [
*final_mit_mot,
*get_partial_traces(traces),
*final_shared,
*final_untraced_sit_sot,
]
if len(scan_outs_final) == 1:
......
......@@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
outer_in_shared_names = op.outer_shared(outer_in_names)
outer_in_untraced_sit_sot_names = op.outer_untraced_sit_sot(outer_in_names)
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
# These are all the outer-input names that have produce outputs/have output
# taps (i.e. they have inner-outputs and corresponding outer-outputs).
# Outer-outputs are ordered as follows:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs
outer_in_outtap_names = (
outer_in_mit_mot_names
+ outer_in_mit_sot_names
+ outer_in_sit_sot_names
+ outer_in_nit_sot_names
+ outer_in_shared_names
+ outer_in_untraced_sit_sot_names
)
# We create distinct variables for/references to the storage arrays for
......@@ -138,8 +138,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
for outer_in_name in outer_in_nit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"
for outer_in_name in outer_in_shared_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage"
for outer_in_name in outer_in_untraced_sit_sot_names:
outer_in_to_storage_name[outer_in_name] = (
f"{outer_in_name}_untraced_sit_sot_storage"
)
outer_output_names = list(outer_in_to_storage_name.values())
assert len(outer_output_names) == len(node.outputs)
......@@ -147,7 +149,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Construct the inner-input expressions (e.g. indexed storage expressions)
# Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# shared-inputs + non-sequences.
# untraced-sit-sot-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: list[str] = []
inner_in_exprs_scalar: list[str] = []
inner_in_exprs: list[str] = []
......@@ -204,11 +206,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Inner-outputs consist of:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
# shared-outputs [+ while-condition]
# untraced-sit-sot-outputs [+ while-condition]
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
# The assignment statements that copy inner-outputs into the outer-outputs
# storage
inner_out_to_outer_in_stmts: list[str] = []
......
import typing
import warnings
from itertools import chain
......@@ -11,6 +12,7 @@ from pytensor.graph.basic import Constant, Variable
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.traversal import explicit_graph_inputs
from pytensor.graph.type import HasShape
from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until
......@@ -22,6 +24,10 @@ from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.updates import OrderedUpdates
if typing.TYPE_CHECKING:
from pytensor.tensor.type import TensorVariable
def get_updates_and_outputs(ls):
"""Recognize and order the updates, outputs, and stopping condition for a `Scan`.
......@@ -469,7 +475,7 @@ def scan(
# Make sure we get rid of numpy arrays or ints or anything like that
# passed as inputs to scan
non_seqs = []
non_seqs: list[Variable] = []
for elem in wrap_into_list(non_sequences):
if not isinstance(elem, Variable):
non_seqs.append(pt.as_tensor_variable(elem))
......@@ -685,10 +691,10 @@ def scan(
# MIT_MOT -- not provided by the user only by the grad function
n_mit_mot = 0
mit_mot_scan_inputs = []
mit_mot_inner_inputs = []
mit_mot_inner_outputs = []
mit_mot_out_slices = []
mit_mot_scan_inputs: list[TensorVariable] = []
mit_mot_inner_inputs: list[TensorVariable] = []
mit_mot_inner_outputs: list[TensorVariable] = []
mit_mot_out_slices: list[TensorVariable] = []
# SIT_SOT -- provided by the user
n_mit_sot = 0
......@@ -706,6 +712,12 @@ def scan(
sit_sot_inner_outputs = []
sit_sot_rightOrder = []
n_untraced_sit_sot_outs = 0
untraced_sit_sot_scan_inputs = []
untraced_sit_sot_inner_inputs = []
untraced_sit_sot_inner_outputs = []
untraced_sit_sot_rightOrder = []
# go through outputs picking up time slices as needed
for i, init_out in enumerate(outs_info):
# Note that our convention dictates that if an output uses
......@@ -741,17 +753,35 @@ def scan(
# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
# defined in scan utils
sit_sot_scan_inputs.append(
expand_empty(
shape_padleft(actual_arg),
actual_n_steps,
if isinstance(actual_arg.type, HasShape):
sit_sot_scan_inputs.append(
expand_empty(
shape_padleft(actual_arg),
actual_n_steps,
)
)
)
sit_sot_inner_slices.append(actual_arg)
sit_sot_inner_slices.append(actual_arg)
sit_sot_inner_inputs.append(arg)
sit_sot_rightOrder.append(i)
n_sit_sot += 1
sit_sot_inner_inputs.append(arg)
sit_sot_rightOrder.append(i)
n_sit_sot += 1
else:
# Assume variables without shape cannot be stacked (e.g., RNG variables)
# Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
from pytensor.tensor.random.type import RandomType
if not isinstance(arg.type, RandomType):
warnings.warn(
(
f"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. "
"Only the last value will be returned, not the entire sequence."
),
UserWarning,
)
untraced_sit_sot_scan_inputs.append(actual_arg)
untraced_sit_sot_inner_inputs.append(arg)
n_untraced_sit_sot_outs += 1
untraced_sit_sot_rightOrder.append(i)
elif init_out.get("taps", None):
if np.any(np.array(init_out.get("taps", [])) > 0):
......@@ -802,10 +832,11 @@ def scan(
# a map); in that case we do not have to do anything ..
# Re-order args
max_mit_sot = np.max([-1, *mit_sot_rightOrder]) + 1
max_sit_sot = np.max([-1, *sit_sot_rightOrder]) + 1
n_elems = np.max([max_mit_sot, max_sit_sot])
_ordered_args = [[] for x in range(n_elems)]
max_mit_sot = max(mit_sot_rightOrder, default=-1) + 1
max_sit_sot = max(sit_sot_rightOrder, default=-1) + 1
max_untraced_sit_sot_outs = max(untraced_sit_sot_rightOrder, default=-1) + 1
n_elems = np.max((max_mit_sot, max_sit_sot, max_untraced_sit_sot_outs))
_ordered_args: list[list[Variable]] = [[] for x in range(n_elems)]
offset = 0
for idx in range(n_mit_sot):
n_inputs = len(mit_sot_tap_array[idx])
......@@ -825,6 +856,11 @@ def scan(
else:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
for idx in range(n_untraced_sit_sot_outs):
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
untraced_sit_sot_inner_inputs[idx]
]
ordered_args = list(chain.from_iterable(_ordered_args))
if single_step_requested:
args = inner_slices + ordered_args + non_seqs
......@@ -939,18 +975,19 @@ def scan(
if "taps" in out and out["taps"] != [-1]:
mit_sot_inner_outputs.append(outputs[i])
# Step 5.2 Outputs with tap equal to -1
# Step 5.2 Outputs with tap equal to -1 (traced and untraced)
for i, out in enumerate(outs_info):
if "taps" in out and out["taps"] == [-1]:
sit_sot_inner_outputs.append(outputs[i])
output = outputs[i]
if isinstance(output.type, HasShape):
sit_sot_inner_outputs.append(output)
else:
untraced_sit_sot_inner_outputs.append(output)
# Step 5.3 Outputs that correspond to update rules of shared variables
# This whole special logic for shared variables is deprecated
sit_sot_shared: list[Variable] = []
inner_replacements = {}
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
shared_inner_outputs = []
sit_sot_shared = []
no_update_shared_inputs = []
for input in dummy_inputs:
if not isinstance(input.variable, SharedVariable):
......@@ -976,8 +1013,8 @@ def scan(
new_var = safe_new(input.variable)
if getattr(input.variable, "name", None) is not None:
new_var.name = input.variable.name + "_copy"
if input.variable.name is not None:
new_var.name = f"{input.variable.name}_copy"
inner_replacements[input.variable] = new_var
......@@ -1003,10 +1040,10 @@ def scan(
sit_sot_shared.append(input.variable)
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
n_shared_outs += 1
untraced_sit_sot_inner_inputs.append(new_var)
untraced_sit_sot_scan_inputs.append(input.variable)
untraced_sit_sot_inner_outputs.append(input.update)
n_untraced_sit_sot_outs += 1
else:
no_update_shared_inputs.append(input)
......@@ -1071,7 +1108,7 @@ def scan(
+ mit_mot_inner_inputs
+ mit_sot_inner_inputs
+ sit_sot_inner_inputs
+ shared_inner_inputs
+ untraced_sit_sot_inner_inputs
+ other_shared_inner_args
+ other_inner_args
)
......@@ -1081,7 +1118,7 @@ def scan(
+ mit_sot_inner_outputs
+ sit_sot_inner_outputs
+ nit_sot_inner_outputs
+ shared_inner_outputs
+ untraced_sit_sot_inner_outputs
)
if condition is not None:
inner_outs.append(condition)
......@@ -1101,7 +1138,7 @@ def scan(
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_shared_outs=n_shared_outs,
n_untraced_sit_sot_outs=n_untraced_sit_sot_outs,
n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
as_while=as_while,
......@@ -1127,7 +1164,7 @@ def scan(
+ mit_mot_scan_inputs
+ mit_sot_scan_inputs
+ sit_sot_scan_inputs
+ shared_scan_inputs
+ untraced_sit_sot_scan_inputs
+ [actual_n_steps for x in range(n_nit_sot)]
+ other_shared_scan_args
+ other_scan_args
......@@ -1173,13 +1210,26 @@ def scan(
nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])
offset += n_nit_sot
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
update_map[shared_scan_inputs[idx]] = update_rule
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
# Support for explicit untraced sit_sot
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
untraced_sit_sot_outs = scan_outs[
offset : offset + n_explicit_untraced_sit_sot_outs
]
offset += n_explicit_untraced_sit_sot_outs
for idx, update_rule in enumerate(scan_outs[offset:]):
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
# Step 10. I need to reorder the outputs to be in the order expected by
# the user
rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
rightOrder = (
mit_sot_rightOrder
+ sit_sot_rightOrder
+ untraced_sit_sot_rightOrder
+ nit_sot_rightOrder
)
scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder):
if pos >= 0:
......
......@@ -46,6 +46,7 @@ relies on the following elements to work properly :
import dataclasses
import logging
import time
import warnings
from collections.abc import Callable, Iterable
from copy import copy
from itertools import chain, product
......@@ -208,10 +209,19 @@ class ScanInfo:
mit_sot_in_slices: tuple
sit_sot_in_slices: tuple
n_nit_sot: int
n_shared_outs: int
n_untraced_sit_sot_outs: int
n_non_seqs: int
as_while: bool
@property
def n_shared_outs(self):
warnings.warn(
"The 'n_shared_outs' property is deprecated. Use 'n_untraced_sit_sot_outs' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.n_untraced_sit_sot_outs
@property
def n_mit_mot(self):
return len(self.mit_mot_in_slices)
......@@ -239,7 +249,7 @@ class ScanInfo:
+ sum(len(x) for x in self.mit_mot_in_slices)
+ sum(len(x) for x in self.mit_sot_in_slices)
+ self.n_sit_sot
+ self.n_shared_outs
+ self.n_untraced_sit_sot_outs
+ self.n_non_seqs
)
......@@ -250,7 +260,7 @@ class ScanInfo:
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs
+ self.n_untraced_sit_sot_outs
+ int(self.as_while)
)
......@@ -263,7 +273,7 @@ class ScanInfo:
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs
+ self.n_untraced_sit_sot_outs
+ self.n_non_seqs
)
......@@ -274,7 +284,7 @@ class ScanInfo:
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs
+ self.n_untraced_sit_sot_outs
)
......@@ -381,7 +391,7 @@ class ScanMethodsMixin:
+ self.info.n_mit_mot
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_shared_outs
+ self.info.n_untraced_sit_sot_outs
)
return list_inputs[offset : offset + self.info.n_nit_sot]
......@@ -394,15 +404,23 @@ class ScanMethodsMixin:
offset = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
return list_outputs[offset : offset + self.info.n_nit_sot]
def inner_shared(self, list_inputs):
def inner_untraced_sit_sot(self, list_inputs):
n_taps_upto_sit_sot = sum(
len(x)
for x in chain(self.info.mit_mot_in_slices, self.info.mit_sot_in_slices)
)
offset = self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot
return list_inputs[offset : offset + self.info.n_shared_outs]
return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs]
def outer_shared(self, list_inputs):
def inner_shared(self, list_inputs):
warnings.warn(
"The 'inner_shared' method is deprecated. Use 'inner_untraced_sit_sot' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.inner_untraced_sit_sot(list_inputs)
def outer_untraced_sit_sot(self, list_inputs):
offset = (
1
+ self.info.n_seqs
......@@ -410,23 +428,47 @@ class ScanMethodsMixin:
+ self.info.n_mit_sot
+ self.info.n_sit_sot
)
return list_inputs[offset : offset + self.info.n_shared_outs]
return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs]
def inner_shared_outs(self, list_outputs):
def outer_shared(self, list_inputs):
warnings.warn(
"The 'outer_shared' method is deprecated. Use 'outer_untraced_sit_sot' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.outer_untraced_sit_sot(list_inputs)
def inner_untraced_sit_sot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
offset = (
self.info.n_mit_sot + n_taps + self.info.n_sit_sot + self.info.n_nit_sot
)
return list_outputs[offset : offset + self.info.n_shared_outs]
return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs]
def outer_shared_outs(self, list_outputs):
def inner_shared_outs(self, list_outputs):
warnings.warn(
"The 'inner_shared_outs' method is deprecated. Use 'inner_untraced_sit_sot_outs' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.inner_untraced_sit_sot_outs(list_outputs)
def outer_untraced_sit_sot_outs(self, list_outputs):
offset = (
self.info.n_mit_mot
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
return list_outputs[offset : offset + self.info.n_shared_outs]
return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs]
def outer_shared_outs(self, list_outputs):
warnings.warn(
"The 'outer_shared_outs' method is deprecated. Use 'outer_untraced_sit_sot_outs' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.outer_untraced_sit_sot_outs(list_outputs)
def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum(
......@@ -437,7 +479,7 @@ class ScanMethodsMixin:
self.info.n_seqs
+ n_taps_upto_sit_sot
+ self.info.n_sit_sot
+ self.info.n_shared_outs
+ self.info.n_untraced_sit_sot_outs
)
return list_inputs[offset:]
......@@ -449,7 +491,7 @@ class ScanMethodsMixin:
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
+ self.info.n_shared_outs
+ self.info.n_untraced_sit_sot_outs
)
return list_inputs[offset:]
......@@ -525,8 +567,8 @@ class ScanMethodsMixin:
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.info.n_shared_outs
# nitsots come *after* untraced_sitsot variables.
outer_iidx += self.info.n_untraced_sit_sot_outs
# Handle nitsots variables
for i in range(self.info.n_nit_sot):
......@@ -541,11 +583,11 @@ class ScanMethodsMixin:
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx -= self.info.n_shared_outs + self.info.n_nit_sot
# nitsots come *after* untraced_sit_sot variables.
outer_iidx -= self.info.n_untraced_sit_sot_outs + self.info.n_nit_sot
# Handle shared states
for i in range(self.info.n_shared_outs):
# Handle untraced_sitsot states
for i in range(self.info.n_untraced_sit_sot_outs):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx])
......@@ -557,7 +599,7 @@ class ScanMethodsMixin:
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
# nitsots come *after* untraced_sitsot variables.
outer_iidx += self.info.n_nit_sot
# Handle non-sequence inputs
......@@ -708,7 +750,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Inputs of the inner function of `Scan`.
These take the following general form:
sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + shared-inputs + non-sequences
sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + untraced-sit-sot-inputs + shared-inputs + non-sequences
where each term is a list of `Variable`\s.
......@@ -716,7 +758,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Outputs of the inner function of `Scan`.
These take the following general form:
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs [+ while-condition]
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs [+ while-condition]
where each term is a list of `Variable`\s.
......@@ -817,7 +859,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
typeConstructor((None, *o.type.shape), o.type.dtype)
)
# shared outputs + possibly the ending condition
# untraced_sit_sot outputs + possibly the ending condition
for o in self.fgraph.outputs[end:]:
self.output_types.append(o.type)
......@@ -836,10 +878,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
]
self.mintaps += [0 for x in range(info.n_nit_sot)]
self.seqs_arg_offset = 1 + info.n_seqs
self.shared_arg_offset = (
self.untraced_sit_sot_arg_offset = (
self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
)
self.nit_sot_arg_offset = self.shared_arg_offset + info.n_shared_outs
self.nit_sot_arg_offset = (
self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs
)
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
......@@ -908,7 +952,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs +
untraced-sit-sot-inputs + shared-inputs
nit-sots +
non-sequences
......@@ -923,7 +967,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
[n_steps] +
sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs +
untraced-sit-sot-inputs + shared-inputs
nit-sots +
non-sequences
......@@ -931,7 +975,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs +
nit-sots +
shared-outputs
untraced-sit-sot-outputs
These outer-outputs essentially follow the same form as their
corresponding inner-outputs, excluding the final "while" condition
......@@ -949,7 +993,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ len(self.info.mit_mot_in_slices)
+ len(self.info.mit_sot_in_slices)
+ len(self.inner_sitsot(self.inner_inputs))
+ len(self.inner_shared(self.inner_inputs))
+ len(self.inner_untraced_sit_sot(self.inner_inputs))
+ len(self.inner_non_seqs(self.inner_inputs))
)
......@@ -1134,60 +1178,60 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same
# Check that the untraced (u) sit-sot variable and their update rule have the same
# dtype. Maybe even same type ?!
for idx, (inner_shared, inner_shared_out, _outer_shared) in enumerate(
for idx, (inner_u_sitsot, inner_u_sitsot_out, _outer_u_sitsot) in enumerate(
zip(
self.inner_shared(self.inner_inputs),
self.inner_shared_outs(self.inner_outputs),
self.outer_shared(inputs),
self.inner_untraced_sit_sot(self.inner_inputs),
self.inner_untraced_sit_sot_outs(self.inner_outputs),
self.outer_untraced_sit_sot(inputs),
strict=True,
)
):
outer_shared = copy_var_format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared)
outer_u_sitsot = copy_var_format(_outer_u_sitsot, as_var=inner_u_sitsot)
new_inputs.append(outer_u_sitsot)
if (
hasattr(outer_shared, "dtype")
and outer_shared.dtype != inner_shared_out.dtype
hasattr(outer_u_sitsot, "dtype")
and outer_u_sitsot.dtype != inner_u_sitsot_out.dtype
):
raise ValueError(
err_msg2
% (
str(outer_shared),
str(outer_u_sitsot),
idx + argoffset,
outer_shared.dtype,
inner_shared_out.dtype,
outer_u_sitsot.dtype,
inner_u_sitsot_out.dtype,
)
)
if (
hasattr(outer_shared, "dtype")
and outer_shared.ndim != inner_shared_out.ndim
hasattr(outer_u_sitsot, "dtype")
and outer_u_sitsot.ndim != inner_u_sitsot_out.ndim
):
raise ValueError(
err_msg3
% (
str(outer_shared),
str(outer_u_sitsot),
idx + argoffset,
outer_shared.ndim,
inner_shared_out.ndim,
outer_u_sitsot.ndim,
inner_u_sitsot_out.ndim,
)
)
if hasattr(outer_shared, "dtype") and (
outer_shared.dtype != inner_shared.dtype
or outer_shared.ndim != inner_shared.ndim
if hasattr(outer_u_sitsot, "dtype") and (
outer_u_sitsot.dtype != inner_u_sitsot.dtype
or outer_u_sitsot.ndim != inner_u_sitsot.ndim
):
raise ValueError(
err_msg1
% (
"initial state (outputs_info in scan nomenclature) ",
str(outer_shared),
str(outer_u_sitsot),
argoffset + idx,
outer_shared.dtype,
outer_shared.ndim,
str(inner_shared),
inner_shared.dtype,
inner_shared.ndim,
outer_u_sitsot.dtype,
outer_u_sitsot.ndim,
str(inner_u_sitsot),
inner_u_sitsot.dtype,
inner_u_sitsot.ndim,
)
)
# We do not need to call `copy_var_format` on outer_nisot arguments.
......@@ -1585,7 +1629,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try:
t_fn, n_steps = scan_perform_ext.perform(
self.info.n_shared_outs,
self.info.n_untraced_sit_sot_outs,
self.info.n_mit_mot_outs,
self.info.n_seqs,
self.info.n_mit_mot,
......@@ -1719,7 +1763,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# The length of each output
store_steps = [
arg.shape[0]
for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset]
for arg in inputs[self.seqs_arg_offset : self.untraced_sit_sot_arg_offset]
]
store_steps += list(
inputs[self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot]
......@@ -1784,7 +1828,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
info.sit_sot_in_slices,
)
)
+ info.n_shared_outs
+ info.n_untraced_sit_sot_outs
)
for idx in range(len(other_args)):
inner_input_storage[idx + offset].storage[0] = other_args[idx]
......@@ -1827,14 +1871,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
]
offset += 1
a_offset = self.shared_arg_offset
a_offset = self.untraced_sit_sot_arg_offset
o_offset = self.n_outs + info.n_nit_sot
if i == 0:
for j in range(info.n_shared_outs):
for j in range(info.n_untraced_sit_sot_outs):
inner_input_storage[offset].storage[0] = inputs[a_offset + j]
offset += 1
else:
for j in range(info.n_shared_outs):
for j in range(info.n_untraced_sit_sot_outs):
inner_input_storage[offset].storage[0] = output_storage[
o_offset + j
][0]
......@@ -1866,14 +1910,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot):
inner_output_storage[idx + offset].storage[0] = None
# 4.3. Collect slices for shared outputs
# 4.3. Collect slices for untraced sitsot outputs
offset += self.n_outs + info.n_nit_sot - info.n_mit_mot
for idx in range(info.n_shared_outs):
for idx in range(info.n_untraced_sit_sot_outs):
inner_output_storage[idx + offset].storage[0] = None
# 4.4. If there is a condition add it to the mix
if info.as_while:
pdx = offset + info.n_shared_outs
pdx = offset + info.n_untraced_sit_sot_outs
inner_output_storage[pdx].storage[0] = None
# 4.5. Keep a reference to the variables (ndarrays,
......@@ -1942,7 +1986,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dt_fn = time.perf_counter() - t0_fn
if info.as_while:
pdx = offset + info.n_shared_outs
pdx = offset + info.n_untraced_sit_sot_outs
cond = inner_output_storage[pdx].storage[0] == 0
t_fn += dt_fn
......@@ -2089,10 +2133,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
j + offset_out
].storage[0]
# 5.6 Copy over the values for outputs corresponding to shared
# 5.6 Copy over the values for outputs corresponding to untraced sitsot
# variables
begin = end
end += info.n_shared_outs
end += info.n_untraced_sit_sot_outs
for j in range(begin, end):
jout = j + offset_out
output_storage[j][0] = inner_output_storage[jout].storage[0]
......@@ -2240,13 +2284,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# out_equivalent[self.inner_inputs[inner_inp_idx]] = corresponding_tap
outer_inp_idx += 1
# shared_outs
# untraced sit_sot outputs
offset = 1 + info.n_seqs + n_outs
for idx in range(info.n_shared_outs):
for idx in range(info.n_untraced_sit_sot_outs):
outs_shape += [input_shapes[idx + offset]]
# non_sequences
offset += info.n_nit_sot + info.n_shared_outs
offset += info.n_nit_sot + info.n_untraced_sit_sot_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inner_inputs)
......@@ -2288,7 +2332,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# in the inner function.
r = node.outputs[n_outs + x]
assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset + info.n_shared_outs + x]]
shp = [node.inputs[offset + info.n_untraced_sit_sot_outs + x]]
for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True):
# Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates
......@@ -2305,7 +2349,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp.append(v_shp_i[0])
scan_outs.append(tuple(shp))
scan_outs += list(input_shapes[offset : offset + info.n_shared_outs])
scan_outs += list(input_shapes[offset : offset + info.n_untraced_sit_sot_outs])
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if info.as_while:
......@@ -2735,7 +2779,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitmot_inp_taps.append([])
mitmot_out_taps.append([])
undefined_msg = None
through_shared = False
through_untraced = False
disconnected = True
for mit_mot_out_slice in info.mit_mot_out_slices[idx]:
......@@ -2779,9 +2823,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
disconnected &= disconnected_dC_dinps_t[ins_pos]
through_shared = any(
through_untraced = any(
_sh in graph_inputs([dC_dinps_t[ins_pos]])
for _sh in self.inner_shared(self_inputs)
for _sh in self.inner_untraced_sit_sot(self_inputs)
)
ins_pos += 1
......@@ -2795,8 +2839,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append("through_shared")
elif through_untraced:
type_outs.append("through_untraced")
elif disconnected:
type_outs.append("disconnected")
else:
......@@ -2814,7 +2858,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_pos += 1
n_mitmot_inps += 1
undefined_msg = None
through_shared = False
through_untraced = False
disconnected = True
mitmot_inp_taps[idx + offset].append(0)
for tap in taps:
......@@ -2836,9 +2880,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
disconnected &= disconnected_dC_dinps_t[ins_pos]
through_shared = any(
through_untraced = any(
_sh in graph_inputs([dC_dinps_t[ins_pos]])
for _sh in self.inner_shared(self_inputs)
for _sh in self.inner_untraced_sit_sot(self_inputs)
)
n_mitmot_inps += 1
......@@ -2847,8 +2891,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if undefined_msg:
type_outs.append(undefined_msg)
elif through_shared:
type_outs.append("through_shared")
elif through_untraced:
type_outs.append("through_untraced")
elif disconnected:
type_outs.append("disconnected")
else:
......@@ -2884,15 +2928,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
through_shared = any(
through_untraced = any(
_sh in graph_inputs([dC_dinps_t[ins_pos]])
for _sh in self.inner_shared(self_inputs)
for _sh in self.inner_untraced_sit_sot(self_inputs)
)
if isinstance(dC_dinps_t[ins_pos].type, NullType):
type_outs.append(dC_dinps_t[ins_pos].type.why_null)
elif through_shared:
type_outs.append("through_shared")
elif through_untraced:
type_outs.append("through_untraced")
elif disconnected_dC_dinps_t[ins_pos]:
type_outs.append("disconnected")
else:
......@@ -2911,10 +2955,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_out_nitsot = dC_dinps_t[: info.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot):
through_shared = False
for _sh in self.inner_shared(self_inputs):
through_untraced = False
for _sh in self.inner_untraced_sit_sot(self_inputs):
if _sh in graph_inputs([vl]):
through_shared = True
through_untraced = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
......@@ -2922,18 +2966,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_out_sitsot[_p] = pt.zeros(
diff_inputs[ins_pos + _p].shape, dtype=config.floatX
)
elif through_shared:
type_outs.append("through_shared")
elif through_untraced:
type_outs.append("through_untraced")
elif disconnected_dC_dinps_t[_p + ins_pos]:
type_outs.append("disconnected")
else:
type_outs.append("connected")
for _p, vl in enumerate(inner_out_nitsot):
through_shared = False
for _sh in self.inner_shared(self_inputs):
through_untraced = False
for _sh in self.inner_untraced_sit_sot(self_inputs):
if _sh in graph_inputs([vl]):
through_shared = True
through_untraced = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of
......@@ -2942,8 +2986,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
diff_inputs[_p].shape, dtype=config.floatX
)
if through_shared:
type_outs.append("through_shared")
if through_untraced:
type_outs.append("through_untraced")
elif disconnected_dC_dinps_t[_p]:
type_outs.append("disconnected")
else:
......@@ -2983,7 +3027,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ outer_inp_mitmot
+ outer_inp_sitsot
+ [n_steps if info.as_while else inputs[0] for _ in range(n_nit_sot)]
+ self.outer_shared(inputs)
+ self.outer_untraced_sit_sot(inputs)
+ self.outer_non_seqs(inputs)
)
......@@ -2991,7 +3035,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_inp_seqs
+ inner_inp_mitmot
+ inner_inp_sitsot
+ self.inner_shared(self_inputs)
+ self.inner_untraced_sit_sot(self_inputs)
+ self.inner_non_seqs(self_inputs)
)
inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot
......@@ -3003,8 +3047,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices=(),
sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)),
n_nit_sot=n_nit_sot,
n_shared_outs=0,
n_non_seqs=len(self.outer_shared(inputs))
n_untraced_sit_sot_outs=0,
n_non_seqs=len(self.outer_untraced_sit_sot(inputs))
+ len(self.outer_non_seqs(inputs)),
as_while=False,
)
......@@ -3047,10 +3091,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append(x[::-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
elif t == "through_shared":
elif t == "through_untraced":
gradients.append(
grad_undefined(
self, p + 1, inputs[p + 1], "Depends on a shared variable"
self, p + 1, inputs[p + 1], "Depends on a untraced variable"
)
)
else:
......@@ -3075,13 +3119,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append(x[::-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
elif t == "through_shared":
elif t == "through_untraced":
gradients.append(
grad_undefined(
self,
p + 1 + info.n_seqs,
inputs[p + 1 + info.n_seqs],
"Depends on a shared variable",
"Depends on an untraced variable",
)
)
else:
......@@ -3090,7 +3134,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
start = len(gradients)
node = outs[0].owner
for idx in range(info.n_shared_outs):
for idx in range(info.n_untraced_sit_sot_outs):
disconnected = True
connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags, strict=True):
......@@ -3116,13 +3160,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append(x[-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
elif t == "through_shared":
elif t == "through_untraced":
gradients.append(
grad_undefined(
self,
p + begin + 1,
inputs[p + begin + 1],
"Depends on a shared variable",
"Depends on a untraced variable",
)
)
else:
......@@ -3152,7 +3196,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self_inputs = self.inner_inputs
rop_of_inputs = (
self_inputs[: info.n_seqs + self.n_outs]
+ self_inputs[info.n_seqs + self.n_outs + info.n_shared_outs :]
+ self_inputs[info.n_seqs + self.n_outs + info.n_untraced_sit_sot_outs :]
)
self_outputs = self.inner_outputs
......@@ -3162,8 +3206,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs = self_outputs[:-1]
else:
rop_self_outputs = self_outputs
if info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -info.n_shared_outs]
if info.n_untraced_sit_sot_outs > 0:
rop_self_outputs = rop_self_outputs[: -info.n_untraced_sit_sot_outs]
rop_outs = Rop(
rop_self_outputs,
rop_of_inputs,
......@@ -3247,13 +3291,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
scan_sit_sot = inputs[b:e] + clean_eval_points
inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# Shared outs ...
# Untraced outs ...
b = e
e = e + info.n_shared_outs
e = e + info.n_untraced_sit_sot_outs
ib = ie
ie = ie + info.n_shared_outs
scan_shared = inputs[b:e]
inner_shared = self_inputs[ib:ie]
ie = ie + info.n_untraced_sit_sot_outs
scan_untraced = inputs[b:e]
inner_untraced = self_inputs[ib:ie]
# NIT_SOT sequences
b = e
......@@ -3268,7 +3312,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
clean_eval_points.append(inp.zeros_like())
scan_other = inputs[e:] + clean_eval_points
# inner_eval_points do not have entries for shared variables
# inner_eval_points do not have entries for untraced variables
inner_other = self_inputs[ie:] + inner_eval_points[ib:]
# Outputs
......@@ -3287,15 +3331,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
e = e + info.n_nit_sot
inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + info.n_shared_outs
inner_out_shared = self_outputs[b:e]
e = e + info.n_untraced_sit_sot_outs
inner_out_untraced = self_outputs[b:e]
inner_ins = (
inner_seqs
+ inner_mit_mot
+ inner_mit_sot
+ inner_sit_sot
+ inner_shared
+ inner_untraced
+ inner_other
)
inner_outs = (
......@@ -3303,7 +3347,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ inner_out_mit_sot
+ inner_out_sit_sot
+ inner_out_nit_sot
+ inner_out_shared
+ inner_out_untraced
)
if info.as_while:
......@@ -3314,7 +3358,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
*scan_mit_mot,
*scan_mit_sot,
*scan_sit_sot,
*scan_shared,
*scan_untraced,
*scan_nit_sot,
*scan_other,
]
......@@ -3326,7 +3370,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices=new_mit_sot_in_slices,
sit_sot_in_slices=new_sit_sot_in_slices,
n_nit_sot=info.n_nit_sot * 2,
n_shared_outs=info.n_shared_outs,
n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs,
n_non_seqs=len(inner_other),
as_while=info.as_while,
)
......@@ -3358,7 +3402,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
b = e + info.n_nit_sot
e = e + info.n_nit_sot * 2
final_outs += outputs[b:e]
final_outs += [None] * info.n_shared_outs
final_outs += [None] * info.n_untraced_sit_sot_outs
return final_outs
......
......@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices))
)
st += op_info.n_sit_sot
st += op_info.n_shared_outs
st += op_info.n_untraced_sit_sot_outs
op_ins = op.inner_inputs
op_outs = op.inner_outputs
......@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
+ op_info.n_mit_sot
+ op_info.n_sit_sot
+ op_info.n_nit_sot
+ op_info.n_shared_outs
+ op_info.n_untraced_sit_sot_outs
+ 1
)
outer_non_seqs = node.inputs[st:]
......@@ -983,7 +983,7 @@ class ScanInplaceOptimizer(GraphRewriter):
ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs)
ls_end = op.outer_untraced_sit_sot(node.inputs)
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
......@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ idx
+ op_info.n_seqs
+ 1
+ op_info.n_shared_outs
+ op_info.n_untraced_sit_sot_outs
)
if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = 1 if required_orphan else val
......@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
):
in_idx = offset + idx + op_info.n_shared_outs
in_idx = offset + idx + op_info.n_untraced_sit_sot_outs
if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps
......@@ -1886,8 +1886,8 @@ class ScanMerge(GraphRewriter):
for idx, nd in enumerate(nodes):
# Shared
inner_ins[idx].append(nd.op.inner_shared(nd.op.inner_inputs))
outer_ins += nd.op.outer_shared(nd.inputs)
inner_ins[idx].append(nd.op.inner_untraced_sit_sot(nd.op.inner_inputs))
outer_ins += nd.op.outer_untraced_sit_sot(nd.inputs)
for idx, nd in enumerate(nodes):
# NitSot
......@@ -1897,8 +1897,10 @@ class ScanMerge(GraphRewriter):
for idx, nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs))
outer_outs += nd.op.outer_untraced_sit_sot_outs(nd.outputs)
inner_outs[idx].append(
nd.op.inner_untraced_sit_sot_outs(nd.op.inner_outputs)
)
n_non_seqs = 0
for idx, nd in enumerate(nodes):
......@@ -1978,7 +1980,9 @@ class ScanMerge(GraphRewriter):
mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices,
n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes),
n_untraced_sit_sot_outs=sum(
nd.op.info.n_untraced_sit_sot_outs for nd in nodes
),
n_non_seqs=n_non_seqs,
as_while=as_while,
)
......@@ -2360,7 +2364,7 @@ def scan_push_out_dot1(fgraph, node):
# When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot
op = node.op
op: Scan = node.op
sitsot_ins = op.inner_sitsot(op.inner_inputs)
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_sitsot = op.outer_sitsot_outs(node.outputs)
......@@ -2416,9 +2420,13 @@ def scan_push_out_dot1(fgraph, node):
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
inner_shared = op.inner_shared(op.inner_inputs)
outer_shared = op.outer_shared(node.inputs)
inner_shared_outs = op.inner_shared_outs(op.inner_outputs)
inner_untraced_sitsot = op.inner_untraced_sitsot(op.inner_inputs)
outer_untraced_sitsot_outs = op.outer_untraced_sitsot_outs(
node.inputs
)
inner_untraced_sitsot_outs = op.inner_untraced_sitsot_outs(
op.inner_outputs
)
inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
......@@ -2441,7 +2449,7 @@ def scan_push_out_dot1(fgraph, node):
+ inner_mitmot
+ inner_mitsot
+ inner_sitsot
+ inner_shared
+ inner_untraced_sitsot
+ inner_non_seqs
)
_new_inner_outs = (
......@@ -2449,7 +2457,7 @@ def scan_push_out_dot1(fgraph, node):
+ inner_mitsot_outs
+ inner_sitsot_outs
+ inner_nitsot_outs
+ inner_shared_outs
+ inner_untraced_sitsot_outs
)
new_inner_inps, new_inner_outs = reconstruct_graph(
_new_inner_inps, _new_inner_outs
......@@ -2471,7 +2479,7 @@ def scan_push_out_dot1(fgraph, node):
*outer_mitmot,
*outer_mitsot,
*outer_sitsot,
*outer_shared,
*outer_untraced_sitsot_outs,
*outer_nitsot,
node.inputs[0],
*outer_non_seqs,
......
......@@ -370,7 +370,9 @@ def scan_can_remove_outs(op, out_idxs):
out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins
out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)]
out_ins += [
[op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot_outs)
]
added = True
out_idxs_mask = [1 for idx in out_idxs]
......@@ -409,7 +411,7 @@ def compress_outs(op, not_required, inputs):
mit_sot_in_slices=(),
sit_sot_in_slices=(),
n_nit_sot=0,
n_shared_outs=0,
n_untraced_sit_sot_outs=0,
n_non_seqs=0,
as_while=op_info.as_while,
)
......@@ -515,17 +517,19 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]]
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot_outs]]
else:
o_offset += 1
offset += op_info.n_nit_sot
shared_ins = []
for idx in range(op_info.n_shared_outs):
for idx in range(op_info.n_untraced_sit_sot_outs):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1)
info = dataclasses.replace(
info, n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs + 1
)
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
op_inputs += [op.inner_inputs[i_offset]]
......@@ -539,7 +543,9 @@ def compress_outs(op, not_required, inputs):
# other stuff
op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :]
node_inputs += inputs[
ni_offset + op_info.n_untraced_sit_sot_outs + op_info.n_nit_sot :
]
if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1
......@@ -658,11 +664,11 @@ class ScanArgs:
p += n_sit_sot
q += n_sit_sot
n_shared_outs = info.n_shared_outs
self.outer_in_shared = list(outer_inputs[p : p + n_shared_outs])
self.inner_in_shared = list(inner_inputs[q : q + n_shared_outs])
p += n_shared_outs
q += n_shared_outs
n_untraced_sit_sot_outs = info.n_untraced_sit_sot_outs
self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot_outs])
self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot_outs])
p += n_untraced_sit_sot_outs
q += n_untraced_sit_sot_outs
n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
......@@ -702,10 +708,10 @@ class ScanArgs:
p += n_nit_sot
q += n_nit_sot
self.outer_out_shared = list(outer_outputs[p : p + n_shared_outs])
self.inner_out_shared = list(inner_outputs[q : q + n_shared_outs])
p += n_shared_outs
q += n_shared_outs
self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot_outs])
self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot_outs])
p += n_untraced_sit_sot_outs
q += n_untraced_sit_sot_outs
assert p == len(outer_outputs)
assert q == len(inner_outputs)
......@@ -816,7 +822,7 @@ class ScanArgs:
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
n_untraced_sit_sot_outs=len(self.outer_in_shared),
n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while,
)
......
......@@ -85,7 +85,7 @@ from tests.link.numba.test_basic import compare_numba_and_py
3,
[],
[np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_shared_outs > 0,
lambda op: op.info.n_untraced_sit_sot_outs > 0,
),
# mit-sot (that's also a type of sit-sot)
(
......
......@@ -42,11 +42,13 @@ from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.random import normal
from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.shape import Shape_i, reshape, specify_shape
from pytensor.tensor.sharedvar import SharedVariable
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
TensorType,
dcol,
dmatrix,
dscalar,
......@@ -4007,7 +4009,7 @@ class TestExamples:
[{}],
[],
3,
lambda op: op.info.n_shared_outs > 0,
lambda op: op.info.n_untraced_sit_sot_outs > 0,
),
# mit-sot (that's also a type of sit-sot)
(
......@@ -4106,3 +4108,34 @@ def test_output_storage_reuse(linker_mode):
res = f_cvm()
assert np.array_equal(res, np.array([3, 1, 0]))
def test_rng_outputs_info():
rng_init = random_generator_type("rng")
rng_x0, x0 = pt.random.normal(0, rng=rng_init, dtype="float64").owner.outputs
def step(prev_x, prev_rng):
next_rng, next_x = pt.random.normal(
prev_x, rng=prev_rng, dtype="float64"
).owner.outputs
return next_x, next_rng
[xs, rng_final], updates = scan(
fn=step,
outputs_info=[x0, rng_x0],
n_steps=10,
)
assert isinstance(xs.type, TensorType)
assert isinstance(rng_final.type, RandomGeneratorType)
assert not updates
fn = function([rng_init], [xs, rng_final])
xs_eval, rng_final_eval = fn(np.random.default_rng(0))
rng_ref = np.random.default_rng(0)
assert not random_generator_type.values_eq(rng_ref, rng_final_eval)
xs_ref = [rng_ref.normal(0)]
for i in range(10):
xs_ref.append(rng_ref.normal(xs_ref[-1]))
assert random_generator_type.values_eq(rng_ref, rng_final_eval)
np.testing.assert_allclose(xs_eval, xs_ref[1:])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论