提交 1d19c375 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow non-shared untraced SIT-SOT

上级 d10c61ba
...@@ -60,23 +60,23 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -60,23 +60,23 @@ def jax_funcify_Scan(op: Scan, **kwargs):
mit_mot_init, mit_mot_init,
mit_sot_init, mit_sot_init,
sit_sot_init, sit_sot_init,
op.outer_shared(outer_inputs), op.outer_untraced_sit_sot(outer_inputs),
op.outer_non_seqs(outer_inputs), op.outer_non_seqs(outer_inputs),
) # JAX `init` ) # JAX `init`
def jax_args_to_inner_func_args(carry, x): def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func. """Convert JAX scan arguments into format expected by scan_inner_func.
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs) scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, untraced SIT-SOT, non_seqs)
""" """
# `carry` contains all inner taps, shared terms, and non_seqs # `carry` contains the step counter, all inner taps, untraced SIT-SOT states, and non_seqs
( (
i, i,
inner_mit_mot, inner_mit_mot,
inner_mit_sot, inner_mit_sot,
inner_sit_sot, inner_sit_sot,
inner_shared, inner_untraced_sit_sot,
inner_non_seqs, inner_non_seqs,
) = carry ) = carry
...@@ -108,7 +108,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -108,7 +108,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
*mit_mot_flatten, *mit_mot_flatten,
*mit_sot_flatten, *mit_sot_flatten,
*inner_sit_sot, *inner_sit_sot,
*inner_shared, *inner_untraced_sit_sot,
*inner_non_seqs, *inner_non_seqs,
) )
...@@ -118,14 +118,14 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -118,14 +118,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
): ):
"""Convert inner_scan_func outputs into format expected by JAX scan. """Convert inner_scan_func outputs into format expected by JAX scan.
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys) old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
""" """
( (
i, i,
old_mit_mot, old_mit_mot,
old_mit_sot, old_mit_sot,
_old_sit_sot, _old_sit_sot,
_old_shared, _old_untraced_sit_sot,
inner_non_seqs, inner_non_seqs,
) = old_carry ) = old_carry
...@@ -133,7 +133,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -133,7 +133,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs) new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs) new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs) new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
new_shared = op.inner_shared_outs(inner_scan_outs) new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_scan_outs)
# New carry for next step # New carry for next step
# Update MIT-MOT buffer at positions indicated by output taps # Update MIT-MOT buffer at positions indicated by output taps
...@@ -150,14 +150,14 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -150,14 +150,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
old_mit_sot, new_mit_sot_vals, strict=True old_mit_sot, new_mit_sot_vals, strict=True
) )
] ]
# For SIT-SOT, and shared just pass along the new value # For SIT-SOT and untraced SIT-SOT just pass along the new value
# Non-sequences remain unchanged # Non-sequences remain unchanged
new_carry = ( new_carry = (
i + 1, i + 1,
new_mit_mot, new_mit_mot,
new_mit_sot, new_mit_sot,
new_sit_sot, new_sit_sot,
new_shared, new_untraced_sit_sot,
inner_non_seqs, inner_non_seqs,
) )
...@@ -183,7 +183,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -183,7 +183,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
final_mit_mot, final_mit_mot,
_final_mit_sot, _final_mit_sot,
_final_sit_sot, _final_sit_sot,
final_shared, final_untraced_sit_sot,
_final_non_seqs, _final_non_seqs,
), ),
traces, traces,
...@@ -238,7 +238,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -238,7 +238,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
scan_outs_final = [ scan_outs_final = [
*final_mit_mot, *final_mit_mot,
*get_partial_traces(traces), *get_partial_traces(traces),
*final_shared, *final_untraced_sit_sot,
] ]
if len(scan_outs_final) == 1: if len(scan_outs_final) == 1:
......
...@@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names) outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names) outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names) outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
outer_in_shared_names = op.outer_shared(outer_in_names) outer_in_untraced_sit_sot_names = op.outer_untraced_sit_sot(outer_in_names)
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names) outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
# These are all the outer-input names that produce outputs/have output # These are all the outer-input names that produce outputs/have output
# taps (i.e. they have inner-outputs and corresponding outer-outputs). # taps (i.e. they have inner-outputs and corresponding outer-outputs).
# Outer-outputs are ordered as follows: # Outer-outputs are ordered as follows:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs
outer_in_outtap_names = ( outer_in_outtap_names = (
outer_in_mit_mot_names outer_in_mit_mot_names
+ outer_in_mit_sot_names + outer_in_mit_sot_names
+ outer_in_sit_sot_names + outer_in_sit_sot_names
+ outer_in_nit_sot_names + outer_in_nit_sot_names
+ outer_in_shared_names + outer_in_untraced_sit_sot_names
) )
# We create distinct variables for/references to the storage arrays for # We create distinct variables for/references to the storage arrays for
...@@ -138,8 +138,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -138,8 +138,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
for outer_in_name in outer_in_nit_sot_names: for outer_in_name in outer_in_nit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage" outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"
for outer_in_name in outer_in_shared_names: for outer_in_name in outer_in_untraced_sit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage" outer_in_to_storage_name[outer_in_name] = (
f"{outer_in_name}_untraced_sit_sot_storage"
)
outer_output_names = list(outer_in_to_storage_name.values()) outer_output_names = list(outer_in_to_storage_name.values())
assert len(outer_output_names) == len(node.outputs) assert len(outer_output_names) == len(node.outputs)
...@@ -147,7 +149,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -147,7 +149,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Construct the inner-input expressions (e.g. indexed storage expressions) # Construct the inner-input expressions (e.g. indexed storage expressions)
# Inner-inputs are ordered as follows: # Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# shared-inputs + non-sequences. # untraced-sit-sot-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: list[str] = [] temp_scalar_storage_alloc_stmts: list[str] = []
inner_in_exprs_scalar: list[str] = [] inner_in_exprs_scalar: list[str] = []
inner_in_exprs: list[str] = [] inner_in_exprs: list[str] = []
...@@ -204,11 +206,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -204,11 +206,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Inner-outputs consist of: # Inner-outputs consist of:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
# shared-outputs [+ while-condition] # untraced-sit-sot-outputs [+ while-condition]
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))] inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
# The assignment statements that copy inner-outputs into the outer-outputs # The assignment statements that copy inner-outputs into the outer-outputs
# storage # storage
inner_out_to_outer_in_stmts: list[str] = [] inner_out_to_outer_in_stmts: list[str] = []
......
import typing
import warnings import warnings
from itertools import chain from itertools import chain
...@@ -11,6 +12,7 @@ from pytensor.graph.basic import Constant, Variable ...@@ -11,6 +12,7 @@ from pytensor.graph.basic import Constant, Variable
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.traversal import explicit_graph_inputs from pytensor.graph.traversal import explicit_graph_inputs
from pytensor.graph.type import HasShape
from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until from pytensor.scan.utils import expand_empty, safe_new, until
...@@ -22,6 +24,10 @@ from pytensor.tensor.type import TensorType, integer_dtypes ...@@ -22,6 +24,10 @@ from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.updates import OrderedUpdates from pytensor.updates import OrderedUpdates
if typing.TYPE_CHECKING:
from pytensor.tensor.type import TensorVariable
def get_updates_and_outputs(ls): def get_updates_and_outputs(ls):
"""Recognize and order the updates, outputs, and stopping condition for a `Scan`. """Recognize and order the updates, outputs, and stopping condition for a `Scan`.
...@@ -469,7 +475,7 @@ def scan( ...@@ -469,7 +475,7 @@ def scan(
# Make sure we get rid of numpy arrays or ints or anything like that # Make sure we get rid of numpy arrays or ints or anything like that
# passed as inputs to scan # passed as inputs to scan
non_seqs = [] non_seqs: list[Variable] = []
for elem in wrap_into_list(non_sequences): for elem in wrap_into_list(non_sequences):
if not isinstance(elem, Variable): if not isinstance(elem, Variable):
non_seqs.append(pt.as_tensor_variable(elem)) non_seqs.append(pt.as_tensor_variable(elem))
...@@ -685,10 +691,10 @@ def scan( ...@@ -685,10 +691,10 @@ def scan(
# MIT_MOT -- not provided by the user only by the grad function # MIT_MOT -- not provided by the user only by the grad function
n_mit_mot = 0 n_mit_mot = 0
mit_mot_scan_inputs = [] mit_mot_scan_inputs: list[TensorVariable] = []
mit_mot_inner_inputs = [] mit_mot_inner_inputs: list[TensorVariable] = []
mit_mot_inner_outputs = [] mit_mot_inner_outputs: list[TensorVariable] = []
mit_mot_out_slices = [] mit_mot_out_slices: list[TensorVariable] = []
# SIT_SOT -- provided by the user # SIT_SOT -- provided by the user
n_mit_sot = 0 n_mit_sot = 0
...@@ -706,6 +712,12 @@ def scan( ...@@ -706,6 +712,12 @@ def scan(
sit_sot_inner_outputs = [] sit_sot_inner_outputs = []
sit_sot_rightOrder = [] sit_sot_rightOrder = []
n_untraced_sit_sot_outs = 0
untraced_sit_sot_scan_inputs = []
untraced_sit_sot_inner_inputs = []
untraced_sit_sot_inner_outputs = []
untraced_sit_sot_rightOrder = []
# go through outputs picking up time slices as needed # go through outputs picking up time slices as needed
for i, init_out in enumerate(outs_info): for i, init_out in enumerate(outs_info):
# Note that our convention dictates that if an output uses # Note that our convention dictates that if an output uses
...@@ -741,17 +753,35 @@ def scan( ...@@ -741,17 +753,35 @@ def scan(
# We need now to allocate space for storing the output and copy # We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function # the initial state over. We do this using the expand function
# defined in scan utils # defined in scan utils
sit_sot_scan_inputs.append( if isinstance(actual_arg.type, HasShape):
expand_empty( sit_sot_scan_inputs.append(
shape_padleft(actual_arg), expand_empty(
actual_n_steps, shape_padleft(actual_arg),
actual_n_steps,
)
) )
) sit_sot_inner_slices.append(actual_arg)
sit_sot_inner_slices.append(actual_arg) sit_sot_inner_inputs.append(arg)
sit_sot_inner_inputs.append(arg) sit_sot_rightOrder.append(i)
sit_sot_rightOrder.append(i) n_sit_sot += 1
n_sit_sot += 1 else:
# Assume variables without shape cannot be stacked (e.g., RNG variables)
# This behavior is new, so warn the user about it; RNG variables are exempt because they are the main motivation for this feature
from pytensor.tensor.random.type import RandomType
if not isinstance(arg.type, RandomType):
warnings.warn(
(
f"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. "
"Only the last value will be returned, not the entire sequence."
),
UserWarning,
)
untraced_sit_sot_scan_inputs.append(actual_arg)
untraced_sit_sot_inner_inputs.append(arg)
n_untraced_sit_sot_outs += 1
untraced_sit_sot_rightOrder.append(i)
elif init_out.get("taps", None): elif init_out.get("taps", None):
if np.any(np.array(init_out.get("taps", [])) > 0): if np.any(np.array(init_out.get("taps", [])) > 0):
...@@ -802,10 +832,11 @@ def scan( ...@@ -802,10 +832,11 @@ def scan(
# a map); in that case we do not have to do anything .. # a map); in that case we do not have to do anything ..
# Re-order args # Re-order args
max_mit_sot = np.max([-1, *mit_sot_rightOrder]) + 1 max_mit_sot = max(mit_sot_rightOrder, default=-1) + 1
max_sit_sot = np.max([-1, *sit_sot_rightOrder]) + 1 max_sit_sot = max(sit_sot_rightOrder, default=-1) + 1
n_elems = np.max([max_mit_sot, max_sit_sot]) max_untraced_sit_sot_outs = max(untraced_sit_sot_rightOrder, default=-1) + 1
_ordered_args = [[] for x in range(n_elems)] n_elems = np.max((max_mit_sot, max_sit_sot, max_untraced_sit_sot_outs))
_ordered_args: list[list[Variable]] = [[] for x in range(n_elems)]
offset = 0 offset = 0
for idx in range(n_mit_sot): for idx in range(n_mit_sot):
n_inputs = len(mit_sot_tap_array[idx]) n_inputs = len(mit_sot_tap_array[idx])
...@@ -825,6 +856,11 @@ def scan( ...@@ -825,6 +856,11 @@ def scan(
else: else:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]] _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
for idx in range(n_untraced_sit_sot_outs):
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
untraced_sit_sot_inner_inputs[idx]
]
ordered_args = list(chain.from_iterable(_ordered_args)) ordered_args = list(chain.from_iterable(_ordered_args))
if single_step_requested: if single_step_requested:
args = inner_slices + ordered_args + non_seqs args = inner_slices + ordered_args + non_seqs
...@@ -939,18 +975,19 @@ def scan( ...@@ -939,18 +975,19 @@ def scan(
if "taps" in out and out["taps"] != [-1]: if "taps" in out and out["taps"] != [-1]:
mit_sot_inner_outputs.append(outputs[i]) mit_sot_inner_outputs.append(outputs[i])
# Step 5.2 Outputs with tap equal to -1 # Step 5.2 Outputs with tap equal to -1 (traced and untraced)
for i, out in enumerate(outs_info): for i, out in enumerate(outs_info):
if "taps" in out and out["taps"] == [-1]: if "taps" in out and out["taps"] == [-1]:
sit_sot_inner_outputs.append(outputs[i]) output = outputs[i]
if isinstance(output.type, HasShape):
sit_sot_inner_outputs.append(output)
else:
untraced_sit_sot_inner_outputs.append(output)
# Step 5.3 Outputs that correspond to update rules of shared variables # Step 5.3 Outputs that correspond to update rules of shared variables
# This whole special logic for shared variables is deprecated
sit_sot_shared: list[Variable] = []
inner_replacements = {} inner_replacements = {}
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
shared_inner_outputs = []
sit_sot_shared = []
no_update_shared_inputs = [] no_update_shared_inputs = []
for input in dummy_inputs: for input in dummy_inputs:
if not isinstance(input.variable, SharedVariable): if not isinstance(input.variable, SharedVariable):
...@@ -976,8 +1013,8 @@ def scan( ...@@ -976,8 +1013,8 @@ def scan(
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if getattr(input.variable, "name", None) is not None: if input.variable.name is not None:
new_var.name = input.variable.name + "_copy" new_var.name = f"{input.variable.name}_copy"
inner_replacements[input.variable] = new_var inner_replacements[input.variable] = new_var
...@@ -1003,10 +1040,10 @@ def scan( ...@@ -1003,10 +1040,10 @@ def scan(
sit_sot_shared.append(input.variable) sit_sot_shared.append(input.variable)
else: else:
shared_inner_inputs.append(new_var) untraced_sit_sot_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable) untraced_sit_sot_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update) untraced_sit_sot_inner_outputs.append(input.update)
n_shared_outs += 1 n_untraced_sit_sot_outs += 1
else: else:
no_update_shared_inputs.append(input) no_update_shared_inputs.append(input)
...@@ -1071,7 +1108,7 @@ def scan( ...@@ -1071,7 +1108,7 @@ def scan(
+ mit_mot_inner_inputs + mit_mot_inner_inputs
+ mit_sot_inner_inputs + mit_sot_inner_inputs
+ sit_sot_inner_inputs + sit_sot_inner_inputs
+ shared_inner_inputs + untraced_sit_sot_inner_inputs
+ other_shared_inner_args + other_shared_inner_args
+ other_inner_args + other_inner_args
) )
...@@ -1081,7 +1118,7 @@ def scan( ...@@ -1081,7 +1118,7 @@ def scan(
+ mit_sot_inner_outputs + mit_sot_inner_outputs
+ sit_sot_inner_outputs + sit_sot_inner_outputs
+ nit_sot_inner_outputs + nit_sot_inner_outputs
+ shared_inner_outputs + untraced_sit_sot_inner_outputs
) )
if condition is not None: if condition is not None:
inner_outs.append(condition) inner_outs.append(condition)
...@@ -1101,7 +1138,7 @@ def scan( ...@@ -1101,7 +1138,7 @@ def scan(
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices), mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array), mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)), sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_shared_outs=n_shared_outs, n_untraced_sit_sot_outs=n_untraced_sit_sot_outs,
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args), n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
as_while=as_while, as_while=as_while,
...@@ -1127,7 +1164,7 @@ def scan( ...@@ -1127,7 +1164,7 @@ def scan(
+ mit_mot_scan_inputs + mit_mot_scan_inputs
+ mit_sot_scan_inputs + mit_sot_scan_inputs
+ sit_sot_scan_inputs + sit_sot_scan_inputs
+ shared_scan_inputs + untraced_sit_sot_scan_inputs
+ [actual_n_steps for x in range(n_nit_sot)] + [actual_n_steps for x in range(n_nit_sot)]
+ other_shared_scan_args + other_shared_scan_args
+ other_scan_args + other_scan_args
...@@ -1173,13 +1210,26 @@ def scan( ...@@ -1173,13 +1210,26 @@ def scan(
nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot]) nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])
offset += n_nit_sot offset += n_nit_sot
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
update_map[shared_scan_inputs[idx]] = update_rule
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs # Support for explicit untraced sit_sot
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
untraced_sit_sot_outs = scan_outs[
offset : offset + n_explicit_untraced_sit_sot_outs
]
offset += n_explicit_untraced_sit_sot_outs
for idx, update_rule in enumerate(scan_outs[offset:]):
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
_scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
# Step 10. I need to reorder the outputs to be in the order expected by # Step 10. I need to reorder the outputs to be in the order expected by
# the user # the user
rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder rightOrder = (
mit_sot_rightOrder
+ sit_sot_rightOrder
+ untraced_sit_sot_rightOrder
+ nit_sot_rightOrder
)
scan_out_list = [None] * len(rightOrder) scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder): for idx, pos in enumerate(rightOrder):
if pos >= 0: if pos >= 0:
......
...@@ -46,6 +46,7 @@ relies on the following elements to work properly : ...@@ -46,6 +46,7 @@ relies on the following elements to work properly :
import dataclasses import dataclasses
import logging import logging
import time import time
import warnings
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from copy import copy from copy import copy
from itertools import chain, product from itertools import chain, product
...@@ -208,10 +209,19 @@ class ScanInfo: ...@@ -208,10 +209,19 @@ class ScanInfo:
mit_sot_in_slices: tuple mit_sot_in_slices: tuple
sit_sot_in_slices: tuple sit_sot_in_slices: tuple
n_nit_sot: int n_nit_sot: int
n_shared_outs: int n_untraced_sit_sot_outs: int
n_non_seqs: int n_non_seqs: int
as_while: bool as_while: bool
# Deprecated read-only alias for `n_untraced_sit_sot_outs`: emits a
# DeprecationWarning and then returns that attribute's value unchanged.
@property
def n_shared_outs(self):
warnings.warn(
"The 'n_shared_outs' property is deprecated. Use 'n_untraced_sit_sot_outs' instead.",
DeprecationWarning,
# stacklevel=2 attributes the warning to the code that accessed the property
stacklevel=2,
)
return self.n_untraced_sit_sot_outs
@property @property
def n_mit_mot(self): def n_mit_mot(self):
return len(self.mit_mot_in_slices) return len(self.mit_mot_in_slices)
...@@ -239,7 +249,7 @@ class ScanInfo: ...@@ -239,7 +249,7 @@ class ScanInfo:
+ sum(len(x) for x in self.mit_mot_in_slices) + sum(len(x) for x in self.mit_mot_in_slices)
+ sum(len(x) for x in self.mit_sot_in_slices) + sum(len(x) for x in self.mit_sot_in_slices)
+ self.n_sit_sot + self.n_sit_sot
+ self.n_shared_outs + self.n_untraced_sit_sot_outs
+ self.n_non_seqs + self.n_non_seqs
) )
...@@ -250,7 +260,7 @@ class ScanInfo: ...@@ -250,7 +260,7 @@ class ScanInfo:
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_shared_outs + self.n_untraced_sit_sot_outs
+ int(self.as_while) + int(self.as_while)
) )
...@@ -263,7 +273,7 @@ class ScanInfo: ...@@ -263,7 +273,7 @@ class ScanInfo:
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_shared_outs + self.n_untraced_sit_sot_outs
+ self.n_non_seqs + self.n_non_seqs
) )
...@@ -274,7 +284,7 @@ class ScanInfo: ...@@ -274,7 +284,7 @@ class ScanInfo:
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_shared_outs + self.n_untraced_sit_sot_outs
) )
...@@ -381,7 +391,7 @@ class ScanMethodsMixin: ...@@ -381,7 +391,7 @@ class ScanMethodsMixin:
+ self.info.n_mit_mot + self.info.n_mit_mot
+ self.info.n_mit_sot + self.info.n_mit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_shared_outs + self.info.n_untraced_sit_sot_outs
) )
return list_inputs[offset : offset + self.info.n_nit_sot] return list_inputs[offset : offset + self.info.n_nit_sot]
...@@ -394,15 +404,23 @@ class ScanMethodsMixin: ...@@ -394,15 +404,23 @@ class ScanMethodsMixin:
offset = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot offset = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
return list_outputs[offset : offset + self.info.n_nit_sot] return list_outputs[offset : offset + self.info.n_nit_sot]
def inner_shared(self, list_inputs): def inner_untraced_sit_sot(self, list_inputs):
n_taps_upto_sit_sot = sum( n_taps_upto_sit_sot = sum(
len(x) len(x)
for x in chain(self.info.mit_mot_in_slices, self.info.mit_sot_in_slices) for x in chain(self.info.mit_mot_in_slices, self.info.mit_sot_in_slices)
) )
offset = self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot offset = self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot
return list_inputs[offset : offset + self.info.n_shared_outs] return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs]
def outer_shared(self, list_inputs): def inner_shared(self, list_inputs):
warnings.warn(
"The 'inner_shared' method is deprecated. Use 'inner_untraced_sit_sot' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.inner_untraced_sit_sot(list_inputs)
def outer_untraced_sit_sot(self, list_inputs):
offset = ( offset = (
1 1
+ self.info.n_seqs + self.info.n_seqs
...@@ -410,23 +428,47 @@ class ScanMethodsMixin: ...@@ -410,23 +428,47 @@ class ScanMethodsMixin:
+ self.info.n_mit_sot + self.info.n_mit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
) )
return list_inputs[offset : offset + self.info.n_shared_outs] return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs]
def inner_shared_outs(self, list_outputs): def outer_shared(self, list_inputs):
warnings.warn(
"The 'outer_shared' method is deprecated. Use 'outer_untraced_sit_sot' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.outer_untraced_sit_sot(list_inputs)
def inner_untraced_sit_sot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.info.mit_mot_out_slices) n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
offset = ( offset = (
self.info.n_mit_sot + n_taps + self.info.n_sit_sot + self.info.n_nit_sot self.info.n_mit_sot + n_taps + self.info.n_sit_sot + self.info.n_nit_sot
) )
return list_outputs[offset : offset + self.info.n_shared_outs] return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs]
def outer_shared_outs(self, list_outputs): def inner_shared_outs(self, list_outputs):
warnings.warn(
"The 'inner_shared_outs' method is deprecated. Use 'inner_untraced_sit_sot_outs' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.inner_untraced_sit_sot_outs(list_outputs)
def outer_untraced_sit_sot_outs(self, list_outputs):
offset = ( offset = (
self.info.n_mit_mot self.info.n_mit_mot
+ self.info.n_mit_sot + self.info.n_mit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_nit_sot + self.info.n_nit_sot
) )
return list_outputs[offset : offset + self.info.n_shared_outs] return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs]
# Deprecated alias for `outer_untraced_sit_sot_outs`: emits a
# DeprecationWarning, then forwards `list_outputs` and returns the result.
def outer_shared_outs(self, list_outputs):
warnings.warn(
"The 'outer_shared_outs' method is deprecated. Use 'outer_untraced_sit_sot_outs' instead.",
DeprecationWarning,
# stacklevel=2 attributes the warning to the caller of this method
stacklevel=2,
)
return self.outer_untraced_sit_sot_outs(list_outputs)
def inner_non_seqs(self, list_inputs): def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum( n_taps_upto_sit_sot = sum(
...@@ -437,7 +479,7 @@ class ScanMethodsMixin: ...@@ -437,7 +479,7 @@ class ScanMethodsMixin:
self.info.n_seqs self.info.n_seqs
+ n_taps_upto_sit_sot + n_taps_upto_sit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_shared_outs + self.info.n_untraced_sit_sot_outs
) )
return list_inputs[offset:] return list_inputs[offset:]
...@@ -449,7 +491,7 @@ class ScanMethodsMixin: ...@@ -449,7 +491,7 @@ class ScanMethodsMixin:
+ self.info.n_mit_sot + self.info.n_mit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_nit_sot + self.info.n_nit_sot
+ self.info.n_shared_outs + self.info.n_untraced_sit_sot_outs
) )
return list_inputs[offset:] return list_inputs[offset:]
...@@ -525,8 +567,8 @@ class ScanMethodsMixin: ...@@ -525,8 +567,8 @@ class ScanMethodsMixin:
outer_oidx += 1 outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables. # nitsots come *after* untraced_sitsot variables.
outer_iidx += self.info.n_shared_outs outer_iidx += self.info.n_untraced_sit_sot_outs
# Handle nitsots variables # Handle nitsots variables
for i in range(self.info.n_nit_sot): for i in range(self.info.n_nit_sot):
...@@ -541,11 +583,11 @@ class ScanMethodsMixin: ...@@ -541,11 +583,11 @@ class ScanMethodsMixin:
outer_oidx += 1 outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables. # nitsots come *after* untraced_sit_sot variables.
outer_iidx -= self.info.n_shared_outs + self.info.n_nit_sot outer_iidx -= self.info.n_untraced_sit_sot_outs + self.info.n_nit_sot
# Handle shared states # Handle untraced_sitsot states
for i in range(self.info.n_shared_outs): for i in range(self.info.n_untraced_sit_sot_outs):
outer_input_indices.append(outer_iidx) outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx]) inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx]) inner_output_indices.append([inner_oidx])
...@@ -557,7 +599,7 @@ class ScanMethodsMixin: ...@@ -557,7 +599,7 @@ class ScanMethodsMixin:
outer_oidx += 1 outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables. # nitsots come *after* untraced_sitsot variables.
outer_iidx += self.info.n_nit_sot outer_iidx += self.info.n_nit_sot
# Handle non-sequence inputs # Handle non-sequence inputs
...@@ -708,7 +750,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -708,7 +750,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Inputs of the inner function of `Scan`. Inputs of the inner function of `Scan`.
These take the following general form: These take the following general form:
sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + shared-inputs + non-sequences sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + untraced-sit-sot-inputs + shared-inputs + non-sequences
where each term is a list of `Variable`\s. where each term is a list of `Variable`\s.
...@@ -716,7 +758,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -716,7 +758,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Outputs of the inner function of `Scan`. Outputs of the inner function of `Scan`.
These take the following general form: These take the following general form:
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs [+ while-condition] mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs [+ while-condition]
where each term is a list of `Variable`\s. where each term is a list of `Variable`\s.
...@@ -817,7 +859,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -817,7 +859,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
typeConstructor((None, *o.type.shape), o.type.dtype) typeConstructor((None, *o.type.shape), o.type.dtype)
) )
# shared outputs + possibly the ending condition # untraced_sit_sot outputs + possibly the ending condition
for o in self.fgraph.outputs[end:]: for o in self.fgraph.outputs[end:]:
self.output_types.append(o.type) self.output_types.append(o.type)
...@@ -836,10 +878,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -836,10 +878,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
] ]
self.mintaps += [0 for x in range(info.n_nit_sot)] self.mintaps += [0 for x in range(info.n_nit_sot)]
self.seqs_arg_offset = 1 + info.n_seqs self.seqs_arg_offset = 1 + info.n_seqs
self.shared_arg_offset = ( self.untraced_sit_sot_arg_offset = (
self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
) )
self.nit_sot_arg_offset = self.shared_arg_offset + info.n_shared_outs self.nit_sot_arg_offset = (
self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs
)
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count # XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs # of the number of outputs generated by taps with inputs
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
...@@ -908,7 +952,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -908,7 +952,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
sequences + sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs + untraced-sit-sot-inputs + shared-inputs
nit-sots + nit-sots +
non-sequences non-sequences
...@@ -923,7 +967,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -923,7 +967,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
[n_steps] + [n_steps] +
sequences + sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs + untraced-sit-sot-inputs + shared-inputs
nit-sots + nit-sots +
non-sequences non-sequences
...@@ -931,7 +975,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -931,7 +975,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + mit-mot-outputs + mit-sot-outputs + sit-sot-outputs +
nit-sots + nit-sots +
shared-outputs untraced-sit-sot-outputs
These outer-outputs essentially follow the same form as their These outer-outputs essentially follow the same form as their
corresponding inner-outputs, excluding the final "while" condition corresponding inner-outputs, excluding the final "while" condition
...@@ -949,7 +993,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -949,7 +993,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ len(self.info.mit_mot_in_slices) + len(self.info.mit_mot_in_slices)
+ len(self.info.mit_sot_in_slices) + len(self.info.mit_sot_in_slices)
+ len(self.inner_sitsot(self.inner_inputs)) + len(self.inner_sitsot(self.inner_inputs))
+ len(self.inner_shared(self.inner_inputs)) + len(self.inner_untraced_sit_sot(self.inner_inputs))
+ len(self.inner_non_seqs(self.inner_inputs)) + len(self.inner_non_seqs(self.inner_inputs))
) )
...@@ -1134,60 +1178,60 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1134,60 +1178,60 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the untraced (u) sit-sot variable and their update rule have the same
# dtype. Maybe even same type ?! # dtype. Maybe even same type ?!
for idx, (inner_shared, inner_shared_out, _outer_shared) in enumerate( for idx, (inner_u_sitsot, inner_u_sitsot_out, _outer_u_sitsot) in enumerate(
zip( zip(
self.inner_shared(self.inner_inputs), self.inner_untraced_sit_sot(self.inner_inputs),
self.inner_shared_outs(self.inner_outputs), self.inner_untraced_sit_sot_outs(self.inner_outputs),
self.outer_shared(inputs), self.outer_untraced_sit_sot(inputs),
strict=True, strict=True,
) )
): ):
outer_shared = copy_var_format(_outer_shared, as_var=inner_shared) outer_u_sitsot = copy_var_format(_outer_u_sitsot, as_var=inner_u_sitsot)
new_inputs.append(outer_shared) new_inputs.append(outer_u_sitsot)
if ( if (
hasattr(outer_shared, "dtype") hasattr(outer_u_sitsot, "dtype")
and outer_shared.dtype != inner_shared_out.dtype and outer_u_sitsot.dtype != inner_u_sitsot_out.dtype
): ):
raise ValueError( raise ValueError(
err_msg2 err_msg2
% ( % (
str(outer_shared), str(outer_u_sitsot),
idx + argoffset, idx + argoffset,
outer_shared.dtype, outer_u_sitsot.dtype,
inner_shared_out.dtype, inner_u_sitsot_out.dtype,
) )
) )
if ( if (
hasattr(outer_shared, "dtype") hasattr(outer_u_sitsot, "dtype")
and outer_shared.ndim != inner_shared_out.ndim and outer_u_sitsot.ndim != inner_u_sitsot_out.ndim
): ):
raise ValueError( raise ValueError(
err_msg3 err_msg3
% ( % (
str(outer_shared), str(outer_u_sitsot),
idx + argoffset, idx + argoffset,
outer_shared.ndim, outer_u_sitsot.ndim,
inner_shared_out.ndim, inner_u_sitsot_out.ndim,
) )
) )
if hasattr(outer_shared, "dtype") and ( if hasattr(outer_u_sitsot, "dtype") and (
outer_shared.dtype != inner_shared.dtype outer_u_sitsot.dtype != inner_u_sitsot.dtype
or outer_shared.ndim != inner_shared.ndim or outer_u_sitsot.ndim != inner_u_sitsot.ndim
): ):
raise ValueError( raise ValueError(
err_msg1 err_msg1
% ( % (
"initial state (outputs_info in scan nomenclature) ", "initial state (outputs_info in scan nomenclature) ",
str(outer_shared), str(outer_u_sitsot),
argoffset + idx, argoffset + idx,
outer_shared.dtype, outer_u_sitsot.dtype,
outer_shared.ndim, outer_u_sitsot.ndim,
str(inner_shared), str(inner_u_sitsot),
inner_shared.dtype, inner_u_sitsot.dtype,
inner_shared.ndim, inner_u_sitsot.ndim,
) )
) )
# We do not need to call `copy_var_format` on outer_nisot arguments. # We do not need to call `copy_var_format` on outer_nisot arguments.
...@@ -1585,7 +1629,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1585,7 +1629,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try: try:
t_fn, n_steps = scan_perform_ext.perform( t_fn, n_steps = scan_perform_ext.perform(
self.info.n_shared_outs, self.info.n_untraced_sit_sot_outs,
self.info.n_mit_mot_outs, self.info.n_mit_mot_outs,
self.info.n_seqs, self.info.n_seqs,
self.info.n_mit_mot, self.info.n_mit_mot,
...@@ -1719,7 +1763,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1719,7 +1763,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# The length of each output # The length of each output
store_steps = [ store_steps = [
arg.shape[0] arg.shape[0]
for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset] for arg in inputs[self.seqs_arg_offset : self.untraced_sit_sot_arg_offset]
] ]
store_steps += list( store_steps += list(
inputs[self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot] inputs[self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot]
...@@ -1784,7 +1828,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1784,7 +1828,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
info.sit_sot_in_slices, info.sit_sot_in_slices,
) )
) )
+ info.n_shared_outs + info.n_untraced_sit_sot_outs
) )
for idx in range(len(other_args)): for idx in range(len(other_args)):
inner_input_storage[idx + offset].storage[0] = other_args[idx] inner_input_storage[idx + offset].storage[0] = other_args[idx]
...@@ -1827,14 +1871,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1827,14 +1871,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
] ]
offset += 1 offset += 1
a_offset = self.shared_arg_offset a_offset = self.untraced_sit_sot_arg_offset
o_offset = self.n_outs + info.n_nit_sot o_offset = self.n_outs + info.n_nit_sot
if i == 0: if i == 0:
for j in range(info.n_shared_outs): for j in range(info.n_untraced_sit_sot_outs):
inner_input_storage[offset].storage[0] = inputs[a_offset + j] inner_input_storage[offset].storage[0] = inputs[a_offset + j]
offset += 1 offset += 1
else: else:
for j in range(info.n_shared_outs): for j in range(info.n_untraced_sit_sot_outs):
inner_input_storage[offset].storage[0] = output_storage[ inner_input_storage[offset].storage[0] = output_storage[
o_offset + j o_offset + j
][0] ][0]
...@@ -1866,14 +1910,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1866,14 +1910,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot): for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot):
inner_output_storage[idx + offset].storage[0] = None inner_output_storage[idx + offset].storage[0] = None
# 4.3. Collect slices for shared outputs # 4.3. Collect slices for untraced sitsot outputs
offset += self.n_outs + info.n_nit_sot - info.n_mit_mot offset += self.n_outs + info.n_nit_sot - info.n_mit_mot
for idx in range(info.n_shared_outs): for idx in range(info.n_untraced_sit_sot_outs):
inner_output_storage[idx + offset].storage[0] = None inner_output_storage[idx + offset].storage[0] = None
# 4.4. If there is a condition add it to the mix # 4.4. If there is a condition add it to the mix
if info.as_while: if info.as_while:
pdx = offset + info.n_shared_outs pdx = offset + info.n_untraced_sit_sot_outs
inner_output_storage[pdx].storage[0] = None inner_output_storage[pdx].storage[0] = None
# 4.5. Keep a reference to the variables (ndarrays, # 4.5. Keep a reference to the variables (ndarrays,
...@@ -1942,7 +1986,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1942,7 +1986,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dt_fn = time.perf_counter() - t0_fn dt_fn = time.perf_counter() - t0_fn
if info.as_while: if info.as_while:
pdx = offset + info.n_shared_outs pdx = offset + info.n_untraced_sit_sot_outs
cond = inner_output_storage[pdx].storage[0] == 0 cond = inner_output_storage[pdx].storage[0] == 0
t_fn += dt_fn t_fn += dt_fn
...@@ -2089,10 +2133,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2089,10 +2133,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
j + offset_out j + offset_out
].storage[0] ].storage[0]
# 5.6 Copy over the values for outputs corresponding to shared # 5.6 Copy over the values for outputs corresponding to untraced sitsot
# variables # variables
begin = end begin = end
end += info.n_shared_outs end += info.n_untraced_sit_sot_outs
for j in range(begin, end): for j in range(begin, end):
jout = j + offset_out jout = j + offset_out
output_storage[j][0] = inner_output_storage[jout].storage[0] output_storage[j][0] = inner_output_storage[jout].storage[0]
...@@ -2240,13 +2284,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2240,13 +2284,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# out_equivalent[self.inner_inputs[inner_inp_idx]] = corresponding_tap # out_equivalent[self.inner_inputs[inner_inp_idx]] = corresponding_tap
outer_inp_idx += 1 outer_inp_idx += 1
# shared_outs # untraced sit_sot outputs
offset = 1 + info.n_seqs + n_outs offset = 1 + info.n_seqs + n_outs
for idx in range(info.n_shared_outs): for idx in range(info.n_untraced_sit_sot_outs):
outs_shape += [input_shapes[idx + offset]] outs_shape += [input_shapes[idx + offset]]
# non_sequences # non_sequences
offset += info.n_nit_sot + info.n_shared_outs offset += info.n_nit_sot + info.n_untraced_sit_sot_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inner_inputs) assert len(inner_ins_shapes) == len(self.inner_inputs)
...@@ -2288,7 +2332,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2288,7 +2332,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# in the inner function. # in the inner function.
r = node.outputs[n_outs + x] r = node.outputs[n_outs + x]
assert r.ndim == 1 + len(out_shape_x) assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset + info.n_shared_outs + x]] shp = [node.inputs[offset + info.n_untraced_sit_sot_outs + x]]
for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True): for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True):
# Validate shp_i. v_shape_i is either None (if invalid), # Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates # or a (variable, Boolean) tuple. The Boolean indicates
...@@ -2305,7 +2349,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2305,7 +2349,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp.append(v_shp_i[0]) shp.append(v_shp_i[0])
scan_outs.append(tuple(shp)) scan_outs.append(tuple(shp))
scan_outs += list(input_shapes[offset : offset + info.n_shared_outs]) scan_outs += list(input_shapes[offset : offset + info.n_untraced_sit_sot_outs])
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if info.as_while: if info.as_while:
...@@ -2735,7 +2779,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2735,7 +2779,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitmot_inp_taps.append([]) mitmot_inp_taps.append([])
mitmot_out_taps.append([]) mitmot_out_taps.append([])
undefined_msg = None undefined_msg = None
through_shared = False through_untraced = False
disconnected = True disconnected = True
for mit_mot_out_slice in info.mit_mot_out_slices[idx]: for mit_mot_out_slice in info.mit_mot_out_slices[idx]:
...@@ -2779,9 +2823,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2779,9 +2823,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
disconnected &= disconnected_dC_dinps_t[ins_pos] disconnected &= disconnected_dC_dinps_t[ins_pos]
through_shared = any( through_untraced = any(
_sh in graph_inputs([dC_dinps_t[ins_pos]]) _sh in graph_inputs([dC_dinps_t[ins_pos]])
for _sh in self.inner_shared(self_inputs) for _sh in self.inner_untraced_sit_sot(self_inputs)
) )
ins_pos += 1 ins_pos += 1
...@@ -2795,8 +2839,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2795,8 +2839,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if undefined_msg: if undefined_msg:
type_outs.append(undefined_msg) type_outs.append(undefined_msg)
elif through_shared: elif through_untraced:
type_outs.append("through_shared") type_outs.append("through_untraced")
elif disconnected: elif disconnected:
type_outs.append("disconnected") type_outs.append("disconnected")
else: else:
...@@ -2814,7 +2858,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2814,7 +2858,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_pos += 1 out_pos += 1
n_mitmot_inps += 1 n_mitmot_inps += 1
undefined_msg = None undefined_msg = None
through_shared = False through_untraced = False
disconnected = True disconnected = True
mitmot_inp_taps[idx + offset].append(0) mitmot_inp_taps[idx + offset].append(0)
for tap in taps: for tap in taps:
...@@ -2836,9 +2880,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2836,9 +2880,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
disconnected &= disconnected_dC_dinps_t[ins_pos] disconnected &= disconnected_dC_dinps_t[ins_pos]
through_shared = any( through_untraced = any(
_sh in graph_inputs([dC_dinps_t[ins_pos]]) _sh in graph_inputs([dC_dinps_t[ins_pos]])
for _sh in self.inner_shared(self_inputs) for _sh in self.inner_untraced_sit_sot(self_inputs)
) )
n_mitmot_inps += 1 n_mitmot_inps += 1
...@@ -2847,8 +2891,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2847,8 +2891,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if undefined_msg: if undefined_msg:
type_outs.append(undefined_msg) type_outs.append(undefined_msg)
elif through_shared: elif through_untraced:
type_outs.append("through_shared") type_outs.append("through_untraced")
elif disconnected: elif disconnected:
type_outs.append("disconnected") type_outs.append("disconnected")
else: else:
...@@ -2884,15 +2928,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2884,15 +2928,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
inner_out_mitmot.append(dC_dinps_t[ins_pos]) inner_out_mitmot.append(dC_dinps_t[ins_pos])
through_shared = any( through_untraced = any(
_sh in graph_inputs([dC_dinps_t[ins_pos]]) _sh in graph_inputs([dC_dinps_t[ins_pos]])
for _sh in self.inner_shared(self_inputs) for _sh in self.inner_untraced_sit_sot(self_inputs)
) )
if isinstance(dC_dinps_t[ins_pos].type, NullType): if isinstance(dC_dinps_t[ins_pos].type, NullType):
type_outs.append(dC_dinps_t[ins_pos].type.why_null) type_outs.append(dC_dinps_t[ins_pos].type.why_null)
elif through_shared: elif through_untraced:
type_outs.append("through_shared") type_outs.append("through_untraced")
elif disconnected_dC_dinps_t[ins_pos]: elif disconnected_dC_dinps_t[ins_pos]:
type_outs.append("disconnected") type_outs.append("disconnected")
else: else:
...@@ -2911,10 +2955,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2911,10 +2955,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_out_nitsot = dC_dinps_t[: info.n_seqs] inner_out_nitsot = dC_dinps_t[: info.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:] inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot): for _p, vl in enumerate(inner_out_sitsot):
through_shared = False through_untraced = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_untraced_sit_sot(self_inputs):
if _sh in graph_inputs([vl]): if _sh in graph_inputs([vl]):
through_shared = True through_untraced = True
if isinstance(vl.type, NullType): if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null) type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of # Replace the inner output with a zero tensor of
...@@ -2922,18 +2966,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2922,18 +2966,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_out_sitsot[_p] = pt.zeros( inner_out_sitsot[_p] = pt.zeros(
diff_inputs[ins_pos + _p].shape, dtype=config.floatX diff_inputs[ins_pos + _p].shape, dtype=config.floatX
) )
elif through_shared: elif through_untraced:
type_outs.append("through_shared") type_outs.append("through_untraced")
elif disconnected_dC_dinps_t[_p + ins_pos]: elif disconnected_dC_dinps_t[_p + ins_pos]:
type_outs.append("disconnected") type_outs.append("disconnected")
else: else:
type_outs.append("connected") type_outs.append("connected")
for _p, vl in enumerate(inner_out_nitsot): for _p, vl in enumerate(inner_out_nitsot):
through_shared = False through_untraced = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_untraced_sit_sot(self_inputs):
if _sh in graph_inputs([vl]): if _sh in graph_inputs([vl]):
through_shared = True through_untraced = True
if isinstance(vl.type, NullType): if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null) type_outs.append(vl.type.why_null)
# Replace the inner output with a zero tensor of # Replace the inner output with a zero tensor of
...@@ -2942,8 +2986,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2942,8 +2986,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
diff_inputs[_p].shape, dtype=config.floatX diff_inputs[_p].shape, dtype=config.floatX
) )
if through_shared: if through_untraced:
type_outs.append("through_shared") type_outs.append("through_untraced")
elif disconnected_dC_dinps_t[_p]: elif disconnected_dC_dinps_t[_p]:
type_outs.append("disconnected") type_outs.append("disconnected")
else: else:
...@@ -2983,7 +3027,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2983,7 +3027,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ outer_inp_mitmot + outer_inp_mitmot
+ outer_inp_sitsot + outer_inp_sitsot
+ [n_steps if info.as_while else inputs[0] for _ in range(n_nit_sot)] + [n_steps if info.as_while else inputs[0] for _ in range(n_nit_sot)]
+ self.outer_shared(inputs) + self.outer_untraced_sit_sot(inputs)
+ self.outer_non_seqs(inputs) + self.outer_non_seqs(inputs)
) )
...@@ -2991,7 +3035,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2991,7 +3035,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_inp_seqs inner_inp_seqs
+ inner_inp_mitmot + inner_inp_mitmot
+ inner_inp_sitsot + inner_inp_sitsot
+ self.inner_shared(self_inputs) + self.inner_untraced_sit_sot(self_inputs)
+ self.inner_non_seqs(self_inputs) + self.inner_non_seqs(self_inputs)
) )
inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot
...@@ -3003,8 +3047,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3003,8 +3047,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)), sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)),
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_shared_outs=0, n_untraced_sit_sot_outs=0,
n_non_seqs=len(self.outer_shared(inputs)) n_non_seqs=len(self.outer_untraced_sit_sot(inputs))
+ len(self.outer_non_seqs(inputs)), + len(self.outer_non_seqs(inputs)),
as_while=False, as_while=False,
) )
...@@ -3047,10 +3091,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3047,10 +3091,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == "disconnected": elif t == "disconnected":
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
elif t == "through_shared": elif t == "through_untraced":
gradients.append( gradients.append(
grad_undefined( grad_undefined(
self, p + 1, inputs[p + 1], "Depends on a shared variable" self, p + 1, inputs[p + 1], "Depends on a untraced variable"
) )
) )
else: else:
...@@ -3075,13 +3119,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3075,13 +3119,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append(x[::-1]) gradients.append(x[::-1])
elif t == "disconnected": elif t == "disconnected":
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
elif t == "through_shared": elif t == "through_untraced":
gradients.append( gradients.append(
grad_undefined( grad_undefined(
self, self,
p + 1 + info.n_seqs, p + 1 + info.n_seqs,
inputs[p + 1 + info.n_seqs], inputs[p + 1 + info.n_seqs],
"Depends on a shared variable", "Depends on an untraced variable",
) )
) )
else: else:
...@@ -3090,7 +3134,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3090,7 +3134,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
start = len(gradients) start = len(gradients)
node = outs[0].owner node = outs[0].owner
for idx in range(info.n_shared_outs): for idx in range(info.n_untraced_sit_sot_outs):
disconnected = True disconnected = True
connected_flags = self.connection_pattern(node)[idx + start] connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags, strict=True): for dC_dout, connected in zip(dC_douts, connected_flags, strict=True):
...@@ -3116,13 +3160,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3116,13 +3160,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append(x[-1]) gradients.append(x[-1])
elif t == "disconnected": elif t == "disconnected":
gradients.append(DisconnectedType()()) gradients.append(DisconnectedType()())
elif t == "through_shared": elif t == "through_untraced":
gradients.append( gradients.append(
grad_undefined( grad_undefined(
self, self,
p + begin + 1, p + begin + 1,
inputs[p + begin + 1], inputs[p + begin + 1],
"Depends on a shared variable", "Depends on a untraced variable",
) )
) )
else: else:
...@@ -3152,7 +3196,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3152,7 +3196,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self_inputs = self.inner_inputs self_inputs = self.inner_inputs
rop_of_inputs = ( rop_of_inputs = (
self_inputs[: info.n_seqs + self.n_outs] self_inputs[: info.n_seqs + self.n_outs]
+ self_inputs[info.n_seqs + self.n_outs + info.n_shared_outs :] + self_inputs[info.n_seqs + self.n_outs + info.n_untraced_sit_sot_outs :]
) )
self_outputs = self.inner_outputs self_outputs = self.inner_outputs
...@@ -3162,8 +3206,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3162,8 +3206,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs = self_outputs[:-1] rop_self_outputs = self_outputs[:-1]
else: else:
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
if info.n_shared_outs > 0: if info.n_untraced_sit_sot_outs > 0:
rop_self_outputs = rop_self_outputs[: -info.n_shared_outs] rop_self_outputs = rop_self_outputs[: -info.n_untraced_sit_sot_outs]
rop_outs = Rop( rop_outs = Rop(
rop_self_outputs, rop_self_outputs,
rop_of_inputs, rop_of_inputs,
...@@ -3247,13 +3291,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3247,13 +3291,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
scan_sit_sot = inputs[b:e] + clean_eval_points scan_sit_sot = inputs[b:e] + clean_eval_points
inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# Shared outs ... # Untraced outs ...
b = e b = e
e = e + info.n_shared_outs e = e + info.n_untraced_sit_sot_outs
ib = ie ib = ie
ie = ie + info.n_shared_outs ie = ie + info.n_untraced_sit_sot_outs
scan_shared = inputs[b:e] scan_untraced = inputs[b:e]
inner_shared = self_inputs[ib:ie] inner_untraced = self_inputs[ib:ie]
# NIT_SOT sequences # NIT_SOT sequences
b = e b = e
...@@ -3268,7 +3312,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3268,7 +3312,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
clean_eval_points.append(inp.zeros_like()) clean_eval_points.append(inp.zeros_like())
scan_other = inputs[e:] + clean_eval_points scan_other = inputs[e:] + clean_eval_points
# inner_eval_points do not have entries for shared variables # inner_eval_points do not have entries for untraced variables
inner_other = self_inputs[ie:] + inner_eval_points[ib:] inner_other = self_inputs[ie:] + inner_eval_points[ib:]
# Outputs # Outputs
...@@ -3287,15 +3331,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3287,15 +3331,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
e = e + info.n_nit_sot e = e + info.n_nit_sot
inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e] inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e b = e
e = e + info.n_shared_outs e = e + info.n_untraced_sit_sot_outs
inner_out_shared = self_outputs[b:e] inner_out_untraced = self_outputs[b:e]
inner_ins = ( inner_ins = (
inner_seqs inner_seqs
+ inner_mit_mot + inner_mit_mot
+ inner_mit_sot + inner_mit_sot
+ inner_sit_sot + inner_sit_sot
+ inner_shared + inner_untraced
+ inner_other + inner_other
) )
inner_outs = ( inner_outs = (
...@@ -3303,7 +3347,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3303,7 +3347,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ inner_out_mit_sot + inner_out_mit_sot
+ inner_out_sit_sot + inner_out_sit_sot
+ inner_out_nit_sot + inner_out_nit_sot
+ inner_out_shared + inner_out_untraced
) )
if info.as_while: if info.as_while:
...@@ -3314,7 +3358,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3314,7 +3358,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
*scan_mit_mot, *scan_mit_mot,
*scan_mit_sot, *scan_mit_sot,
*scan_sit_sot, *scan_sit_sot,
*scan_shared, *scan_untraced,
*scan_nit_sot, *scan_nit_sot,
*scan_other, *scan_other,
] ]
...@@ -3326,7 +3370,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3326,7 +3370,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices=new_mit_sot_in_slices, mit_sot_in_slices=new_mit_sot_in_slices,
sit_sot_in_slices=new_sit_sot_in_slices, sit_sot_in_slices=new_sit_sot_in_slices,
n_nit_sot=info.n_nit_sot * 2, n_nit_sot=info.n_nit_sot * 2,
n_shared_outs=info.n_shared_outs, n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs,
n_non_seqs=len(inner_other), n_non_seqs=len(inner_other),
as_while=info.as_while, as_while=info.as_while,
) )
...@@ -3358,7 +3402,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3358,7 +3402,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
b = e + info.n_nit_sot b = e + info.n_nit_sot
e = e + info.n_nit_sot * 2 e = e + info.n_nit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
final_outs += [None] * info.n_shared_outs final_outs += [None] * info.n_untraced_sit_sot_outs
return final_outs return final_outs
......
...@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices)) sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices))
) )
st += op_info.n_sit_sot st += op_info.n_sit_sot
st += op_info.n_shared_outs st += op_info.n_untraced_sit_sot_outs
op_ins = op.inner_inputs op_ins = op.inner_inputs
op_outs = op.inner_outputs op_outs = op.inner_outputs
...@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
+ op_info.n_mit_sot + op_info.n_mit_sot
+ op_info.n_sit_sot + op_info.n_sit_sot
+ op_info.n_nit_sot + op_info.n_nit_sot
+ op_info.n_shared_outs + op_info.n_untraced_sit_sot_outs
+ 1 + 1
) )
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
...@@ -983,7 +983,7 @@ class ScanInplaceOptimizer(GraphRewriter): ...@@ -983,7 +983,7 @@ class ScanInplaceOptimizer(GraphRewriter):
ls = op.outer_mitmot(node.inputs) ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs) ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs) ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs) ls_end = op.outer_untraced_sit_sot(node.inputs)
ls_end += op.outer_nitsot(node.inputs) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs) ls_end += op.outer_non_seqs(node.inputs)
...@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ idx + idx
+ op_info.n_seqs + op_info.n_seqs
+ 1 + 1
+ op_info.n_shared_outs + op_info.n_untraced_sit_sot_outs
) )
if nw_inputs[pos] == node.inputs[0]: if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = 1 if required_orphan else val nw_inputs[pos] = 1 if required_orphan else val
...@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
elif ( elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
): ):
in_idx = offset + idx + op_info.n_shared_outs in_idx = offset + idx + op_info.n_untraced_sit_sot_outs
if nw_inputs[in_idx] == node.inputs[0]: if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps nw_inputs[in_idx] = nw_steps
...@@ -1886,8 +1886,8 @@ class ScanMerge(GraphRewriter): ...@@ -1886,8 +1886,8 @@ class ScanMerge(GraphRewriter):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
inner_ins[idx].append(nd.op.inner_shared(nd.op.inner_inputs)) inner_ins[idx].append(nd.op.inner_untraced_sit_sot(nd.op.inner_inputs))
outer_ins += nd.op.outer_shared(nd.inputs) outer_ins += nd.op.outer_untraced_sit_sot(nd.inputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# NitSot # NitSot
...@@ -1897,8 +1897,10 @@ class ScanMerge(GraphRewriter): ...@@ -1897,8 +1897,10 @@ class ScanMerge(GraphRewriter):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
outer_outs += nd.op.outer_shared_outs(nd.outputs) outer_outs += nd.op.outer_untraced_sit_sot_outs(nd.outputs)
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs)) inner_outs[idx].append(
nd.op.inner_untraced_sit_sot_outs(nd.op.inner_outputs)
)
n_non_seqs = 0 n_non_seqs = 0
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
...@@ -1978,7 +1980,9 @@ class ScanMerge(GraphRewriter): ...@@ -1978,7 +1980,9 @@ class ScanMerge(GraphRewriter):
mit_sot_in_slices=mit_sot_in_slices, mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices, sit_sot_in_slices=sit_sot_in_slices,
n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes), n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes), n_untraced_sit_sot_outs=sum(
nd.op.info.n_untraced_sit_sot_outs for nd in nodes
),
n_non_seqs=n_non_seqs, n_non_seqs=n_non_seqs,
as_while=as_while, as_while=as_while,
) )
...@@ -2360,7 +2364,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2360,7 +2364,7 @@ def scan_push_out_dot1(fgraph, node):
# When seq[t] is a vector/matrix and `value` is a matrix # When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works when only you need X[-1] in the end # Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot # and assumes dimshuffle are applied to vectors before calling dot
op = node.op op: Scan = node.op
sitsot_ins = op.inner_sitsot(op.inner_inputs) sitsot_ins = op.inner_sitsot(op.inner_inputs)
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_sitsot = op.outer_sitsot_outs(node.outputs) outer_sitsot = op.outer_sitsot_outs(node.outputs)
...@@ -2416,9 +2420,13 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2416,9 +2420,13 @@ def scan_push_out_dot1(fgraph, node):
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_nitsot = op.outer_nitsot(node.inputs) outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs) inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
inner_shared = op.inner_shared(op.inner_inputs) inner_untraced_sitsot = op.inner_untraced_sitsot(op.inner_inputs)
outer_shared = op.outer_shared(node.inputs) outer_untraced_sitsot_outs = op.outer_untraced_sitsot_outs(
inner_shared_outs = op.inner_shared_outs(op.inner_outputs) node.inputs
)
inner_untraced_sitsot_outs = op.inner_untraced_sitsot_outs(
op.inner_outputs
)
inner_non_seqs = op.inner_non_seqs(op.inner_inputs) inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -2441,7 +2449,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2441,7 +2449,7 @@ def scan_push_out_dot1(fgraph, node):
+ inner_mitmot + inner_mitmot
+ inner_mitsot + inner_mitsot
+ inner_sitsot + inner_sitsot
+ inner_shared + inner_untraced_sitsot
+ inner_non_seqs + inner_non_seqs
) )
_new_inner_outs = ( _new_inner_outs = (
...@@ -2449,7 +2457,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2449,7 +2457,7 @@ def scan_push_out_dot1(fgraph, node):
+ inner_mitsot_outs + inner_mitsot_outs
+ inner_sitsot_outs + inner_sitsot_outs
+ inner_nitsot_outs + inner_nitsot_outs
+ inner_shared_outs + inner_untraced_sitsot_outs
) )
new_inner_inps, new_inner_outs = reconstruct_graph( new_inner_inps, new_inner_outs = reconstruct_graph(
_new_inner_inps, _new_inner_outs _new_inner_inps, _new_inner_outs
...@@ -2471,7 +2479,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2471,7 +2479,7 @@ def scan_push_out_dot1(fgraph, node):
*outer_mitmot, *outer_mitmot,
*outer_mitsot, *outer_mitsot,
*outer_sitsot, *outer_sitsot,
*outer_shared, *outer_untraced_sitsot_outs,
*outer_nitsot, *outer_nitsot,
node.inputs[0], node.inputs[0],
*outer_non_seqs, *outer_non_seqs,
......
...@@ -370,7 +370,9 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -370,7 +370,9 @@ def scan_can_remove_outs(op, out_idxs):
out_ins += [op.inner_inputs[offset : offset + n_ins]] out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins offset += n_ins
out_ins += [[] for k in range(op.info.n_nit_sot)] out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)] out_ins += [
[op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot_outs)
]
added = True added = True
out_idxs_mask = [1 for idx in out_idxs] out_idxs_mask = [1 for idx in out_idxs]
...@@ -409,7 +411,7 @@ def compress_outs(op, not_required, inputs): ...@@ -409,7 +411,7 @@ def compress_outs(op, not_required, inputs):
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=(), sit_sot_in_slices=(),
n_nit_sot=0, n_nit_sot=0,
n_shared_outs=0, n_untraced_sit_sot_outs=0,
n_non_seqs=0, n_non_seqs=0,
as_while=op_info.as_while, as_while=op_info.as_while,
) )
...@@ -515,17 +517,19 @@ def compress_outs(op, not_required, inputs): ...@@ -515,17 +517,19 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1) info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]] nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot_outs]]
else: else:
o_offset += 1 o_offset += 1
offset += op_info.n_nit_sot offset += op_info.n_nit_sot
shared_ins = [] shared_ins = []
for idx in range(op_info.n_shared_outs): for idx in range(op_info.n_untraced_sit_sot_outs):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1) info = dataclasses.replace(
info, n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs + 1
)
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
...@@ -539,7 +543,9 @@ def compress_outs(op, not_required, inputs): ...@@ -539,7 +543,9 @@ def compress_outs(op, not_required, inputs):
# other stuff # other stuff
op_inputs += op.inner_inputs[i_offset:] op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:])) info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :] node_inputs += inputs[
ni_offset + op_info.n_untraced_sit_sot_outs + op_info.n_nit_sot :
]
if op_info.as_while: if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1 map_old_new[o_offset] = len(op_outputs) - 1
...@@ -658,11 +664,11 @@ class ScanArgs: ...@@ -658,11 +664,11 @@ class ScanArgs:
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
n_shared_outs = info.n_shared_outs n_untraced_sit_sot_outs = info.n_untraced_sit_sot_outs
self.outer_in_shared = list(outer_inputs[p : p + n_shared_outs]) self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot_outs])
self.inner_in_shared = list(inner_inputs[q : q + n_shared_outs]) self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot_outs])
p += n_shared_outs p += n_untraced_sit_sot_outs
q += n_shared_outs q += n_untraced_sit_sot_outs
n_nit_sot = info.n_nit_sot n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot]) self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
...@@ -702,10 +708,10 @@ class ScanArgs: ...@@ -702,10 +708,10 @@ class ScanArgs:
p += n_nit_sot p += n_nit_sot
q += n_nit_sot q += n_nit_sot
self.outer_out_shared = list(outer_outputs[p : p + n_shared_outs]) self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot_outs])
self.inner_out_shared = list(inner_outputs[q : q + n_shared_outs]) self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot_outs])
p += n_shared_outs p += n_untraced_sit_sot_outs
q += n_shared_outs q += n_untraced_sit_sot_outs
assert p == len(outer_outputs) assert p == len(outer_outputs)
assert q == len(inner_outputs) assert q == len(inner_outputs)
...@@ -816,7 +822,7 @@ class ScanArgs: ...@@ -816,7 +822,7 @@ class ScanArgs:
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices), mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot), sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot), n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared), n_untraced_sit_sot_outs=len(self.outer_in_shared),
n_non_seqs=len(self.inner_in_non_seqs), n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while, as_while=self.as_while,
) )
......
...@@ -85,7 +85,7 @@ from tests.link.numba.test_basic import compare_numba_and_py ...@@ -85,7 +85,7 @@ from tests.link.numba.test_basic import compare_numba_and_py
3, 3,
[], [],
[np.array([0.50100236, 2.16822932, 1.36326596])], [np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_shared_outs > 0, lambda op: op.info.n_untraced_sit_sot_outs > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
( (
......
...@@ -42,11 +42,13 @@ from pytensor.tensor.math import all as pt_all ...@@ -42,11 +42,13 @@ from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.random import normal from pytensor.tensor.random import normal
from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.shape import Shape_i, reshape, specify_shape from pytensor.tensor.shape import Shape_i, reshape, specify_shape
from pytensor.tensor.sharedvar import SharedVariable from pytensor.tensor.sharedvar import SharedVariable
from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType,
dcol, dcol,
dmatrix, dmatrix,
dscalar, dscalar,
...@@ -4007,7 +4009,7 @@ class TestExamples: ...@@ -4007,7 +4009,7 @@ class TestExamples:
[{}], [{}],
[], [],
3, 3,
lambda op: op.info.n_shared_outs > 0, lambda op: op.info.n_untraced_sit_sot_outs > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
( (
...@@ -4106,3 +4108,34 @@ def test_output_storage_reuse(linker_mode): ...@@ -4106,3 +4108,34 @@ def test_output_storage_reuse(linker_mode):
res = f_cvm() res = f_cvm()
assert np.array_equal(res, np.array([3, 1, 0])) assert np.array_equal(res, np.array([3, 1, 0]))
def test_rng_outputs_info():
    """Check that a random generator can be carried by `scan` via `outputs_info`.

    A symbolic generator is used for an initial draw, then both the draw and
    the updated generator are passed as `outputs_info`. The scan outputs are
    compared against the same sequence of draws made directly with NumPy.
    """
    n_steps = 10

    # Initial state: one normal draw, together with the generator it returns.
    rng_input = random_generator_type("rng")
    initial_draw = pt.random.normal(0, rng=rng_input, dtype="float64")
    rng_after_x0, x0 = initial_draw.owner.outputs

    def draw_next(last_x, last_rng):
        # Each step draws around the previous value and hands the updated
        # generator on to the next iteration.
        draw = pt.random.normal(last_x, rng=last_rng, dtype="float64")
        updated_rng, new_x = draw.owner.outputs
        return new_x, updated_rng

    scan_outputs, updates = scan(
        fn=draw_next,
        outputs_info=[x0, rng_after_x0],
        n_steps=n_steps,
    )
    xs, rng_final = scan_outputs

    assert isinstance(xs.type, TensorType)
    assert isinstance(rng_final.type, RandomGeneratorType)
    # The generator is returned as a regular output, so no updates are expected.
    assert not updates

    compiled = function([rng_input], [xs, rng_final])
    xs_eval, rng_final_eval = compiled(np.random.default_rng(0))

    # A fresh reference generator must differ from the returned one until it
    # has performed the same draws.
    rng_ref = np.random.default_rng(0)
    assert not random_generator_type.values_eq(rng_ref, rng_final_eval)

    # Replay the draws with NumPy: the initial one, then one per scan step.
    last_value = rng_ref.normal(0)
    expected_xs = []
    for _ in range(n_steps):
        last_value = rng_ref.normal(last_value)
        expected_xs.append(last_value)

    assert random_generator_type.values_eq(rng_ref, rng_final_eval)
    np.testing.assert_allclose(xs_eval, expected_xs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论