提交 ccda3ccb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix lift_rv_shapes when size is empty and parameters are broadcasted

上级 7a987064
...@@ -58,14 +58,20 @@ def lift_rv_shapes(node): ...@@ -58,14 +58,20 @@ def lift_rv_shapes(node):
dist_params = broadcast_params(dist_params, node.op.ndims_params) dist_params = broadcast_params(dist_params, node.op.ndims_params)
dist_params = [ if get_vector_length(size) > 0:
broadcast_to( dist_params = [
p, (tuple(size) + tuple(p.shape)) if node.op.ndim_supp > 0 else size broadcast_to(
) p, (tuple(size) + tuple(p.shape)) if node.op.ndim_supp > 0 else size
for p in dist_params )
] for p in dist_params
]
new_node = node.op.make_node(rng, None, dtype, *dist_params)
if config.compute_test_value != "off":
compute_test_value(new_node)
return node.op.make_node(rng, None, dtype, *dist_params) return new_node
@local_optimizer([DimShuffle]) @local_optimizer([DimShuffle])
......
...@@ -101,6 +101,13 @@ def test_lift_rv_shapes(): ...@@ -101,6 +101,13 @@ def test_lift_rv_shapes():
test_size = [] test_size = []
check_shape_lifted_rv(normal, test_params, test_size, rng) check_shape_lifted_rv(normal, test_params, test_size, rng)
test_params = [
np.array([0.0, 1.0], dtype=config.floatX),
np.array(5.0, dtype=config.floatX),
]
test_size = []
check_shape_lifted_rv(normal, test_params, test_size, rng)
test_params = [ test_params = [
np.array([0.0, 1.0], dtype=config.floatX), np.array([0.0, 1.0], dtype=config.floatX),
np.array(5.0, dtype=config.floatX), np.array(5.0, dtype=config.floatX),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论