提交 f3a7d94f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Scan dispatches: correct handling of signed mitmot taps

Unlike MIT-SOT and SIT-SOT taps, MIT-MOT taps can be either positive or negative, depending on the order of differentiation.
上级 ebc0de09
......@@ -90,7 +90,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
chain.from_iterable(
buffer[(i + np.array(taps))]
for buffer, taps in zip(
inner_mit_mot, info.mit_mot_in_slices, strict=True
inner_mit_mot, info.normalized_mit_mot_in_slices, strict=True
)
)
)
......@@ -140,7 +140,10 @@ def jax_funcify_Scan(op: Scan, **kwargs):
new_mit_mot = [
buffer.at[i + np.array(taps)].set(new_vals)
for buffer, new_vals, taps in zip(
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
old_mit_mot,
new_mit_mot_vals,
info.normalized_mit_mot_out_slices,
strict=True,
)
]
# Discard oldest MIT-SOT and append newest value
......
......@@ -27,9 +27,8 @@ def idx_to_str(
idx_symbol: str = "i",
allow_scalar=False,
) -> str:
if offset < 0:
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
elif offset > 0:
assert offset >= 0
if offset > 0:
indices = f"{idx_symbol} + {offset}"
else:
indices = idx_symbol
......@@ -226,33 +225,16 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# storage array like a circular buffer, and that's why we need to track the
# storage size along with the taps length/indexing offset.
def add_output_storage_post_proc_stmt(
outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str
outer_in_name: str, max_offset: int, storage_size: str
):
tap_size = max(tap_sizes)
if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts.append(
dedent(
f"""
if i + {tap_size} < {storage_size}:
{storage_size} = i + {tap_size}
{outer_in_name} = {outer_in_name}[:{storage_size}]
"""
).strip()
)
# Rotate the storage so that the last computed value is at the end of
# the storage array.
# Rotate the storage so that the last computed value is at the end of the storage array.
# This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`.
# If the storage size only allows one entry, there's nothing to rotate
output_storage_post_proc_stmts.append(
dedent(
f"""
if 1 < {storage_size} < (i + {tap_size}):
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
if 1 < {storage_size} < (i + {max_offset}):
{outer_in_name}_shift = (i + {max_offset}) % ({storage_size})
if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
......@@ -261,6 +243,18 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
).strip()
)
if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts.append(
dedent(
f"""
elif {storage_size} > (i + {max_offset}):
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
"""
).strip()
)
# Special in-loop statements that create (nit-sot) storage arrays after a
# single iteration is performed. This is necessary because we don't know
# the exact shapes of the storage arrays that need to be allocated until
......@@ -288,12 +282,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
storage_size_name = f"{outer_in_name}_len"
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
input_taps = inner_in_names_to_input_taps[outer_in_name]
tap_storage_size = -min(input_taps)
assert tap_storage_size >= 0
max_lookback_inp_tap = -min(0, min(input_taps))
assert max_lookback_inp_tap >= 0
for in_tap in input_taps:
tap_offset = in_tap + tap_storage_size
assert tap_offset >= 0
tap_offset = max_lookback_inp_tap + in_tap
is_vector = outer_in_var.ndim == 1
add_inner_in_expr(
outer_in_name,
......@@ -302,22 +295,25 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
vector_slice_opt=is_vector,
)
output_taps = inner_in_names_to_output_taps.get(
outer_in_name, [tap_storage_size]
)
inner_out_to_outer_in_stmts.extend(
idx_to_str(
storage_name,
out_tap,
size=storage_size_name,
allow_scalar=True,
output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0])
for out_tap in output_taps:
tap_offset = max_lookback_inp_tap + out_tap
assert tap_offset >= 0
inner_out_to_outer_in_stmts.append(
idx_to_str(
storage_name,
tap_offset,
size=storage_size_name,
allow_scalar=True,
)
)
for out_tap in output_taps
)
add_output_storage_post_proc_stmt(
storage_name, output_taps, storage_size_name
)
if outer_in_name not in outer_in_mit_mot_names:
# MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
add_output_storage_post_proc_stmt(
storage_name, max_offset_out_tap, storage_size_name
)
else:
storage_size_stmt = ""
......@@ -351,7 +347,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
)
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name)
# In case of nit-sots we are provided the length of the array in
# the iteration dimension instead of actual arrays, hence we
......
......@@ -288,6 +288,26 @@ class ScanInfo:
+ self.n_untraced_sit_sot_outs
)
@property
def normalized_mit_mot_in_slices(self) -> tuple[tuple[int, ...], ...]:
    """Return mit_mot_in slices normalized as an offset from the oldest tap."""
    # TODO: Make this the canonical representation
    def _shifted(taps: tuple[int, ...]) -> tuple[int, ...]:
        # Shift every tap so the most negative one lands on offset 0.
        # Slices whose taps are all non-negative are left untouched.
        offset = max(0, -min(taps))
        return tuple(tap + offset for tap in taps)

    return tuple(_shifted(taps) for taps in self.mit_mot_in_slices)
@property
def normalized_mit_mot_out_slices(self) -> tuple[tuple[int, ...], ...]:
    """Return mit_mot_out slices normalized as an offset from the oldest tap."""
    # TODO: Make this the canonical representation
    def _shifted(taps: tuple[int, ...]) -> tuple[int, ...]:
        # Shift every tap so the most negative one lands on offset 0.
        # Slices whose taps are all non-negative are left untouched.
        offset = max(0, -min(taps))
        return tuple(tap + offset for tap in taps)

    return tuple(_shifted(taps) for taps in self.mit_mot_out_slices)
TensorConstructorType = Callable[
[Iterable[bool | int | None], str | np.generic], TensorType
......
......@@ -15,6 +15,7 @@ from pytensor.tensor import random
from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.scan.test_basic import ScanCompatibilityTests
jax = pytest.importorskip("jax")
......@@ -626,3 +627,7 @@ def test_scan_benchmark(model, mode, gradient_backend, benchmark):
block_until_ready(*test_input_vals) # Warmup
benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
def test_higher_order_derivatives():
    # Run the shared backend-agnostic check under the JAX compilation mode.
    # Per its docstring, the check exercises MIT-MOT taps of different signs,
    # which appear when a Scan is differentiated repeatedly.
    ScanCompatibilityTests.check_higher_order_derivative(mode="JAX")
......@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import RandomStream
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
from tests.scan.test_basic import ScanCompatibilityTests
@pytest.mark.parametrize(
......@@ -652,3 +653,7 @@ class TestScanMITSOTBuffer:
def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
def test_higher_order_derivatives():
    # Run the shared backend-agnostic check under the NUMBA compilation mode.
    # Per its docstring, the check exercises MIT-MOT taps of different signs,
    # which appear when a Scan is differentiated repeatedly.
    ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")
......@@ -4082,6 +4082,9 @@ class TestExamples:
# Also, the purpose of this test is not clear.
self._grad_mout_helper(1, None)
def test_higher_order_derivatives(self):
    # Run the shared backend-agnostic check with mode=None (i.e. the default
    # compilation mode). The check exercises MIT-MOT taps of different signs,
    # which appear when a Scan is differentiated repeatedly.
    ScanCompatibilityTests.check_higher_order_derivative(mode=None)
@pytest.mark.parametrize(
"fn, sequences, outputs_info, non_sequences, n_steps, op_check",
......@@ -4398,3 +4401,33 @@ def test_scan_mode_compatibility(scan_mode):
# Expected value computed by running correct Scan once
np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
class ScanCompatibilityTests:
    """Collection of tests of subtle required behaviors of Scan, reusable across the different backend test suites."""

    @staticmethod
    def check_higher_order_derivative(mode):
        """Check 1st/2nd/3rd order derivatives through Scan, which exercise MIT-MOT taps of both signs."""
        x = pt.dscalar("x")
        # Squaring four times in a row, so the final output equals x ** 16
        outputs = scan(
            fn=lambda prev: prev**2,
            outputs_info=[x],
            n_steps=4,
            return_updates=False,
        )
        value = outputs[-1]
        d1 = grad(value, x)
        d2 = grad(d1, x)
        d3 = grad(d2, x)
        compiled_fn = function([x], [value, d1, d2, d3], mode=mode)

        x_test = np.array(0.95, dtype=x.type.dtype)
        value_res, d1_res, d2_res, _d3_res = compiled_fn(x_test)
        np.testing.assert_allclose(value_res, x_test**16)
        np.testing.assert_allclose(d1_res, 16 * x_test**15)
        np.testing.assert_allclose(d2_res, (16 * 15) * x_test**14)
        # FIXME: All implementations of Scan seem to get this one wrong!
        # np.testing.assert_allclose(_d3_res, (16 * 15 * 14) * x_test**13)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论