提交 b9af9523 作者: Ricardo Vieira 提交者: Ricardo Vieira

Optimize partial trace definition

上级 b3b68618
...@@ -209,18 +209,28 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -209,18 +209,28 @@ def jax_funcify_Scan(op: Scan, **kwargs):
): ):
if init_state is not None: if init_state is not None:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
trace = jnp.atleast_1d(trace)
init_state = jnp.expand_dims(
init_state, range(trace.ndim - init_state.ndim)
)
full_trace = jnp.concatenate([init_state, trace], axis=0)
buffer_size = buffer.shape[0] buffer_size = buffer.shape[0]
if trace.shape[0] > buffer_size:
# Trace is longer than buffer, keep just the last `buffer.shape[0]` entries
partial_trace = trace[-buffer_size:]
else:
# Trace is shorter than buffer, this happens when we keep the initial_state
if init_state.ndim < buffer.ndim:
init_state = init_state[None]
if (
n_init_needed := buffer_size - trace.shape[0]
) < init_state.shape[0]:
# We may not need to keep all the initial states
init_state = init_state[-n_init_needed:]
partial_trace = jnp.concatenate([init_state, trace], axis=0)
else: else:
# NIT-SOT: Buffer is just the number of entries that should be returned # NIT-SOT: Buffer is just the number of entries that should be returned
full_trace = jnp.atleast_1d(trace)
buffer_size = buffer buffer_size = buffer
partial_trace = (
trace[-buffer_size:] if trace.shape[0] > buffer else trace
)
partial_trace = full_trace[-buffer_size:] assert partial_trace.shape[0] == buffer_size
partial_traces.append(partial_trace) partial_traces.append(partial_trace)
return partial_traces return partial_traces
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论