提交 b3b68618 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Only MIT-MOT require working on buffers directly

Using JAX Scan machinery to create MIT-SOT, SIT-SOT, and NIT-SOT buffers for us seems to be more performant than working directly on the pre-allocated buffers and reading/writing at every iteration. There is no machinery to work with MIT-MOT directly (just like in PyTensor user-facing Scan).
上级 14e6c781
......@@ -2,85 +2,26 @@ from itertools import chain
import jax.numpy as jnp
import numpy as np
from jax._src.lax.control_flow import fori_loop
from jax._src.lax.control_flow import scan as jax_scan
from pytensor.compile.mode import JAX, get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan.op import Scan
def call_inner_func_with_indexed_buffers(
info,
scan_inner_func,
i,
sequences,
mit_mot_buffers,
mit_sot_buffers,
sit_sot_buffers,
shareds,
non_sequences,
):
sequence_vals = [seq[i] for seq in sequences]
# chain.from_iterable is used flatten the first dimension of each indexed buffer
# [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]]
# Benchmarking suggests unpacking advanced indexing on all taps is faster than basic index one tap at a time
mit_mot_vals = list(
chain.from_iterable(
buffer[(i + np.array(in_taps))]
for buffer, in_taps in zip(
mit_mot_buffers, info.mit_mot_in_slices, strict=True
)
)
)
mit_sot_vals = list(
chain.from_iterable(
# Convert negative taps (-2, -1) to positive indices (0, 1)
buffer[((i + (np.array(in_taps) - min(in_taps))) % buffer.shape[0])]
for buffer, in_taps in zip(
mit_sot_buffers, info.mit_sot_in_slices, strict=True
)
)
)
sit_sot_vals = [buffer[i % buffer.shape[0]] for buffer in sit_sot_buffers]
return scan_inner_func(
*sequence_vals,
*mit_mot_vals,
*mit_sot_vals,
*sit_sot_vals,
*shareds,
*non_sequences,
)
def update_buffers(buffers, update_vals, indices, may_roll: bool = True):
return tuple(
buffer.at[(index % buffer.shape[0]) if may_roll else index].set(update_val)
for buffer, update_val, index in zip(buffers, update_vals, indices, strict=True)
)
def align_buffers(buffers, n_steps, max_taps):
return [
jnp.roll(
buffer,
shift=jnp.where(
# Only needs rolling if last write position is beyond the buffer length
(n_steps + max_tap) > buffer.shape[0],
# Roll left by the amount of overflow
-((n_steps + max_tap + 1) % buffer.shape[0]),
0,
),
axis=0,
)
for buffer, max_tap in zip(buffers, max_taps, strict=True)
]
@jax_funcify.register(Scan)
def jax_funcify_Scan(op: Scan, node, **kwargs):
op = op # Need to bind to a local variable
def jax_funcify_Scan(op: Scan, **kwargs):
# Note: This implementation is different from the internal PyTensor Scan op.
# In particular, we don't make use of the provided buffers for recurring outputs (MIT-SOT, SIT-SOT)
# These buffers include the initial state and enough space to store as many intermediate results as needed.
# Instead, we let JAX scan recreate the concatenated buffer itself from the values computed in each iteration,
# and then prepend the initial_state and/or truncate results we don't need at the end.
# Likewise, we allow JAX to stack NIT-SOT outputs itself, instead of writing to an empty buffer with the final size.
# In contrast, MIT-MOT behave like PyTensor Scan. We read from and write to the original buffer as we iterate.
# Hopefully, JAX can do the same sort of memory optimizations as PyTensor does.
# Performance-wise, the benchmarks show this approach is better, specially when auto-diffing through JAX.
# For an implementation that is closer to the internal PyTensor Scan, check intermediate commit in
# https://github.com/pymc-devs/pytensor/pull/1651
info = op.info
if info.as_while:
......@@ -91,199 +32,207 @@ def jax_funcify_Scan(op: Scan, node, **kwargs):
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
)
rewriter(op.fgraph)
# TODO: Use scan name from Op when available
scan_inner_func = jax_funcify(op.fgraph, fgraph_name="scan_inner_func", **kwargs)
def scan(*outer_inputs, op=op, node=node):
n_steps = outer_inputs[0]
sequences = op.outer_seqs(outer_inputs)
has_empty_sequences = any(seq.shape[0] == 0 for seq in sequences)
init_mit_mot_buffers = op.outer_mitmot(outer_inputs)
init_mit_sot_buffers = op.outer_mitsot(outer_inputs)
init_sit_sot_buffers = op.outer_sitsot(outer_inputs)
nit_sot_buffer_lens = op.outer_nitsot(outer_inputs)
# Shareds are special-cased SIT-SOTs that are not traced, but updated at each step.
# Only last value is returned. It's a hack for special types (like RNG) that can't be "concatenated" over time.
init_shareds = op.outer_shared(outer_inputs)
non_sequences = op.outer_non_seqs(outer_inputs)
assert (
1
+ len(sequences)
+ len(init_mit_mot_buffers)
+ len(init_mit_sot_buffers)
+ len(init_sit_sot_buffers)
+ len(nit_sot_buffer_lens)
+ len(init_shareds)
+ len(non_sequences)
) == len(outer_inputs)
# Initialize NIT-SOT buffers
if nit_sot_buffer_lens:
if has_empty_sequences:
# In this case we cannot call the inner function to infer the shapes of the nit_sot outputs
# So we must rely on static shapes of the outputs (if available)
nit_sot_core_shapes = [
n.type.shape for n in op.inner_nitsot_outs(op.fgraph.outputs)
]
if any(d is None for shape in nit_sot_core_shapes for d in shape):
raise ValueError(
"Scan with NIT-SOT outputs (None in outputs_info) cannot have 0 steps unless the output shapes are statically known)\n"
f"The static shapes of the NIT-SOT outputs for this Scan {node.op} are: {nit_sot_core_shapes}."
)
scan_inner_func = jax_funcify(op.fgraph, **kwargs)
def scan(*outer_inputs):
# Extract JAX scan inputs
# JAX doesn't want some inputs to be tuple, but later lists (e.g., from list-comprehensions).
# We convert everything to list, so that it remains a list after slicing.
outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length`
seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`
# MIT-MOT don't have a concept of "initial state"
# The whole buffer is meaningful at the start of the Scan
mit_mot_init = op.outer_mitmot(outer_inputs)
# For MIT-SOT and SIT-SOT, extract the initial states from the outer input buffers
mit_sot_init = [
buff[: -min(tap)]
for buff, tap in zip(
op.outer_mitsot(outer_inputs), op.info.mit_sot_in_slices, strict=True
)
]
sit_sot_init = [buff[0] for buff in op.outer_sitsot(outer_inputs)]
else:
# Otherwise, call the function once to get the shapes and dtypes of the nit_sot outputs
buffer_vals = call_inner_func_with_indexed_buffers(
info,
scan_inner_func,
0,
sequences,
init_mit_mot_buffers,
init_mit_sot_buffers,
init_sit_sot_buffers,
init_shareds,
non_sequences,
)
nit_sot_core_shapes = [
n.shape for n in op.inner_nitsot_outs(buffer_vals)
]
nit_sot_dtypes = [
n.type.dtype for n in op.inner_nitsot_outs(op.fgraph.outputs)
]
init_nit_sot_buffers = tuple(
jnp.empty(
(nit_sot_buffer_len, *nit_sot_core_shape),
dtype=nit_sot_dtype,
init_carry = (
0, # loop counter, needed for indexing MIT-MOT
mit_mot_init,
mit_sot_init,
sit_sot_init,
op.outer_shared(outer_inputs),
op.outer_non_seqs(outer_inputs),
) # JAX `init`
def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func.
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs)
"""
# `carry` contains all inner taps, shared terms, and non_seqs
(
i,
inner_mit_mot,
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_non_seqs,
) = carry
# `x` contains the inner sequences
inner_seqs = x
# chain.from_iterable is used to flatten the first dimension of each indexed buffer
# [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]]
# Benchmarking suggests unpacking advanced indexing on all taps is faster than basic index one tap at a time
mit_mot_flatten = list(
chain.from_iterable(
buffer[(i + np.array(taps))]
for buffer, taps in zip(
inner_mit_mot, info.mit_mot_in_slices, strict=True
)
)
for nit_sot_buffer_len, nit_sot_core_shape, nit_sot_dtype in zip(
nit_sot_buffer_lens,
nit_sot_core_shapes,
nit_sot_dtypes,
strict=True,
)
mit_sot_flatten = list(
chain.from_iterable(
buffer[np.array(taps)]
for buffer, taps in zip(
inner_mit_sot, info.mit_sot_in_slices, strict=True
)
)
)
else:
init_nit_sot_buffers = ()
if has_empty_sequences:
# fori_loop still gets called with n_steps=0, which would raise an IndexError, we return early here
init_vals = (
*init_mit_mot_buffers,
*init_mit_sot_buffers,
*init_sit_sot_buffers,
*init_nit_sot_buffers,
*init_shareds,
return (
*inner_seqs,
*mit_mot_flatten,
*mit_sot_flatten,
*inner_sit_sot,
*inner_shared,
*inner_non_seqs,
)
return init_vals[0] if len(init_vals) == 1 else init_vals
def body_fun(i, prev_vals):
def inner_func_outs_to_jax_outs(
old_carry,
inner_scan_outs,
):
"""Convert inner_scan_func outputs into format expected by JAX scan.
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
"""
(
mit_mot_buffers,
mit_sot_buffers,
sit_sot_buffers,
nit_sot_buffers,
shareds,
) = prev_vals
next_vals = call_inner_func_with_indexed_buffers(
info,
scan_inner_func,
i,
sequences,
mit_mot_buffers,
mit_sot_buffers,
sit_sot_buffers,
shareds,
non_sequences,
)
# For MIT-MOT buffers, we want to store at the positions indicated by the output taps
mit_mot_updated_buffers = update_buffers(
mit_mot_buffers,
op.inner_mitmot_outs_grouped(next_vals),
# Taps are positive, we stack them to obtain advanced indices
indices=[i + jnp.stack(taps) for taps in info.mit_mot_out_slices],
# MIT-MOT buffers never roll, as they are never truncated
may_roll=False,
)
# For regular buffers, we want to store at the position after the last reading
mit_sot_updated_buffers = update_buffers(
mit_sot_buffers,
op.inner_mitsot_outs(next_vals),
indices=[i - min(taps) for taps in info.mit_sot_in_slices],
)
sit_sot_updated_buffers = update_buffers(
sit_sot_buffers,
op.inner_sitsot_outs(next_vals),
# Taps are always -1 for SIT-SOT, so we just use i + 1
indices=[i + 1 for _ in sit_sot_buffers],
)
nit_sot_updated_buffers = update_buffers(
nit_sot_buffers,
op.inner_nitsot_outs(next_vals),
# Taps are always 0 for NIT-SOT, so we just use i
indices=[i for _ in nit_sot_buffers],
old_mit_mot,
old_mit_sot,
_old_sit_sot,
_old_shared,
inner_non_seqs,
) = old_carry
new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_scan_outs)
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
new_shared = op.inner_shared_outs(inner_scan_outs)
# New carry for next step
# Update MIT-MOT buffer at positions indicated by output taps
new_mit_mot = [
buffer.at[i + np.array(taps)].set(new_vals)
for buffer, new_vals, taps in zip(
old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
)
]
# Discard oldest MIT-SOT and append newest value
new_mit_sot = [
jnp.concatenate([old_buffer[1:], new_val[None, ...]], axis=0)
for old_buffer, new_val in zip(
old_mit_sot, new_mit_sot_vals, strict=True
)
]
# For SIT-SOT, and shared just pass along the new value
# Non-sequences remain unchanged
new_carry = (
i + 1,
new_mit_mot,
new_mit_sot,
new_sit_sot,
new_shared,
inner_non_seqs,
)
shareds_update_vals = op.inner_shared_outs(next_vals)
return (
mit_mot_updated_buffers,
mit_sot_updated_buffers,
sit_sot_updated_buffers,
nit_sot_updated_buffers,
shareds_update_vals,
)
# Select new MIT-SOT, SIT-SOT, and NIT-SOT for tracing
traced_outs = [
*new_mit_sot_vals,
*new_sit_sot,
*new_nit_sot,
]
return new_carry, traced_outs
def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_func_args(carry, x)
inner_scan_outs = list(scan_inner_func(*inner_args))
new_carry, traced_outs = inner_func_outs_to_jax_outs(carry, inner_scan_outs)
return new_carry, traced_outs
# Extract PyTensor scan outputs
(
updated_mit_mot_buffers,
updated_mit_sot_buffers,
updated_sit_sot_buffers,
updated_nit_sot_buffers,
updated_shareds,
) = fori_loop(
0,
n_steps,
body_fun,
init_val=(
init_mit_mot_buffers,
init_mit_sot_buffers,
init_sit_sot_buffers,
init_nit_sot_buffers,
init_shareds,
(
_final_i,
final_mit_mot,
_final_mit_sot,
_final_sit_sot,
final_shared,
_final_non_seqs,
),
)
# Roll the output buffers to match PyTensor Scan semantics
# MIT-MOT buffers are never truncated, so no rolling is needed
aligned_mit_mot_buffers = updated_mit_mot_buffers
aligned_mit_sot_buffers = align_buffers(
updated_mit_sot_buffers,
n_steps,
# (-3, -1) -> max is 2
max_taps=[-min(taps) - 1 for taps in info.mit_sot_in_slices],
)
aligned_sit_sot_buffers = align_buffers(
updated_sit_sot_buffers,
n_steps,
max_taps=[0 for _ in updated_sit_sot_buffers],
)
aligned_nit_sot_buffers = align_buffers(
updated_nit_sot_buffers,
n_steps,
max_taps=[0 for _ in updated_nit_sot_buffers],
)
all_outputs = tuple(
chain.from_iterable(
(
aligned_mit_mot_buffers,
aligned_mit_sot_buffers,
aligned_sit_sot_buffers,
aligned_nit_sot_buffers,
updated_shareds,
)
traces,
) = jax_scan(jax_inner_func, init_carry, seqs, length=n_steps)
def get_partial_traces(traces):
"""Convert JAX scan traces to PyTensor traces.
We need to:
1. Prepend initial states to JAX output traces
2. Slice final traces if Scan was instructed to only keep a portion
"""
init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
buffers = (
op.outer_mitsot(outer_inputs)
+ op.outer_sitsot(outer_inputs)
+ op.outer_nitsot(outer_inputs)
)
)
return all_outputs[0] if len(all_outputs) == 1 else all_outputs
partial_traces = []
for init_state, trace, buffer in zip(
init_states, traces, buffers, strict=True
):
if init_state is not None:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
trace = jnp.atleast_1d(trace)
init_state = jnp.expand_dims(
init_state, range(trace.ndim - init_state.ndim)
)
full_trace = jnp.concatenate([init_state, trace], axis=0)
buffer_size = buffer.shape[0]
else:
# NIT-SOT: Buffer is just the number of entries that should be returned
full_trace = jnp.atleast_1d(trace)
buffer_size = buffer
partial_trace = full_trace[-buffer_size:]
partial_traces.append(partial_trace)
return partial_traces
scan_outs_final = [
*final_mit_mot,
*get_partial_traces(traces),
*final_shared,
]
if len(scan_outs_final) == 1:
scan_outs_final = scan_outs_final[0]
return scan_outs_final
return scan
......@@ -8,7 +8,6 @@ from pytensor import function, ifelse, shared
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan import until
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
......@@ -335,6 +334,9 @@ def test_default_mode_excludes_incompatible_rewrites():
def test_dynamic_sequence_length():
# Imported here to not trigger import of JAX in non-JAX CI jobs
from pytensor.link.jax.dispatch.basic import jax_funcify
class IncWithoutStaticShape(Op):
def make_node(self, x):
x = pt.as_tensor_variable(x)
......@@ -358,10 +360,10 @@ def test_dynamic_sequence_length():
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
np.testing.assert_allclose(f([[1, 2, 3]]), np.array([[2, 3, 4]]))
with pytest.raises(ValueError):
f(np.zeros((0, 3)))
# This works if we use JAX scan internally, but not if we use a fori_loop with a buffer allocated by us
np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
# But should be fine with static shape
# With known static shape we should always manage, regardless of the internal implementation
out2, _ = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论