提交 5db98be1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Lift broadcast-only DimShuffles through RandomVariable Ops

上级 a95c3f85
...@@ -146,11 +146,11 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -146,11 +146,11 @@ def local_dimshuffle_rv_lift(fgraph, node):
ds_reps_new_dims = dim_orders[:reps_ind_split_idx] ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
ds_ind_new_dims = dim_orders[reps_ind_split_idx:] ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
ds_only_in_ind = ds_ind_new_dims and all( ds_in_ind_space = ds_ind_new_dims and all(
d >= reps_ind_split_idx for n, d in ds_ind_new_dims d >= reps_ind_split_idx for n, d in ds_ind_new_dims
) )
if ds_only_in_ind: if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
# Update the `size` array to reflect the `DimShuffle`d dimensions, # Update the `size` array to reflect the `DimShuffle`d dimensions,
# since the trailing dimensions in `size` represent the independent # since the trailing dimensions in `size` represent the independent
...@@ -163,10 +163,15 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -163,10 +163,15 @@ def local_dimshuffle_rv_lift(fgraph, node):
# Compute the new axes parameter(s) for the `DimShuffle` that will be # Compute the new axes parameter(s) for the `DimShuffle` that will be
# applied to the `RandomVariable` parameters (they need to be offset) # applied to the `RandomVariable` parameters (they need to be offset)
rv_params_new_order = [ if ds_ind_new_dims:
d - reps_ind_split_idx if isinstance(d, int) else d rv_params_new_order = [
for d in ds_new_order[ds_ind_new_dims[0][0] :] d - reps_ind_split_idx if isinstance(d, int) else d
] for d in ds_new_order[ds_ind_new_dims[0][0] :]
]
else:
# This case is reached when, for example, `ds_new_order` only
# consists of new broadcastable dimensions (i.e. `"x"`s)
rv_params_new_order = ds_new_order
# Lift the `DimShuffle`s into the parameters # Lift the `DimShuffle`s into the parameters
# NOTE: The parameters might not be broadcasted against each other, so # NOTE: The parameters might not be broadcasted against each other, so
...@@ -192,11 +197,11 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -192,11 +197,11 @@ def local_dimshuffle_rv_lift(fgraph, node):
return [new_node.outputs[1]] return [new_node.outputs[1]]
ds_only_in_reps = ds_reps_new_dims and all( ds_in_reps_space = ds_reps_new_dims and all(
d < reps_ind_split_idx for n, d in ds_reps_new_dims d < reps_ind_split_idx for n, d in ds_reps_new_dims
) )
if ds_only_in_reps: if ds_in_reps_space:
# Update the `size` array to reflect the `DimShuffle`d dimensions. # Update the `size` array to reflect the `DimShuffle`d dimensions.
# There should be no need to `DimShuffle` now. # There should be no need to `DimShuffle` now.
new_size = [ new_size = [
......
...@@ -134,6 +134,28 @@ def test_lift_rv_shapes(): ...@@ -134,6 +134,28 @@ def test_lift_rv_shapes():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ds_order, lifted, dist_op, dist_params, size, rtol", "ds_order, lifted, dist_op, dist_params, size, rtol",
[ [
(
("x",),
True,
normal,
(
np.array(-10.0, dtype=np.float64),
np.array(1e-6, dtype=np.float64),
),
(),
1e-7,
),
(
("x", "x", "x"),
True,
normal,
(
np.array(-10.0, dtype=np.float64),
np.array(1e-6, dtype=np.float64),
),
(),
1e-7,
),
( (
(1, 0, 2), (1, 0, 2),
True, True,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论