Commit f3a7d94f authored by Ricardo Vieira, committed by Ricardo Vieira

Scan dispatches: correct handling of signed mitmot taps

Unlike MIT-SOT and SIT-SOT these can be positive or negative, depending on the order of differentiation
Parent ebc0de09
...@@ -90,7 +90,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -90,7 +90,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
chain.from_iterable( chain.from_iterable(
buffer[(i + np.array(taps))] buffer[(i + np.array(taps))]
for buffer, taps in zip( for buffer, taps in zip(
inner_mit_mot, info.mit_mot_in_slices, strict=True inner_mit_mot, info.normalized_mit_mot_in_slices, strict=True
) )
) )
) )
...@@ -140,7 +140,10 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -140,7 +140,10 @@ def jax_funcify_Scan(op: Scan, **kwargs):
new_mit_mot = [ new_mit_mot = [
buffer.at[i + np.array(taps)].set(new_vals) buffer.at[i + np.array(taps)].set(new_vals)
for buffer, new_vals, taps in zip( for buffer, new_vals, taps in zip(
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True old_mit_mot,
new_mit_mot_vals,
info.normalized_mit_mot_out_slices,
strict=True,
) )
] ]
# Discard oldest MIT-SOT and append newest value # Discard oldest MIT-SOT and append newest value
......
...@@ -27,9 +27,8 @@ def idx_to_str( ...@@ -27,9 +27,8 @@ def idx_to_str(
idx_symbol: str = "i", idx_symbol: str = "i",
allow_scalar=False, allow_scalar=False,
) -> str: ) -> str:
if offset < 0: assert offset >= 0
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}" if offset > 0:
elif offset > 0:
indices = f"{idx_symbol} + {offset}" indices = f"{idx_symbol} + {offset}"
else: else:
indices = idx_symbol indices = idx_symbol
...@@ -226,33 +225,16 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -226,33 +225,16 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# storage array like a circular buffer, and that's why we need to track the # storage array like a circular buffer, and that's why we need to track the
# storage size along with the taps length/indexing offset. # storage size along with the taps length/indexing offset.
def add_output_storage_post_proc_stmt( def add_output_storage_post_proc_stmt(
outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str outer_in_name: str, max_offset: int, storage_size: str
): ):
tap_size = max(tap_sizes) # Rotate the storage so that the last computed value is at the end of the storage array.
if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts.append(
dedent(
f"""
if i + {tap_size} < {storage_size}:
{storage_size} = i + {tap_size}
{outer_in_name} = {outer_in_name}[:{storage_size}]
"""
).strip()
)
# Rotate the storage so that the last computed value is at the end of
# the storage array.
# This is needed when the output storage array does not have a length # This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`. # equal to the number of taps plus `n_steps`.
# If the storage size only allows one entry, there's nothing to rotate
output_storage_post_proc_stmts.append( output_storage_post_proc_stmts.append(
dedent( dedent(
f""" f"""
if 1 < {storage_size} < (i + {tap_size}): if 1 < {storage_size} < (i + {max_offset}):
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) {outer_in_name}_shift = (i + {max_offset}) % ({storage_size})
if {outer_in_name}_shift > 0: if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
...@@ -261,6 +243,18 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -261,6 +243,18 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
).strip() ).strip()
) )
if op.info.as_while:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts.append(
dedent(
f"""
elif {storage_size} > (i + {max_offset}):
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
"""
).strip()
)
# Special in-loop statements that create (nit-sot) storage arrays after a # Special in-loop statements that create (nit-sot) storage arrays after a
# single iteration is performed. This is necessary because we don't know # single iteration is performed. This is necessary because we don't know
# the exact shapes of the storage arrays that need to be allocated until # the exact shapes of the storage arrays that need to be allocated until
...@@ -288,12 +282,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -288,12 +282,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
storage_size_name = f"{outer_in_name}_len" storage_size_name = f"{outer_in_name}_len"
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]" storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
input_taps = inner_in_names_to_input_taps[outer_in_name] input_taps = inner_in_names_to_input_taps[outer_in_name]
tap_storage_size = -min(input_taps) max_lookback_inp_tap = -min(0, min(input_taps))
assert tap_storage_size >= 0 assert max_lookback_inp_tap >= 0
for in_tap in input_taps: for in_tap in input_taps:
tap_offset = in_tap + tap_storage_size tap_offset = max_lookback_inp_tap + in_tap
assert tap_offset >= 0
is_vector = outer_in_var.ndim == 1 is_vector = outer_in_var.ndim == 1
add_inner_in_expr( add_inner_in_expr(
outer_in_name, outer_in_name,
...@@ -302,22 +295,25 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -302,22 +295,25 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
vector_slice_opt=is_vector, vector_slice_opt=is_vector,
) )
output_taps = inner_in_names_to_output_taps.get( output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0])
outer_in_name, [tap_storage_size] for out_tap in output_taps:
) tap_offset = max_lookback_inp_tap + out_tap
inner_out_to_outer_in_stmts.extend( assert tap_offset >= 0
idx_to_str( inner_out_to_outer_in_stmts.append(
storage_name, idx_to_str(
out_tap, storage_name,
size=storage_size_name, tap_offset,
allow_scalar=True, size=storage_size_name,
allow_scalar=True,
)
) )
for out_tap in output_taps
)
add_output_storage_post_proc_stmt( if outer_in_name not in outer_in_mit_mot_names:
storage_name, output_taps, storage_size_name # MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
) max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
add_output_storage_post_proc_stmt(
storage_name, max_offset_out_tap, storage_size_name
)
else: else:
storage_size_stmt = "" storage_size_stmt = ""
...@@ -351,7 +347,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -351,7 +347,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
inner_out_to_outer_in_stmts.append( inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True) idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
) )
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name) add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name)
# In case of nit-sots we are provided the length of the array in # In case of nit-sots we are provided the length of the array in
# the iteration dimension instead of actual arrays, hence we # the iteration dimension instead of actual arrays, hence we
......
...@@ -288,6 +288,26 @@ class ScanInfo: ...@@ -288,6 +288,26 @@ class ScanInfo:
+ self.n_untraced_sit_sot_outs + self.n_untraced_sit_sot_outs
) )
@property
def normalized_mit_mot_in_slices(self) -> tuple[tuple[int, ...], ...]:
    """Return mit_mot_in slices re-expressed as non-negative offsets from the oldest tap.

    Each slice is shifted by the magnitude of its most negative tap (if any),
    so every returned tap is >= 0.
    """
    # TODO: Make this the canonical representation
    return tuple(
        tuple(tap - min(0, min(in_slice)) for tap in in_slice)
        for in_slice in self.mit_mot_in_slices
    )
@property
def normalized_mit_mot_out_slices(self) -> tuple[tuple[int, ...], ...]:
    """Return mit_mot_out slices re-expressed as non-negative offsets from the oldest tap.

    Each slice is shifted by the magnitude of its most negative tap (if any),
    so every returned tap is >= 0.
    """
    # TODO: Make this the canonical representation
    normalized = []
    for taps in self.mit_mot_out_slices:
        shift = -min(0, min(taps))
        normalized.append(tuple(tap + shift for tap in taps))
    return tuple(normalized)
TensorConstructorType = Callable[ TensorConstructorType = Callable[
[Iterable[bool | int | None], str | np.generic], TensorType [Iterable[bool | int | None], str | np.generic], TensorType
......
...@@ -15,6 +15,7 @@ from pytensor.tensor import random ...@@ -15,6 +15,7 @@ from pytensor.tensor import random
from pytensor.tensor.math import gammaln, log from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
from tests.scan.test_basic import ScanCompatibilityTests
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
...@@ -626,3 +627,7 @@ def test_scan_benchmark(model, mode, gradient_backend, benchmark): ...@@ -626,3 +627,7 @@ def test_scan_benchmark(model, mode, gradient_backend, benchmark):
block_until_ready(*test_input_vals) # Warmup block_until_ready(*test_input_vals) # Warmup
benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1) benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
def test_higher_order_derivatives():
    # Reuse the backend-agnostic Scan check with mode="JAX" so the JAX
    # dispatcher is exercised against signed mit-mot taps.
    ScanCompatibilityTests.check_higher_order_derivative(mode="JAX")
...@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import Elemwise ...@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.random.utils import RandomStream
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
from tests.scan.test_basic import ScanCompatibilityTests
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -652,3 +653,7 @@ class TestScanMITSOTBuffer: ...@@ -652,3 +653,7 @@ class TestScanMITSOTBuffer:
def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark): def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark) self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
def test_higher_order_derivatives():
    # Reuse the backend-agnostic Scan check with mode="NUMBA" so the Numba
    # dispatcher is exercised against signed mit-mot taps.
    ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")
...@@ -4082,6 +4082,9 @@ class TestExamples: ...@@ -4082,6 +4082,9 @@ class TestExamples:
# Also, the purpose of this test is not clear. # Also, the purpose of this test is not clear.
self._grad_mout_helper(1, None) self._grad_mout_helper(1, None)
def test_higher_order_derivatives(self):
    # Run the shared higher-order-derivative check with the default mode
    # (mode=None), exercising the reference Scan implementation.
    ScanCompatibilityTests.check_higher_order_derivative(mode=None)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fn, sequences, outputs_info, non_sequences, n_steps, op_check", "fn, sequences, outputs_info, non_sequences, n_steps, op_check",
...@@ -4398,3 +4401,33 @@ def test_scan_mode_compatibility(scan_mode): ...@@ -4398,3 +4401,33 @@ def test_scan_mode_compatibility(scan_mode):
# Expected value computed by running correct Scan once # Expected value computed by running correct Scan once
np.testing.assert_allclose(fn(*numerical_inputs), [44, 38]) np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
class ScanCompatibilityTests:
    """Backend-agnostic checks of subtle required Scan behaviors.

    These are plain static helpers (not collected directly by pytest) so that
    each backend's test module can invoke them with its own compilation mode.
    """

    @staticmethod
    def check_higher_order_derivative(mode):
        """Exercise mit-mot taps of both signs via repeated differentiation.

        Parameters
        ----------
        mode
            Compilation mode passed to `function` (e.g. "JAX", "NUMBA", or
            None for the default backend).
        """
        x = pt.dscalar("x")
        # Squaring four times makes the last entry of xs equal to x ** 16
        xs = scan(
            fn=lambda xtm1: xtm1**2,
            outputs_info=[x],
            n_steps=4,
            return_updates=False,
        )
        out = xs[-1]
        first = grad(out, x)
        second = grad(first, x)
        third = grad(second, x)
        fn = function([x], [out, first, second, third], mode=mode)

        x_test = np.array(0.95, dtype=x.type.dtype)
        out_val, first_val, second_val, _third_val = fn(x_test)
        np.testing.assert_allclose(out_val, x_test**16)
        np.testing.assert_allclose(first_val, 16 * x_test**15)
        np.testing.assert_allclose(second_val, (16 * 15) * x_test**14)
        # FIXME: All implementations of Scan seem to get this one wrong!
        # np.testing.assert_allclose(_third_val, (16 * 15 * 14) * x_test**13)
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment