提交 8e8c5e8e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix SharedVariable.value issue in broadcast_shape_iter

Closes #1255
上级 967d1538
...@@ -1493,7 +1493,7 @@ def broadcast_shape_iter( ...@@ -1493,7 +1493,7 @@ def broadcast_shape_iter(
(one_at,) * (max_dims - len(a)) (one_at,) * (max_dims - len(a))
+ tuple( + tuple(
one_at one_at
if getattr(sh, "value", sh) == 1 if sh == 1 or isinstance(sh, Constant) and sh.value == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh) else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a for sh in a
) )
......
...@@ -1142,6 +1142,12 @@ def test_broadcast_shape_basic(): ...@@ -1142,6 +1142,12 @@ def test_broadcast_shape_basic():
b_at = broadcast_shape(x_at, y_at) b_at = broadcast_shape(x_at, y_at)
assert isinstance(b_at[-1].owner.op, Assert) assert isinstance(b_at[-1].owner.op, Assert)
# N.B. Shared variable shape values shouldn't be treated as constants,
# because they can change.
s = aesara.shared(1)
b_at = broadcast_shape((s, 2), (2, 1), arrays_are_shapes=True)
assert isinstance(b_at[0].owner.op, Assert)
def test_broadcast_shape_constants(): def test_broadcast_shape_constants():
"""Make sure `broadcast_shape` uses constants when it can.""" """Make sure `broadcast_shape` uses constants when it can."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论