提交 68b41a48 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Better error for fallback of vectorize_node with non-tensor types

上级 301f10dc
......@@ -14,9 +14,10 @@ from pytensor.graph.replace import (
_vectorize_not_needed,
vectorize_graph,
)
from pytensor.scalar import ScalarType
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
from pytensor.tensor.utils import (
_parse_gufunc_signature,
broadcast_static_dim_lengths,
......@@ -373,6 +374,12 @@ class Blockwise(Op):
@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
for inp in node.inputs:
if not isinstance(inp.type, (TensorType, ScalarType)):
raise NotImplementedError(
f"Cannot vectorize node {node} with input {inp} of type {inp.type}"
)
if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature
else:
......
import re
from itertools import product
from typing import Optional, Union
......@@ -12,7 +13,7 @@ from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
......@@ -42,6 +43,19 @@ def test_vectorize_blockwise():
assert new_vect_node.inputs[0] is tns4
def test_vectorize_node_fallback_unsupported_type():
x = tensor("x", shape=(2, 6))
node = x[:, [0, 2, 4]].owner
with pytest.raises(
NotImplementedError,
match=re.escape(
"Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
),
):
vectorize_node_fallback(node.op, node, node.inputs)
def check_blockwise_runtime_broadcasting(mode):
a = tensor("a", shape=(None, 3, 5))
b = tensor("b", shape=(None, 5, 3))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论