Commit 6898f749 authored by Ricardo Vieira, committed by Ricardo Vieira

Remove BroadcastTo
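
This replaces the dedicated `BroadcastTo` `Op` with the existing `Alloc`: `broadcast_to(x, shape)` now simply returns `alloc(x, *shape)`, so the JAX and Numba dispatchers and the `BroadcastTo`-specific rewrites are removed along with the `Op`. A minimal sketch of the resulting behaviour (illustrative only; the `Alloc` check assumes the freshly built graph has not been rewritten yet):

    import numpy as np

    import pytensor.tensor as at
    from pytensor.tensor.basic import Alloc
    from pytensor.tensor.extra_ops import broadcast_to

    x = at.vector("x")
    y = broadcast_to(x, (3, 2))

    # The graph now contains an Alloc node rather than a BroadcastTo node.
    assert isinstance(y.owner.op, Alloc)

    # Runtime semantics are unchanged: trailing dims broadcast as in NumPy.
    print(y.eval({x: np.zeros(2, dtype=x.dtype)}).shape)  # (3, 2)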

Parent 5f809cfe
@@ -3,10 +3,8 @@ import warnings
import jax.numpy as jnp
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.basic import infer_static_shape
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
FillDiagonal,
FillDiagonalOffset,
@@ -102,18 +100,6 @@ def jax_funcify_RavelMultiIndex(op, **kwargs):
return ravelmultiindex
@jax_funcify.register(BroadcastTo)
def jax_funcify_BroadcastTo(op, node, **kwargs):
shape = node.inputs[1:]
static_shape = infer_static_shape(shape)[1]
def broadcast_to(x, *shape):
shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
return jnp.broadcast_to(x, shape)
return broadcast_to
@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op, **kwargs):
def filldiagonal(value, diagonal):
@@ -2,7 +2,6 @@ import warnings
import numba
import numpy as np
from numba.misc.special import literal_unroll
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -10,7 +9,6 @@ from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
FillDiagonal,
FillDiagonalOffset,
@@ -353,29 +351,6 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
return searchsorted
@numba_funcify.register(BroadcastTo)
def numba_funcify_BroadcastTo(op, node, **kwargs):
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, len(node.inputs) - 1
)
# TODO broadcastable checks
@numba_basic.numba_njit
def broadcast_to(x, *shape):
scalars_shape = create_zeros_tuple()
i = 0
for s_i in literal_unroll(shape):
scalars_shape = numba_basic.tuple_setitem(
scalars_shape, i, numba_basic.to_scalar(s_i)
)
i += 1
return np.broadcast_to(x, scalars_shape)
return broadcast_to
@numba_funcify.register(CheckAndRaise)
def numba_funcify_CheckAndRaise(op, node, **kwargs):
error = op.exc_type
@@ -23,7 +23,7 @@ from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor.basic import get_vector_length, second
from pytensor.tensor.basic import alloc, second
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
return tuple(result_dims)
class BroadcastTo(COp):
"""An `Op` for `numpy.broadcast_to`."""
_output_type_depends_on_input_value = True
__props__ = ()
view_map = {0: [0]}
def __call__(self, a, shape, **kwargs):
return super().__call__(a, *shape, **kwargs)
def make_node(self, a, *shape):
a = at.as_tensor_variable(a)
shape, static_shape = at.infer_static_shape(shape)
if len(shape) < a.ndim:
raise ValueError(
f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
)
out = TensorType(dtype=a.type.dtype, shape=static_shape)()
# Attempt to prevent in-place operations on this view-based output
out.tag.indestructible = True
return Apply(self, [a] + shape, [out])
def perform(self, node, inputs, output_storage):
a, *shape = inputs
z = output_storage[0]
z[0] = np.broadcast_to(a, shape)
def grad(self, inputs, outputs_gradients):
a, *shape = inputs
(dout,) = outputs_gradients
# Determine the dimensions that were added by broadcasting
new_dims = list(range(dout.ndim - a.ndim))
d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
# Determine the dimensions that were broadcast
_, static_shape = at.infer_static_shape(shape)
# TODO: This needs to be performed at run-time when static shape
# information isn't available.
bcast_sums = [
i
for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
if a_s == 1 and s_s != 1
]
if bcast_sums:
d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
return [d_wrt_a] + [
grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
]
def infer_shape(self, fgraph, node, ins_shapes):
return [node.inputs[1:]]
def c_code(self, node, name, inputs, outputs, sub):
inp_dims = node.inputs[0].ndim
out_dims = node.outputs[0].ndim
new_dims = out_dims - inp_dims
(x, *shape) = inputs
(out,) = outputs
fail = sub["fail"]
# TODO: Could just use `PyArray_Return`, no?
dims_array = ", ".join(
[
f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
for i, shape in enumerate(shape)
]
)
src = (
"""
npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
NpyIter *iter;
PyArrayObject *ops[1] = {%(x)s};
npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
PyArray_Descr *op_dtypes[1] = {NULL};
int oa_ndim = %(out_dims)s;
int* op_axes[1] = {NULL};
npy_intp buffersize = 0;
for(int i = 0; i < %(inp_dims)s; i++)
{
if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
i,
(long long int) itershape[i + %(new_dims)s],
(long long int) PyArray_DIMS(%(x)s)[i]
);
%(fail)s
}
}
iter = NpyIter_AdvancedNew(
1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
);
%(out)s = NpyIter_GetIterView(iter, 0);
if(%(out)s == NULL){
NpyIter_Deallocate(iter);
%(fail)s;
}
if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
%(fail)s;
}
"""
% locals()
)
return src
def c_code_cache_version(self):
return (2,)
broadcast_to_ = BroadcastTo()
def geomspace(start, end, steps, base=10.0):
from pytensor.tensor.math import log
@@ -1762,13 +1627,7 @@ def broadcast_to(
broadcasted array may refer to a single memory location.
"""
x = at.as_tensor(x)
shape_len = get_vector_length(shape)
if x.ndim == 0 and shape_len == 0:
return x
return broadcast_to_(x, shape)
return alloc(x, *shape)
def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
@@ -2,7 +2,7 @@ import pytensor.scalar.basic as aes
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import Alloc, as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique
from pytensor.tensor.extra_ops import Repeat, Unique
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
@@ -60,39 +60,6 @@ def local_Unique_Alloc_lift(fgraph, node):
return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
"""Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
This isn't really so much a lift as a "reduction/consumption".
"""
if not isinstance(node.op, Unique):
return False
if (
node.op.return_index
or node.op.return_inverse
or node.op.return_counts
or node.op.axis is not None
):
return False
bcast_var = node.inputs[0]
if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
return False
bcasted_var, *bcast_shape = bcast_var.owner.inputs
new_unique, *_ = node.op.make_node(bcasted_var).outputs
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
@@ -161,16 +128,3 @@ def local_Unique_second(fgraph, node):
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
bcast_shape = node.inputs[1:]
if not bcast_shape:
bcasted_var = node.inputs[0]
# If this isn't true, the graph is invalid
assert bcasted_var.ndim == 0
return [bcasted_var]
@@ -7,7 +7,7 @@ from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as at_extra_ops
from pytensor.tensor.type import matrix, vector
from pytensor.tensor.type import matrix
from tests.link.jax.test_basic import compare_jax_and_py
@@ -63,29 +63,6 @@ def test_extra_ops():
)
@pytest.mark.parametrize(
"x, shape",
[
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
),
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
],
)
def test_BroadcastTo(x, shape):
out = at_extra_ops.broadcast_to(x, shape)
fgraph = FunctionGraph(outputs=[out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
@@ -36,41 +36,6 @@ def test_Bartlett(val):
)
@pytest.mark.parametrize(
"x, shape",
[
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
],
)
def test_BroadcastTo(x, shape):
g = extra_ops.BroadcastTo()(x, shape)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"val, axis, mode",
[
@@ -8,7 +8,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.basic import Alloc, alloc, as_tensor_variable, second
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
from pytensor.tensor.extra_ops import Repeat, Unique, repeat, unique
from pytensor.tensor.type import dscalar
@@ -103,64 +103,6 @@ def test_local_Unique_Alloc_lift(
assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
"x_val, axis, new_shape",
[
(np.array(-10, dtype=np.int64), None, (2, 3)),
(np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_BroadcastTo(
x_val, axis, new_shape, return_index, return_counts, return_inverse
):
x = as_tensor_variable(x_val).type()
y = unique(
BroadcastTo()(x, tuple(new_shape)),
return_index=return_index,
return_counts=return_counts,
return_inverse=return_inverse,
axis=axis,
)
if isinstance(y, list):
y, *_ = y
# This approach allows us to directly confirm that `x` is in the result.
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_rewritten_fg = rewrite_graph(
y_fg,
clone=False,
include=["canonicalize", "local_Unique_BroadcastTo_lift"],
exclude=["local_Unique_scalar"],
)
y_rewritten = y_rewritten_fg.outputs[0]
y_rewritten_start = y_rewritten
assert isinstance(y_rewritten_start.owner.op, Unique)
assert y_rewritten_start.owner.inputs[0] == x
assert not any(
isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes
)
default_mode = get_default_mode()
# The rewrite has already been applied to `y_rewritten`, so we can--and
# should--exclude it from the compilation of both our reference, `y`, and
# the rewritten result, `y_rewritten`.
rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
# Make sure that the original `BroadcastTo` is used to compute the
# reference `y` result
assert any(
isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
)
y_exp_val, y_val = y_fn(x_val)
assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
"x_val, unique_axis, repeats, repeat_axis",
[
@@ -287,16 +229,3 @@ def test_local_Unique_second(
y_exp_val, y_val = y_fn(x_val)
assert np.array_equal(y_exp_val, y_val)
def test_local_remove_scalar_BroadcastTo():
x = dscalar()
y = BroadcastTo()(x, ())
assert isinstance(y.owner.op, BroadcastTo)
res = rewrite_graph(
y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
)
assert res is x
@@ -8,14 +8,12 @@ from pytensor import function
from pytensor import tensor as at
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, applys_between
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.basic import Constant, applys_between, equal_computations
from pytensor.raise_op import Assert
from pytensor.tensor import alloc
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CpuContiguous,
CumOp,
FillDiagonal,
@@ -47,7 +45,6 @@ from pytensor.tensor.extra_ops import (
to_one_hot,
unravel_index,
)
from pytensor.tensor.subtensor import AdvancedIncSubtensor
from pytensor.tensor.type import (
TensorType,
dmatrix,
@@ -61,7 +58,6 @@ from pytensor.tensor.type import (
lscalar,
matrix,
scalar,
tensor,
tensor3,
vector,
)
@@ -1246,183 +1242,15 @@ def test_broadcast_shape_symbolic_one_symbolic():
assert res_shape[2].data == 3
class TestBroadcastTo(utt.InferShapeTester):
def setup_method(self):
super().setup_method()
self.op_class = BroadcastTo
self.op = broadcast_to
def test_avoid_useless_scalars(self):
x = scalar()
y = broadcast_to(x, ())
assert y is x
def test_avoid_useless_subtensors(self):
x = scalar()
y = broadcast_to(x, (1, 2))
# There shouldn't be any unnecessary `Subtensor` operations
# (e.g. from `at.as_tensor((1, 2))[0]`)
assert y.owner.inputs[1].owner is None
assert y.owner.inputs[2].owner is None
@pytest.mark.parametrize("linker", ["cvm", "py"])
def test_perform(self, linker):
a = pytensor.shared(np.full((3, 1, 1), 5))
s_0 = iscalar("s_0")
s_1 = iscalar("s_1")
shape = (s_0, s_1, 1)
bcast_res = broadcast_to(a, shape)
assert bcast_res.broadcastable == (False, False, True)
bcast_fn = pytensor.function(
[s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
)
bcast_fn.vm.allow_gc = False
bcast_at = bcast_fn(3, 4)
bcast_np = np.broadcast_to(5, (3, 4, 1))
assert np.array_equal(bcast_at, bcast_np)
with pytest.raises(ValueError):
bcast_fn(5, 4)
if linker != "py":
bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
bcast_in = bcast_fn.vm.storage_map[a]
bcast_out = bcast_fn.vm.storage_map[bcast_var]
assert np.shares_memory(bcast_out[0], bcast_in[0])
def test_make_node_error_handling(self):
with pytest.raises(
ValueError,
match="Broadcast target shape has 1 dims, which is shorter than input with 2 dims",
):
broadcast_to(at.zeros((3, 4)), (5,))
def test_broadcast_to():
x = vector("x")
y1 = scalar(dtype="int64")
y2 = scalar(dtype="int64")
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
assert equal_computations(
[broadcast_to(x, (y1, y2))],
[alloc(x, y1, y2)],
)
@pytest.mark.parametrize("valid", (True, False))
def test_memory_leak(self, valid):
import gc
import tracemalloc
from pytensor.link.c.cvm import CVM
n = 100_000
x = pytensor.shared(np.ones((1, n), dtype=np.float64))
y = broadcast_to(x, (5, n))
f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm"))
assert isinstance(f.vm, CVM)
assert len(f.maker.fgraph.apply_nodes) == 2
assert any(
isinstance(node.op, BroadcastTo) for node in f.maker.fgraph.apply_nodes
)
tracemalloc.start()
blocks_last = None
block_diffs = []
for i in range(1, 50):
if valid:
x.set_value(np.ones((1, n)))
_ = f()
else:
x.set_value(np.ones((2, n)))
try:
_ = f()
except ValueError:
pass
else:
raise RuntimeError("Should have failed")
_ = gc.collect()
blocks_i, _ = tracemalloc.get_traced_memory()
if blocks_last is not None:
blocks_diff = (blocks_i - blocks_last) // 10**3
block_diffs.append(blocks_diff)
blocks_last = blocks_i
tracemalloc.stop()
assert np.all(np.array(block_diffs) <= (0 + 1e-8))
@pytest.mark.parametrize(
"fn,input_dims",
[
[lambda x: broadcast_to(x, (1,)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
[lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
[lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
],
)
def test_gradient(self, fn, input_dims):
rng = np.random.default_rng(43)
utt.verify_grad(
fn,
[rng.random(input_dims).astype(config.floatX)],
n_tests=1,
rng=rng,
)
def test_infer_shape(self):
rng = np.random.default_rng(43)
a = tensor(dtype=config.floatX, shape=(None, 1, None))
shape = list(a.shape)
out = self.op(a, shape)
self._compile_and_check(
[a] + shape,
[out],
[rng.random((2, 1, 3)).astype(config.floatX), 2, 1, 3],
self.op_class,
)
a = tensor(dtype=config.floatX, shape=(None, 1, None))
shape = [iscalar() for i in range(4)]
self._compile_and_check(
[a] + shape,
[self.op(a, shape)],
[rng.random((2, 1, 3)).astype(config.floatX), 6, 2, 5, 3],
self.op_class,
)
def test_inplace(self):
"""Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``."""
a = at.zeros((5,))
d = at.vector("d")
c = at.set_subtensor(a[np.r_[0, 1, 3]], d)
b = broadcast_to(c, (5,))
q = b[np.r_[0, 1, 3]]
e = at.set_subtensor(q, np.r_[0, 0, 0])
opts = RewriteDatabaseQuery(include=["inplace"])
py_mode = Mode("py", opts)
e_fn = function([d], e, mode=py_mode)
advincsub_node = e_fn.maker.fgraph.outputs[0].owner
assert isinstance(advincsub_node.op, AdvancedIncSubtensor)
assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
assert advincsub_node.op.inplace is False
def test_rebuild(self):
x = vector(shape=(50,))
x_test = np.zeros((50,), dtype=config.floatX)
i = 0
y = broadcast_to(i, x.shape)
assert y.type.shape == (50,)
assert y.shape.eval({x: x_test}) == (50,)
assert y.eval({x: x_test}).shape == (50,)
x_new = vector(shape=(100,))
x_new_test = np.zeros((100,), dtype=config.floatX)
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (100,)
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)
def test_broadcast_arrays():