提交 aa89e44b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor helper to create safe gufunc signature

上级 b6874c66
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
from pytensor import config from pytensor import config
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import ( from pytensor.graph.replace import (
...@@ -22,27 +22,11 @@ from pytensor.tensor.utils import ( ...@@ -22,27 +22,11 @@ from pytensor.tensor.utils import (
_parse_gufunc_signature, _parse_gufunc_signature,
broadcast_static_dim_lengths, broadcast_static_dim_lengths,
import_func_from_string, import_func_from_string,
safe_signature,
) )
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
def safe_signature(
    core_inputs: Sequence[Variable],
    core_outputs: Sequence[Variable],
) -> str:
    """Build a gufunc signature with a distinct dummy name for every core axis.

    Each operand contributes one parenthesized group; axis ``d`` of the
    ``n``-th input is labelled ``i{n}{d}`` (outputs use an ``o`` prefix), so
    the signature never implies that any two dimensions must match.
    """

    def group(prefix: str, ndim: int) -> str:
        # One unique dummy name per core dimension, e.g. "(i00,i01)" for rank 2.
        return "(" + ",".join(f"{prefix}{axis}" for axis in range(ndim)) + ")"

    inputs_sig = ",".join(
        group(f"i{pos}", var.type.ndim) for pos, var in enumerate(core_inputs)
    )
    outputs_sig = ",".join(
        group(f"o{pos}", var.type.ndim) for pos, var in enumerate(core_outputs)
    )
    return f"{inputs_sig}->{outputs_sig}"
class Blockwise(Op): class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions. """Generalizes a core `Op` to work with batched dimensions.
...@@ -385,7 +369,10 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: ...@@ -385,7 +369,10 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
else: else:
# TODO: This is pretty bad for shape inference and merge optimization! # TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops # Should get better as we add signatures to our Ops
signature = safe_signature(node.inputs, node.outputs) signature = safe_signature(
[inp.type.ndim for inp in node.inputs],
[out.type.ndim for out in node.outputs],
)
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs)) return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
......
...@@ -172,7 +172,11 @@ _ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*" ...@@ -172,7 +172,11 @@ _ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*"
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$" _SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature(signature): def _parse_gufunc_signature(
signature,
) -> tuple[
list[tuple[str, ...]], ...
]:  # mypy doesn't know it's always a length two tuple
""" """
Parse string signatures for a generalized universal function. Parse string signatures for a generalized universal function.
...@@ -198,3 +202,20 @@ def _parse_gufunc_signature(signature): ...@@ -198,3 +202,20 @@ def _parse_gufunc_signature(signature):
] ]
for arg_list in signature.split("->") for arg_list in signature.split("->")
) )
def safe_signature(
    core_inputs_ndim: Sequence[int],
    core_outputs_ndim: Sequence[int],
) -> str:
    """Build a gufunc signature from core ndims, one unique name per axis.

    Axis ``d`` of the ``n``-th input is labelled ``i{n}{d}`` (``o`` prefix for
    outputs), so no dimensions are ever claimed to be equal to each other.
    """

    def render(prefix: str, ndims: Sequence[int]) -> str:
        # One parenthesized group per operand, comma-joined.
        groups = []
        for pos, ndim in enumerate(ndims):
            dims = ",".join(f"{prefix}{pos}{axis}" for axis in range(ndim))
            groups.append(f"({dims})")
        return ",".join(groups)

    return f"{render('i', core_inputs_ndim)}->{render('o', core_outputs_ndim)}"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论