Commit e265debc authored by Ricardo Vieira, committed by Ricardo Vieira

Generalize dot rewrites to work with Blockwise

Parent b437c8d6
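Concretely, the rewrites now also fire on batched matmuls expressed as `Blockwise` of a core `Dot`. A minimal sketch of such a graph (illustrative only, assuming current `pytensor.tensor` APIs: `pt.tensor`, the `@` operator, and the `.mT` matrix-transpose property):

```python
# Illustrative only: a batched matmul (Blockwise of core Dot) whose
# transpose the generalized rewrite can lift toward the inputs.
import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 3, 4))  # stack of 5 matrices, each 3x4
y = pt.tensor("y", shape=(5, 4, 2))  # stack of 5 matrices, each 4x2

out = (x @ y).mT  # .mT transposes the last two (core) dimensions
# After rewriting, this should be computed as y.mT @ x.mT, letting the
# DimShuffle merge with other DimShuffles closer to the inputs.
```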
@@ -44,6 +44,7 @@ from pytensor.tensor.math import (
     Prod,
     Sum,
     _conj,
+    _dot,
     _inner_prod,
     _matrix_matrix_matmul,
     _matrix_vec_prod,
@@ -98,6 +99,7 @@ from pytensor.tensor.rewriting.basic import (
     register_useless,
 )
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
+from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
@@ -175,21 +177,20 @@ def local_lift_transpose_through_dot(fgraph, node):
     These rewrites "lift" (propagate towards the inputs) `DimShuffle`
     through dot product. It allows to put the graph in a more standard shape,
     and to later merge consecutive `DimShuffle`\s.
-
-    The transformation should be apply whether or not the transpose is
-    inplace. The newly-introduced transpositions are not inplace, this will
-    be taken care of in a later rewrite phase.
     """
-    if not (isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)):
-        return False
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    if not (
+        is_matrix_transpose(node.outputs[0])
+        and node.inputs[0].owner
+        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matrix_matrix_matmul))
+    ):
         return False
 
     x, y = node.inputs[0].owner.inputs
-    if x.ndim == y.ndim == 2:
+    if x.ndim >= y.ndim >= 2:
         # Output is dot product of transposed inputs in reverse order
-        ret = [dot(y.T, x.T)]
+        ret = [dot_op(y.mT, x.mT)]
 
         # Copy over stack trace to output from result of dot-product
         copy_stack_trace(node.inputs[0], ret)
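For reference, the identity the generalized branch relies on holds for stacks of matrices just as in the 2D case. A quick numpy check (illustrative only; `.mT` is emulated with `np.swapaxes` for portability):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 4))
y = rng.normal(size=(5, 4, 2))

# (x @ y).mT == y.mT @ x.mT, where .mT swaps the last two axes
lhs = np.swapaxes(x @ y, -1, -2)
rhs = np.swapaxes(y, -1, -2) @ np.swapaxes(x, -1, -2)
assert np.allclose(lhs, rhs)
```

The hunks below apply the same generalization to `local_subtensor_of_dot` in the subtensor rewrites.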
@@ -5,7 +5,7 @@ import numpy as np
 from pytensor import Variable
 from pytensor.compile import optdb
-from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
@@ -119,21 +119,48 @@ def local_subtensor_of_dot(fgraph, node):
     the remaining entries of ``idxs`` (if any), modified to skip the
     second-to-last dimension of ``B`` (because dot sums over this dimension).
     """
-    if not isinstance(node.op, Subtensor):
-        return
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    x, *idx_vars = node.inputs
+    if not (
+        x.owner is not None
+        and (
+            isinstance(x.owner.op, Dot)
+            or (
+                isinstance(x.owner.op, Blockwise)
+                and isinstance(x.owner.op.core_op, Dot)
+            )
+        )
+    ):
         return
 
     # If there is other node that use the outputs of the dot
     # We don't want to compute twice the sub part.
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
     if not idx_list:
         # Nothing to do, `local_useless_slice` will handle this case
         return None
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        if not idx_list:
+            # Indexing only over batch dimensions of Blockwise, nothing to do here
+            # This will be handled by `local_subtensor_of_batch_dims`
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
+
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
@@ -142,26 +169,22 @@ def local_subtensor_of_dot(fgraph, node):
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
     # Copy over previous output stacktrace to a_sub and b_sub,
     # because an error in the subtensor operation (e.g. an index error)
     # on either a or b must correspond to an error in the
     # subtensor operation on their dot product.
     copy_stack_trace(node.outputs[0], [a_sub, b_sub])
 
+    if batch_ndim:
+        # Replace dummy inputs by the original batch ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
+
     # Copy over previous output stacktrace and previous dot product stacktrace,
     # because an error here may correspond to an error in either the original
     # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
 
     return [r]
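The docstring's index-splitting rule extends naturally to the batched case handled above: leading batch indices apply to both operands, and the remaining core indices split between `A`'s leading dimensions and `B`'s trailing ones. A small numpy check of that identity (illustrative, not part of the commit):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3, 4))  # batched left operand
B = rng.normal(size=(5, 4, 2))  # batched right operand

# Batch index 2 selects the same batch entry of both operands; the core
# index 1 then picks a row of A, since dot sums over A's last dimension.
assert np.allclose((A @ B)[2, 1], A[2, 1] @ B[2])
```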
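The `vectorize_graph` step is the key trick: the rewrite is derived on dummy core (unbatched) inputs, then re-batched by substituting the original `Blockwise` inputs. A hedged sketch of that pattern (variable names here are hypothetical, not from the commit):

```python
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

a_core = pt.matrix("a")         # dummy core input, one batch entry of A
b_core = pt.matrix("b")         # dummy core input, one batch entry of B
core_out = a_core[1:] @ b_core  # rewritten graph for the core (2D) case

# Re-batching: replace the dummies with the real batched inputs;
# Blockwise ops are reintroduced wherever the extra dimensions demand.
a_batch = pt.tensor("A", shape=(7, 3, 4))
b_batch = pt.tensor("B", shape=(7, 4, 2))
batched_out = vectorize_graph(core_out, replace={a_core: a_batch, b_core: b_batch})
```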