提交 a64055dd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fail early in RandomVariable.make_node when size is incompatible with parameters dimensionality

上级 237f54f9
...@@ -192,6 +192,19 @@ class RandomVariable(Op): ...@@ -192,6 +192,19 @@ class RandomVariable(Op):
size_len = get_vector_length(size) size_len = get_vector_length(size)
if size_len > 0: if size_len > 0:
# Fail early when size is incompatible with parameters
for i, (param, param_ndim_supp) in enumerate(
zip(dist_params, self.ndims_params)
):
param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp
if param_batched_dims > size_len:
raise ValueError(
f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
f"Size length must be 0 or >= {param_batched_dims}"
)
if self.ndim_supp == 0: if self.ndim_supp == 0:
return size return size
else: else:
......
...@@ -217,3 +217,17 @@ def test_random_maker_ops_no_seed(): ...@@ -217,3 +217,17 @@ def test_random_maker_ops_no_seed():
z = function(inputs=[], outputs=[default_rng()])() z = function(inputs=[], outputs=[default_rng()])()
aes_res = z[0] aes_res = z[0]
assert isinstance(aes_res, np.random.Generator) assert isinstance(aes_res, np.random.Generator)
def test_RandomVariable_incompatible_size():
rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
with pytest.raises(
ValueError, match="Size length is incompatible with batched dimensions"
):
rv_op(np.zeros((1, 3)), 1, size=(3,))
rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
with pytest.raises(
ValueError, match="Size length is incompatible with batched dimensions"
):
rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论