提交 9004e5f2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

scan_save_mem_rewrite: short-circuit mit-mot

mit-mot outputs are never memory optimized. Skipping these outputs lets us get rid of two faulty logical branches that existed to mask each other: 1. an `if i <= op.info.n_mit_mot:` inside an `else` branch. This was logically wrong in that it included the first non mit-mot output (it should have been `<`, not `<=`). When this was the output of a while scan, it created an artificial dependency on the scan output shape and didn't allow the rewrite to happen. 2. Because of this, the outer `if (i <= op.info.n_mit_mot and ...)` had been added to sidestep that artificial dependency. According to its comment, it was supposed to specifically handle sit-sot/mit-sot outputs of while loops, but it was again looking at all mit-mot outputs plus the first non mit-mot output. It was logically wrong, but it canceled out the first logical mistake. If we remove both, things just work.
上级 c3d877fe
...@@ -1469,8 +1469,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1469,8 +1469,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
nw_steps = select_max(nw_steps, 1) nw_steps = select_max(nw_steps, 1)
# 2.4 Loop over the clients again now looking just to see how many # 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store # intermediate steps to store. Skip mit_mot outputs as their
for i, out in enumerate(node.outputs[:c_outs]): # store_steps is always 0 (all intermediate values are needed).
for i, out in enumerate(
node.outputs[op_info.n_mit_mot : c_outs], start=op_info.n_mit_mot
):
# look at all its clients # look at all its clients
for cl, _ in fgraph.clients[out]: for cl, _ in fgraph.clients[out]:
if isinstance(cl.op, Output): if isinstance(cl.op, Output):
...@@ -1495,26 +1498,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1495,26 +1498,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
store_steps[i] = 0 store_steps[i] = 0
break break
# Special case for recurrent outputs where only the last result
# is requested. This is needed for this rewrite to apply to
# do-while Scans at all. Otherwise, `get_canonical_form_slice` in
# the `else` branch would reintroduce a shape dependency on the
# original Scan which would lead this rewrite to abort in the end.
if (
i <= op.info.n_mit_mot
and isinstance(this_slice[0], ScalarConstant)
and this_slice[0].value == -1
):
start = nw_steps - 1
else:
if i <= op.info.n_mit_mot:
try:
length = shape_of[out][0]
except KeyError:
length = out.shape[0]
else:
length = node.inputs[0] + init_l[i] length = node.inputs[0] + init_l[i]
cf_slice = get_canonical_form_slice(this_slice[0], length) cf_slice = get_canonical_form_slice(this_slice[0], length)
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
......
...@@ -1692,13 +1692,16 @@ class TestSaveMem: ...@@ -1692,13 +1692,16 @@ class TestSaveMem:
# ys_trace is an Alloc that controls the size of the inner buffer, # ys_trace is an Alloc that controls the size of the inner buffer,
# it should have shape[0] == 3, with two entries for the taps and one # it should have shape[0] == 3, with two entries for the taps and one
# entry for the intermediate output # extra entry to prevent aliasing between the inputs and outputs
# of the pre-allocation mechanism. JIT linkers don't use pre-allocation
# so the buffer is one element smaller.
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, ys_trace = scan_node.inputs _, ys_trace = scan_node.inputs
debug_fn = pytensor.function( debug_fn = pytensor.function(
[n_steps, x0], ys_trace.shape[0], accept_inplace=True [n_steps, x0], ys_trace.shape[0], accept_inplace=True
) )
assert debug_fn(n_steps=1000, x0=[1, 1]) == 3 expected_size = 2 if isinstance(f.maker.linker, JITLinker) else 3
assert debug_fn(n_steps=1000, x0=[1, 1]) == expected_size
def test_while_scan_map(self): def test_while_scan_map(self):
xs = vector("xs") xs = vector("xs")
...@@ -1752,13 +1755,15 @@ class TestSaveMem: ...@@ -1752,13 +1755,15 @@ class TestSaveMem:
f(x0=0, seq=test_seq, n_steps=0) f(x0=0, seq=test_seq, n_steps=0)
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly. # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
# JIT linkers don't use pre-allocation so the buffer is one element smaller.
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, _, ys_trace, len_zs = scan_node.inputs _, _, ys_trace, len_zs = scan_node.inputs
debug_fn = pytensor.function( debug_fn = pytensor.function(
[x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True [x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
) )
stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200) stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200)
assert stored_ys_steps == 2 expected_y_steps = 1 if isinstance(f.maker.linker, JITLinker) else 2
assert stored_ys_steps == expected_y_steps
assert stored_zs_steps == 1 assert stored_zs_steps == 1
@pytest.mark.parametrize("val_ndim", (0, 1)) @pytest.mark.parametrize("val_ndim", (0, 1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论