提交 aa89e44b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor helper to create safe gufunc signature

上级 b6874c66
......@@ -6,7 +6,7 @@ import numpy as np
from pytensor import config
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import (
......@@ -22,27 +22,11 @@ from pytensor.tensor.utils import (
_parse_gufunc_signature,
broadcast_static_dim_lengths,
import_func_from_string,
safe_signature,
)
from pytensor.tensor.variable import TensorVariable
def safe_signature(
core_inputs: Sequence[Variable],
core_outputs: Sequence[Variable],
) -> str:
def operand_sig(operand: Variable, prefix: str) -> str:
operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim))
return f"({operands})"
inputs_sig = ",".join(
operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs)
)
outputs_sig = ",".join(
operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs)
)
return f"{inputs_sig}->{outputs_sig}"
class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions.
......@@ -385,7 +369,10 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
else:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops
signature = safe_signature(node.inputs, node.outputs)
signature = safe_signature(
[inp.type.ndim for inp in node.inputs],
[out.type.ndim for out in node.outputs],
)
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
......
......@@ -172,7 +172,11 @@ _ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*"
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature(signature):
def _parse_gufunc_signature(
signature,
) -> tuple[
list[tuple[str, ...]], ...
]: # mypy doesn't know it's alwayl a length two tuple
"""
Parse string signatures for a generalized universal function.
......@@ -198,3 +202,20 @@ def _parse_gufunc_signature(signature):
]
for arg_list in signature.split("->")
)
def safe_signature(
core_inputs_ndim: Sequence[int],
core_outputs_ndim: Sequence[int],
) -> str:
def operand_sig(operand_ndim: int, prefix: str) -> str:
operands = ",".join(f"{prefix}{i}" for i in range(operand_ndim))
return f"({operands})"
inputs_sig = ",".join(
operand_sig(ndim, prefix=f"i{n}") for n, ndim in enumerate(core_inputs_ndim)
)
outputs_sig = ",".join(
operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim)
)
return f"{inputs_sig}->{outputs_sig}"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论