Commit 220fef2d — authored by Ricardo Vieira, committed by Ricardo Vieira

Generalize `local_subtensor_of_elemwise` to Blockwise

Parent commit: 49f83bc0
......@@ -20,6 +20,7 @@ from pytensor.tensor.basic import (
join,
register_infer_shape,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import squeeze
......@@ -169,8 +170,8 @@ def local_subtensor_of_dot(fgraph, node):
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])
def local_subtensor_of_elemwise(fgraph, node):
"""Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
def local_subtensor_of_batch_dims(fgraph, node):
"""Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior.
exp(x)[:, 0] -> exp(x[:, 0])
add(x, y)[0] -> add(x[0], y[0])
......@@ -178,7 +179,7 @@ def local_subtensor_of_elemwise(fgraph, node):
"""
elem, *idx = node.inputs
if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)):
return None
if len(fgraph.clients[elem]) > 1:
......@@ -188,9 +189,34 @@ def local_subtensor_of_elemwise(fgraph, node):
idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
batch_ndim = (
elem.owner.op.batch_ndim(elem.owner)
if isinstance(elem.owner.op, Blockwise)
else elem.ndim
)
if len(idx_tuple) > batch_ndim:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
if all(is_full_slice(idx) for idx in batch_indices):
# No batch indices, nothing to do
return None
elem_with_batch_indices = elem[batch_indices]
[elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform(
fgraph, elem_with_batch_indices.owner
)
# Reapply the core_indices
core_ndim = elem.type.ndim - batch_ndim
# Number of batch dims may have changed with the lifting of indices, so we recompute
new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim
new_indices = (*(slice(None),) * new_batch_ndim, *core_indices)
new_elem = elem_with_batch_indices_lifted[new_indices]
copy_stack_trace(node.outputs[0], new_elem)
return [new_elem]
elem_inputs = elem.owner.inputs
elem_bcast = elem.type.broadcastable
if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
elem_bcast = elem.type.broadcastable[:batch_ndim]
if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
# No need to worry about implicit broadcasting.
indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
......@@ -201,7 +227,7 @@ def local_subtensor_of_elemwise(fgraph, node):
zip(
idx_tuple,
elem_bcast,
*(inp.type.broadcastable for inp in elem_inputs),
*(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs),
# Indices can be shorter than input ndims
strict=False,
)
......
......@@ -14,6 +14,7 @@ from pytensor.compile import DeepCopyOp, get_default_mode, get_mode
from pytensor.graph import (
Constant,
FunctionGraph,
Op,
RewriteDatabaseQuery,
Type,
rewrite_graph,
......@@ -23,6 +24,7 @@ from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.printing import debugprint
from pytensor.tensor import (
add,
dvector,
exp,
iscalar,
iscalars,
......@@ -39,11 +41,12 @@ from pytensor.tensor import (
from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_make_vector,
local_subtensor_of_elemwise,
local_subtensor_of_batch_dims,
local_subtensor_shape_constant,
)
from pytensor.tensor.shape import SpecifyShape, _shape
......@@ -60,7 +63,7 @@ mode_opt = get_mode(mode_opt)
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
class TestLocalSubtensorOfElemwise:
class TestLocalSubtensorOfBatchDims:
def test_unary_multiple_clients(self):
# as test0, but we reuse the output of the elemwise
# So we should not lift the subtensor
......@@ -146,7 +149,7 @@ class TestLocalSubtensorOfElemwise:
),
],
)
def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
def test_elemwise(self, original_fn, expected_fn):
rng = np.random.default_rng(257)
x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))
......@@ -165,7 +168,7 @@ class TestLocalSubtensorOfElemwise:
out.eval({x: x_test, y: y_test}, **eval_kwargs),
)
def test_local_subtensor_of_elemwise_multiple_clients(self):
def test_elemwise_multiple_clients(self):
x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))
out1 = add(x, y)
......@@ -173,11 +176,48 @@ class TestLocalSubtensorOfElemwise:
# Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
# Otherwise it should work
fgraph.remove_output(0)
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
def test_blockwise(self):
    """Subtensor indexing on batch dims should be lifted through a Blockwise op.

    Builds a Blockwise around a minimal core Op and checks that
    ``rewrite_graph`` moves batch-dim slices onto the Blockwise inputs,
    both when all indices are batch indices and when core-dim indices
    are mixed in (the latter must stay applied after the Blockwise).
    """

    class CoreTestOp(Op):
        # Minimal core Op: "valid"-mode 1D convolution of two vectors,
        # with signature (a),(b)->(c) when wrapped in Blockwise below.
        itypes = [dvector, dvector]
        otypes = [dvector]

        def perform(self, node, inputs, output_storage):
            output_storage[0][0] = np.convolve(*inputs, mode="valid")

    core_test_op = CoreTestOp()
    block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)")

    # Two batch dims on x (7, 5); y gets a broadcastable middle dim.
    x = tensor3("x", shape=(7, 5, 11), dtype="float64")
    y = tensor("y", shape=(7, 33), dtype="float64")
    out = block_test_op(x, y[:, None, :])
    assert isinstance(out.owner.op, Blockwise)

    # Slices on batch dims only: both should be lifted onto the inputs.
    out_sliced = out[2:][:, 3:]
    rewritten_out_sliced = rewrite_graph(out_sliced)
    expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :])
    assert equal_computations([rewritten_out_sliced], [expected_out_sliced])

    # Numerical check: the rewritten graph computes the same values.
    rng = np.random.default_rng(191)
    x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
    y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
    np.testing.assert_allclose(
        rewritten_out_sliced.eval(
            {x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
        ),
        out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
    )

    # Check slice on core dims
    # Batch indices (2:, 0) are lifted; the core-dim slice (4:) is
    # reapplied after the Blockwise.
    out_sliced = out[2:][:, 0][:, 4:]
    rewritten_out_sliced = rewrite_graph(out_sliced)
    expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
    assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
def test_local_subtensor_of_dot():
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment