提交 220fef2d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Generalize `local_subtensor_of_elemwise` to Blockwise

上级 49f83bc0
...@@ -20,6 +20,7 @@ from pytensor.tensor.basic import ( ...@@ -20,6 +20,7 @@ from pytensor.tensor.basic import (
join, join,
register_infer_shape, register_infer_shape,
) )
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import squeeze from pytensor.tensor.extra_ops import squeeze
...@@ -169,8 +170,8 @@ def local_subtensor_of_dot(fgraph, node): ...@@ -169,8 +170,8 @@ def local_subtensor_of_dot(fgraph, node):
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
@node_rewriter([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_of_elemwise(fgraph, node): def local_subtensor_of_batch_dims(fgraph, node):
"""Lift a Subtensor through an Elemwise and its implicit broadcasting behavior. """Lift a Subtensor through the batch dims of an (Elemwise or Blockwise) operation and its implicit broadcasting behavior.
exp(x)[:, 0] -> exp(x[:, 0]) exp(x)[:, 0] -> exp(x[:, 0])
add(x, y)[0] -> add(x[0], y[0]) add(x, y)[0] -> add(x[0], y[0])
...@@ -178,7 +179,7 @@ def local_subtensor_of_elemwise(fgraph, node): ...@@ -178,7 +179,7 @@ def local_subtensor_of_elemwise(fgraph, node):
""" """
elem, *idx = node.inputs elem, *idx = node.inputs
if not (elem.owner and isinstance(elem.owner.op, Elemwise)): if not (elem.owner and isinstance(elem.owner.op, Elemwise | Blockwise)):
return None return None
if len(fgraph.clients[elem]) > 1: if len(fgraph.clients[elem]) > 1:
...@@ -188,9 +189,34 @@ def local_subtensor_of_elemwise(fgraph, node): ...@@ -188,9 +189,34 @@ def local_subtensor_of_elemwise(fgraph, node):
idx_tuple = indices_from_subtensor(idx, node.op.idx_list) idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
batch_ndim = (
elem.owner.op.batch_ndim(elem.owner)
if isinstance(elem.owner.op, Blockwise)
else elem.ndim
)
if len(idx_tuple) > batch_ndim:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
if all(is_full_slice(idx) for idx in batch_indices):
# No batch indices, nothing to do
return None
elem_with_batch_indices = elem[batch_indices]
[elem_with_batch_indices_lifted] = local_subtensor_of_batch_dims.transform(
fgraph, elem_with_batch_indices.owner
)
# Reapply the core_indices
core_ndim = elem.type.ndim - batch_ndim
# Number of batch dims may have changed with the lifting of indices, so we recompute
new_batch_ndim = elem_with_batch_indices_lifted.type.ndim - core_ndim
new_indices = (*(slice(None),) * new_batch_ndim, *core_indices)
new_elem = elem_with_batch_indices_lifted[new_indices]
copy_stack_trace(node.outputs[0], new_elem)
return [new_elem]
elem_inputs = elem.owner.inputs elem_inputs = elem.owner.inputs
elem_bcast = elem.type.broadcastable elem_bcast = elem.type.broadcastable[:batch_ndim]
if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs): if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
# No need to worry about implicit broadcasting. # No need to worry about implicit broadcasting.
indexed_inputs = [inp[idx_tuple] for inp in elem_inputs] indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
...@@ -201,7 +227,7 @@ def local_subtensor_of_elemwise(fgraph, node): ...@@ -201,7 +227,7 @@ def local_subtensor_of_elemwise(fgraph, node):
zip( zip(
idx_tuple, idx_tuple,
elem_bcast, elem_bcast,
*(inp.type.broadcastable for inp in elem_inputs), *(inp.type.broadcastable[:batch_ndim] for inp in elem_inputs),
# Indices can be shorter than input ndims # Indices can be shorter than input ndims
strict=False, strict=False,
) )
......
...@@ -14,6 +14,7 @@ from pytensor.compile import DeepCopyOp, get_default_mode, get_mode ...@@ -14,6 +14,7 @@ from pytensor.compile import DeepCopyOp, get_default_mode, get_mode
from pytensor.graph import ( from pytensor.graph import (
Constant, Constant,
FunctionGraph, FunctionGraph,
Op,
RewriteDatabaseQuery, RewriteDatabaseQuery,
Type, Type,
rewrite_graph, rewrite_graph,
...@@ -23,6 +24,7 @@ from pytensor.graph.rewriting.basic import check_stack_trace ...@@ -23,6 +24,7 @@ from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.printing import debugprint from pytensor.printing import debugprint
from pytensor.tensor import ( from pytensor.tensor import (
add, add,
dvector,
exp, exp,
iscalar, iscalar,
iscalars, iscalars,
...@@ -39,11 +41,12 @@ from pytensor.tensor import ( ...@@ -39,11 +41,12 @@ from pytensor.tensor import (
from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.subtensor_lift import ( from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_make_vector, local_subtensor_make_vector,
local_subtensor_of_elemwise, local_subtensor_of_batch_dims,
local_subtensor_shape_constant, local_subtensor_shape_constant,
) )
from pytensor.tensor.shape import SpecifyShape, _shape from pytensor.tensor.shape import SpecifyShape, _shape
...@@ -60,7 +63,7 @@ mode_opt = get_mode(mode_opt) ...@@ -60,7 +63,7 @@ mode_opt = get_mode(mode_opt)
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None) NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
class TestLocalSubtensorOfElemwise: class TestLocalSubtensorOfBatchDims:
def test_unary_multiple_clients(self): def test_unary_multiple_clients(self):
# as test0, but we reuse the output of the elemwise # as test0, but we reuse the output of the elemwise
# So we should not lift the subtensor # So we should not lift the subtensor
...@@ -146,7 +149,7 @@ class TestLocalSubtensorOfElemwise: ...@@ -146,7 +149,7 @@ class TestLocalSubtensorOfElemwise:
), ),
], ],
) )
def test_local_subtensor_of_elemwise(self, original_fn, expected_fn): def test_elemwise(self, original_fn, expected_fn):
rng = np.random.default_rng(257) rng = np.random.default_rng(257)
x = pt.matrix("x", shape=(5, 3)) x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3)) y = pt.matrix("y", shape=(5, 3))
...@@ -165,7 +168,7 @@ class TestLocalSubtensorOfElemwise: ...@@ -165,7 +168,7 @@ class TestLocalSubtensorOfElemwise:
out.eval({x: x_test, y: y_test}, **eval_kwargs), out.eval({x: x_test, y: y_test}, **eval_kwargs),
) )
def test_local_subtensor_of_elemwise_multiple_clients(self): def test_elemwise_multiple_clients(self):
x = pt.matrix("x", shape=(5, 3)) x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3)) y = pt.matrix("y", shape=(5, 3))
out1 = add(x, y) out1 = add(x, y)
...@@ -173,11 +176,48 @@ class TestLocalSubtensorOfElemwise: ...@@ -173,11 +176,48 @@ class TestLocalSubtensorOfElemwise:
# Rewrite should fail when another node uses out1 directly (in this case it's an extra output) # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
fgraph = FunctionGraph([x, y], [out1, out2], clone=False) fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is None
# Otherwise it should work # Otherwise it should work
fgraph.remove_output(0) fgraph.remove_output(0)
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None assert local_subtensor_of_batch_dims.transform(fgraph, out2.owner) is not None
def test_blockwise(self):
class CoreTestOp(Op):
itypes = [dvector, dvector]
otypes = [dvector]
def perform(self, node, inputs, output_storage):
output_storage[0][0] = np.convolve(*inputs, mode="valid")
core_test_op = CoreTestOp()
block_test_op = Blockwise(core_test_op, signature="(a),(b)->(c)")
x = tensor3("x", shape=(7, 5, 11), dtype="float64")
y = tensor("y", shape=(7, 33), dtype="float64")
out = block_test_op(x, y[:, None, :])
assert isinstance(out.owner.op, Blockwise)
out_sliced = out[2:][:, 3:]
rewritten_out_sliced = rewrite_graph(out_sliced)
expected_out_sliced = block_test_op(x[2:, 3:], y[2:][:, None, :])
assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
rng = np.random.default_rng(191)
x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
np.testing.assert_allclose(
rewritten_out_sliced.eval(
{x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE
),
out_sliced.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
)
# Check slice on core dims
out_sliced = out[2:][:, 0][:, 4:]
rewritten_out_sliced = rewrite_graph(out_sliced)
expected_out_sliced = block_test_op(x[2:, 0], y[2:])[:, 4:]
assert equal_computations([rewritten_out_sliced], [expected_out_sliced])
def test_local_subtensor_of_dot(): def test_local_subtensor_of_dot():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论