提交 0cf56c68 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix broadcasting bug in broadcast_shape_iter

上级 7ed6a03b
...@@ -22,7 +22,9 @@ from aesara.tensor import basic as aet ...@@ -22,7 +22,9 @@ from aesara.tensor import basic as aet
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as aet_abs from aesara.tensor.math import abs as aet_abs
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import eq, ge, lt, maximum, minimum, or_, prod from aesara.tensor.math import eq, ge, lt
from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, minimum, or_, prod
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -1476,7 +1478,8 @@ def broadcast_shape(*arrays, **kwargs): ...@@ -1476,7 +1478,8 @@ def broadcast_shape(*arrays, **kwargs):
arrays_are_shapes: bool (Optional) arrays_are_shapes: bool (Optional)
Indicates whether or not the `arrays` contains shape tuples. Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1`` or ``1`` exactly. are (scalar) constants with the value ``1``--or simply the integer
``1``.
""" """
return broadcast_shape_iter(arrays, **kwargs) return broadcast_shape_iter(arrays, **kwargs)
...@@ -1486,77 +1489,86 @@ def broadcast_shape_iter( ...@@ -1486,77 +1489,86 @@ def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]], arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False, arrays_are_shapes: bool = False,
): ):
"""Compute the shape resulting from broadcasting arrays. r"""Compute the shape resulting from broadcasting arrays.
.. warning::
This function will not make copies, so be careful when calling it with
a generator/iterator!
Parameters Parameters
---------- ----------
arrays arrays
An iterable of tensors, or a tuple of shapes (as tuples), An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed. for which the broadcast shape is computed.
XXX: Do not call this with a generator/iterator; this function will not
make copies!
arrays_are_shapes arrays_are_shapes
Indicates whether or not the `arrays` contains shape tuples. Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1`` or ``1`` exactly. are (scalar) constants with the value ``1``--or simply the integer
``1``.
""" """
one = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1) one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
if arrays_are_shapes: if arrays_are_shapes:
max_dims = max(len(a) for a in arrays) max_dims = max(len(a) for a in arrays)
array_shapes = [ array_shapes = [
(one,) * (max_dims - len(a)) (one_at,) * (max_dims - len(a))
+ tuple(one if getattr(sh, "value", sh) == 1 else sh for sh in a) + tuple(one_at if getattr(sh, "value", sh) == 1 else sh for sh in a)
for a in arrays for a in arrays
] ]
else: else:
max_dims = max(a.ndim for a in arrays) max_dims = max(a.ndim for a in arrays)
array_shapes = [ array_shapes = [
(one,) * (max_dims - a.ndim) (one_at,) * (max_dims - a.ndim)
+ tuple(one if bcast else sh for sh, bcast in zip(a.shape, a.broadcastable)) + tuple(
one_at if bcast else sh for sh, bcast in zip(a.shape, a.broadcastable)
)
for a in arrays for a in arrays
] ]
result_dims = [] result_dims = []
for dim_shapes in zip(*array_shapes): for dim_shapes in zip(*array_shapes):
non_bcast_shapes = [shape for shape in dim_shapes if shape != one] # Get the shapes in this dimension that are not definitively
# broadcastable (i.e. not symbolically known to be broadcastable)
if len(non_bcast_shapes) > 0: maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
# Either there's only one non-broadcastable dimensions--and that's
# what determines the dimension size, or there are multiple if len(maybe_non_bcast_shapes) == 0:
# non-broadcastable dimensions that must be equal # Every shape was broadcastable in this dimension
i_dim = non_bcast_shapes.pop() result_dims.append(one_at)
elif len(maybe_non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension
result_dims.extend(maybe_non_bcast_shapes)
else:
# More than one shape might not be broadcastable in this dimension
potentially_unequal_dims = [ all_dims_equal = all(
dim
for dim in non_bcast_shapes
# TODO FIXME: This is a largely deficient means of comparing graphs # TODO FIXME: This is a largely deficient means of comparing graphs
# (and especially shapes) # (and especially shapes)
if not equal_computations([i_dim], [dim]) equal_computations([maybe_non_bcast_shapes[0]], [dim])
] for dim in maybe_non_bcast_shapes[1:]
)
if potentially_unequal_dims: if all_dims_equal:
# In this case, we can't tell whether or not the dimensions are result_dims.append(maybe_non_bcast_shapes[0])
# equal, so we'll need to assert their equality and move the error continue
# handling to evaluation time.
assert_dim = Assert("Could not broadcast dimensions") non_bcast_vec = aet.as_tensor(maybe_non_bcast_shapes)
eq_condition = aet_all( non_bcast_vec = aet.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec)
[ dim_max = aet_max(non_bcast_vec)
or_(eq(dim, one), eq(i_dim, dim))
for dim in potentially_unequal_dims assert_dim = Assert("Could not broadcast dimensions")
] assert_cond = aet_all(
) or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, aet_abs(dim_max)))
eq_condition = or_(eq(i_dim, one), eq_condition) )
result_dims.append(assert_dim(i_dim, eq_condition)) bcast_dim = assert_dim(dim_max, assert_cond)
else:
result_dims.append(i_dim) result_dims.append(bcast_dim)
else:
# Every array was broadcastable in this dimension
result_dims.append(one)
return tuple(result_dims) return tuple(result_dims)
......
...@@ -971,7 +971,7 @@ class TestRavelMultiIndex(utt.InferShapeTester): ...@@ -971,7 +971,7 @@ class TestRavelMultiIndex(utt.InferShapeTester):
ravel_multi_index(((3, 4),), ((3, 4),)) ravel_multi_index(((3, 4),), ((3, 4),))
def test_broadcast_shape(): def test_broadcast_shape_basic():
def shape_tuple(x, use_bcast=True): def shape_tuple(x, use_bcast=True):
if use_bcast: if use_bcast:
return tuple( return tuple(
...@@ -1006,11 +1006,6 @@ def test_broadcast_shape(): ...@@ -1006,11 +1006,6 @@ def test_broadcast_shape():
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
) )
assert np.array_equal([z.eval() for z in b_aet], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
# These are all constants, so there shouldn't be any asserts in the
# resulting graph.
assert not any(
isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
)
x = np.array([1, 2, 3]) x = np.array([1, 2, 3])
y = np.array([4, 5, 6]) y = np.array([4, 5, 6])
...@@ -1023,12 +1018,6 @@ def test_broadcast_shape(): ...@@ -1023,12 +1018,6 @@ def test_broadcast_shape():
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
) )
assert np.array_equal([z.eval() for z in b_aet], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
x = np.empty((1, 2, 3)) x = np.empty((1, 2, 3))
y = np.array(1) y = np.array(1)
...@@ -1038,9 +1027,6 @@ def test_broadcast_shape(): ...@@ -1038,9 +1027,6 @@ def test_broadcast_shape():
b_aet = broadcast_shape(x_aet, y_aet) b_aet = broadcast_shape(x_aet, y_aet)
assert b_aet[0].value == 1 assert b_aet[0].value == 1
assert np.array_equal([z.eval() for z in b_aet], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
assert not any(
isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
)
b_aet = broadcast_shape( b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
) )
...@@ -1054,12 +1040,6 @@ def test_broadcast_shape(): ...@@ -1054,12 +1040,6 @@ def test_broadcast_shape():
b_aet = broadcast_shape(x_aet, y_aet) b_aet = broadcast_shape(x_aet, y_aet)
assert b_aet[1].value == 1 assert b_aet[1].value == 1
assert np.array_equal([z.eval() for z in b_aet], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
b_aet = broadcast_shape( b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
) )
...@@ -1073,12 +1053,6 @@ def test_broadcast_shape(): ...@@ -1073,12 +1053,6 @@ def test_broadcast_shape():
y_shapes = (y1_shp_aet, 1, x2_shp_aet) y_shapes = (y1_shp_aet, 1, x2_shp_aet)
y_aet = aet.ones(y_shapes) y_aet = aet.ones(y_shapes)
b_aet = broadcast_shape(x_aet, y_aet) b_aet = broadcast_shape(x_aet, y_aet)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
res = aet.as_tensor(b_aet).eval( res = aet.as_tensor(b_aet).eval(
{ {
x1_shp_aet: 10, x1_shp_aet: 10,
...@@ -1094,6 +1068,36 @@ def test_broadcast_shape(): ...@@ -1094,6 +1068,36 @@ def test_broadcast_shape():
assert isinstance(b_aet[-1].owner.op, Assert) assert isinstance(b_aet[-1].owner.op, Assert)
@pytest.mark.parametrize(
("s1_vals", "s2_vals", "exp_res"),
[
((2, 2), (1, 2), (2, 2)),
((0, 2), (1, 2), (0, 2)),
],
)
@config.change_flags(compute_test_value="raise")
def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
s1_1, s1_2 = aet.lscalars("s1_1", "s1_2")
s2_1, s2_2 = aet.lscalars("s2_1", "s2_2")
s1_1.tag.test_value = s1_vals[0]
s1_2.tag.test_value = s1_vals[1]
s2_1.tag.test_value = s2_vals[0]
s2_2.tag.test_value = s2_vals[1]
res = broadcast_shape((s1_1, s1_2), (s2_1, s2_2), arrays_are_shapes=True)
res = aet.as_tensor(res)
assert (
tuple(
res.eval(
{s1_1: s1_vals[0], s1_2: s1_vals[1], s2_1: s2_vals[0], s2_2: s2_vals[1]}
)
)
== exp_res
)
class TestBroadcastTo(utt.InferShapeTester): class TestBroadcastTo(utt.InferShapeTester):
rng = np.random.default_rng(43) rng = np.random.default_rng(43)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论