提交 b437c8d6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Lift subtensor through squeeze

上级 220fef2d
...@@ -461,6 +461,41 @@ def local_subtensor_of_expand_dims(fgraph, node): ...@@ -461,6 +461,41 @@ def local_subtensor_of_expand_dims(fgraph, node):
return [out] return [out]
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_squeeze(fgraph, node):
"""Lift subtensor through a squeeze operation"""
x, *idxs_vars = node.inputs
if not (
x.owner is not None
and isinstance(x.owner.op, DimShuffle)
and x.owner.op.is_squeeze
):
return None
[x_before_squeeze] = x.owner.inputs
idxs = indices_from_subtensor(idxs_vars, node.op.idx_list)
dropped_dims = x.owner.op.drop
# Apply indices directly on x
# Add empty slices on the axis that squeeze would have removed
new_idxs = np.insert(np.array(idxs, dtype=object), dropped_dims, slice(None))
x_indexed = x_before_squeeze[tuple(new_idxs)]
# Reapply squeeze
# Indexing may have squeezed some dimensions, so we need to recalculate dropped_dims
new_dropped_dims = np.array(dropped_dims)
for i, new_idx in reversed(tuple(enumerate(new_idxs))):
if not isinstance(new_idx, slice):
# If it's not a slice, it's an integer which drops the dimension
new_dropped_dims[new_dropped_dims > i] -= 1
new_x = x_indexed.squeeze(tuple(new_dropped_dims))
copy_stack_trace(x, new_x)
return [new_x]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([Subtensor]) @node_rewriter([Subtensor])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论