提交 9578bd3b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow specializing shape of predefined tensors types

上级 3c66aa65
...@@ -123,6 +123,18 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -123,6 +123,18 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
self.name = name self.name = name
self.numpy_dtype = np.dtype(self.dtype) self.numpy_dtype = np.dtype(self.dtype)
def __call__(self, *args, shape=None, **kwargs):
if shape is not None:
# Check if shape is compatible with the original type
new_type = self.clone(shape=shape)
if self.is_super(new_type):
return new_type(*args, **kwargs)
else:
raise ValueError(
f"{shape=} is incompatible with original type shape {self.shape=}"
)
return super().__call__(*args, **kwargs)
def clone( def clone(
self, dtype=None, shape=None, broadcastable=None, **kwargs self, dtype=None, shape=None, broadcastable=None, **kwargs
) -> "TensorType": ) -> "TensorType":
......
...@@ -10,6 +10,10 @@ from pytensor.tensor.shape import SpecifyShape ...@@ -10,6 +10,10 @@ from pytensor.tensor.shape import SpecifyShape
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
col, col,
dmatrix,
drow,
fmatrix,
frow,
matrix, matrix,
row, row,
scalar, scalar,
...@@ -477,3 +481,21 @@ def test_row_matrix_creator_helpers(helper): ...@@ -477,3 +481,21 @@ def test_row_matrix_creator_helpers(helper):
match = "The second dimension of a `col` must have shape 1, got 5" match = "The second dimension of a `col` must have shape 1, got 5"
with pytest.raises(ValueError, match=match): with pytest.raises(ValueError, match=match):
helper(shape=(2, 5)) helper(shape=(2, 5))
def test_shape_of_predefined_dtype_tensor():
# Valid: None dimensions can be specialized
assert fmatrix(shape=(1, None)).type == frow
assert drow(shape=(1, 5)).type == dmatrix(shape=(1, 5)).type
# Invalid: Number of dimensions must match
with pytest.raises(ValueError):
fmatrix(shape=(None, None, None))
# Invalid: Fixed shapes must match
with pytest.raises(ValueError):
fmatrix(shape=(3, 5)).type(shape=(4, 5))
# Invalid: Known shapes can't be lost
with pytest.raises(ValueError):
drow(shape=(None, None))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论