提交 4d539fa5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Canonicalize subtensor negative integer indices

上级 617964ff
......@@ -724,6 +724,44 @@ def local_useless_subtensor(fgraph, node):
return [node.inputs[0]]
@register_canonicalize
@node_rewriter([Subtensor])
def local_convert_negative_indices(fgraph, node):
"""Convert negative indices in `Subtensor` with static length to positive indices."""
x, *raw_idxs = node.inputs
idxs = indices_from_subtensor(raw_idxs, node.op.idx_list)
new_idxs = None
for i, (dim_length, idx) in enumerate(zip(x.type.shape, idxs)):
if (
dim_length is None
or isinstance(idx, slice)
or not isinstance(idx, Constant)
):
continue
val = idx.data
if val >= 0:
continue
new_val = val + dim_length
if new_val < 0:
# This is an invalid index, keep original to not confuse the user
return None
if new_idxs is None:
new_idxs = list(idxs)
new_idxs[i] = new_val
if new_idxs is None:
# No negative indices to convert
return None
new_subtensor = x[tuple(new_idxs)]
copy_stack_trace(node.outputs, new_subtensor)
return [new_subtensor]
@register_canonicalize
@register_specialize
@node_rewriter([AdvancedSubtensor1])
......
......@@ -1992,3 +1992,20 @@ def test_extract_diag_of_diagonal_set_subtensor():
expected_outs.append(outs[-1])
assert equal_computations(rewritten_outs, expected_outs)
def test_local_convert_negative_indices():
x = pt.tensor("x", shape=(None, 3, 1))
# Dim length is unknown rewrite can't be applied
rewritten_out = rewrite_graph(x[-2])
assert equal_computations([rewritten_out], [x[-2]])
# Rewrite applies
rewritten_out = rewrite_graph(x[:, -2])
assert equal_computations([rewritten_out], [x[:, 1]])
# Rewrite doesn't apply because index is invalid
# TODO: If Subtensor decides to raise on make_node, this test can be removed
rewritten_out = rewrite_graph(x[:, :, -2])
assert equal_computations([rewritten_out], [x[:, :, -2]])
......@@ -202,7 +202,7 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):
out = original_fn(x)
expected_opt_out = expected_fn(x)
opt_out = rewrite_graph(out)
opt_out = rewrite_graph(out, exclude=("local_convert_negative_indices",))
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
[expected_opt_out, opt_out], print_type=True
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论