提交 10c36d2a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Apply useless blockwise rewrite when there are only dummy batch dims

Also extend the eager rewrite to more Ops. The Blockwise MatrixInverse grad test became more sensitive in float32, because desired stabilization rewrites (mainly `inv_as_solve`) that target Dot of Blockwise{MatrixInverse} are now triggered by the default blockwise grad but not by the non-default non-blockwise grad.
上级 fe5865ef
...@@ -163,8 +163,8 @@ class Blockwise(Op): ...@@ -163,8 +163,8 @@ class Blockwise(Op):
return Apply(self, batched_inputs, batched_outputs) return Apply(self, batched_inputs, batched_outputs)
def batch_ndim(self, node: Apply) -> int:
    """Return the number of batch dimensions of ``node``'s outputs.

    Computed as the output ndim minus the length of the first core
    output signature.
    """
    core_out_ndim = len(self.outputs_sig[0])
    return cast(int, node.outputs[0].type.ndim - core_out_ndim)
def infer_shape( def infer_shape(
self, fgraph, node, input_shapes self, fgraph, node, input_shapes
...@@ -172,7 +172,7 @@ class Blockwise(Op): ...@@ -172,7 +172,7 @@ class Blockwise(Op):
from pytensor.tensor import broadcast_shape from pytensor.tensor import broadcast_shape
from pytensor.tensor.shape import Shape_i from pytensor.tensor.shape import Shape_i
batch_ndims = self._batch_ndim_from_outputs(node.outputs) batch_ndims = self.batch_ndim(node)
core_dims: dict[str, Any] = {} core_dims: dict[str, Any] = {}
batch_shapes = [] batch_shapes = []
for input_shape, sig in zip(input_shapes, self.inputs_sig): for input_shape, sig in zip(input_shapes, self.inputs_sig):
...@@ -278,7 +278,7 @@ class Blockwise(Op): ...@@ -278,7 +278,7 @@ class Blockwise(Op):
return new_rval return new_rval
# Sum out the broadcasted dimensions # Sum out the broadcasted dimensions
batch_ndims = self._batch_ndim_from_outputs(outs) batch_ndims = self.batch_ndim(outs[0].owner)
batch_shape = outs[0].type.shape[:batch_ndims] batch_shape = outs[0].type.shape[:batch_ndims]
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
if isinstance(rval[i].type, (NullType, DisconnectedType)): if isinstance(rval[i].type, (NullType, DisconnectedType)):
...@@ -320,7 +320,7 @@ class Blockwise(Op): ...@@ -320,7 +320,7 @@ class Blockwise(Op):
return self._gufunc return self._gufunc
def _check_runtime_broadcast(self, node, inputs): def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self._batch_ndim_from_outputs(node.outputs) batch_ndim = self.batch_ndim(node)
for dims_and_bcast in zip( for dims_and_bcast in zip(
*[ *[
......
...@@ -2,9 +2,15 @@ from pytensor.compile.mode import optdb ...@@ -2,9 +2,15 @@ from pytensor.compile.mode import optdb
from pytensor.graph import node_rewriter from pytensor.graph import node_rewriter
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, shape_padleft
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import _matrix_matrix_matmul from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import register_canonicalize from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def local_useless_unbatched_blockwise(fgraph, node):
    """Replace a Blockwise by its core Op when all batch dims are broadcastable.

    If every batch dimension of every input is a dummy (broadcastable)
    dimension, the core Op can be applied directly: squeeze the dummy
    batch dims, apply the core Op, and pad the outputs back to the
    original number of dimensions.
    """
    op = node.op
    inputs = node.inputs
    batch_ndims = node.op.batch_ndim(node)

    # Only applicable when every batch dim of every input is broadcastable.
    if all(all(inp.type.broadcastable[:batch_ndims]) for inp in inputs):
        if batch_ndims:
            # Squeeze out the dummy batch dims before applying the core Op
            dummy_axes = tuple(range(batch_ndims))
            inputs = [inp.squeeze(dummy_axes) for inp in inputs]

        new_outs = op.core_op.make_node(*inputs).outputs

        if batch_ndims:
            # Pad the dummy batch dims back onto the core outputs
            new_outs = [shape_padleft(out, batch_ndims) for out in new_outs]

        return copy_stack_trace(node.outputs, new_outs)
# We register this rewrite late, so that other rewrites need only target Blockwise Ops # We register this rewrite late, so that other rewrites need only target Blockwise Ops
...@@ -46,6 +61,22 @@ optdb.register( ...@@ -46,6 +61,22 @@ optdb.register(
# Avoid redundant cases early on for Ops whose default form is not Blockwised
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
def local_eager_useless_unbatched_blockwise(fgraph, node):
    """Eagerly remove useless Blockwise for a whitelist of core Ops.

    Delegates to ``local_useless_unbatched_blockwise`` for core Ops that
    benefit from being unblocked early in the rewrite pipeline.
    """
    eager_core_ops = (
        # Many Dot-related rewrites (e.g., all of BlasOpt) happen before specialize
        Dot,
        # These Ops can't always be trivially vectorized at runtime,
        # since their inputs may imply non-rectangular shapes.
        Alloc,
        ARange,
        Subtensor,
        AdvancedSubtensor,
        AdvancedIncSubtensor,
    )
    if isinstance(node.op.core_op, eager_core_ops):
        return local_useless_unbatched_blockwise.fn(fgraph, node)
...@@ -293,7 +293,7 @@ class BlockwiseOpTester: ...@@ -293,7 +293,7 @@ class BlockwiseOpTester:
pt_out, pt_out,
np_out, np_out,
rtol=1e-7 if config.floatX == "float64" else 1e-5, rtol=1e-7 if config.floatX == "float64" else 1e-5,
atol=1e-6 if config.floatX == "float64" else 1e-5, atol=1e-6 if config.floatX == "float64" else 1e-4,
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论