提交 2ce0ce13 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Note failing scan rewrite

上级 9a124cac
......@@ -658,10 +658,9 @@ def inner_sitsot_only_last_step_used(
fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs
) -> bool:
"""
Given a inner nit-sot output of `Scan`, return ``True`` iff the outer
nit-sot output has only one client and that client is a `Subtensor`
instance that takes only the last step (last element along the first
axis).
Given an inner sit-sot output of `Scan`, return ``True`` iff the outer
sit-sot output has only one client and that client is a `Subtensor`
instance that takes only the last step (last element along the first axis).
"""
idx = scan_args.inner_out_sit_sot.index(var)
outer_var = scan_args.outer_out_sit_sot[idx]
......@@ -832,6 +831,14 @@ def scan_push_out_add(fgraph, node):
Like `scan_push_out_seq`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
FIXME: This rewrite doesn't cover user defined graphs,
since it doesn't account for the intermediate slice
returned by the scan constructor for sit-sot (i.e., something like output[1:]).
It only looks for `outputs[-1]` but the user will only ever write `outputs[1:][-1]`
The relevant helper function is `inner_sitsot_only_last_step_used` which is only used by this rewrite
Note this rewrite is registered before subtensor_merge, but even if it were registered after,
subtensor_merge is a mess and doesn't simplify x[1:][-1] to x[-1] unless the length of x is statically known
"""
# Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is
......@@ -857,6 +864,7 @@ def scan_push_out_add(fgraph, node):
isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, ps.Add)
and nd.out in args.inner_out_sit_sot
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
):
# Ensure that one of the input to the add is the output of
......@@ -920,6 +928,7 @@ def scan_push_out_add(fgraph, node):
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
# TODO: If we fix the FIXME above, we have to make sure we replace the last subtensor, not the immediate one
subtensor_node = fgraph.clients[outer_sitsot][0][0]
outer_sitsot_last_step = subtensor_node.outputs[0]
......
......@@ -600,10 +600,12 @@ class TestPushOutAddScan:
is used to compute the sum over the dot products between the corresponding
elements of two list of matrices.
TODO FIXME XXX: These aren't real tests; they simply confirm that a few
FIXME: These aren't real tests; they simply confirm that a few
graph that could be relevant to the push-out optimizations can be compiled
and evaluated. None of them confirm that a push-out optimization has been
performed.
FIXME: The rewrite is indeed broken, probably for a long while; see the FIXME details in the respective rewrite
"""
def test_sum_dot(self):
......@@ -614,7 +616,15 @@ class TestPushOutAddScan:
sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
outputs_info=[pt.zeros_like(A)],
)
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
# instead of `scan_out[1:][-1]` that the user would define by writing `s[-1]`
# It however, tests the only case the rewrite supports now
f = function([A, B], S.owner.inputs[0][-1])
has_scan = any(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes)
# Rewrite is only triggered in fast_run mode
assert has_scan if (config.mode == "FAST_COMPILE") else (not has_scan)
rng = np.random.default_rng(utt.fetch_seed())
vA = rng.uniform(size=(5, 5)).astype(config.floatX)
vB = rng.uniform(size=(5, 5)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论