提交 5248aa9a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix bug in broadcast_shape_iter when all dims are broadcastable

上级 95deb922
...@@ -1560,11 +1560,11 @@ def broadcast_shape_iter( ...@@ -1560,11 +1560,11 @@ def broadcast_shape_iter(
non_bcast_vec = aet.as_tensor(maybe_non_bcast_shapes) non_bcast_vec = aet.as_tensor(maybe_non_bcast_shapes)
non_bcast_vec = aet.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec) non_bcast_vec = aet.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec)
dim_max = aet_max(non_bcast_vec) dim_max = aet_abs(aet_max(non_bcast_vec))
assert_dim = Assert("Could not broadcast dimensions") assert_dim = Assert("Could not broadcast dimensions")
assert_cond = aet_all( assert_cond = aet_all(
or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, aet_abs(dim_max))) or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, dim_max))
) )
bcast_dim = assert_dim(dim_max, assert_cond) bcast_dim = assert_dim(dim_max, assert_cond)
......
...@@ -1073,29 +1073,25 @@ def test_broadcast_shape_basic(): ...@@ -1073,29 +1073,25 @@ def test_broadcast_shape_basic():
[ [
((2, 2), (1, 2), (2, 2)), ((2, 2), (1, 2), (2, 2)),
((0, 2), (1, 2), (0, 2)), ((0, 2), (1, 2), (0, 2)),
((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
], ],
) )
@config.change_flags(compute_test_value="raise")
def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res): def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
s1_1, s1_2 = aet.lscalars("s1_1", "s1_2") s1s = aet.lscalars(len(s1_vals))
s2_1, s2_2 = aet.lscalars("s2_1", "s2_2") eval_point = {}
for s, s_val in zip(s1s, s1_vals):
s1_1.tag.test_value = s1_vals[0] eval_point[s] = s_val
s1_2.tag.test_value = s1_vals[1] s.tag.test_value = s_val
s2_1.tag.test_value = s2_vals[0]
s2_2.tag.test_value = s2_vals[1] s2s = aet.lscalars(len(s2_vals))
for s, s_val in zip(s2s, s2_vals):
res = broadcast_shape((s1_1, s1_2), (s2_1, s2_2), arrays_are_shapes=True) eval_point[s] = s_val
s.tag.test_value = s_val
res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
res = aet.as_tensor(res) res = aet.as_tensor(res)
assert ( assert tuple(res.eval(eval_point)) == exp_res
tuple(
res.eval(
{s1_1: s1_vals[0], s1_2: s1_vals[1], s2_1: s2_vals[0], s2_2: s2_vals[1]}
)
)
== exp_res
)
class TestBroadcastTo(utt.InferShapeTester): class TestBroadcastTo(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论