提交 0cf56c68 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix broadcasting bug in broadcast_shape_iter

上级 7ed6a03b
......@@ -22,7 +22,9 @@ from aesara.tensor import basic as aet
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as aet_abs
from aesara.tensor.math import all as aet_all
from aesara.tensor.math import eq, ge, lt, maximum, minimum, or_, prod
from aesara.tensor.math import eq, ge, lt
from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, minimum, or_, prod
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from aesara.tensor.type import (
......@@ -1476,7 +1478,8 @@ def broadcast_shape(*arrays, **kwargs):
arrays_are_shapes: bool (Optional)
Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1`` or ``1`` exactly.
are (scalar) constants with the value ``1``--or simply the integer
``1``.
"""
return broadcast_shape_iter(arrays, **kwargs)
......@@ -1486,77 +1489,86 @@ def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False,
):
"""Compute the shape resulting from broadcasting arrays.
r"""Compute the shape resulting from broadcasting arrays.
.. warning::
This function will not make copies, so be careful when calling it with
a generator/iterator!
Parameters
----------
arrays
An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed.
XXX: Do not call this with a generator/iterator; this function will not
make copies!
arrays_are_shapes
Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1`` or ``1`` exactly.
are (scalar) constants with the value ``1``--or simply the integer
``1``.
"""
one = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
if arrays_are_shapes:
max_dims = max(len(a) for a in arrays)
array_shapes = [
(one,) * (max_dims - len(a))
+ tuple(one if getattr(sh, "value", sh) == 1 else sh for sh in a)
(one_at,) * (max_dims - len(a))
+ tuple(one_at if getattr(sh, "value", sh) == 1 else sh for sh in a)
for a in arrays
]
else:
max_dims = max(a.ndim for a in arrays)
array_shapes = [
(one,) * (max_dims - a.ndim)
+ tuple(one if bcast else sh for sh, bcast in zip(a.shape, a.broadcastable))
(one_at,) * (max_dims - a.ndim)
+ tuple(
one_at if bcast else sh for sh, bcast in zip(a.shape, a.broadcastable)
)
for a in arrays
]
result_dims = []
for dim_shapes in zip(*array_shapes):
non_bcast_shapes = [shape for shape in dim_shapes if shape != one]
if len(non_bcast_shapes) > 0:
# Either there's only one non-broadcastable dimensions--and that's
# what determines the dimension size, or there are multiple
# non-broadcastable dimensions that must be equal
i_dim = non_bcast_shapes.pop()
# Get the shapes in this dimension that are not definitively
# broadcastable (i.e. not symbolically known to be broadcastable)
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
if len(maybe_non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension
result_dims.append(one_at)
elif len(maybe_non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension
result_dims.extend(maybe_non_bcast_shapes)
else:
# More than one shape might not be broadcastable in this dimension
potentially_unequal_dims = [
dim
for dim in non_bcast_shapes
all_dims_equal = all(
# TODO FIXME: This is a largely deficient means of comparing graphs
# (and especially shapes)
if not equal_computations([i_dim], [dim])
]
equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:]
)
if potentially_unequal_dims:
# In this case, we can't tell whether or not the dimensions are
# equal, so we'll need to assert their equality and move the error
# handling to evaluation time.
assert_dim = Assert("Could not broadcast dimensions")
eq_condition = aet_all(
[
or_(eq(dim, one), eq(i_dim, dim))
for dim in potentially_unequal_dims
]
)
eq_condition = or_(eq(i_dim, one), eq_condition)
result_dims.append(assert_dim(i_dim, eq_condition))
else:
result_dims.append(i_dim)
else:
# Every array was broadcastable in this dimension
result_dims.append(one)
if all_dims_equal:
result_dims.append(maybe_non_bcast_shapes[0])
continue
non_bcast_vec = aet.as_tensor(maybe_non_bcast_shapes)
non_bcast_vec = aet.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec)
dim_max = aet_max(non_bcast_vec)
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = aet_all(
or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, aet_abs(dim_max)))
)
bcast_dim = assert_dim(dim_max, assert_cond)
result_dims.append(bcast_dim)
return tuple(result_dims)
......
......@@ -971,7 +971,7 @@ class TestRavelMultiIndex(utt.InferShapeTester):
ravel_multi_index(((3, 4),), ((3, 4),))
def test_broadcast_shape():
def test_broadcast_shape_basic():
def shape_tuple(x, use_bcast=True):
if use_bcast:
return tuple(
......@@ -1006,11 +1006,6 @@ def test_broadcast_shape():
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# These are all constants, so there shouldn't be any asserts in the
# resulting graph.
assert not any(
isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
)
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
......@@ -1023,12 +1018,6 @@ def test_broadcast_shape():
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
x = np.empty((1, 2, 3))
y = np.array(1)
......@@ -1038,9 +1027,6 @@ def test_broadcast_shape():
b_aet = broadcast_shape(x_aet, y_aet)
assert b_aet[0].value == 1
assert np.array_equal([z.eval() for z in b_aet], b.shape)
assert not any(
isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
)
b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
......@@ -1054,12 +1040,6 @@ def test_broadcast_shape():
b_aet = broadcast_shape(x_aet, y_aet)
assert b_aet[1].value == 1
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
......@@ -1073,12 +1053,6 @@ def test_broadcast_shape():
y_shapes = (y1_shp_aet, 1, x2_shp_aet)
y_aet = aet.ones(y_shapes)
b_aet = broadcast_shape(x_aet, y_aet)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
res = aet.as_tensor(b_aet).eval(
{
x1_shp_aet: 10,
......@@ -1094,6 +1068,36 @@ def test_broadcast_shape():
assert isinstance(b_aet[-1].owner.op, Assert)
@pytest.mark.parametrize(
    ("s1_vals", "s2_vals", "exp_res"),
    [
        ((2, 2), (1, 2), (2, 2)),
        ((0, 2), (1, 2), (0, 2)),
    ],
)
@config.change_flags(compute_test_value="raise")
def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
    """Check `broadcast_shape` on shape tuples made of symbolic scalars.

    Two length-2 shape tuples are built from ``lscalars`` variables, so the
    dimension sizes are only known at evaluation time.  The parametrized
    cases pair a leading dimension of ``2`` (or ``0``) with a leading
    dimension of ``1`` and expect the non-``1`` size to be the result.

    Parameters
    ----------
    s1_vals
        Concrete values for the two entries of the first shape tuple.
    s2_vals
        Concrete values for the two entries of the second shape tuple.
    exp_res
        The expected broadcast shape as a tuple of integers.
    """
    s1_1, s1_2 = aet.lscalars("s1_1", "s1_2")
    s2_1, s2_2 = aet.lscalars("s2_1", "s2_2")

    # One mapping serves both as the source of test values and as the
    # substitution dictionary handed to `eval` below.
    point = {
        s1_1: s1_vals[0],
        s1_2: s1_vals[1],
        s2_1: s2_vals[0],
        s2_2: s2_vals[1],
    }

    # Test values must be attached before the graph is built, because this
    # test runs with `compute_test_value="raise"`.
    for sym, val in point.items():
        sym.tag.test_value = val

    sym_shape = broadcast_shape((s1_1, s1_2), (s2_1, s2_2), arrays_are_shapes=True)
    shape_vec = aet.as_tensor(sym_shape)

    computed = tuple(shape_vec.eval(point))
    assert computed == exp_res
class TestBroadcastTo(utt.InferShapeTester):
rng = np.random.default_rng(43)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论