提交 3edbbc4e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent broadcast_to from creating useless Ops

上级 8af9aa23
...@@ -10,7 +10,7 @@ from aesara.gradient import ( ...@@ -10,7 +10,7 @@ from aesara.gradient import (
disconnected_type, disconnected_type,
grad_undefined, grad_undefined,
) )
from aesara.graph.basic import Apply, equal_computations from aesara.graph.basic import Apply, Variable, equal_computations
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList, Generic from aesara.graph.type import EnumList, Generic
...@@ -19,6 +19,7 @@ from aesara.raise_op import Assert ...@@ -19,6 +19,7 @@ from aesara.raise_op import Assert
from aesara.scalar import int32 as int_t from aesara.scalar import int32 as int_t
from aesara.scalar import upcast from aesara.scalar import upcast
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as at_abs from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
...@@ -1627,7 +1628,37 @@ class BroadcastTo(Op): ...@@ -1627,7 +1628,37 @@ class BroadcastTo(Op):
return [node.inputs[1:]] return [node.inputs[1:]]
broadcast_to = BroadcastTo() broadcast_to_ = BroadcastTo()
def broadcast_to(
    x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable, ...]]
) -> TensorVariable:
    """Broadcast an array to a new shape.

    Parameters
    ----------
    x
        The array to broadcast.
    shape
        The shape of the desired array.

    Returns
    -------
    broadcast
        A readonly view on the original array with the given shape. It is
        typically not contiguous. Furthermore, more than one element of a
        broadcasted array may refer to a single memory location.

    """
    x = at.as_tensor(x)
    shape = at.as_tensor(shape, ndim=1, dtype="int64")

    shape_len = get_vector_length(shape)

    # Broadcasting a scalar to a 0-d shape is a no-op; return the input
    # directly instead of creating a useless `BroadcastTo` node.
    if x.ndim == 0 and shape_len == 0:
        return x

    return broadcast_to_(x, shape)
def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]: def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
......
...@@ -51,14 +51,7 @@ from aesara.tensor.basic_opt import ( ...@@ -51,14 +51,7 @@ from aesara.tensor.basic_opt import (
register_specialize, register_specialize,
) )
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import ( from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
BroadcastTo,
Repeat,
Unique,
broadcast_to,
repeat,
unique,
)
from aesara.tensor.math import ( from aesara.tensor.math import (
add, add,
bitwise_and, bitwise_and,
...@@ -3359,7 +3352,6 @@ def test_local_Unique_Alloc_lift( ...@@ -3359,7 +3352,6 @@ def test_local_Unique_Alloc_lift(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x_val, axis, new_shape", "x_val, axis, new_shape",
[ [
(np.array(-10, dtype=np.int64), None, ()),
(np.array(-10, dtype=np.int64), None, (2, 3)), (np.array(-10, dtype=np.int64), None, (2, 3)),
(np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
], ],
...@@ -3372,7 +3364,7 @@ def test_local_Unique_BroadcastTo( ...@@ -3372,7 +3364,7 @@ def test_local_Unique_BroadcastTo(
): ):
x = as_tensor_variable(x_val).type() x = as_tensor_variable(x_val).type()
y = unique( y = unique(
broadcast_to(x, tuple(new_shape)), BroadcastTo()(x, tuple(new_shape)),
return_index=return_index, return_index=return_index,
return_counts=return_counts, return_counts=return_counts,
return_inverse=return_inverse, return_inverse=return_inverse,
......
...@@ -1095,6 +1095,11 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1095,6 +1095,11 @@ class TestBroadcastTo(utt.InferShapeTester):
self.op_class = BroadcastTo self.op_class = BroadcastTo
self.op = broadcast_to self.op = broadcast_to
def test_avoid_useless_scalars(self):
    """Broadcasting a scalar to an empty shape should return the input unchanged."""
    x = scalar()
    result = broadcast_to(x, ())
    assert result is x
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
def test_perform(self): def test_perform(self):
a = scalar() a = scalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论