提交 e4b15e48 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use Composite graphs in aesara.tensor.extra_ops.broadcast_shape_iter

上级 9d5ab765
......@@ -23,6 +23,7 @@ from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import Assert
from aesara.scalar import int32 as int_t
from aesara.scalar import upcast
from aesara.scalar.basic import Composite
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
......@@ -1552,16 +1553,32 @@ def broadcast_shape_iter(
# be broadcastable or equal to the one non-broadcastable
# constant `const_nt_shape_var`.
assert_dim = Assert("Could not broadcast dimensions")
scalar_nonconst_nb_shapes = [
at.scalar_from_tensor(s)
if isinstance(s.type, TensorType)
else s
for s in nonconst_nb_shapes
]
dummy_nonconst_nb_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_nonconst_nb_shapes
]
assert_cond = reduce(
aes.and_,
(
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
for nbv in nonconst_nb_shapes
for nbv in dummy_nonconst_nb_shapes
),
)
bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
bcast_dim = assert_dim(
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
)
else:
bcast_dim = const_nt_shape_var
else:
......@@ -1579,21 +1596,37 @@ def broadcast_shape_iter(
result_dims.append(maybe_non_bcast_shapes[0])
continue
scalar_maybe_non_bcast_shapes = [
at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
for s in maybe_non_bcast_shapes
]
dummy_maybe_non_bcast_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_maybe_non_bcast_shapes
]
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in maybe_non_bcast_shapes
for nbv in dummy_maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
for nbv in non_bcast_vec
),
)
bcast_dim = assert_dim(dim_max, assert_cond)
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
bcast_dim = assert_dim(
dim_max_op(*scalar_maybe_non_bcast_shapes),
assert_cond_op(*scalar_maybe_non_bcast_shapes),
)
result_dims.append(bcast_dim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论