提交 fa0ec65b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make arrays_are_shapes keyword explicit in broadcast_shape_iter

上级 8706f75a
from collections.abc import Collection from collections.abc import Collection
from typing import Tuple from typing import Iterable, Tuple, Union
import numpy as np import numpy as np
...@@ -1482,17 +1482,20 @@ def broadcast_shape(*arrays, **kwargs): ...@@ -1482,17 +1482,20 @@ def broadcast_shape(*arrays, **kwargs):
return broadcast_shape_iter(arrays, **kwargs) return broadcast_shape_iter(arrays, **kwargs)
def broadcast_shape_iter(arrays, **kwargs): def broadcast_shape_iter(
arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
arrays_are_shapes: bool = False,
):
"""Compute the shape resulting from broadcasting arrays. """Compute the shape resulting from broadcasting arrays.
Parameters Parameters
---------- ----------
arrays: Iterable[TensorVariable] or Iterable[Tuple[Variable]] arrays
An iterable of tensors, or a tuple of shapes (as tuples), An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed. for which the broadcast shape is computed.
XXX: Do not call this with a generator/iterator; this function will not XXX: Do not call this with a generator/iterator; this function will not
make copies! make copies!
arrays_are_shapes: bool (Optional) arrays_are_shapes
Indicates whether or not the `arrays` contains shape tuples. Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1`` or ``1`` exactly. are (scalar) constants with the value ``1`` or ``1`` exactly.
...@@ -1500,7 +1503,6 @@ def broadcast_shape_iter(arrays, **kwargs): ...@@ -1500,7 +1503,6 @@ def broadcast_shape_iter(arrays, **kwargs):
""" """
one = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1) one = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
arrays_are_shapes = kwargs.pop("arrays_are_shapes", False)
if arrays_are_shapes: if arrays_are_shapes:
max_dims = max(len(a) for a in arrays) max_dims = max(len(a) for a in arrays)
...@@ -1560,6 +1562,7 @@ def broadcast_shape_iter(arrays, **kwargs): ...@@ -1560,6 +1562,7 @@ def broadcast_shape_iter(arrays, **kwargs):
class BroadcastTo(Op): class BroadcastTo(Op):
"""An `Op` for `numpy.broadcast_to`."""
view_map = {0: [0]} view_map = {0: [0]}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论