提交 4e107721 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix a RandomVariable DimShuffle lift case for empty sizes

上级 5a2fb70b
...@@ -177,9 +177,10 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -177,9 +177,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
# Update the `size` array to reflect the `DimShuffle`d dimensions, # Update the `size` array to reflect the `DimShuffle`d dimensions,
# since the trailing dimensions in `size` represent the independent # since the trailing dimensions in `size` represent the independent
# variates dimensions (for univariate distributions, at least) # variates dimensions (for univariate distributions, at least)
has_size = get_vector_length(size) > 0
new_size = ( new_size = (
[constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order] [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
if get_vector_length(size) > 0 if has_size
else size else size
) )
...@@ -190,6 +191,14 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -190,6 +191,14 @@ def local_dimshuffle_rv_lift(fgraph, node):
d - reps_ind_split_idx if isinstance(d, int) else d d - reps_ind_split_idx if isinstance(d, int) else d
for d in ds_new_order[ds_ind_new_dims[0][0] :] for d in ds_new_order[ds_ind_new_dims[0][0] :]
] ]
if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
# Additional broadcast dimensions need to be added to the
# independent dimensions (i.e. parameters), since there's no
# `size` to which they can be added
rv_params_new_order = (
list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
)
else: else:
# This case is reached when, for example, `ds_new_order` only # This case is reached when, for example, `ds_new_order` only
# consists of new broadcastable dimensions (i.e. `"x"`s) # consists of new broadcastable dimensions (i.e. `"x"`s)
......
...@@ -147,6 +147,17 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -147,6 +147,17 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ds_order, lifted, dist_op, dist_params, size, rtol", "ds_order, lifted, dist_op, dist_params, size, rtol",
[ [
(
("x", 0),
True,
normal,
(
np.array([0.0, -100.0], dtype=np.float64),
np.array(1e-6, dtype=np.float64),
),
(),
1e-7,
),
( (
("x",), ("x",),
True, True,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论