Unverified 提交 2e9d502f authored 作者: Abhinav Khot's avatar Abhinav Khot 提交者: GitHub

Fix `get_vector_length` incorrectly returning for shared variable without static shape (#1295)

上级 9e603cf4
......@@ -3,7 +3,6 @@ import warnings
import numpy as np
from pytensor.compile import SharedVariable, shared_constructor
from pytensor.tensor import _get_vector_length
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import TensorVariable
......@@ -51,11 +50,6 @@ class TensorSharedVariable(SharedVariable, TensorVariable):
self.container.value = 0 * self.container.value
@_get_vector_length.register(TensorSharedVariable)
def _get_vector_length_TensorSharedVariable(var_inst, var):
return len(var.get_value(borrow=True))
@shared_constructor.register(np.ndarray)
def tensor_constructor(
value,
......
......@@ -965,8 +965,10 @@ class TestUnravelIndex(utt.InferShapeTester):
f_array_array = fn(indices, shape_array)
np.testing.assert_equal(ref, f_array_array())
# shape given as an PyTensor variable
shape_symb = pytensor.shared(shape_array)
# shape given as a shared PyTensor variable with static shape
shape_symb = pytensor.shared(
shape_array, shape=shape_array.shape, strict=True
)
f_array_symb = fn(indices, shape_symb)
np.testing.assert_equal(ref, f_array_symb())
......
......@@ -605,6 +605,7 @@ def makeSharedTester(
def test_values_eq(self):
# Test the type.values_eq[_approx] function
dtype = self.dtype
if dtype is None:
dtype = pytensor.config.floatX
......@@ -691,9 +692,13 @@ def test_scalar_shared_deprecated():
def test_get_vector_length():
x = pytensor.shared(np.array((2, 3, 4, 5)))
arr = np.array((2, 3, 4, 5))
x = pytensor.shared(arr, shape=arr.shape, strict=True)
assert get_vector_length(x) == 4
with pytest.raises(ValueError):
get_vector_length(pytensor.shared(arr))
def test_shared_masked_array_not_implemented():
x = np.ma.masked_greater(np.array([1, 2, 3, 4]), 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论