Commit 54f4b200, authored by Ricardo Vieira; committed by Ricardo Vieira

Do not return Constants in shape Op

Parent commit: 8112576b
......@@ -146,14 +146,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
if not isinstance(x, Variable):
x = at.as_tensor_variable(x)
x_type = x.type
if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape):
res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
else:
res = _shape(x)
return res
return _shape(x)
@_get_vector_length.register(Shape)
......
......@@ -12,7 +12,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, OptionalApplyType, Variable
from pytensor.graph.utils import MetaType
from pytensor.scalar import ComplexError, IntegerDivisionError
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst
......@@ -259,9 +259,6 @@ class _tensor_py_operators:
@property
def shape(self):
if not any(s is None for s in self.type.shape):
return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)
return at.shape(self)
@property
......
......@@ -1477,13 +1477,12 @@ class TestSaveMem:
f(x0=0, seq=test_seq, n_steps=0)
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
# If a MissingInputError is raised, it means the rewrite failed
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, _, ys_trace, len_zs = scan_node.inputs
debug_fn = pytensor.function(
[n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
[x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
)
stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200)
assert stored_ys_steps == 2
assert stored_zs_steps == 1
......
......@@ -3506,7 +3506,7 @@ class TestSize:
def test_scalar(self):
x = scalar()
y = np.array(7, dtype=config.floatX)
assert y.size == function([], x.size)()
assert y.size == function([x], x.size)(y)
def test_shared(self):
# NB: we also test higher order tensors at the same time.
......
......@@ -7,7 +7,6 @@ from numpy.testing import assert_array_almost_equal
import pytensor
from pytensor import function
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import (
SVD,
......@@ -274,9 +273,7 @@ def test_det_grad():
def test_det_shape():
x = matrix()
det_shape = det(x).shape
assert isinstance(det_shape, Constant)
assert tuple(det_shape.data) == ()
assert det(x).type.shape == ()
def test_slogdet():
......
......@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import (
Reshape,
Shape,
Shape_i,
SpecifyShape,
Unbroadcast,
......@@ -397,7 +398,7 @@ class TestSpecifyShape(utt.InferShapeTester):
shape = as_tensor_variable([2])
y = specify_shape(x, shape)
assert y.type.shape == (2,)
assert y.shape.equals(shape)
assert isinstance(y.shape.owner.op, Shape)
def test_fixed_partial_shapes(self):
x = TensorType("floatX", (None, None))("x")
......
......@@ -6,12 +6,14 @@ from numpy.testing import assert_array_equal, assert_equal, assert_string_equal
import pytensor
import tests.unittest_tools as utt
from pytensor.compile import DeepCopyOp
from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.type import (
TensorType,
......@@ -245,8 +247,14 @@ def test__getitem__newaxis(x, indices, new_order):
def test_fixed_shape_variable_basic():
x = TensorVariable(TensorType("int64", shape=(4,)), None)
assert isinstance(x.shape, Constant)
assert np.array_equal(x.shape.data, (4,))
assert x.type.shape == (4,)
assert isinstance(x.shape.owner.op, Shape)
shape_fn = pytensor.function([x], x.shape)
opt_shape = shape_fn.maker.fgraph.outputs[0]
assert isinstance(opt_shape.owner.op, DeepCopyOp)
assert isinstance(opt_shape.owner.inputs[0], Constant)
assert np.array_equal(opt_shape.owner.inputs[0].data, (4,))
x = TensorConstant(
TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]])
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment