提交 2ce0ce13 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Note failing scan rewrite

上级 9a124cac
...@@ -658,10 +658,9 @@ def inner_sitsot_only_last_step_used( ...@@ -658,10 +658,9 @@ def inner_sitsot_only_last_step_used(
fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs
) -> bool: ) -> bool:
""" """
Given a inner nit-sot output of `Scan`, return ``True`` iff the outer Given a inner sit-sot output of `Scan`, return ``True`` iff the outer
nit-sot output has only one client and that client is a `Subtensor` sit-sot output has only one client and that client is a `Subtensor`
instance that takes only the last step (last element along the first instance that takes only the last step (last element along the first axis).
axis).
""" """
idx = scan_args.inner_out_sit_sot.index(var) idx = scan_args.inner_out_sit_sot.index(var)
outer_var = scan_args.outer_out_sit_sot[idx] outer_var = scan_args.outer_out_sit_sot[idx]
...@@ -832,6 +831,14 @@ def scan_push_out_add(fgraph, node): ...@@ -832,6 +831,14 @@ def scan_push_out_add(fgraph, node):
Like `scan_push_out_seq`, this optimization aims to replace many operations Like `scan_push_out_seq`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to on small tensors by few operations on large tensors. It can also lead to
increased memory usage. increased memory usage.
FIXME: This rewrite doesn't cover user defined graphs,
since it doesn't account for the intermediate slice
returned by the scan constructor for sit-sot (i.e., something like output[1:]).
It only looks for `outputs[-1]` but the user will only ever write `outputs[1:][-1]`
The relevant helper function is `inner_sitsot_only_last_step_used` which is only used by this rewrite
Note this rewrite is registered before subtensor_merge, but even if it were after subtensor_merge is a mess
and doesn't simplify to x[1:][-1] to x[-1] unless x length is statically known
""" """
# Don't perform the optimization on `as_while` `Scan`s. Because these # Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is # `Scan`s don't run for a predetermined number of steps, handling them is
...@@ -857,6 +864,7 @@ def scan_push_out_add(fgraph, node): ...@@ -857,6 +864,7 @@ def scan_push_out_add(fgraph, node):
isinstance(nd.op, Elemwise) isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, ps.Add) and isinstance(nd.op.scalar_op, ps.Add)
and nd.out in args.inner_out_sit_sot and nd.out in args.inner_out_sit_sot
# FIXME: This function doesn't handle `sitsot_out[1:][-1]` pattern
and inner_sitsot_only_last_step_used(fgraph, nd.out, args) and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
): ):
# Ensure that one of the input to the add is the output of # Ensure that one of the input to the add is the output of
...@@ -920,6 +928,7 @@ def scan_push_out_add(fgraph, node): ...@@ -920,6 +928,7 @@ def scan_push_out_add(fgraph, node):
# external Dot instead of the output of scan # external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot # Modify the outer graph to add the outer Dot
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx] outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
# TODO: If we fix the FIXME above, we have to make sure we replace the last subtensor, not the immediate one
subtensor_node = fgraph.clients[outer_sitsot][0][0] subtensor_node = fgraph.clients[outer_sitsot][0][0]
outer_sitsot_last_step = subtensor_node.outputs[0] outer_sitsot_last_step = subtensor_node.outputs[0]
......
...@@ -600,10 +600,12 @@ class TestPushOutAddScan: ...@@ -600,10 +600,12 @@ class TestPushOutAddScan:
is used to compute the sum over the dot products between the corresponding is used to compute the sum over the dot products between the corresponding
elements of two list of matrices. elements of two list of matrices.
TODO FIXME XXX: These aren't real tests; they simply confirm that a few FIXME: These aren't real tests; they simply confirm that a few
graph that could be relevant to the push-out optimizations can be compiled graph that could be relevant to the push-out optimizations can be compiled
and evaluated. None of them confirm that a push-out optimization has been and evaluated. None of them confirm that a push-out optimization has been
performed. performed.
FIXME: The rewrite is indeed broken, probably fro a long while, see FIXME details in the respective rewrite
""" """
def test_sum_dot(self): def test_sum_dot(self):
...@@ -614,7 +616,15 @@ class TestPushOutAddScan: ...@@ -614,7 +616,15 @@ class TestPushOutAddScan:
sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)], sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
outputs_info=[pt.zeros_like(A)], outputs_info=[pt.zeros_like(A)],
) )
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
# instead of `scan_out[1:][-1]` that the user would define by writing `s[-1]`
# It however, tests the only case the rewrite supports now
f = function([A, B], S.owner.inputs[0][-1]) f = function([A, B], S.owner.inputs[0][-1])
has_scan = any(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes)
# Rewrite is only triggered in fast_run mode
assert has_scan if (config.mode == "FAST_COMPILE") else (not has_scan)
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
vA = rng.uniform(size=(5, 5)).astype(config.floatX) vA = rng.uniform(size=(5, 5)).astype(config.floatX)
vB = rng.uniform(size=(5, 5)).astype(config.floatX) vB = rng.uniform(size=(5, 5)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论