提交 4b812709 作者: Ricardo Vieira 提交者: Ricardo Vieira

Improve static shape of shaped xtensor Ops

Also avoid useless cast to XTensorType when 0d-tensor suffices
上级 16487418
...@@ -29,13 +29,13 @@ from pytensor.graph.fg import FunctionGraph, Output ...@@ -29,13 +29,13 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape from pytensor.graph.type import HasDataType, HasShape
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import int32 from pytensor.scalar import int32
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import ( from pytensor.tensor import (
_as_tensor_variable, _as_tensor_variable,
_get_vector_length, _get_vector_length,
...@@ -292,13 +292,8 @@ def _get_underlying_scalar_constant_value( ...@@ -292,13 +292,8 @@ def _get_underlying_scalar_constant_value(
max_recur : int max_recur : int
The maximum number of recursion. The maximum number of recursion.
Notes
-----
There may be another function similar to this one in the code,
but I'm not sure where it is.
""" """
from pytensor.compile.ops import DeepCopyOp, OutputGuard from pytensor.compile.ops import DeepCopyOp, OutputGuard, TypeCastingOp
from pytensor.sparse import CSM from pytensor.sparse import CSM
from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.subtensor import Subtensor
...@@ -319,13 +314,20 @@ def _get_underlying_scalar_constant_value( ...@@ -319,13 +314,20 @@ def _get_underlying_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
if isinstance(v, Constant): if isinstance(v, Constant):
if isinstance(v.type, TensorType) and v.unique_value is not None: v_type = v.type
return v.unique_value if isinstance(v_type, HasShape) and isinstance(v_type, HasDataType):
if v_type.ndim == 0:
return np.array(v.data, dtype=v.type.dtype)
elif isinstance(v.type, ScalarType): elif (not any(s is None for s in v_type.shape)) and (
return v.data np.prod(v_type.shape) == 1
):
return np.array(v.data, dtype=v_type.dtype).squeeze()
elif isinstance(v.type, NoneTypeT): elif isinstance(v_type, TensorType) and v.unique_value is not None:
return np.array(v.unique_value)
elif isinstance(v_type, NoneTypeT):
return None return None
raise NotScalarConstantError() raise NotScalarConstantError()
...@@ -333,9 +335,9 @@ def _get_underlying_scalar_constant_value( ...@@ -333,9 +335,9 @@ def _get_underlying_scalar_constant_value(
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
op = v.owner.op op = v.owner.op
max_recur -= 1 max_recur -= 1
if isinstance(op, Alloc | DimShuffle | OutputGuard | DeepCopyOp): if isinstance(
# OutputGuard is only used in debugmode but we op, Alloc | DimShuffle | TypeCastingOp | DeepCopyOp | OutputGuard
# keep it here to avoid problems with old pickles ):
v = v.owner.inputs[0] v = v.owner.inputs[0]
continue continue
elif isinstance(op, Shape_i): elif isinstance(op, Shape_i):
...@@ -343,7 +345,6 @@ def _get_underlying_scalar_constant_value( ...@@ -343,7 +345,6 @@ def _get_underlying_scalar_constant_value(
inp = v.owner.inputs[0] inp = v.owner.inputs[0]
if isinstance(inp, Constant): if isinstance(inp, Constant):
return np.asarray(np.shape(inp.data)[i]) return np.asarray(np.shape(inp.data)[i])
# The shape of a broadcastable dimension is 1
if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None: if isinstance(inp.type, HasShape) and inp.type.shape[i] is not None:
return np.asarray(inp.type.shape[i]) return np.asarray(inp.type.shape[i])
...@@ -600,7 +601,10 @@ def get_scalar_constant_value( ...@@ -600,7 +601,10 @@ def get_scalar_constant_value(
If 'v' is not a scalar, it raises a NotScalarConstantError. If 'v' is not a scalar, it raises a NotScalarConstantError.
""" """
if isinstance(v, TensorVariable | np.ndarray): if isinstance(v, Variable) and isinstance(v.type, HasShape):
if v.type.ndim != 0:
raise NotScalarConstantError("Input ndim != 0")
elif isinstance(v, np.ndarray):
if v.ndim != 0: if v.ndim != 0:
raise NotScalarConstantError("Input ndim != 0") raise NotScalarConstantError("Input ndim != 0")
return get_underlying_scalar_constant_value( return get_underlying_scalar_constant_value(
......
...@@ -676,7 +676,7 @@ def get_constant_idx( ...@@ -676,7 +676,7 @@ def get_constant_idx(
>>> b.owner.op.idx_list >>> b.owner.op.idx_list
(0, slice(1, 2, None)) (0, slice(1, 2, None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(np.int64(1), np.int64(3), None)] [v, slice(1, 3, None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
Traceback (most recent call last): Traceback (most recent call last):
pytensor.tensor.exceptions.NotScalarConstantError pytensor.tensor.exceptions.NotScalarConstantError
...@@ -696,7 +696,7 @@ def get_constant_idx( ...@@ -696,7 +696,7 @@ def get_constant_idx(
val, val,
only_process_constants=only_process_constants, only_process_constants=only_process_constants,
elemwise=elemwise, elemwise=elemwise,
) ).item()
except NotScalarConstantError: except NotScalarConstantError:
if allow_partial: if allow_partial:
return val return val
......
...@@ -119,7 +119,6 @@ def lower_expand_dims(fgraph, node): ...@@ -119,7 +119,6 @@ def lower_expand_dims(fgraph, node):
# Convert inputs to tensors # Convert inputs to tensors
x_tensor = tensor_from_xtensor(x) x_tensor = tensor_from_xtensor(x)
size_tensor = tensor_from_xtensor(size)
# Get the new dimension name and position # Get the new dimension name and position
new_axis = 0 # Always insert at front new_axis = 0 # Always insert at front
...@@ -130,7 +129,7 @@ def lower_expand_dims(fgraph, node): ...@@ -130,7 +129,7 @@ def lower_expand_dims(fgraph, node):
result_tensor = expand_dims(x_tensor, new_axis) result_tensor = expand_dims(x_tensor, new_axis)
else: else:
# Otherwise broadcast to the requested size # Otherwise broadcast to the requested size
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape)) result_tensor = broadcast_to(x_tensor, (size, *x_tensor.shape))
# Preserve static shape information # Preserve static shape information
result_tensor = specify_shape(result_tensor, out.type.shape) result_tensor = specify_shape(result_tensor, out.type.shape)
......
...@@ -123,7 +123,10 @@ class UnStack(XOp): ...@@ -123,7 +123,10 @@ class UnStack(XOp):
raise ValueError( raise ValueError(
f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}" f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}"
) )
unstacked_lengths = [as_tensor(length, ndim=0) for length in unstacked_length] unstacked_lengths = [
as_tensor(length, allow_xtensor_conversion=True)
for length in unstacked_length
]
if not all(length.dtype in discrete_dtypes for length in unstacked_lengths): if not all(length.dtype in discrete_dtypes for length in unstacked_lengths):
raise TypeError("Unstacked lengths must be discrete dtypes.") raise TypeError("Unstacked lengths must be discrete dtypes.")
...@@ -441,7 +444,7 @@ class ExpandDims(XOp): ...@@ -441,7 +444,7 @@ class ExpandDims(XOp):
if self.dim in x.type.dims: if self.dim in x.type.dims:
raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}") raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")
size = as_xtensor(size, dims=()) size = as_tensor(size, allow_xtensor_conversion=True)
if not (size.dtype in integer_dtypes and size.ndim == 0): if not (size.dtype in integer_dtypes and size.ndim == 0):
raise ValueError(f"size should be an integer scalar, got {size.type}") raise ValueError(f"size should be an integer scalar, got {size.type}")
try: try:
......
...@@ -16,6 +16,7 @@ from pytensor.graph.type import HasShape ...@@ -16,6 +16,7 @@ from pytensor.graph.type import HasShape
from pytensor.scalar import discrete_dtypes from pytensor.scalar import discrete_dtypes
from pytensor.tensor import ( from pytensor.tensor import (
TensorVariable, TensorVariable,
as_tensor,
broadcast_shape, broadcast_shape,
broadcast_to, broadcast_to,
tensor, tensor,
...@@ -232,7 +233,7 @@ class XRV(XOp, RNGConsumerOp): ...@@ -232,7 +233,7 @@ class XRV(XOp, RNGConsumerOp):
) )
extra_dim_lengths = [ extra_dim_lengths = [
as_xtensor(dim_length).values as_tensor(dim_length, allow_xtensor_conversion=True)
for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)] for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)]
] ]
if not all( if not all(
......
...@@ -3504,12 +3504,11 @@ class TestGetUnderlyingScalarConstantValue: ...@@ -3504,12 +3504,11 @@ class TestGetUnderlyingScalarConstantValue:
assert get_underlying_scalar_constant_value(s) == c.data assert get_underlying_scalar_constant_value(s) == c.data
def test_copy(self): def test_copy(self):
# Make sure we do not return a writeable internal storage of a constant, # Make sure we do not return the internal storage of a constant,
# so we cannot change the value of a constant by mistake. # so we cannot change the value of a constant by mistake.
c = constant(3) c = constant(3)
d = get_scalar_constant_value(c) d = get_scalar_constant_value(c)
with pytest.raises(ValueError, match="output array is read-only"): d += 1
d += 1
e = get_scalar_constant_value(c) e = get_scalar_constant_value(c)
assert e == 3, (c, d, e) assert e == 3, (c, d, e)
......
...@@ -132,6 +132,14 @@ def test_dtype(): ...@@ -132,6 +132,14 @@ def test_dtype():
assert x.type.dtype == "float32" assert x.type.dtype == "float32"
def test_static_shape():
x = xtensor("x", dims=("a", "b"), shape=(1, None))
y = xtensor("y", dims=("c", "d"), shape=(2, None))
out = normal(x, 1, extra_dims=y.sizes)
assert out.type.dims == ("c", "d", "a", "b")
assert out.type.shape == (2, None, 1, None)
def test_normal(): def test_normal():
rng = random_generator_type("rng") rng = random_generator_type("rng")
c_size = tensor("c_size", shape=(), dtype=int) c_size = tensor("c_size", shape=(), dtype=int)
......
...@@ -25,7 +25,7 @@ from pytensor.xtensor.shape import ( ...@@ -25,7 +25,7 @@ from pytensor.xtensor.shape import (
unstack, unstack,
zeros_like, zeros_like,
) )
from pytensor.xtensor.type import as_xtensor, xtensor from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import ( from tests.xtensor.util import (
check_vectorization, check_vectorization,
...@@ -369,16 +369,22 @@ def test_expand_dims(): ...@@ -369,16 +369,22 @@ def test_expand_dims():
# Implicit size 1 # Implicit size 1
y = x.expand_dims("country") y = x.expand_dims("country")
assert y.type.dims == ("country", "city", "year")
assert y.type.shape == (1, 2, 2)
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country")) xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))
# Test with multiple dimensions # Test with multiple dimensions
y = x.expand_dims(["country", "state"]) y = x.expand_dims(["country", "state"])
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (1, 1, 2, 2)
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"])) xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"]))
# Test with a dict of name-size pairs # Test with a dict of name-size pairs
y = x.expand_dims({"country": 2, "state": 3}) y = x.expand_dims({"country": 2, "state": 3})
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (2, 3, 2, 2)
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3})) xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3}))
...@@ -390,6 +396,8 @@ def test_expand_dims(): ...@@ -390,6 +396,8 @@ def test_expand_dims():
# Test with a dict of name-coord array pairs # Test with a dict of name-coord array pairs
with pytest.warns(UserWarning, match="only its length is used"): with pytest.warns(UserWarning, match="only its length is used"):
y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}) y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])})
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (2, 3, 2, 2)
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose( xr_assert_allclose(
fn(x_test), fn(x_test),
...@@ -399,12 +407,16 @@ def test_expand_dims(): ...@@ -399,12 +407,16 @@ def test_expand_dims():
# Symbolic size 1 # Symbolic size 1
size_sym_1 = scalar("size_sym_1", dtype="int64") size_sym_1 = scalar("size_sym_1", dtype="int64")
y = x.expand_dims({"country": size_sym_1}) y = x.expand_dims({"country": size_sym_1})
assert y.type.dims == ("country", "city", "year")
assert y.type.shape == (None, 2, 2)
fn = xr_function([x, size_sym_1], y) fn = xr_function([x, size_sym_1], y)
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1})) xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1}))
# Test with symbolic sizes in dict # Test with symbolic sizes in dict
size_sym_2 = scalar("size_sym_2", dtype="int64") size_sym_2 = scalar("size_sym_2", dtype="int64")
y = x.expand_dims({"country": size_sym_1, "state": size_sym_2}) y = x.expand_dims({"country": size_sym_1, "state": size_sym_2})
assert y.type.dims == ("country", "state", "city", "year")
assert y.type.shape == (None, None, 2, 2)
fn = xr_function([x, size_sym_1, size_sym_2], y) fn = xr_function([x, size_sym_1, size_sym_2], y)
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3})) xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
...@@ -415,16 +427,24 @@ def test_expand_dims(): ...@@ -415,16 +427,24 @@ def test_expand_dims():
# Test with axis parameter # Test with axis parameter
y = x.expand_dims("country", axis=1) y = x.expand_dims("country", axis=1)
assert y.type == XTensorType(
dtype=x.dtype, dims=("city", "country", "year"), shape=(2, 1, 2)
)
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1)) xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1))
# Test with negative axis parameter # Test with negative axis parameter
y = x.expand_dims("country", axis=-1) y = x.expand_dims("country", axis=-1)
assert y.type == XTensorType(
dtype=x.dtype, dims=("city", "year", "country"), shape=(2, 2, 1)
)
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=-1)) xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=-1))
# Add two new dims with axis parameters # Add two new dims with axis parameters
y = x.expand_dims(["country", "state"], axis=[1, 2]) y = x.expand_dims(["country", "state"], axis=[1, 2])
assert y.type.dims == ("city", "country", "state", "year")
assert y.type.shape == (2, 1, 1, 2)
fn = xr_function([x], y) fn = xr_function([x], y)
xr_assert_allclose( xr_assert_allclose(
fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2]) fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论