Commit 4b812709 — authored by Ricardo Vieira, committed by Ricardo Vieira

Improve static shape of shaped xtensor Ops

Also avoid useless cast to XTensorType when 0d-tensor suffices
Parent commit: 16487418
......@@ -29,13 +29,13 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape
from pytensor.graph.type import HasDataType, HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import int32
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import (
_as_tensor_variable,
_get_vector_length,
......@@ -292,13 +292,8 @@ def _get_underlying_scalar_constant_value(
max_recur : int
The maximum number of recursion.
Notes
-----
There may be another function similar to this one in the code,
but I'm not sure where it is.
"""
from pytensor.compile.ops import DeepCopyOp, OutputGuard
from pytensor.compile.ops import DeepCopyOp, OutputGuard, TypeCastingOp
from pytensor.sparse import CSM
from pytensor.tensor.subtensor import Subtensor
......@@ -319,13 +314,20 @@ def _get_underlying_scalar_constant_value(
raise NotScalarConstantError()
if isinstance(v, Constant):
if isinstance(v.type, TensorType) and v.unique_value is not None:
return v.unique_value
v_type = v.type
if isinstance(v_type, HasShape) and isinstance(v_type, HasDataType):
if v_type.ndim == 0:
return np.array(v.data, dtype=v.type.dtype)
elif isinstance(v.type, ScalarType):
return v.data
elif (not any(s is None for s in v_type.shape)) and (
np.prod(v_type.shape) == 1
):
return np.array(v.data, dtype=v_type.dtype).squeeze()
elif isinstance(v.type, NoneTypeT):
elif isinstance(v_type, TensorType) and v.unique_value is not None:
return np.array(v.unique_value)
elif isinstance(v_type, NoneTypeT):
return None
raise NotScalarConstantError()
......@@ -333,9 +335,9 @@ def _get_underlying_scalar_constant_value(
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
op = v.owner.op
max_recur -= 1
if isinstance(op, Alloc | DimShuffle | OutputGuard | DeepCopyOp):
# OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles
if isinstance(
op, Alloc | DimShuffle | TypeCastingOp | DeepCopyOp | OutputGuard
):
v = v.owner.inputs[0]
continue
elif isinstance(op, Shape_i):
......@@ -343,7 +345,6 @@ def _get_underlying_scalar_constant_value(
inp = v.owner.inputs[0]
if isinstance(inp, Constant):
return np.asarray(np.shape(inp.data)[i])
# The shape of a broadcastable dimension is 1
if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None:
return np.asarray(inp.type.shape[i])
......@@ -600,7 +601,10 @@ def get_scalar_constant_value(
If 'v' is not a scalar, it raises a NotScalarConstantError.
"""
if isinstance(v, TensorVariable | np.ndarray):
if isinstance(v, Variable) and isinstance(v.type, HasShape):
if v.type.ndim != 0:
raise NotScalarConstantError("Input ndim != 0")
elif isinstance(v, np.ndarray):
if v.ndim != 0:
raise NotScalarConstantError("Input ndim != 0")
return get_underlying_scalar_constant_value(
......
......@@ -676,7 +676,7 @@ def get_constant_idx(
>>> b.owner.op.idx_list
(0, slice(1, 2, None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(np.int64(1), np.int64(3), None)]
[v, slice(1, 3, None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
Traceback (most recent call last):
pytensor.tensor.exceptions.NotScalarConstantError
......@@ -696,7 +696,7 @@ def get_constant_idx(
val,
only_process_constants=only_process_constants,
elemwise=elemwise,
)
).item()
except NotScalarConstantError:
if allow_partial:
return val
......
......@@ -119,7 +119,6 @@ def lower_expand_dims(fgraph, node):
# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
size_tensor = tensor_from_xtensor(size)
# Get the new dimension name and position
new_axis = 0 # Always insert at front
......@@ -130,7 +129,7 @@ def lower_expand_dims(fgraph, node):
result_tensor = expand_dims(x_tensor, new_axis)
else:
# Otherwise broadcast to the requested size
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape))
result_tensor = broadcast_to(x_tensor, (size, *x_tensor.shape))
# Preserve static shape information
result_tensor = specify_shape(result_tensor, out.type.shape)
......
......@@ -123,7 +123,10 @@ class UnStack(XOp):
raise ValueError(
f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}"
)
unstacked_lengths = [as_tensor(length, ndim=0) for length in unstacked_length]
unstacked_lengths = [
as_tensor(length, allow_xtensor_conversion=True)
for length in unstacked_length
]
if not all(length.dtype in discrete_dtypes for length in unstacked_lengths):
raise TypeError("Unstacked lengths must be discrete dtypes.")
......@@ -441,7 +444,7 @@ class ExpandDims(XOp):
if self.dim in x.type.dims:
raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")
size = as_xtensor(size, dims=())
size = as_tensor(size, allow_xtensor_conversion=True)
if not (size.dtype in integer_dtypes and size.ndim == 0):
raise ValueError(f"size should be an integer scalar, got {size.type}")
try:
......
......@@ -16,6 +16,7 @@ from pytensor.graph.type import HasShape
from pytensor.scalar import discrete_dtypes
from pytensor.tensor import (
TensorVariable,
as_tensor,
broadcast_shape,
broadcast_to,
tensor,
......@@ -232,7 +233,7 @@ class XRV(XOp, RNGConsumerOp):
)
extra_dim_lengths = [
as_xtensor(dim_length).values
as_tensor(dim_length, allow_xtensor_conversion=True)
for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)]
]
if not all(
......
......@@ -3504,12 +3504,11 @@ class TestGetUnderlyingScalarConstantValue:
assert get_underlying_scalar_constant_value(s) == c.data
def test_copy(self):
# Make sure we do not return a writeable internal storage of a constant,
# Make sure we do not return the internal storage of a constant,
# so we cannot change the value of a constant by mistake.
c = constant(3)
d = get_scalar_constant_value(c)
with pytest.raises(ValueError, match="output array is read-only"):
d += 1
d += 1
e = get_scalar_constant_value(c)
assert e == 3, (c, d, e)
......
......@@ -132,6 +132,14 @@ def test_dtype():
assert x.type.dtype == "float32"
def test_static_shape():
    """Static shape info from the inputs and ``extra_dims`` propagates to the XRV output."""
    loc = xtensor("x", dims=("a", "b"), shape=(1, None))
    extra = xtensor("y", dims=("c", "d"), shape=(2, None))
    draws = normal(loc, 1, extra_dims=extra.sizes)
    # extra dims come first, followed by the broadcast input dims
    assert draws.type.dims == ("c", "d", "a", "b")
    assert draws.type.shape == (2, None, 1, None)
def test_normal():
rng = random_generator_type("rng")
c_size = tensor("c_size", shape=(), dtype=int)
......
......@@ -25,7 +25,7 @@ from pytensor.xtensor.shape import (
unstack,
zeros_like,
)
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import (
check_vectorization,
......@@ -369,16 +369,22 @@ def test_expand_dims():
# Implicit size 1
y = x.expand_dims("country")
assert y.type.dims == ("country", "city", "year")
assert y.type.shape == (1, 2, 2)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))
# Test with multiple dimensions
y = x.expand_dims(["country", "state"])
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (1, 1, 2, 2)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"]))
# Test with a dict of name-size pairs
y = x.expand_dims({"country": 2, "state": 3})
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (2, 3, 2, 2)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3}))
......@@ -390,6 +396,8 @@ def test_expand_dims():
# Test with a dict of name-coord array pairs
with pytest.warns(UserWarning, match="only its length is used"):
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (2, 3, 2, 2)
fn = xr_function([x], y)
xr_assert_allclose(
fn(x_test),
......@@ -399,12 +407,16 @@ def test_expand_dims():
# Symbolic size 1
size_sym_1 = scalar("size_sym_1", dtype="int64")
y = x.expand_dims({"country": size_sym_1})
assert y.type.dims == ("country", "city", "year")
assert y.type.shape == (None, 2, 2)
fn = xr_function([x, size_sym_1], y)
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1}))
# Test with symbolic sizes in dict
size_sym_2 = scalar("size_sym_2", dtype="int64")
y = x.expand_dims({"country": size_sym_1, "state": size_sym_2})
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (None, None, 2, 2)
fn = xr_function([x, size_sym_1, size_sym_2], y)
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
......@@ -415,16 +427,24 @@ def test_expand_dims():
# Test with axis parameter
y = x.expand_dims("country", axis=1)
assert y.type == XTensorType(
dtype=x.dtype, dims=("city", "country", "year"), shape=(2, 1, 2)
)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1))
# Test with negative axis parameter
y = x.expand_dims("country", axis=-1)
assert y.type == XTensorType(
dtype=x.dtype, dims=("city", "year", "country"), shape=(2, 2, 1)
)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=-1))
# Add two new dims with axis parameters
y = x.expand_dims(["country", "state"], axis=[1, 2])
assert y.type.dims == ("city", "country", "state", "year")
assert y.type.shape == (2, 1, 1, 2)
fn = xr_function([x], y)
xr_assert_allclose(
fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2])
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed carefully.
Finish editing this comment first!
Register or sign in to post a comment