提交 5248aa9a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix bug in broadcast_shape_iter when all dims are broadcastable

上级 95deb922
......@@ -1560,11 +1560,11 @@ def broadcast_shape_iter(
non_bcast_vec = aet.as_tensor(maybe_non_bcast_shapes)
non_bcast_vec = aet.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec)
dim_max = aet_max(non_bcast_vec)
dim_max = aet_abs(aet_max(non_bcast_vec))
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = aet_all(
or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, aet_abs(dim_max)))
or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, dim_max))
)
bcast_dim = assert_dim(dim_max, assert_cond)
......
......@@ -1073,29 +1073,25 @@ def test_broadcast_shape_basic():
[
((2, 2), (1, 2), (2, 2)),
((0, 2), (1, 2), (0, 2)),
((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
],
)
@config.change_flags(compute_test_value="raise")
def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
s1_1, s1_2 = aet.lscalars("s1_1", "s1_2")
s2_1, s2_2 = aet.lscalars("s2_1", "s2_2")
s1_1.tag.test_value = s1_vals[0]
s1_2.tag.test_value = s1_vals[1]
s2_1.tag.test_value = s2_vals[0]
s2_2.tag.test_value = s2_vals[1]
res = broadcast_shape((s1_1, s1_2), (s2_1, s2_2), arrays_are_shapes=True)
s1s = aet.lscalars(len(s1_vals))
eval_point = {}
for s, s_val in zip(s1s, s1_vals):
eval_point[s] = s_val
s.tag.test_value = s_val
s2s = aet.lscalars(len(s2_vals))
for s, s_val in zip(s2s, s2_vals):
eval_point[s] = s_val
s.tag.test_value = s_val
res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
res = aet.as_tensor(res)
assert (
tuple(
res.eval(
{s1_1: s1_vals[0], s1_2: s1_vals[1], s2_1: s2_vals[0], s2_2: s2_vals[1]}
)
)
== exp_res
)
assert tuple(res.eval(eval_point)) == exp_res
class TestBroadcastTo(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论