提交 0c398e34 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rewrite away blockwise Subtensor in gradient of Blockwise(Conv1d)

上级 a0a494ab
......@@ -14,7 +14,12 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize,
)
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
Subtensor,
indices_from_subtensor,
)
@node_rewriter([Blockwise])
......@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):
Reshape is tricky to vectorize eagerly, because a graph like
`x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
that must be vectorized before we arrize at the reshape operation.
that must be vectorized before we arrive at the reshape operation.
For the square Reshape case, we must wait for all the intemediate
For the square Reshape case, we must wait for all the intermediate
operations to be lifted as Allocs
"""
if not isinstance(node.op.core_op, Reshape):
......@@ -234,6 +239,29 @@ def local_blockwise_reshape(fgraph, node):
return [new_out]
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_of_subtensor(fgraph, node):
    """Replace a Blockwise-wrapped Subtensor by a plain indexing operation.

    Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when only the indexed
    tensor x carries batch dimensions and the index inputs a/b carry none.
    """
    core_op = node.op.core_op
    if not isinstance(core_op, Subtensor):
        return None

    indexed, *index_inputs = node.inputs
    # Bail out unless every index input is batch-free: all of its dims must
    # be broadcastable, i.e. it is effectively a scalar per batch entry.
    if any(not all(inp.type.broadcastable) for inp in index_inputs):
        return None

    scalar_indices = indices_from_subtensor(
        [inp.squeeze() for inp in index_inputs], core_op.idx_list
    )
    # Prepend one full slice per batch dimension so the core indices apply
    # to the trailing (core) dimensions only.
    batch_slices = (slice(None),) * node.op.batch_ndim(node)
    return [indexed[(*batch_slices, *scalar_indices)]]
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
blockwise_op = node.op
......
......@@ -4,9 +4,11 @@ import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve
from pytensor import config, function
from pytensor import config, function, grad
from pytensor.graph import ancestors, rewrite_graph
from pytensor.tensor import matrix, vector
from pytensor.tensor.signal.conv import convolve1d
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Conv1d, convolve1d
from tests import unittest_tools as utt
......@@ -60,3 +62,23 @@ def test_convolve1d_batch_same():
res = out.eval({x: x_test, y: y_test})
assert res.shape == (2, 8)
@pytest.mark.parametrize("mode", ("full", "valid", "same"))
def test_convolve1d_batch_graph(mode):
    """Test that we don't have slow Blockwise Subtensors in graph of a batched convolve1d"""
    x = matrix("x")
    y = matrix("y")
    conv_out = convolve1d(x, y, mode=mode)
    conv_grads = grad(conv_out.sum(), wrt=[x, y])
    rewritten = rewrite_graph(
        conv_grads, include=("ShapeOpt", "canonicalize", "stabilize", "specialize")
    )
    # Every Blockwise node left in the rewritten gradient graph must wrap
    # a Conv1d core op — no Blockwise Subtensors may survive.
    for var in ancestors(rewritten):
        apply_node = var.owner
        if apply_node is not None and isinstance(apply_node.op, Blockwise):
            assert isinstance(apply_node.op.core_op, Conv1d)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论