提交 68b41a48 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Better error for fallback of vectorize_node with non-tensor types

上级 301f10dc
...@@ -14,9 +14,10 @@ from pytensor.graph.replace import ( ...@@ -14,9 +14,10 @@ from pytensor.graph.replace import (
_vectorize_not_needed, _vectorize_not_needed,
vectorize_graph, vectorize_graph,
) )
from pytensor.scalar import ScalarType
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
from pytensor.tensor.utils import ( from pytensor.tensor.utils import (
_parse_gufunc_signature, _parse_gufunc_signature,
broadcast_static_dim_lengths, broadcast_static_dim_lengths,
...@@ -373,6 +374,12 @@ class Blockwise(Op): ...@@ -373,6 +374,12 @@ class Blockwise(Op):
@_vectorize_node.register(Op) @_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
for inp in node.inputs:
if not isinstance(inp.type, (TensorType, ScalarType)):
raise NotImplementedError(
f"Cannot vectorize node {node} with input {inp} of type {inp.type}"
)
if hasattr(op, "gufunc_signature"): if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature signature = op.gufunc_signature
else: else:
......
import re
from itertools import product from itertools import product
from typing import Optional, Union from typing import Optional, Union
...@@ -12,7 +13,7 @@ from pytensor.graph import Apply, Op ...@@ -12,7 +13,7 @@ from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, tensor from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
...@@ -42,6 +43,19 @@ def test_vectorize_blockwise(): ...@@ -42,6 +43,19 @@ def test_vectorize_blockwise():
assert new_vect_node.inputs[0] is tns4 assert new_vect_node.inputs[0] is tns4
def test_vectorize_node_fallback_unsupported_type():
x = tensor("x", shape=(2, 6))
node = x[:, [0, 2, 4]].owner
with pytest.raises(
NotImplementedError,
match=re.escape(
"Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
),
):
vectorize_node_fallback(node.op, node, node.inputs)
def check_blockwise_runtime_broadcasting(mode): def check_blockwise_runtime_broadcasting(mode):
a = tensor("a", shape=(None, 3, 5)) a = tensor("a", shape=(None, 3, 5))
b = tensor("b", shape=(None, 5, 3)) b = tensor("b", shape=(None, 5, 3))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论