提交 f2a7fb99 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix broadcasting conditions in broadcast_shape_iter

上级 d3a435a9
...@@ -1548,15 +1548,26 @@ def broadcast_shape_iter( ...@@ -1548,15 +1548,26 @@ def broadcast_shape_iter(
) )
if len(nonconst_nb_shapes) > 0: if len(nonconst_nb_shapes) > 0:
# All the potential non-broadcast shapes need to either
# be broadcastable or equal to the one non-broadcastable
# constant `const_nt_shape_var`.
assert_dim = Assert("Could not broadcast dimensions") assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce( assert_cond = reduce(
aes.and_, aes.and_,
(aes.eq(nbv, const_nt_shape_var) for nbv in nonconst_nb_shapes), (
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
for nbv in nonconst_nb_shapes
),
) )
bcast_dim = assert_dim(const_nt_shape_var, assert_cond) bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
else: else:
bcast_dim = const_nt_shape_var bcast_dim = const_nt_shape_var
else: else:
# There are no constant, non-broadcastable shapes in this
# dimension.
all_dims_equal = all( all_dims_equal = all(
# TODO FIXME: This is a largely deficient, and expensive, means # TODO FIXME: This is a largely deficient, and expensive, means
# of comparing graphs (and especially shapes) # of comparing graphs (and especially shapes)
......
...@@ -1198,6 +1198,31 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res): ...@@ -1198,6 +1198,31 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
assert tuple(res.eval(eval_point)) == exp_res assert tuple(res.eval(eval_point)) == exp_res
def test_broadcast_shape_symbolic_one_symbolic():
"""Test case for a constant non-broadcast shape and a symbolic shape."""
one_at = at.as_tensor(1, dtype=np.int64)
three_at = at.as_tensor(3, dtype=np.int64)
int_div = one_at / one_at
assert int_div.owner.op == at.true_div
index_shapes = [
(one_at, one_at, three_at),
(one_at, int_div, one_at),
(one_at, one_at, int_div),
]
res_shape = broadcast_shape(*index_shapes, arrays_are_shapes=True)
from aesara.graph.rewriting.utils import rewrite_graph
res_shape = rewrite_graph(res_shape)
assert res_shape[0].data == 1
assert res_shape[1].data == 1
assert res_shape[2].data == 3
class TestBroadcastTo(utt.InferShapeTester): class TestBroadcastTo(utt.InferShapeTester):
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论