提交 fa0ec65b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make arrays_are_shapes keyword explicit in broadcast_shape_iter

上级 8706f75a
from collections.abc import Collection
from typing import Tuple
from typing import Iterable, Tuple, Union
import numpy as np
......@@ -1482,17 +1482,20 @@ def broadcast_shape(*arrays, **kwargs):
return broadcast_shape_iter(arrays, **kwargs)
def broadcast_shape_iter(arrays, **kwargs):
def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False,
):
"""Compute the shape resulting from broadcasting arrays.
Parameters
----------
arrays: Iterable[TensorVariable] or Iterable[Tuple[Variable]]
arrays
An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed.
XXX: Do not call this with a generator/iterator; this function will not
make copies!
arrays_are_shapes: bool (Optional)
arrays_are_shapes
Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1`` or ``1`` exactly.
......@@ -1500,7 +1503,6 @@ def broadcast_shape_iter(arrays, **kwargs):
"""
one = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
arrays_are_shapes = kwargs.pop("arrays_are_shapes", False)
if arrays_are_shapes:
max_dims = max(len(a) for a in arrays)
......@@ -1560,6 +1562,7 @@ def broadcast_shape_iter(arrays, **kwargs):
class BroadcastTo(Op):
"""An `Op` for `numpy.broadcast_to`."""
view_map = {0: [0]}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论