Unverified 提交 9ae07ab0 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: GitHub

Fix JAX Scan for output ndim > 1 (#288)

上级 cb417fe5
...@@ -154,10 +154,11 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -154,10 +154,11 @@ def jax_funcify_Scan(op: Scan, **kwargs):
for init_state, trace, buffer in zip(init_states, traces, buffers): for init_state, trace, buffer in zip(init_states, traces, buffers):
if init_state is not None: if init_state is not None:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
full_trace = jnp.concatenate( trace = jnp.atleast_1d(trace)
[jnp.atleast_1d(init_state), jnp.atleast_1d(trace)], init_state = jnp.expand_dims(
axis=0, init_state, range(trace.ndim - init_state.ndim)
) )
full_trace = jnp.concatenate([init_state, trace], axis=0)
buffer_size = buffer.shape[0] buffer_size = buffer.shape[0]
else: else:
# NIT-SOT: Buffer is just the number of entries that should be returned # NIT-SOT: Buffer is just the number of entries that should be returned
......
...@@ -13,7 +13,7 @@ from pytensor.scan.basic import scan ...@@ -13,7 +13,7 @@ from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.tensor import random from pytensor.tensor import random
from pytensor.tensor.math import gammaln, log from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import lscalar, scalar, vector from pytensor.tensor.type import dmatrix, dvector, lscalar, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -317,3 +317,104 @@ def test_scan_mitsot_with_nonseq(): ...@@ -317,3 +317,104 @@ def test_scan_mitsot_with_nonseq():
test_input_vals = [np.array(10.0).astype(config.floatX)] test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals) compare_jax_and_py(out_fg, test_input_vals)
@pytest.mark.parametrize("x0_func", [dvector, dmatrix])
@pytest.mark.parametrize("A_func", [dmatrix, dmatrix])
def test_nd_scan_sit_sot(x0_func, A_func):
x0 = x0_func("x0")
A = A_func("A")
n_steps = 3
k = 3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
lambda X, A: A @ X,
non_sequences=[A],
outputs_info=[x0],
n_steps=n_steps,
mode=get_mode("JAX"),
)
x0_val = (
np.arange(k, dtype=config.floatX)
if x0.ndim == 1
else np.diag(np.arange(k, dtype=config.floatX))
)
A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x0, A], [xs])
test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals)
def test_nd_scan_sit_sot_with_seq():
n_steps = 3
k = 3
x = at.matrix("x0", shape=(n_steps, k))
A = at.matrix("A", shape=(k, k))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
lambda X, A: A @ X,
non_sequences=[A],
sequences=[x],
n_steps=n_steps,
mode=get_mode("JAX"),
)
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x, A], [xs])
test_input_vals = [x_val, A_val]
compare_jax_and_py(fg, test_input_vals)
def test_nd_scan_mit_sot():
x0 = at.matrix("x0", shape=(3, 3))
A = at.matrix("A", shape=(3, 3))
B = at.matrix("B", shape=(3, 3))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B],
n_steps=10,
mode=get_mode("JAX"),
)
fg = FunctionGraph([x0, A, B], [xs])
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
A_val = np.eye(3, dtype=config.floatX)
B_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val, B_val]
compare_jax_and_py(fg, test_input_vals)
def test_nd_scan_sit_sot_with_carry():
x0 = at.vector("x0", shape=(3,))
A = at.matrix("A", shape=(3, 3))
def step(x, A):
return A @ x, x.sum()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
step,
outputs_info=[x0, None],
non_sequences=[A],
n_steps=10,
mode=get_mode("JAX"),
)
fg = FunctionGraph([x0, A], xs)
x0_val = np.arange(3, dtype=config.floatX)
A_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论