提交 05d376f3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in vectorize_random_variable when size is empty but not None

上级 85506229
...@@ -238,7 +238,7 @@ class RandomVariable(Op): ...@@ -238,7 +238,7 @@ class RandomVariable(Op):
raise ValueError( raise ValueError(
f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n" f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. " f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
f"Size length must be 0 or >= {param_batched_dims}" f"Size must be None or have length >= {param_batched_dims}"
) )
return tuple(size) + supp_shape return tuple(size) + supp_shape
...@@ -454,11 +454,10 @@ def vectorize_random_variable( ...@@ -454,11 +454,10 @@ def vectorize_random_variable(
original_dist_params = op.dist_params(node) original_dist_params = op.dist_params(node)
old_size = op.size_param(node) old_size = op.size_param(node)
len_old_size = (
None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size)
)
if len_old_size and equal_computations([old_size], [size]): if not isinstance(old_size.type, NoneTypeT) and equal_computations(
[old_size], [size]
):
# If the original RV had a size variable and a new one has not been provided, # If the original RV had a size variable and a new one has not been provided,
# we need to define a new size as the concatenation of the original size dimensions # we need to define a new size as the concatenation of the original size dimensions
# and the novel ones implied by new broadcasted batched parameters dimensions. # and the novel ones implied by new broadcasted batched parameters dimensions.
......
...@@ -296,6 +296,16 @@ def test_vectorize(): ...@@ -296,6 +296,16 @@ def test_vectorize():
assert vect_node.default_output().type.shape == (10, 2, 5) assert vect_node.default_output().type.shape == (10, 2, 5)
def test_vectorize_empty_size():
scalar_mu = pt.scalar("scalar_mu")
scalar_x = pt.random.normal(loc=scalar_mu, size=())
assert scalar_x.type.shape == ()
vector_mu = pt.vector("vector_mu", shape=(5,))
vector_x = vectorize_graph(scalar_x, {scalar_mu: vector_mu})
assert vector_x.type.shape == (5,)
def test_size_none_vs_empty(): def test_size_none_vs_empty():
rv = RandomVariable( rv = RandomVariable(
"normal", "normal",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论