提交 8e8c5e8e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix SharedVariable.value issue in broadcast_shape_iter

Closes #1255
上级 967d1538
......@@ -1493,7 +1493,7 @@ def broadcast_shape_iter(
(one_at,) * (max_dims - len(a))
+ tuple(
one_at
if getattr(sh, "value", sh) == 1
if sh == 1 or isinstance(sh, Constant) and sh.value == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a
)
......
......@@ -1142,6 +1142,12 @@ def test_broadcast_shape_basic():
b_at = broadcast_shape(x_at, y_at)
assert isinstance(b_at[-1].owner.op, Assert)
# N.B. Shared variable shape values shouldn't be treated as constants,
# because they can change.
s = aesara.shared(1)
b_at = broadcast_shape((s, 2), (2, 1), arrays_are_shapes=True)
assert isinstance(b_at[0].owner.op, Assert)
def test_broadcast_shape_constants():
"""Make sure `broadcast_shape` uses constants when it can."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论