提交 5046519a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move subtensor blockwise rewrite

上级 4b18f908
......@@ -17,7 +17,6 @@ from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
Subtensor,
indices_from_subtensor,
)
......@@ -229,29 +228,6 @@ def local_blockwise_reshape(fgraph, node):
return [new_out]
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
"""
if not isinstance(node.op.core_op, Subtensor):
return
x, *idxs = node.inputs
if not all(all(idx.type.broadcastable) for idx in idxs):
return
core_idxs = indices_from_subtensor(
[idx.squeeze() for idx in idxs], node.op.core_op.idx_list
)
# Add empty slices for the batch dims
none_slices = (slice(None),) * node.op.batch_ndim(node)
return [x[(*none_slices, *core_idxs)]]
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
op = Blockwise
......
......@@ -1571,6 +1571,29 @@ compile.optdb.register(
)
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
"""
if not isinstance(node.op.core_op, Subtensor):
return
x, *idxs = node.inputs
if not all(all(idx.type.broadcastable) for idx in idxs):
return
core_idxs = indices_from_subtensor(
[idx.squeeze() for idx in idxs], node.op.core_op.idx_list
)
# Add empty slices for the batch dims
none_slices = (slice(None),) * node.op.batch_ndim(node)
return [x[(*none_slices, *core_idxs)]]
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论