提交 ccf53d19 作者: Ricardo Vieira 提交者: Ricardo Vieira

SplitDims: Fix scalar tensor shape

上级 2ecf8523
......@@ -3,7 +3,7 @@ from itertools import pairwise
from typing import TypeAlias
import numpy as np
from numpy.lib._array_utils_impl import normalize_axis_tuple
from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple
from pytensor import Variable
from pytensor.gradient import DisconnectedType
......@@ -13,7 +13,6 @@ from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import ScalarVariable
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import prod
from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable
......@@ -166,11 +165,16 @@ class SplitDims(Op):
def make_node(self, x, shape):
x = as_tensor_variable(x)
shape = as_tensor_variable(shape, dtype=int, ndim=1)
shape = as_tensor_variable(shape, dtype=int)
if shape.type.numpy_dtype.kind not in "iu":
raise TypeError("shape must be an integer tensor")
if shape.type.ndim != 1:
raise TypeError(
f"shape must be a 1-D tensor, got {shape} with {shape.type.ndim} dimensions"
)
axis = self.axis
_, constant_shape = infer_static_shape(shape)
......@@ -262,25 +266,21 @@ def split_dims(
x = as_tensor_variable(x)
if axis is None:
if x.ndim != 1:
if x.type.ndim != 1:
raise ValueError(
"split_dims can only be called with axis=None for 1d inputs"
)
axis = 0
if isinstance(shape, int):
shape = [shape]
else:
shape = list(shape) # type: ignore[arg-type]
if not shape:
# If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
# example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
# (3, ) and (3, 3) to (3, 4)
return squeeze(x, axis=axis) # type: ignore[no-any-return]
[axis] = normalize_axis_tuple(axis, x.ndim) # type: ignore[misc]
shape = as_tensor_variable(shape, dtype="int64", ndim=1) # type: ignore[arg-type]
axis = normalize_axis_index(axis, x.ndim)
# Convert scalar shape to 1d tuple (shape,)
if not isinstance(shape, Sequence):
if isinstance(shape, TensorVariable | np.ndarray):
if shape.ndim == 0:
shape = (shape,)
elif isinstance(shape, int | np.integer | ScalarVariable):
shape = (shape,)
return SplitDims(axis=axis)(x, shape) # type: ignore[return-value]
......
......@@ -61,9 +61,10 @@ def test_join_dims():
[
(0, pt.as_tensor([2, 3]), (2, 3, 4, 6)),
(2, [2, 3], (6, 4, 2, 3)),
(-1, pt.as_tensor(6), (6, 4, 6)),
(-1, 6, (6, 4, 6)),
],
ids=["tensor", "list", "integer"],
ids=["tensor list", "integer list", "tensor", "integer"],
)
def test_split_dims(axis, shape, expected_shape):
rng = np.random.default_rng()
......@@ -95,7 +96,7 @@ def test_split_dims(axis, shape, expected_shape):
def test_split_size_zero_shape():
x = pt.tensor("x", shape=(1, 4, 6))
x_split = split_dims(x, axis=0, shape=pt.as_tensor(np.zeros((0,))))
x_split = split_dims(x, axis=0, shape=pt.as_tensor(np.zeros((0,), dtype="int32")))
assert x_split.type.shape == (4, 6)
x_value = np.empty((1, 4, 6), dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论