提交 54f4b200 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not return Constants in shape Op

上级 8112576b
...@@ -146,14 +146,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable: ...@@ -146,14 +146,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
if not isinstance(x, Variable): if not isinstance(x, Variable):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
x_type = x.type return _shape(x)
if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape):
res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
else:
res = _shape(x)
return res
@_get_vector_length.register(Shape) @_get_vector_length.register(Shape)
......
...@@ -12,7 +12,7 @@ from pytensor.configdefaults import config ...@@ -12,7 +12,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, OptionalApplyType, Variable from pytensor.graph.basic import Constant, OptionalApplyType, Variable
from pytensor.graph.utils import MetaType from pytensor.graph.utils import MetaType
from pytensor.scalar import ComplexError, IntegerDivisionError from pytensor.scalar import ComplexError, IntegerDivisionError
from pytensor.tensor import _get_vector_length, as_tensor_variable from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
...@@ -259,9 +259,6 @@ class _tensor_py_operators: ...@@ -259,9 +259,6 @@ class _tensor_py_operators:
@property @property
def shape(self): def shape(self):
if not any(s is None for s in self.type.shape):
return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)
return at.shape(self) return at.shape(self)
@property @property
......
...@@ -1477,13 +1477,12 @@ class TestSaveMem: ...@@ -1477,13 +1477,12 @@ class TestSaveMem:
f(x0=0, seq=test_seq, n_steps=0) f(x0=0, seq=test_seq, n_steps=0)
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly. # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
# If a MissingInputError is raised, it means the rewrite failed
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, _, ys_trace, len_zs = scan_node.inputs _, _, ys_trace, len_zs = scan_node.inputs
debug_fn = pytensor.function( debug_fn = pytensor.function(
[n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True [x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
) )
stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200) stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200)
assert stored_ys_steps == 2 assert stored_ys_steps == 2
assert stored_zs_steps == 1 assert stored_zs_steps == 1
......
...@@ -3506,7 +3506,7 @@ class TestSize: ...@@ -3506,7 +3506,7 @@ class TestSize:
def test_scalar(self): def test_scalar(self):
x = scalar() x = scalar()
y = np.array(7, dtype=config.floatX) y = np.array(7, dtype=config.floatX)
assert y.size == function([], x.size)() assert y.size == function([x], x.size)(y)
def test_shared(self): def test_shared(self):
# NB: we also test higher order tensors at the same time. # NB: we also test higher order tensors at the same time.
......
...@@ -7,7 +7,6 @@ from numpy.testing import assert_array_almost_equal ...@@ -7,7 +7,6 @@ from numpy.testing import assert_array_almost_equal
import pytensor import pytensor
from pytensor import function from pytensor import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.tensor.math import _allclose from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD, SVD,
...@@ -274,9 +273,7 @@ def test_det_grad(): ...@@ -274,9 +273,7 @@ def test_det_grad():
def test_det_shape(): def test_det_shape():
x = matrix() x = matrix()
det_shape = det(x).shape assert det(x).type.shape == ()
assert isinstance(det_shape, Constant)
assert tuple(det_shape.data) == ()
def test_slogdet(): def test_slogdet():
......
...@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise ...@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Reshape, Reshape,
Shape,
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast, Unbroadcast,
...@@ -397,7 +398,7 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -397,7 +398,7 @@ class TestSpecifyShape(utt.InferShapeTester):
shape = as_tensor_variable([2]) shape = as_tensor_variable([2])
y = specify_shape(x, shape) y = specify_shape(x, shape)
assert y.type.shape == (2,) assert y.type.shape == (2,)
assert y.shape.equals(shape) assert isinstance(y.shape.owner.op, Shape)
def test_fixed_partial_shapes(self): def test_fixed_partial_shapes(self):
x = TensorType("floatX", (None, None))("x") x = TensorType("floatX", (None, None))("x")
......
...@@ -6,12 +6,14 @@ from numpy.testing import assert_array_equal, assert_equal, assert_string_equal ...@@ -6,12 +6,14 @@ from numpy.testing import assert_array_equal, assert_equal, assert_string_equal
import pytensor import pytensor
import tests.unittest_tools as utt import tests.unittest_tools as utt
from pytensor.compile import DeepCopyOp
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import constant from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq from pytensor.tensor.math import dot, eq
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
...@@ -245,8 +247,14 @@ def test__getitem__newaxis(x, indices, new_order): ...@@ -245,8 +247,14 @@ def test__getitem__newaxis(x, indices, new_order):
def test_fixed_shape_variable_basic(): def test_fixed_shape_variable_basic():
x = TensorVariable(TensorType("int64", shape=(4,)), None) x = TensorVariable(TensorType("int64", shape=(4,)), None)
assert isinstance(x.shape, Constant) assert x.type.shape == (4,)
assert np.array_equal(x.shape.data, (4,)) assert isinstance(x.shape.owner.op, Shape)
shape_fn = pytensor.function([x], x.shape)
opt_shape = shape_fn.maker.fgraph.outputs[0]
assert isinstance(opt_shape.owner.op, DeepCopyOp)
assert isinstance(opt_shape.owner.inputs[0], Constant)
assert np.array_equal(opt_shape.owner.inputs[0].data, (4,))
x = TensorConstant( x = TensorConstant(
TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]]) TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论