提交 8d5a8c8c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make better use of constants in broadcast_shape_iter

上级 cfc931fa
from collections.abc import Collection from collections.abc import Collection
from functools import reduce from functools import reduce
from typing import Iterable, Tuple, Union from typing import Iterable, Set, Tuple, Union
import numpy as np import numpy as np
import numpy.core.numeric import numpy.core.numeric
...@@ -14,7 +14,7 @@ from aesara.gradient import ( ...@@ -14,7 +14,7 @@ from aesara.gradient import (
disconnected_type, disconnected_type,
grad_undefined, grad_undefined,
) )
from aesara.graph.basic import Apply, Variable, equal_computations from aesara.graph.basic import Apply, Constant, Variable, equal_computations
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
...@@ -1491,7 +1491,12 @@ def broadcast_shape_iter( ...@@ -1491,7 +1491,12 @@ def broadcast_shape_iter(
array_shapes = [ array_shapes = [
(one_at,) * (max_dims - len(a)) (one_at,) * (max_dims - len(a))
+ tuple(one_at if getattr(sh, "value", sh) == 1 else sh for sh in a) + tuple(
one_at
if getattr(sh, "value", sh) == 1
else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a
)
for a in arrays for a in arrays
] ]
else: else:
...@@ -1523,9 +1528,38 @@ def broadcast_shape_iter( ...@@ -1523,9 +1528,38 @@ def broadcast_shape_iter(
else: else:
# More than one shape might not be broadcastable in this dimension # More than one shape might not be broadcastable in this dimension
nonconst_nb_shapes: Set[int] = set()
const_nb_shapes: Set[Variable] = set()
for shape in maybe_non_bcast_shapes:
if isinstance(shape, Constant):
const_nb_shapes.add(shape.value.item())
else:
nonconst_nb_shapes.add(shape)
if len(const_nb_shapes) > 1:
raise ValueError("Could not broadcast dimensions")
elif len(const_nb_shapes) == 1:
(const_nb_shape,) = const_nb_shapes
assert const_nb_shape != 1
const_nt_shape_var = aesara.scalar.ScalarConstant(
aesara.scalar.int64, const_nb_shape
)
if len(nonconst_nb_shapes) > 0:
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(aes.eq(nbv, const_nt_shape_var) for nbv in nonconst_nb_shapes),
)
bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
else:
bcast_dim = const_nt_shape_var
else:
all_dims_equal = all( all_dims_equal = all(
# TODO FIXME: This is a largely deficient means of comparing graphs # TODO FIXME: This is a largely deficient, and expensive, means
# (and especially shapes) # of comparing graphs (and especially shapes)
equal_computations([maybe_non_bcast_shapes[0]], [dim]) equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:] for dim in maybe_non_bcast_shapes[1:]
) )
......
...@@ -8,7 +8,7 @@ from aesara import function ...@@ -8,7 +8,7 @@ from aesara import function
from aesara import tensor as at from aesara import tensor as at
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import applys_between from aesara.graph.basic import Constant, applys_between
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
...@@ -1143,6 +1143,35 @@ def test_broadcast_shape_basic(): ...@@ -1143,6 +1143,35 @@ def test_broadcast_shape_basic():
assert isinstance(b_at[-1].owner.op, Assert) assert isinstance(b_at[-1].owner.op, Assert)
def test_broadcast_shape_constants():
"""Make sure `broadcast_shape` uses constants when it can."""
x1_shp_at = iscalar("x1")
y2_shp_at = iscalar("y2")
b_at = broadcast_shape((x1_shp_at, 2), (3, y2_shp_at), arrays_are_shapes=True)
assert len(b_at) == 2
assert isinstance(b_at[0].owner.op, Assert)
assert b_at[0].owner.inputs[0].value.item() == 3
assert isinstance(b_at[1].owner.op, Assert)
assert b_at[1].owner.inputs[0].value.item() == 2
b_at = broadcast_shape((1, 2), (3, 2), arrays_are_shapes=True)
assert len(b_at) == 2
assert all(isinstance(x, Constant) for x in b_at)
assert b_at[0].value.item() == 3
assert b_at[1].value.item() == 2
b_at = broadcast_shape((1,), (1, 1), arrays_are_shapes=True)
assert len(b_at) == 2
assert all(isinstance(x, Constant) for x in b_at)
assert b_at[0].value.item() == 1
assert b_at[1].value.item() == 1
b_at = broadcast_shape((1,), (1,), arrays_are_shapes=True)
assert len(b_at) == 1
assert all(isinstance(x, Constant) for x in b_at)
assert b_at[0].value.item() == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
("s1_vals", "s2_vals", "exp_res"), ("s1_vals", "s2_vals", "exp_res"),
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论