提交 85235623 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do more aggressive scan memory saves in JIT backends

上级 92420c8f
...@@ -454,6 +454,19 @@ else: ...@@ -454,6 +454,19 @@ else:
RewriteDatabaseQuery(include=["fast_run", "py_only"]), RewriteDatabaseQuery(include=["fast_run", "py_only"]),
) )
# Mode for the Numba backend: the NumbaLinker paired with the
# fast_run/numba rewrite set, excluding rewrite tags that are not wanted
# here (C-only rewrites, BlasOpt, local careduce fusion, and the
# pre-allocating variant of the Scan save-mem rewrite).
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=[
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
],
),
)
JAX = Mode( JAX = Mode(
JAXLinker(), JAXLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(
...@@ -463,6 +476,7 @@ JAX = Mode( ...@@ -463,6 +476,7 @@ JAX = Mode(
"BlasOpt", "BlasOpt",
"fusion", "fusion",
"inplace", "inplace",
"scan_save_mem_prealloc",
], ],
), ),
) )
...@@ -476,16 +490,10 @@ PYTORCH = Mode( ...@@ -476,16 +490,10 @@ PYTORCH = Mode(
"fusion", "fusion",
"inplace", "inplace",
"local_uint_constant_indices", "local_uint_constant_indices",
"scan_save_mem_prealloc",
], ],
), ),
) )
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
predefined_modes = { predefined_modes = {
......
...@@ -1085,7 +1085,9 @@ def add_scan_configvars(): ...@@ -1085,7 +1085,9 @@ def add_scan_configvars():
"scan__allow_output_prealloc", "scan__allow_output_prealloc",
"Allow/disallow memory preallocation for outputs inside of scan " "Allow/disallow memory preallocation for outputs inside of scan "
"(default: True)", "(default: True)",
BoolParam(True), # Non-mutable because ScanSaveMem rewrite checks it,
# and we can't have the rewrite and the implementation mismatch
BoolParam(True, mutable=False),
in_c_key=False, in_c_key=False,
) )
......
...@@ -70,7 +70,7 @@ from pytensor.tensor.subtensor import ( ...@@ -70,7 +70,7 @@ from pytensor.tensor.subtensor import (
get_slice_elements, get_slice_elements,
set_subtensor, set_subtensor,
) )
from pytensor.tensor.variable import TensorConstant from pytensor.tensor.variable import TensorConstant, TensorVariable
list_opt_slice = [ list_opt_slice = [
...@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node): ...@@ -1182,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements return subtensor_merge_replacements
@node_rewriter([Scan]) def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
def scan_save_mem(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption. r"""Graph optimizer that reduces scan memory consumption.
This optimizations attempts to determine if a `Scan` node, during its execution, This optimizations attempts to determine if a `Scan` node, during its execution,
...@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node): ...@@ -1214,10 +1213,16 @@ def scan_save_mem(fgraph, node):
The scan perform implementation takes the output sizes into consideration, The scan perform implementation takes the output sizes into consideration,
saving the newest results over the oldest ones whenever the buffer is filled. saving the newest results over the oldest ones whenever the buffer is filled.
"""
if not isinstance(node.op, Scan):
return False
Parameters
----------
backend_supports_output_pre_allocation: bool
When the backend supports output pre-allocation Scan must keep buffers
with a length of required_states + 1, because the inner function will
attempt to write the inner function outputs directly into the provided
position in the outer circular buffer. This would invalidate results,
if the input is still needed for some other output computation.
"""
if hasattr(fgraph, "shape_feature"): if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of shape_of = fgraph.shape_feature.shape_of
else: else:
...@@ -1270,6 +1275,7 @@ def scan_save_mem(fgraph, node): ...@@ -1270,6 +1275,7 @@ def scan_save_mem(fgraph, node):
# Note: For simplicity while Scans also have global_nsteps set to None. # Note: For simplicity while Scans also have global_nsteps set to None.
# All step optimizations require knowing the shape of the output, which # All step optimizations require knowing the shape of the output, which
# cannot be determined from the inputs alone. # cannot be determined from the inputs alone.
global_nsteps: None | dict
assert len(node.outputs) >= c_outs assert len(node.outputs) >= c_outs
if len(node.outputs) == c_outs and not op.info.as_while: if len(node.outputs) == c_outs and not op.info.as_while:
global_nsteps = {"real": -1, "sym": []} global_nsteps = {"real": -1, "sym": []}
...@@ -1277,7 +1283,7 @@ def scan_save_mem(fgraph, node): ...@@ -1277,7 +1283,7 @@ def scan_save_mem(fgraph, node):
global_nsteps = None global_nsteps = None
# Keeps track of the original slices that each client represent # Keeps track of the original slices that each client represent
slices = [None for o in node.outputs] slices: list[None | list] = [None for o in node.outputs]
# A list for each output indicating how many intermediate values # A list for each output indicating how many intermediate values
# should be stored. If negative it means none of the intermediate # should be stored. If negative it means none of the intermediate
...@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node): ...@@ -1294,7 +1300,7 @@ def scan_save_mem(fgraph, node):
# or not # or not
flag_store = False flag_store = False
# 2.2 Loop over the clients # 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
for i, out in enumerate(node.outputs[:c_outs]): for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients # look at all its clients
slices[i] = [] slices[i] = []
...@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node): ...@@ -1337,7 +1343,7 @@ def scan_save_mem(fgraph, node):
except KeyError: except KeyError:
length = out.shape[0] length = out.shape[0]
cf_slice = get_canonical_form_slice(this_slice[0], length) cf_slice = get_canonical_form_slice(this_slice[0], length)
slices[i] += [(cf_slice, this_slice)] slices[i] += [(cf_slice, this_slice)] # type: ignore
if isinstance(this_slice[0], slice) and this_slice[0].stop is None: if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None global_nsteps = None
...@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node): ...@@ -1477,7 +1483,10 @@ def scan_save_mem(fgraph, node):
# for mitsots and sitsots (because mitmots are not # for mitsots and sitsots (because mitmots are not
# currently supported by the mechanism) and only if # currently supported by the mechanism) and only if
# the pre-allocation mechanism is activated. # the pre-allocation mechanism is activated.
prealloc_outs = config.scan__allow_output_prealloc prealloc_outs = (
backend_supports_output_pre_allocation
and config.scan__allow_output_prealloc
)
first_mitsot_idx = op_info.n_mit_mot first_mitsot_idx = op_info.n_mit_mot
last_sitsot_idx = ( last_sitsot_idx = (
...@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node): ...@@ -1486,6 +1495,8 @@ def scan_save_mem(fgraph, node):
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
if prealloc_outs and preallocable_output: if prealloc_outs and preallocable_output:
# TODO: If there's only one output or other outputs do not depend
# on the same input, we could reduce the buffer size to the minimum
pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1) pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
else: else:
pval = select_max(nw_steps - start + init_l[i], init_l[i]) pval = select_max(nw_steps - start + init_l[i], init_l[i])
...@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node): ...@@ -1652,7 +1663,7 @@ def scan_save_mem(fgraph, node):
name=op.name, name=op.name,
allow_gc=op.allow_gc, allow_gc=op.allow_gc,
) )
new_outs = new_op(*node_ins, return_list=True) new_outs = cast(list[TensorVariable], new_op(*node_ins, return_list=True))
old_new = [] old_new = []
# 3.7 Get replace pairs for those outputs that do not change # 3.7 Get replace pairs for those outputs that do not change
...@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node): ...@@ -1682,7 +1693,7 @@ def scan_save_mem(fgraph, node):
sl_ins = get_slice_elements( sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable) nw_slice, lambda entry: isinstance(entry, Variable)
) )
new_o = subtens(new_outs[nw_pos], *sl_ins) new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx) replaced_outs.append(idx)
...@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node): ...@@ -1737,7 +1748,7 @@ def scan_save_mem(fgraph, node):
sl_ins = get_slice_elements( sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable) nw_slice, lambda entry: isinstance(entry, Variable)
) )
new_o = subtens(new_outs[nw_pos], *sl_ins) new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)] old_new += [(old, new_o)]
...@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node): ...@@ -1768,6 +1779,20 @@ def scan_save_mem(fgraph, node):
return False return False
@node_rewriter([Scan])
def scan_save_mem_prealloc(fgraph, node):
    """Reduce Scan memory consumption for backends with output pre-allocation.

    Thin wrapper that delegates to ``scan_save_mem_rewrite`` with
    ``backend_supports_output_pre_allocation=True``, so buffers keep the
    extra slot needed when the inner function writes directly into the
    outer circular buffer.
    """
    return scan_save_mem_rewrite(
        fgraph,
        node,
        backend_supports_output_pre_allocation=True,
    )
@node_rewriter([Scan])
def scan_save_mem_no_prealloc(fgraph, node):
    """Reduce Scan memory consumption for backends without output pre-allocation.

    Thin wrapper that delegates to ``scan_save_mem_rewrite`` with
    ``backend_supports_output_pre_allocation=False``, allowing more
    aggressive buffer-size reductions than the pre-allocating variant.
    """
    return scan_save_mem_rewrite(
        fgraph,
        node,
        backend_supports_output_pre_allocation=False,
    )
class ScanMerge(GraphRewriter): class ScanMerge(GraphRewriter):
r"""Graph optimizer that merges different scan ops. r"""Graph optimizer that merges different scan ops.
...@@ -2495,10 +2520,20 @@ optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05) ...@@ -2495,10 +2520,20 @@ optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05)
optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6) optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node. # ScanSaveMem should execute only once per node.
optdb.register( optdb.register(
"scan_save_mem", "scan_save_mem_prealloc",
in2out(scan_save_mem, ignore_newtrees=True), in2out(scan_save_mem_prealloc, ignore_newtrees=True),
"fast_run", "fast_run",
"scan", "scan",
"scan_save_mem",
position=1.61,
)
optdb.register(
"scan_save_mem_no_prealloc",
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
"numba",
"jax",
"pytorch",
use_db_name_as_tag=False,
position=1.61, position=1.61,
) )
optdb.register( optdb.register(
......
import logging import logging
import sys import sys
import warnings import warnings
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable, Sequence
from itertools import chain, groupby from itertools import chain, groupby
from textwrap import dedent from textwrap import dedent
from typing import cast, overload from typing import cast, overload
...@@ -645,7 +645,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -645,7 +645,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
def get_slice_elements( def get_slice_elements(
idxs: list, idxs: Sequence,
cond: Callable = lambda x: isinstance(x, Variable), cond: Callable = lambda x: isinstance(x, Variable),
) -> list: ) -> list:
"""Extract slice elements conditional on a given predicate function. """Extract slice elements conditional on a given predicate function.
......
...@@ -465,7 +465,7 @@ class TestScanSITSOTBuffer: ...@@ -465,7 +465,7 @@ class TestScanSITSOTBuffer:
) )
if buffer_size == "unit": if buffer_size == "unit":
xs_kept = xs[-1] # Only last state is used xs_kept = xs[-1] # Only last state is used
expected_buffer_size = 2 expected_buffer_size = 1
elif buffer_size == "aligned": elif buffer_size == "aligned":
xs_kept = xs[-2:] # The buffer will be aligned at the end of the 9 steps xs_kept = xs[-2:] # The buffer will be aligned at the end of the 9 steps
expected_buffer_size = 2 expected_buffer_size = 2
...@@ -555,8 +555,7 @@ class TestScanMITSOTBuffer: ...@@ -555,8 +555,7 @@ class TestScanMITSOTBuffer:
accept_inplace=True, accept_inplace=True,
on_unused_input="ignore", on_unused_input="ignore",
) )
assert tuple(mitsot_buffer_shape) == (3,) assert tuple(mitsot_buffer_shape) == (2,)
if benchmark is not None: if benchmark is not None:
numba_fn.trust_input = True numba_fn.trust_input = True
benchmark(numba_fn, *test_vals) benchmark(numba_fn, *test_vals)
......
...@@ -742,7 +742,7 @@ class TestPushOutAddScan: ...@@ -742,7 +742,7 @@ class TestPushOutAddScan:
utt.assert_allclose(f_opt_output, f_no_opt_output) utt.assert_allclose(f_opt_output, f_no_opt_output)
def test_non_zero_init(self): def test_non_zero_init(self):
"""Test the case where the initial value for the nitsot output is non-zero.""" """Test the case where the initial value for the sitsot output is non-zero."""
input1 = tensor3() input1 = tensor3()
input2 = tensor3() input2 = tensor3()
...@@ -759,8 +759,7 @@ class TestPushOutAddScan: ...@@ -759,8 +759,7 @@ class TestPushOutAddScan:
init = pt.as_tensor_variable(np.random.normal(size=(3, 7))) init = pt.as_tensor_variable(np.random.normal(size=(3, 7)))
# Compile the function twice, once with the optimization and once # Compile the function twice, once with the optimization and once without
# without
opt_mode = mode.including("scan") opt_mode = mode.including("scan")
h, _ = pytensor.scan( h, _ = pytensor.scan(
inner_fct, inner_fct,
...@@ -792,7 +791,7 @@ class TestPushOutAddScan: ...@@ -792,7 +791,7 @@ class TestPushOutAddScan:
output_opt = f_opt(input1_value, input2_value, input3_value) output_opt = f_opt(input1_value, input2_value, input3_value)
output_no_opt = f_no_opt(input1_value, input2_value, input3_value) output_no_opt = f_no_opt(input1_value, input2_value, input3_value)
utt.assert_allclose(output_opt, output_no_opt) np.testing.assert_allclose(output_opt, output_no_opt)
class TestScanMerge: class TestScanMerge:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论