提交 1c11dd44 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add broadcast_arrays function

上级 1fa8df43
...@@ -66,6 +66,7 @@ from aesara.tensor.blas import batched_dot, batched_tensordot ...@@ -66,6 +66,7 @@ from aesara.tensor.blas import batched_dot, batched_tensordot
from aesara.tensor.extra_ops import ( from aesara.tensor.extra_ops import (
bartlett, bartlett,
bincount, bincount,
broadcast_arrays,
broadcast_shape, broadcast_shape,
broadcast_shape_iter, broadcast_shape_iter,
broadcast_to, broadcast_to,
......
from collections.abc import Collection from collections.abc import Collection
from typing import Tuple
import numpy as np import numpy as np
...@@ -32,6 +33,7 @@ from aesara.tensor.type import ( ...@@ -32,6 +33,7 @@ from aesara.tensor.type import (
integer_dtypes, integer_dtypes,
vector, vector,
) )
from aesara.tensor.var import TensorVariable
from aesara.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH from aesara.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
...@@ -1614,3 +1616,15 @@ class BroadcastTo(Op): ...@@ -1614,3 +1616,15 @@ class BroadcastTo(Op):
broadcast_to = BroadcastTo() broadcast_to = BroadcastTo()
def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
"""Broadcast any number of arrays against each other.
Parameters
----------
`*args` : array_likes
The arrays to broadcast.
"""
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
...@@ -26,6 +26,7 @@ from aesara.tensor.extra_ops import ( ...@@ -26,6 +26,7 @@ from aesara.tensor.extra_ops import (
UnravelIndex, UnravelIndex,
bartlett, bartlett,
bincount, bincount,
broadcast_arrays,
broadcast_shape, broadcast_shape,
broadcast_to, broadcast_to,
compress, compress,
...@@ -1177,3 +1178,19 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1177,3 +1178,19 @@ class TestBroadcastTo(utt.InferShapeTester):
assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo) assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
assert advincsub_node.op.inplace is False assert advincsub_node.op.inplace is False
def test_broadcast_arrays():
x, y = aet.dvector(), aet.dmatrix()
x_bcast, y_bcast = broadcast_arrays(x, y)
py_mode = Mode("py", None)
bcast_fn = function([x, y], [x_bcast, y_bcast], mode=py_mode)
x_val = np.array([1.0], dtype=np.float64)
y_val = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
x_bcast_val, y_bcast_val = bcast_fn(x_val, y_val)
x_bcast_exp, y_bcast_exp = np.broadcast_arrays(x_val, y_val)
assert np.array_equal(x_bcast_val, x_bcast_exp)
assert np.array_equal(y_bcast_val, y_bcast_exp)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论