提交 10e5c92f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

JAX Scan: Fix bug when recurring trace matches final requested size

上级 25e41c39
......@@ -10,7 +10,7 @@ from pytensor.scan.op import Scan
@jax_funcify.register(Scan)
def jax_funcify_Scan(op: Scan, **kwargs):
def jax_funcify_Scan(op: Scan, node, **kwargs):
# Note: This implementation is different from the internal PyTensor Scan op.
# In particular, we don't make use of the provided buffers for recurring outputs (MIT-SOT, SIT-SOT)
# These buffers include the initial state and enough space to store as many intermediate results as needed.
......@@ -219,6 +219,8 @@ def jax_funcify_Scan(op: Scan, **kwargs):
if trace.shape[0] > buffer_size:
# Trace is longer than buffer, keep just the last `buffer.shape[0]` entries
partial_trace = trace[-buffer_size:]
elif trace.shape[0] == buffer_size:
partial_trace = trace
else:
# Trace is shorter than buffer, this happens when we keep the initial_state
if init_state.ndim < buffer.ndim:
......
......@@ -10,8 +10,8 @@ from pytensor.configdefaults import config
from pytensor.graph import Apply, Op
from pytensor.scan import until
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.tensor import random
from pytensor.scan.op import Scan, ScanInfo
from pytensor.tensor import as_tensor, empty, random
from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -631,3 +631,37 @@ def test_scan_benchmark(model, mode, gradient_backend, benchmark):
def test_higher_order_derivatives():
ScanCompatibilityTests.check_higher_order_derivative(mode="JAX")
def test_trace_truncation_regression_bug():
# Regression bug for a case where the final recurring trace size matched exactly with the number of steps
n_steps = as_tensor(7, dtype=int)
x0 = scalar("x0")
x0_buffer = empty((n_steps,))[0].set(x0)
# I don't know how to create such a Scan naturally, so we use the internal API
xtm1 = x0.type()
scan_op = Scan(
inputs=[xtm1],
outputs=[xtm1 + 1],
info=ScanInfo(
n_seqs=0,
mit_mot_in_slices=(),
mit_mot_out_slices=(),
mit_sot_in_slices=(),
sit_sot_in_slices=((-1,),),
n_nit_sot=0,
n_untraced_sit_sot_outs=0,
n_non_seqs=0,
as_while=False,
),
)
xs_with_x0 = scan_op(n_steps, x0_buffer)
compare_jax_and_py(
[x0],
[xs_with_x0],
[np.array(0)],
jax_mode="JAX",
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论