Commit e1ce1c35 authored by Ricardo Vieira, committed by Ricardo Vieira

Refactor `lower_aligned` helper

Parent: 45a33ada
@@ -9,6 +9,7 @@ from pytensor.tensor import (
 )
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_lower_xtensor
+from pytensor.xtensor.rewriting.utils import lower_aligned
 from pytensor.xtensor.shape import (
     Concat,
     ExpandDims,
@@ -70,15 +71,7 @@ def lower_concat(fgraph, node):
     concat_axis = out_dims.index(concat_dim)

     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

     # Broadcast non-concatenated dimensions of each input
     non_concat_shape = [None] * len(out_dims)
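For reference, each of the removed loops computed the same `dimshuffle` order by hand: the position of each output dim in the input, with `"x"` for dims the input lacks (dimshuffle inserts those as new broadcastable axes). A minimal standalone sketch of that order computation, with made-up dim names:

    # Pure-Python illustration of the alignment order the old loop built
    inp_dims = ("b", "a")
    out_dims = ("a", "b", "c")
    order = [inp_dims.index(d) if d in inp_dims else "x" for d in out_dims]
    assert order == [1, 0, "x"]  # transpose to ("a", "b"), then add a broadcastable "c" axis

This is the order the new `lower_aligned` helper now builds in one place.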
...

+import typing
+from collections.abc import Sequence
+
 from pytensor.compile import optdb
 from pytensor.graph.rewriting.basic import NodeRewriter, in2out
 from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
 from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
+from pytensor.tensor.variable import TensorVariable
+from pytensor.xtensor.type import XTensorVariable

 lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -49,3 +54,10 @@ def register_lower_xtensor(
         **kwargs,
     )
     return node_rewriter
+
+
+def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable:
+    """Lower an XTensorVariable to a TensorVariable so that its dimensions are aligned with `out_dims`."""
+    inp_dims = {d: i for i, d in enumerate(x.type.dims)}
+    ds_order = tuple(inp_dims.get(dim, "x") for dim in out_dims)
+    return typing.cast(TensorVariable, x.values.dimshuffle(ds_order))
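To illustrate the new helper, a minimal usage sketch; it assumes the `xtensor` constructor from `pytensor.xtensor.type` and is illustrative only, not part of this diff:

    from pytensor.xtensor.type import xtensor
    from pytensor.xtensor.rewriting.utils import lower_aligned

    # x has dims ("b", "a"); align it to ("a", "b", "c")
    x = xtensor("x", dims=("b", "a"), shape=(3, 2))
    y = lower_aligned(x, ("a", "b", "c"))
    # y is a plain TensorVariable: axes transposed to ("a", "b") order,
    # plus a new broadcastable axis for the missing "c" dim
    print(y.type.broadcastable)  # (False, False, True)

Because `dict.get` falls back to `"x"` for dims absent from the input, missing output dims become broadcastable axes rather than raising, which is the behavior the per-call-site loops previously ensured.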
...

@@ -2,8 +2,8 @@ from pytensor.graph import node_rewriter
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import compute_batch_shape
-from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+from pytensor.xtensor.basic import xtensor_from_tensor
+from pytensor.xtensor.rewriting.utils import lower_aligned, register_lower_xtensor
 from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
@@ -13,15 +13,7 @@ def lower_elemwise(fgraph, node):
     out_dims = node.outputs[0].type.dims

     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

     tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
         *tensor_inputs, return_list=True
@@ -42,17 +34,10 @@ def lower_blockwise(fgraph, node):
     batch_dims = node.outputs[0].type.dims[:batch_ndim]

     # Convert input XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_inputs = []
-    for inp, core_dims in zip(node.inputs, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [
+        lower_aligned(inp, batch_dims + core_dims)
+        for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True)
+    ]

     signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
     if signature is None:
@@ -92,17 +77,10 @@ def lower_rv(fgraph, node):
     param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]

     # Convert param XTensors to Tensors, align batch dimensions and place core dimensions at the end
-    tensor_params = []
-    for inp, core_dims in zip(params, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in param_batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_params.append(tensor_inp)
+    tensor_params = [
+        lower_aligned(inp, param_batch_dims + core_dims)
+        for inp, core_dims in zip(params, op.core_dims[0], strict=True)
+    ]

     size = None
     if op.extra_dims:
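The blockwise and RV call sites rely on the same helper to left-align batch dims and right-align core dims in a single `dimshuffle`. A standalone sketch of that order computation, with hypothetical dim names:

    # Same logic as lower_aligned, applied to plain tuples
    inp_dims = ("core2", "batch", "core1")
    batch_dims = ("batch", "extra")      # "extra" is missing from the input
    core_dims = ("core1", "core2")
    pos = {d: i for i, d in enumerate(inp_dims)}
    ds_order = tuple(pos.get(d, "x") for d in batch_dims + core_dims)
    assert ds_order == (1, "x", 2, 0)    # batch dims first (broadcasting "extra"), core dims last

At these call sites the core dims are expected to be present in the input, so only batch dims can hit the broadcastable `"x"` branch.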