提交 237f54f9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix shape_inference of `ChoiceRV` when param_shapes are provided

上级 cc054868
......@@ -1990,11 +1990,12 @@ class ChoiceRV(RandomVariable):
raise NotImplementedError()
def _infer_shape(self, size, dist_params, param_shapes=None):
(a, p, _) = dist_params
a, p, _ = dist_params
if isinstance(p.type, pytensor.tensor.type_other.NoneTypeT):
param_shapes = param_shapes[:1] if param_shapes is not None else None
shape = super()._infer_shape(size, (a,), param_shapes)
else:
param_shapes = param_shapes[:2] if param_shapes is not None else None
shape = super()._infer_shape(size, (a, p), param_shapes)
return shape
......
......@@ -1390,6 +1390,19 @@ def test_choice_samples():
compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True)
def test_choice_infer_shape():
node = choice([0, 1]).owner
res = node.op._infer_shape((), node.inputs[3:], None)
assert tuple(res.eval()) == ()
node = choice([0, 1]).owner
# The param_shape of a NoneConst is None, during shape_inference
res = node.op._infer_shape(
(), node.inputs[3:], (node.inputs[3].shape, None, node.inputs[5].shape)
)
assert tuple(res.eval()) == ()
def test_permutation_samples():
compare_sample_values(
permutation,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论