提交 a8f952dc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in supp_shape_from_ref_param_shape

上级 af88b456
...@@ -323,7 +323,7 @@ def supp_shape_from_ref_param_shape( ...@@ -323,7 +323,7 @@ def supp_shape_from_ref_param_shape(
raise ValueError("ndim_supp must be greater than 0") raise ValueError("ndim_supp must be greater than 0")
if param_shapes is not None: if param_shapes is not None:
ref_param = param_shapes[ref_param_idx] ref_param = param_shapes[ref_param_idx]
return (ref_param[-ndim_supp],) return tuple(ref_param[i] for i in range(-ndim_supp, 0))
else: else:
ref_param = dist_params[ref_param_idx] ref_param = dist_params[ref_param_idx]
if ref_param.ndim < ndim_supp: if ref_param.ndim < ndim_supp:
......
...@@ -313,3 +313,11 @@ def test_supp_shape_from_ref_param_shape(): ...@@ -313,3 +313,11 @@ def test_supp_shape_from_ref_param_shape():
ref_param_idx=1, ref_param_idx=1,
) )
assert res == (3, 4) assert res == (3, 4)
res = supp_shape_from_ref_param_shape(
ndim_supp=2,
dist_params=(np.array([1, 2]), np.ones((2, 3, 4))),
param_shapes=((2,), (2, 3, 4)),
ref_param_idx=1,
)
assert res == (3, 4)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论