提交 e1ce1c35 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor `lower_aligned` helper

上级 45a33ada
......@@ -9,6 +9,7 @@ from pytensor.tensor import (
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.rewriting.utils import lower_aligned
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
......@@ -70,15 +71,7 @@ def lower_concat(fgraph, node):
concat_axis = out_dims.index(concat_dim)
# Convert input XTensors to Tensors and align batch dimensions
tensor_inputs = []
for inp in node.inputs:
inp_dims = inp.type.dims
order = [
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
for out_dim in out_dims
]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
tensor_inputs.append(tensor_inp)
tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
# Broadcast non-concatenated dimensions of each input
non_concat_shape = [None] * len(out_dims)
......
import typing
from collections.abc import Sequence
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter, in2out
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.type import XTensorVariable
lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
......@@ -49,3 +54,10 @@ def register_lower_xtensor(
**kwargs,
)
return node_rewriter
def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable:
"""Lower an XTensorVariable to a TensorVariable so that it's dimensions are aligned with "out_dims"."""
inp_dims = {d: i for i, d in enumerate(x.type.dims)}
ds_order = tuple(inp_dims.get(dim, "x") for dim in out_dims)
return typing.cast(TensorVariable, x.values.dimshuffle(ds_order))
......@@ -2,8 +2,8 @@ from pytensor.graph import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import compute_batch_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.basic import xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import lower_aligned, register_lower_xtensor
from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
......@@ -13,15 +13,7 @@ def lower_elemwise(fgraph, node):
out_dims = node.outputs[0].type.dims
# Convert input XTensors to Tensors and align batch dimensions
tensor_inputs = []
for inp in node.inputs:
inp_dims = inp.type.dims
order = [
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
for out_dim in out_dims
]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
tensor_inputs.append(tensor_inp)
tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
*tensor_inputs, return_list=True
......@@ -42,17 +34,10 @@ def lower_blockwise(fgraph, node):
batch_dims = node.outputs[0].type.dims[:batch_ndim]
# Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end
tensor_inputs = []
for inp, core_dims in zip(node.inputs, op.core_dims[0]):
inp_dims = inp.type.dims
# Align the batch dims of the input, and place the core dims on the right
batch_order = [
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
for batch_dim in batch_dims
]
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
tensor_inputs.append(tensor_inp)
tensor_inputs = [
lower_aligned(inp, batch_dims + core_dims)
for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True)
]
signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
if signature is None:
......@@ -92,17 +77,10 @@ def lower_rv(fgraph, node):
param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]
# Convert params Tensors to XTensors, align batch dimensions and place core dimension at the end
tensor_params = []
for inp, core_dims in zip(params, op.core_dims[0]):
inp_dims = inp.type.dims
# Align the batch dims of the input, and place the core dims on the right
batch_order = [
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
for batch_dim in param_batch_dims
]
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
tensor_params.append(tensor_inp)
tensor_params = [
lower_aligned(inp, param_batch_dims + core_dims)
for inp, core_dims in zip(params, op.core_dims[0], strict=True)
]
size = None
if op.extra_dims:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论