提交 14e6c781 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Reimplement JAX Scan dispatcher with MIT-MOT support

上级 97797975
import jax
from itertools import chain
import jax.numpy as jnp
import numpy as np
from jax._src.lax.control_flow import fori_loop
from pytensor.compile.mode import JAX, get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan.op import Scan
def call_inner_func_with_indexed_buffers(
info,
scan_inner_func,
i,
sequences,
mit_mot_buffers,
mit_sot_buffers,
sit_sot_buffers,
shareds,
non_sequences,
):
sequence_vals = [seq[i] for seq in sequences]
# chain.from_iterable is used flatten the first dimension of each indexed buffer
# [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]]
# Benchmarking suggests unpacking advanced indexing on all taps is faster than basic index one tap at a time
mit_mot_vals = list(
chain.from_iterable(
buffer[(i + np.array(in_taps))]
for buffer, in_taps in zip(
mit_mot_buffers, info.mit_mot_in_slices, strict=True
)
)
)
mit_sot_vals = list(
chain.from_iterable(
# Convert negative taps (-2, -1) to positive indices (0, 1)
buffer[((i + (np.array(in_taps) - min(in_taps))) % buffer.shape[0])]
for buffer, in_taps in zip(
mit_sot_buffers, info.mit_sot_in_slices, strict=True
)
)
)
sit_sot_vals = [buffer[i % buffer.shape[0]] for buffer in sit_sot_buffers]
return scan_inner_func(
*sequence_vals,
*mit_mot_vals,
*mit_sot_vals,
*sit_sot_vals,
*shareds,
*non_sequences,
)
def update_buffers(buffers, update_vals, indices, may_roll: bool = True):
return tuple(
buffer.at[(index % buffer.shape[0]) if may_roll else index].set(update_val)
for buffer, update_val, index in zip(buffers, update_vals, indices, strict=True)
)
def align_buffers(buffers, n_steps, max_taps):
return [
jnp.roll(
buffer,
shift=jnp.where(
# Only needs rolling if last write position is beyond the buffer length
(n_steps + max_tap) > buffer.shape[0],
# Roll left by the amount of overflow
-((n_steps + max_tap + 1) % buffer.shape[0]),
0,
),
axis=0,
)
for buffer, max_tap in zip(buffers, max_taps, strict=True)
]
@jax_funcify.register(Scan)
def jax_funcify_Scan(op: Scan, **kwargs):
def jax_funcify_Scan(op: Scan, node, **kwargs):
op = op # Need to bind to a local variable
info = op.info
if info.as_while:
raise NotImplementedError("While Scan cannot yet be converted to JAX")
if info.n_mit_mot:
raise NotImplementedError(
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
)
# Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
rewriter = (
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
)
rewriter(op.fgraph)
scan_inner_func = jax_funcify(op.fgraph, **kwargs)
def scan(*outer_inputs):
# Extract JAX scan inputs
outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length`
seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`
mit_sot_init = []
for tap, seq in zip(
op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True
):
init_slice = seq[: abs(min(tap))]
mit_sot_init.append(init_slice)
sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
init_carry = (
mit_sot_init,
sit_sot_init,
op.outer_shared(outer_inputs),
op.outer_non_seqs(outer_inputs),
) # JAX `init`
def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func.
scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
"""
# `carry` contains all inner taps, shared terms, and non_seqs
(
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_non_seqs,
) = carry
# `x` contains the inner sequences
inner_seqs = x
mit_sot_flatten = []
for array, index in zip(
inner_mit_sot, op.info.mit_sot_in_slices, strict=True
):
mit_sot_flatten.extend(array[jnp.array(index)])
inner_scan_inputs = [
*inner_seqs,
*mit_sot_flatten,
*inner_sit_sot,
*inner_shared,
*inner_non_seqs,
]
return inner_scan_inputs
def inner_func_outs_to_jax_outs(
old_carry,
inner_scan_outs,
):
"""Convert inner_scan_func outputs into format expected by JAX scan.
# TODO: Use scan name from Op when available
scan_inner_func = jax_funcify(op.fgraph, fgraph_name="scan_inner_func", **kwargs)
def scan(*outer_inputs, op=op, node=node):
n_steps = outer_inputs[0]
sequences = op.outer_seqs(outer_inputs)
has_empty_sequences = any(seq.shape[0] == 0 for seq in sequences)
init_mit_mot_buffers = op.outer_mitmot(outer_inputs)
init_mit_sot_buffers = op.outer_mitsot(outer_inputs)
init_sit_sot_buffers = op.outer_sitsot(outer_inputs)
nit_sot_buffer_lens = op.outer_nitsot(outer_inputs)
# Shareds are special-cased SIT-SOTs that are not traced, but updated at each step.
# Only last value is returned. It's a hack for special types (like RNG) that can't be "concatenated" over time.
init_shareds = op.outer_shared(outer_inputs)
non_sequences = op.outer_non_seqs(outer_inputs)
assert (
1
+ len(sequences)
+ len(init_mit_mot_buffers)
+ len(init_mit_sot_buffers)
+ len(init_sit_sot_buffers)
+ len(nit_sot_buffer_lens)
+ len(init_shareds)
+ len(non_sequences)
) == len(outer_inputs)
# Initialize NIT-SOT buffers
if nit_sot_buffer_lens:
if has_empty_sequences:
# In this case we cannot call the inner function to infer the shapes of the nit_sot outputs
# So we must rely on static shapes of the outputs (if available)
nit_sot_core_shapes = [
n.type.shape for n in op.inner_nitsot_outs(op.fgraph.outputs)
]
if any(d is None for shape in nit_sot_core_shapes for d in shape):
raise ValueError(
"Scan with NIT-SOT outputs (None in outputs_info) cannot have 0 steps unless the output shapes are statically known)\n"
f"The static shapes of the NIT-SOT outputs for this Scan {node.op} are: {nit_sot_core_shapes}."
)
old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
"""
(
inner_mit_sot,
_inner_sit_sot,
inner_shared,
inner_non_seqs,
) = old_carry
inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
inner_shared_outs = op.inner_shared_outs(inner_scan_outs)
# Replace the oldest mit_sot tap by the newest value
inner_mit_sot_new = [
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
for old_mit_sot, new_val in zip(
inner_mit_sot, inner_mit_sot_outs, strict=True
else:
# Otherwise, call the function once to get the shapes and dtypes of the nit_sot outputs
buffer_vals = call_inner_func_with_indexed_buffers(
info,
scan_inner_func,
0,
sequences,
init_mit_mot_buffers,
init_mit_sot_buffers,
init_sit_sot_buffers,
init_shareds,
non_sequences,
)
nit_sot_core_shapes = [
n.shape for n in op.inner_nitsot_outs(buffer_vals)
]
nit_sot_dtypes = [
n.type.dtype for n in op.inner_nitsot_outs(op.fgraph.outputs)
]
# Nothing needs to be done with sit_sot
inner_sit_sot_new = inner_sit_sot_outs
inner_shared_new = inner_shared
# Replace old shared inputs by new shared outputs
inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs
new_carry = (
inner_mit_sot_new,
inner_sit_sot_new,
inner_shared_new,
inner_non_seqs,
init_nit_sot_buffers = tuple(
jnp.empty(
(nit_sot_buffer_len, *nit_sot_core_shape),
dtype=nit_sot_dtype,
)
for nit_sot_buffer_len, nit_sot_core_shape, nit_sot_dtype in zip(
nit_sot_buffer_lens,
nit_sot_core_shapes,
nit_sot_dtypes,
strict=True,
)
)
else:
init_nit_sot_buffers = ()
if has_empty_sequences:
# fori_loop still gets called with n_steps=0, which would raise an IndexError, we return early here
init_vals = (
*init_mit_mot_buffers,
*init_mit_sot_buffers,
*init_sit_sot_buffers,
*init_nit_sot_buffers,
*init_shareds,
)
return init_vals[0] if len(init_vals) == 1 else init_vals
# Shared variables and non_seqs are not traced
traced_outs = [
*inner_mit_sot_outs,
*inner_sit_sot_outs,
*inner_nit_sot_outs,
]
return new_carry, traced_outs
def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_func_args(carry, x)
inner_scan_outs = list(scan_inner_func(*inner_args))
new_carry, traced_outs = inner_func_outs_to_jax_outs(carry, inner_scan_outs)
return new_carry, traced_outs
def body_fun(i, prev_vals):
(
mit_mot_buffers,
mit_sot_buffers,
sit_sot_buffers,
nit_sot_buffers,
shareds,
) = prev_vals
next_vals = call_inner_func_with_indexed_buffers(
info,
scan_inner_func,
i,
sequences,
mit_mot_buffers,
mit_sot_buffers,
sit_sot_buffers,
shareds,
non_sequences,
)
# For MIT-MOT buffers, we want to store at the positions indicated by the output taps
mit_mot_updated_buffers = update_buffers(
mit_mot_buffers,
op.inner_mitmot_outs_grouped(next_vals),
# Taps are positive, we stack them to obtain advanced indices
indices=[i + jnp.stack(taps) for taps in info.mit_mot_out_slices],
# MIT-MOT buffers never roll, as they are never truncated
may_roll=False,
)
# For regular buffers, we want to store at the position after the last reading
mit_sot_updated_buffers = update_buffers(
mit_sot_buffers,
op.inner_mitsot_outs(next_vals),
indices=[i - min(taps) for taps in info.mit_sot_in_slices],
)
sit_sot_updated_buffers = update_buffers(
sit_sot_buffers,
op.inner_sitsot_outs(next_vals),
# Taps are always -1 for SIT-SOT, so we just use i + 1
indices=[i + 1 for _ in sit_sot_buffers],
)
nit_sot_updated_buffers = update_buffers(
nit_sot_buffers,
op.inner_nitsot_outs(next_vals),
# Taps are always 0 for NIT-SOT, so we just use i
indices=[i for _ in nit_sot_buffers],
)
shareds_update_vals = op.inner_shared_outs(next_vals)
return (
mit_mot_updated_buffers,
mit_sot_updated_buffers,
sit_sot_updated_buffers,
nit_sot_updated_buffers,
shareds_update_vals,
)
# Extract PyTensor scan outputs
final_carry, traces = jax.lax.scan(
jax_inner_func, init_carry, seqs, length=n_steps
(
updated_mit_mot_buffers,
updated_mit_sot_buffers,
updated_sit_sot_buffers,
updated_nit_sot_buffers,
updated_shareds,
) = fori_loop(
0,
n_steps,
body_fun,
init_val=(
init_mit_mot_buffers,
init_mit_sot_buffers,
init_sit_sot_buffers,
init_nit_sot_buffers,
init_shareds,
),
)
def get_partial_traces(traces):
"""Convert JAX scan traces to PyTensor traces.
# Roll the output buffers to match PyTensor Scan semantics
# MIT-MOT buffers are never truncated, so no rolling is needed
aligned_mit_mot_buffers = updated_mit_mot_buffers
aligned_mit_sot_buffers = align_buffers(
updated_mit_sot_buffers,
n_steps,
# (-3, -1) -> max is 2
max_taps=[-min(taps) - 1 for taps in info.mit_sot_in_slices],
)
We need to:
1. Prepend initial states to JAX output traces
2. Slice final traces if Scan was instructed to only keep a portion
"""
aligned_sit_sot_buffers = align_buffers(
updated_sit_sot_buffers,
n_steps,
max_taps=[0 for _ in updated_sit_sot_buffers],
)
aligned_nit_sot_buffers = align_buffers(
updated_nit_sot_buffers,
n_steps,
max_taps=[0 for _ in updated_nit_sot_buffers],
)
init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
buffers = (
op.outer_mitsot(outer_inputs)
+ op.outer_sitsot(outer_inputs)
+ op.outer_nitsot(outer_inputs)
all_outputs = tuple(
chain.from_iterable(
(
aligned_mit_mot_buffers,
aligned_mit_sot_buffers,
aligned_sit_sot_buffers,
aligned_nit_sot_buffers,
updated_shareds,
)
)
partial_traces = []
for init_state, trace, buffer in zip(
init_states, traces, buffers, strict=True
):
if init_state is not None:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
trace = jnp.atleast_1d(trace)
init_state = jnp.expand_dims(
init_state, range(trace.ndim - init_state.ndim)
)
full_trace = jnp.concatenate([init_state, trace], axis=0)
buffer_size = buffer.shape[0]
else:
# NIT-SOT: Buffer is just the number of entries that should be returned
full_trace = jnp.atleast_1d(trace)
buffer_size = buffer
partial_trace = full_trace[-buffer_size:]
partial_traces.append(partial_trace)
return partial_traces
def get_shared_outs(final_carry):
"""Retrive last state of shared_outs from final_carry.
These outputs cannot be traced in PyTensor Scan
"""
(
_inner_out_mit_sot,
_inner_out_sit_sot,
inner_out_shared,
_inner_in_non_seqs,
) = final_carry
shared_outs = inner_out_shared[: info.n_shared_outs]
return list(shared_outs)
scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
if len(scan_outs_final) == 1:
scan_outs_final = scan_outs_final[0]
return scan_outs_final
)
return all_outputs[0] if len(all_outputs) == 1 else all_outputs
return scan
......@@ -307,6 +307,17 @@ class ScanMethodsMixin:
n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
return list_outputs[:n_taps]
def inner_mitmot_outs_grouped(self, list_outputs):
    """Return the inner MIT-MOT outputs grouped per MIT-MOT variable.

    Like ``inner_mitmot_outs``, but instead of a flat list it returns a list
    of lists — one sub-list per MIT-MOT, holding the outputs for that
    variable's output taps.
    """
    grouped = []
    start = 0
    for out_slices in self.info.mit_mot_out_slices:
        stop = start + len(out_slices)
        grouped.append(list_outputs[start:stop])
        start = stop
    return grouped
def outer_mitmot_outs(self, list_outputs):
    """Return the leading ``n_mit_mot`` entries: the outer MIT-MOT outputs."""
    n_mit_mot = self.info.n_mit_mot
    return list_outputs[:n_mit_mot]
......
......@@ -7,6 +7,8 @@ import pytensor.tensor as pt
from pytensor import function, ifelse, shared
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan import until
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
......@@ -98,16 +100,26 @@ def test_scan_nit_sot(view):
assert len(scan_nodes) == 1
@pytest.mark.xfail(raises=NotImplementedError)
def test_scan_mit_mot():
xs = pt.vector("xs", shape=(10,))
ys, _ = scan(
lambda xtm2, xtm1: (xtm2 + xtm1),
outputs_info=[{"initial": xs, "taps": [-2, -1]}],
def step(xtm1, ytm3, ytm1, rho):
return (xtm1 + ytm1) * rho, ytm3 * (1 - rho) + ytm1 * rho
rho = pt.scalar("rho", dtype="float64")
x0 = pt.vector("xs", shape=(2,))
y0 = pt.vector("ys", shape=(3,))
[outs, _], _ = scan(
step,
outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
non_sequences=[rho],
n_steps=10,
)
grads_wrt_xs = pt.grad(ys.sum(), wrt=xs)
compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)])
grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
compare_jax_and_py(
[x0, y0, rho],
grads,
[np.arange(2), np.array([0.5, 0.5, 0.5]), np.array(0.95)],
jax_mode=get_mode("JAX"),
)
def test_scan_update():
......@@ -323,13 +335,41 @@ def test_default_mode_excludes_incompatible_rewrites():
def test_dynamic_sequence_length():
x = pt.tensor("x", shape=(None,))
out, _ = scan(lambda x: x + 1, sequences=[x])
class IncWithoutStaticShape(Op):
    """Elementwise increment Op whose output deliberately has no static shape.

    Used to exercise the Scan JAX dispatcher path where NIT-SOT output shapes
    cannot be inferred statically.
    """

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        # Output type drops all static shape information on purpose
        out = pt.tensor(shape=(None,) * x.type.ndim)
        return Apply(self, [x], [out])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = inputs[0] + 1
@jax_funcify.register(IncWithoutStaticShape)
def _(op, **kwargs):
    # JAX implementation: plain elementwise increment, mirroring `perform`.
    def inc(x):
        return x + 1

    return inc
inc_without_static_shape = IncWithoutStaticShape()
x = pt.tensor("x", shape=(None, 3))
out, _ = scan(
lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x]
)
f = function([x], out, mode=get_mode("JAX").excluding("scan"))
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
np.testing.assert_allclose(f([]), [])
np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4]))
np.testing.assert_allclose(f([[1, 2, 3]]), np.array([[2, 3, 4]]))
with pytest.raises(ValueError):
f(np.zeros((0, 3)))
# But should be fine with static shape
out2, _ = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None],
sequences=[x],
)
f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
np.testing.assert_allclose(f2(np.zeros((0, 3))), np.empty((0, 3)))
def SEIR_model_logp():
......@@ -499,9 +539,6 @@ def cyclical_reduction():
@pytest.mark.parametrize("mode", ("0forward", "1backward", "2both"))
@pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp])
def test_scan_benchmark(model, mode, gradient_backend, benchmark):
if gradient_backend == "PYTENSOR" and mode in ("1backward", "2both"):
pytest.skip("PYTENSOR backend does not support backward mode yet")
model_dict = model()
graph_inputs = model_dict["graph_inputs"]
differentiable_vars = model_dict["differentiable_vars"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论