提交 6834740a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Michael Osthege

Fix static type shape bug

上级 081a0b48
......@@ -191,7 +191,7 @@ class RandomVariable(Op):
return shape
batch_shape = [
s if b is False else constant(1, "int64")
s if not b else constant(1, "int64")
for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
]
return batch_shape
......
......@@ -109,8 +109,13 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def parse_bcast_and_shape(s):
if isinstance(s, (bool, np.bool_)):
return 1 if s else None
else:
elif isinstance(s, (int, np.int_)):
return int(s)
elif s is None:
return s
raise ValueError(
f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}"
)
self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
self.dtype_specs() # error checking is done there
......
......@@ -16,6 +16,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import ones, stack
from pytensor.tensor.random.basic import (
_gamma,
bernoulli,
......@@ -1465,3 +1466,12 @@ def test_rebuild():
assert y_new.type.shape == (100,)
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)
def test_categorical_join_p_static_shape():
"""Regression test against a bug caused by misreading a numpy.bool_"""
p = ones(3) / 3
prob = stack([p, 1 - p], axis=-1)
assert prob.type.shape == (3, 2)
x = categorical(p=prob)
assert x.type.shape == (3,)
......@@ -2046,8 +2046,14 @@ class TestJoinAndSplit:
def test_static_shape_inference(self):
a = at.tensor(dtype="int8", shape=(2, 3))
b = at.tensor(dtype="int8", shape=(2, 5))
assert at.join(1, a, b).type.shape == (2, 8)
assert at.join(-1, a, b).type.shape == (2, 8)
res = at.join(1, a, b).type.shape
assert res == (2, 8)
assert all(isinstance(s, int) for s in res)
res = at.join(-1, a, b).type.shape
assert res == (2, 8)
assert all(isinstance(s, int) for s in res)
# Check early informative errors from static shape info
with pytest.raises(ValueError, match="must match exactly"):
......@@ -2055,8 +2061,9 @@ class TestJoinAndSplit:
# Check partial inference
d = at.tensor(dtype="int8", shape=(2, None))
assert at.join(1, a, b, d).type.shape == (2, None)
return
res = at.join(1, a, b, d).type.shape
assert res == (2, None)
assert isinstance(res[0], int)
def test_split_0elem(self):
rng = np.random.default_rng(seed=utt.fetch_seed())
......
......@@ -267,6 +267,27 @@ def test_fixed_shape_basic():
assert t2.shape == (2, 4)
def test_shape_type_conversion():
t1 = TensorType("float64", shape=np.array([3], dtype=int))
assert t1.shape == (3,)
assert isinstance(t1.shape[0], int)
assert t1.broadcastable == (False,)
assert isinstance(t1.broadcastable[0], bool)
t2 = TensorType("float64", broadcastable=np.array([True, False], dtype="bool"))
assert t2.shape == (1, None)
assert isinstance(t2.shape[0], int)
assert t2.broadcastable == (True, False)
assert isinstance(t2.broadcastable[0], bool)
assert isinstance(t2.broadcastable[1], bool)
with pytest.raises(
ValueError,
match="TensorType broadcastable/shape must be a boolean, integer or None",
):
TensorType("float64", shape=("1", "2"))
def test_fixed_shape_clone():
t1 = TensorType("float64", (1,))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论