提交 85235623 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do more aggressive scan memory saves in JIT backends

上级 92420c8f
......@@ -454,6 +454,19 @@ else:
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=[
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
],
),
)
JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(
......@@ -463,6 +476,7 @@ JAX = Mode(
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
)
......@@ -476,16 +490,10 @@ PYTORCH = Mode(
"fusion",
"inplace",
"local_uint_constant_indices",
"scan_save_mem_prealloc",
],
),
)
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
predefined_modes = {
......
......@@ -1085,7 +1085,9 @@ def add_scan_configvars():
"scan__allow_output_prealloc",
"Allow/disallow memory preallocation for outputs inside of scan "
"(default: True)",
BoolParam(True),
# Non-mutable because ScanSaveMem rewrite checks it,
# and we can't have the rewrite and the implementation mismatch
BoolParam(True, mutable=False),
in_c_key=False,
)
......
......@@ -70,7 +70,7 @@ from pytensor.tensor.subtensor import (
get_slice_elements,
set_subtensor,
)
from pytensor.tensor.variable import TensorConstant
from pytensor.tensor.variable import TensorConstant, TensorVariable
list_opt_slice = [
......@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements
@node_rewriter([Scan])
def scan_save_mem(fgraph, node):
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
r"""Graph optimizer that reduces scan memory consumption.
This optimizations attempts to determine if a `Scan` node, during its execution,
......@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
The scan perform implementation takes the output sizes into consideration,
saving the newest results over the oldest ones whenever the buffer is filled.
"""
if not isinstance(node.op, Scan):
return False
Parameters
----------
backend_supports_output_pre_allocation: bool
When the backend supports output pre-allocation Scan must keep buffers
with a length of required_states + 1, because the inner function will
attempt to write the inner function outputs directly into the provided
position in the outer circular buffer. This would invalidate results
if the input is still needed for some other output computation.
"""
if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of
else:
......@@ -1270,6 +1275,7 @@ def scan_save_mem(fgraph, node):
# Note: For simplicity while Scans also have global_nsteps set to None.
# All step optimizations require knowing the shape of the output, which
# cannot be determined from the inputs alone.
global_nsteps: None | dict
assert len(node.outputs) >= c_outs
if len(node.outputs) == c_outs and not op.info.as_while:
global_nsteps = {"real": -1, "sym": []}
......@@ -1277,7 +1283,7 @@ def scan_save_mem(fgraph, node):
global_nsteps = None
# Keeps track of the original slices that each client represent
slices = [None for o in node.outputs]
slices: list[None | list] = [None for o in node.outputs]
# A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate
......@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
# or not
flag_store = False
# 2.2 Loop over the clients
# 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients
slices[i] = []
......@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
except KeyError:
length = out.shape[0]
cf_slice = get_canonical_form_slice(this_slice[0], length)
slices[i] += [(cf_slice, this_slice)]
slices[i] += [(cf_slice, this_slice)] # type: ignore
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None
......@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node):
# for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated.
prealloc_outs = config.scan__allow_output_prealloc
prealloc_outs = (
backend_supports_output_pre_allocation
and config.scan__allow_output_prealloc
)
first_mitsot_idx = op_info.n_mit_mot
last_sitsot_idx = (
......@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node):
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
if prealloc_outs and preallocable_output:
# TODO: If there's only one output or other outputs do not depend
# on the same input, we could reduce the buffer size to the minimum
pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
else:
pval = select_max(nw_steps - start + init_l[i], init_l[i])
......@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node):
name=op.name,
allow_gc=op.allow_gc,
)
new_outs = new_op(*node_ins, return_list=True)
new_outs = cast(list[TensorVariable], new_op(*node_ins, return_list=True))
old_new = []
# 3.7 Get replace pairs for those outputs that do not change
......@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node):
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = subtens(new_outs[nw_pos], *sl_ins)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx)
......@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node):
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = subtens(new_outs[nw_pos], *sl_ins)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)]
......@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node):
return False
@node_rewriter([Scan])
def scan_save_mem_prealloc(fgraph, node):
return scan_save_mem_rewrite(
fgraph, node, backend_supports_output_pre_allocation=True
)
@node_rewriter([Scan])
def scan_save_mem_no_prealloc(fgraph, node):
return scan_save_mem_rewrite(
fgraph, node, backend_supports_output_pre_allocation=False
)
class ScanMerge(GraphRewriter):
r"""Graph optimizer that merges different scan ops.
......@@ -2495,10 +2520,20 @@ optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05)
optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node.
optdb.register(
"scan_save_mem",
in2out(scan_save_mem, ignore_newtrees=True),
"scan_save_mem_prealloc",
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
"fast_run",
"scan",
"scan_save_mem",
position=1.61,
)
optdb.register(
"scan_save_mem_no_prealloc",
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
"numba",
"jax",
"pytorch",
use_db_name_as_tag=False,
position=1.61,
)
optdb.register(
......
import logging
import sys
import warnings
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Sequence
from itertools import chain, groupby
from textwrap import dedent
from typing import cast, overload
......@@ -645,7 +645,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
def get_slice_elements(
idxs: list,
idxs: Sequence,
cond: Callable = lambda x: isinstance(x, Variable),
) -> list:
"""Extract slice elements conditional on a given predicate function.
......
......@@ -465,7 +465,7 @@ class TestScanSITSOTBuffer:
)
if buffer_size == "unit":
xs_kept = xs[-1] # Only last state is used
expected_buffer_size = 2
expected_buffer_size = 1
elif buffer_size == "aligned":
xs_kept = xs[-2:] # The buffer will be aligned at the end of the 9 steps
expected_buffer_size = 2
......@@ -555,8 +555,7 @@ class TestScanMITSOTBuffer:
accept_inplace=True,
on_unused_input="ignore",
)
assert tuple(mitsot_buffer_shape) == (3,)
assert tuple(mitsot_buffer_shape) == (2,)
if benchmark is not None:
numba_fn.trust_input = True
benchmark(numba_fn, *test_vals)
......
......@@ -742,7 +742,7 @@ class TestPushOutAddScan:
utt.assert_allclose(f_opt_output, f_no_opt_output)
def test_non_zero_init(self):
"""Test the case where the initial value for the nitsot output is non-zero."""
"""Test the case where the initial value for the sitsot output is non-zero."""
input1 = tensor3()
input2 = tensor3()
......@@ -759,8 +759,7 @@ class TestPushOutAddScan:
init = pt.as_tensor_variable(np.random.normal(size=(3, 7)))
# Compile the function twice, once with the optimization and once
# without
# Compile the function twice, once with the optimization and once without
opt_mode = mode.including("scan")
h, _ = pytensor.scan(
inner_fct,
......@@ -792,7 +791,7 @@ class TestPushOutAddScan:
output_opt = f_opt(input1_value, input2_value, input3_value)
output_no_opt = f_no_opt(input1_value, input2_value, input3_value)
utt.assert_allclose(output_opt, output_no_opt)
np.testing.assert_allclose(output_opt, output_no_opt)
class TestScanMerge:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论