提交 3edbbc4e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent broadcast_to from creating useless Ops

上级 8af9aa23
......@@ -10,7 +10,7 @@ from aesara.gradient import (
disconnected_type,
grad_undefined,
)
from aesara.graph.basic import Apply, equal_computations
from aesara.graph.basic import Apply, Variable, equal_computations
from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList, Generic
......@@ -19,6 +19,7 @@ from aesara.raise_op import Assert
from aesara.scalar import int32 as int_t
from aesara.scalar import upcast
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import all as at_all
......@@ -1627,7 +1628,37 @@ class BroadcastTo(Op):
return [node.inputs[1:]]
# Single module-level instance of the `Op`.  It is bound to a private name so
# that the public `broadcast_to` helper can wrap it and return the input
# unchanged for trivial (0-d to 0-d) broadcasts instead of building a node.
broadcast_to_ = BroadcastTo()
def broadcast_to(
    x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable, ...]]
) -> TensorVariable:
    """Broadcast an array to a new shape.

    Parameters
    ----------
    x
        The tensor to broadcast.
    shape
        The shape of the desired array.

    Returns
    -------
    broadcast
        A readonly view on the original array with the given shape. It is
        typically not contiguous. Furthermore, more than one element of a
        broadcasted array may refer to a single memory location.

    """
    x = at.as_tensor(x)
    shape = at.as_tensor(shape, ndim=1, dtype="int64")
    shape_len = get_vector_length(shape)

    # Broadcasting a 0-d tensor to a 0-d shape is a no-op; return the input
    # directly instead of inserting a useless `BroadcastTo` node in the graph.
    if x.ndim == 0 and shape_len == 0:
        return x

    return broadcast_to_(x, shape)
def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
......
......@@ -51,14 +51,7 @@ from aesara.tensor.basic_opt import (
register_specialize,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
BroadcastTo,
Repeat,
Unique,
broadcast_to,
repeat,
unique,
)
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
from aesara.tensor.math import (
add,
bitwise_and,
......@@ -3359,7 +3352,6 @@ def test_local_Unique_Alloc_lift(
@pytest.mark.parametrize(
"x_val, axis, new_shape",
[
(np.array(-10, dtype=np.int64), None, ()),
(np.array(-10, dtype=np.int64), None, (2, 3)),
(np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
],
......@@ -3372,7 +3364,7 @@ def test_local_Unique_BroadcastTo(
):
x = as_tensor_variable(x_val).type()
y = unique(
broadcast_to(x, tuple(new_shape)),
BroadcastTo()(x, tuple(new_shape)),
return_index=return_index,
return_counts=return_counts,
return_inverse=return_inverse,
......
......@@ -1095,6 +1095,11 @@ class TestBroadcastTo(utt.InferShapeTester):
self.op_class = BroadcastTo
self.op = broadcast_to
def test_avoid_useless_scalars(self):
    # Broadcasting a 0-d tensor to a 0-d shape must be a no-op that hands
    # back the very same variable, not a new `BroadcastTo` node.
    scalar_input = scalar()
    result = broadcast_to(scalar_input, ())
    assert result is scalar_input
@config.change_flags(compute_test_value="raise")
def test_perform(self):
a = scalar()
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论