Commit 6898f749 authored by Ricardo Vieira, committed by Ricardo Vieira

Remove BroadcastTo

Parent 5f809cfe
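Summary: the dedicated `BroadcastTo` `COp` is deleted everywhere (the Op and its C implementation, the JAX and Numba dispatchers, the rewrites, and the tests), and `broadcast_to(x, shape)` now lowers to the existing `alloc(x, *shape)`. A minimal sketch of the user-facing behavior, which is unchanged (plain `pytensor.function` usage, nothing beyond what the diff itself imports):

```python
import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_to

x = at.vector("x")
y = broadcast_to(x, (3, 2))  # the graph now contains an Alloc node, not BroadcastTo

f = pytensor.function([x], y)
assert np.array_equal(f([1.0, 2.0]), np.broadcast_to([1.0, 2.0], (3, 2)))
```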
@@ -3,10 +3,8 @@ import warnings
 import jax.numpy as jnp

 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.basic import infer_static_shape
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -102,18 +100,6 @@ def jax_funcify_RavelMultiIndex(op, **kwargs):
     return ravelmultiindex


-@jax_funcify.register(BroadcastTo)
-def jax_funcify_BroadcastTo(op, node, **kwargs):
-    shape = node.inputs[1:]
-    static_shape = infer_static_shape(shape)[1]
-
-    def broadcast_to(x, *shape):
-        shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
-        return jnp.broadcast_to(x, shape)
-
-    return broadcast_to
-
-
 @jax_funcify.register(FillDiagonal)
 def jax_funcify_FillDiagonal(op, **kwargs):
     def filldiagonal(value, diagonal):
...
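Note on the deleted JAX dispatcher: it recovered static shapes up front because `jnp.broadcast_to` needs concrete dimensions at trace time under `jit`; the surviving `Alloc` dispatch has to solve the same problem. An illustration of that constraint in standalone JAX (not PyTensor code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def ok(x):
    # Concrete shape: fine under jit.
    return jnp.broadcast_to(x, (3, 2))

@jax.jit
def bad(x, n):
    # A traced value cannot be used as a dimension; this fails at trace time.
    return jnp.broadcast_to(x, (n, 2))

print(ok(jnp.ones(2)).shape)  # (3, 2)
try:
    bad(jnp.ones(2), 3)
except Exception as e:
    print(type(e).__name__)  # a concretization TypeError
```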
@@ -2,7 +2,6 @@ import warnings

 import numba
 import numpy as np
-from numba.misc.special import literal_unroll

 from pytensor import config
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -10,7 +9,6 @@ from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -353,29 +351,6 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
     return searchsorted


-@numba_funcify.register(BroadcastTo)
-def numba_funcify_BroadcastTo(op, node, **kwargs):
-    create_zeros_tuple = numba_basic.create_tuple_creator(
-        lambda _: 0, len(node.inputs) - 1
-    )
-
-    # TODO broadcastable checks
-    @numba_basic.numba_njit
-    def broadcast_to(x, *shape):
-        scalars_shape = create_zeros_tuple()
-
-        i = 0
-        for s_i in literal_unroll(shape):
-            scalars_shape = numba_basic.tuple_setitem(
-                scalars_shape, i, numba_basic.to_scalar(s_i)
-            )
-            i += 1
-
-        return np.broadcast_to(x, scalars_shape)
-
-    return broadcast_to
-
-
 @numba_funcify.register(CheckAndRaise)
 def numba_funcify_CheckAndRaise(op, node, **kwargs):
     error = op.exc_type
...
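Context for the deletion above: `np.broadcast_to` is usable inside `numba.njit` functions, but the shape must be a real tuple, which is why the removed dispatcher rebuilt one from the variadic scalar inputs with `literal_unroll` and `tuple_setitem`. A stripped-down sketch of the same call with the rank fixed at two, so no tuple rebuilding is needed:

```python
import numba
import numpy as np

@numba.njit
def broadcast_to_2d(x, d0, d1):
    # Supported in nopython mode when the shape is a proper tuple.
    return np.broadcast_to(x, (d0, d1))

out = broadcast_to_2d(np.ones((1, 4)), 3, 4)
assert out.shape == (3, 4)
```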
@@ -23,7 +23,7 @@ from pytensor.scalar import int32 as int_t
 from pytensor.scalar import upcast
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
-from pytensor.tensor.basic import get_vector_length, second
+from pytensor.tensor.basic import alloc, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
     return tuple(result_dims)


-class BroadcastTo(COp):
-    """An `Op` for `numpy.broadcast_to`."""
-
-    _output_type_depends_on_input_value = True
-
-    __props__ = ()
-
-    view_map = {0: [0]}
-
-    def __call__(self, a, shape, **kwargs):
-        return super().__call__(a, *shape, **kwargs)
-
-    def make_node(self, a, *shape):
-        a = at.as_tensor_variable(a)
-
-        shape, static_shape = at.infer_static_shape(shape)
-
-        if len(shape) < a.ndim:
-            raise ValueError(
-                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
-            )
-
-        out = TensorType(dtype=a.type.dtype, shape=static_shape)()
-
-        # Attempt to prevent in-place operations on this view-based output
-        out.tag.indestructible = True
-
-        return Apply(self, [a] + shape, [out])
-
-    def perform(self, node, inputs, output_storage):
-        a, *shape = inputs
-        z = output_storage[0]
-        z[0] = np.broadcast_to(a, shape)
-
-    def grad(self, inputs, outputs_gradients):
-        a, *shape = inputs
-        (dout,) = outputs_gradients
-
-        # Determine the dimensions that were added by broadcasting
-        new_dims = list(range(dout.ndim - a.ndim))
-
-        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
-
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
-
-        return [d_wrt_a] + [
-            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
-        ]
-
-    def infer_shape(self, fgraph, node, ins_shapes):
-        return [node.inputs[1:]]
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        inp_dims = node.inputs[0].ndim
-        out_dims = node.outputs[0].ndim
-        new_dims = out_dims - inp_dims
-
-        (x, *shape) = inputs
-        (out,) = outputs
-        fail = sub["fail"]
-
-        # TODO: Could just use `PyArray_Return`, no?
-
-        dims_array = ", ".join(
-            [
-                f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
-                for i, shape in enumerate(shape)
-            ]
-        )
-
-        src = (
-            """
-            npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
-
-            NpyIter *iter;
-            PyArrayObject *ops[1] = {%(x)s};
-            npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
-            npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
-            PyArray_Descr *op_dtypes[1] = {NULL};
-            int oa_ndim = %(out_dims)s;
-            int* op_axes[1] = {NULL};
-            npy_intp buffersize = 0;
-
-            for(int i = 0; i < %(inp_dims)s; i++)
-            {
-                if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
-                {
-                    PyErr_Format(PyExc_ValueError,
-                                 "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
-                                 i,
-                                 (long long int) itershape[i + %(new_dims)s],
-                                 (long long int) PyArray_DIMS(%(x)s)[i]
-                    );
-                    %(fail)s
-                }
-            }
-
-            iter = NpyIter_AdvancedNew(
-                1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
-            );
-            %(out)s = NpyIter_GetIterView(iter, 0);
-
-            if(%(out)s == NULL){
-                NpyIter_Deallocate(iter);
-                %(fail)s;
-            }
-
-            if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
-                %(fail)s;
-            }
-            """
-            % locals()
-        )
-
-        return src
-
-    def c_code_cache_version(self):
-        return (2,)
-
-
-broadcast_to_ = BroadcastTo()
-
-
 def geomspace(start, end, steps, base=10.0):
     from pytensor.tensor.math import log
@@ -1762,13 +1627,7 @@ def broadcast_to(
     broadcasted array may refer to a single memory location.

     """
-    x = at.as_tensor(x)
-    shape_len = get_vector_length(shape)
-
-    if x.ndim == 0 and shape_len == 0:
-        return x
-
-    return broadcast_to_(x, shape)
+    return alloc(x, *shape)


 def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
...
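Two reviewer-facing consequences of `return alloc(x, *shape)`: the output is no longer a read-only view (`BroadcastTo` had `view_map = {0: [0]}`), and the hand-written `grad` above is replaced by `Alloc`'s gradient, which likewise sums `dout` over the added and broadcast dimensions. A quick sanity check of the gradient path, assuming the standard `pytensor.grad` API:

```python
import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_to

x = at.vector("x")
y = broadcast_to(x, (3, 2))  # now an Alloc under the hood

# Each input element appears in 3 broadcast copies, so d(sum)/dx == 3.
g = pytensor.grad(y.sum(), x)
f = pytensor.function([x], g)
assert np.array_equal(f(np.ones(2, dtype=pytensor.config.floatX)), [3.0, 3.0])
```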
@@ -2,7 +2,7 @@ import pytensor.scalar.basic as aes
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.basic import Alloc, as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique
+from pytensor.tensor.extra_ops import Repeat, Unique
 from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
@@ -60,39 +60,6 @@ def local_Unique_Alloc_lift(fgraph, node):
     return [new_x]


-@register_useless
-@register_canonicalize
-@node_rewriter([Unique])
-def local_Unique_BroadcastTo_lift(fgraph, node):
-    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
-
-    This isn't really so much a lift as a "reduction/consumption".
-    """
-    if not isinstance(node.op, Unique):
-        return False
-
-    if (
-        node.op.return_index
-        or node.op.return_inverse
-        or node.op.return_counts
-        or node.op.axis is not None
-    ):
-        return False
-
-    bcast_var = node.inputs[0]
-
-    if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
-        return False
-
-    bcasted_var, *bcast_shape = bcast_var.owner.inputs
-
-    new_unique, *_ = node.op.make_node(bcasted_var).outputs
-
-    old_out = node.outputs[0]
-    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
-    return [new_x]
-
-
 @register_useless
 @register_canonicalize
 @node_rewriter([Unique])
@@ -161,16 +128,3 @@ def local_Unique_second(fgraph, node):
     old_out = node.outputs[0]
     new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
     return [new_x]
-
-
-@register_useless
-@register_canonicalize
-@node_rewriter([BroadcastTo])
-def local_remove_scalar_BroadcastTo(fgraph, node):
-    bcast_shape = node.inputs[1:]
-
-    if not bcast_shape:
-        bcasted_var = node.inputs[0]
-        # If this isn't true, the graph is invalid
-        assert bcasted_var.ndim == 0
-        return [bcasted_var]
...
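Both deleted rewrites are subsumed by the `Alloc` path: `local_Unique_Alloc_lift` (kept above) already handles `unique(alloc(x, ...), axis=None)`, which is exactly what `broadcast_to` now produces. A sketch of the check one could run, assuming the rewrite is registered under that name:

```python
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.basic import Alloc
from pytensor.tensor.extra_ops import broadcast_to, unique
from pytensor.tensor.type import lvector

x = lvector("x")
y = unique(broadcast_to(x, (2, 3)), axis=None)

fg = rewrite_graph(
    FunctionGraph(outputs=[y], copy_inputs=False),
    clone=False,
    include=["canonicalize", "local_Unique_Alloc_lift"],
)
# After the lift, the broadcast (now an Alloc) is consumed entirely.
assert not any(isinstance(n.op, Alloc) for n in fg.apply_nodes)
```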
@@ -7,7 +7,7 @@ from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
 from pytensor.tensor import extra_ops as at_extra_ops
-from pytensor.tensor.type import matrix, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py
@@ -63,29 +63,6 @@ def test_extra_ops():
     )


-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    out = at_extra_ops.broadcast_to(x, shape)
-    fgraph = FunctionGraph(outputs=[out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
     reason="Omnistaging cannot be disabled",
...
@@ -36,41 +36,6 @@ def test_Bartlett(val):
     )


-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    g = extra_ops.BroadcastTo()(x, shape)
-    g_fg = FunctionGraph(outputs=[g])
-    compare_numba_and_py(
-        g_fg,
-        [
-            i.tag.test_value
-            for i in g_fg.inputs
-            if not isinstance(i, (SharedVariable, Constant))
-        ],
-    )
-
-
 @pytest.mark.parametrize(
     "val, axis, mode",
     [
...
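If equivalent Numba coverage is ever wanted, a deleted case could be ported against the new `Alloc`-based `broadcast_to`; a hypothetical sketch (the `tests.link.numba.test_basic.compare_numba_and_py` import path is assumed, mirroring the JAX test module's convention):

```python
import numpy as np
import pytensor.tensor as at
from pytensor import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.extra_ops import broadcast_to
from tests.link.numba.test_basic import compare_numba_and_py  # assumed path

x = at.vector("x")
x.tag.test_value = np.random.random(size=(2,)).astype(config.floatX)
out = broadcast_to(
    x, (at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64))
)
fgraph = FunctionGraph(outputs=[out])
compare_numba_and_py(fgraph, [i.tag.test_value for i in fgraph.inputs])
```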
@@ -8,7 +8,7 @@ from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor.basic import Alloc, alloc, as_tensor_variable, second
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
+from pytensor.tensor.extra_ops import Repeat, Unique, repeat, unique
 from pytensor.tensor.type import dscalar
@@ -103,64 +103,6 @@ def test_local_Unique_Alloc_lift(
     assert np.array_equal(y_exp_val, y_val)


-@pytest.mark.parametrize(
-    "x_val, axis, new_shape",
-    [
-        (np.array(-10, dtype=np.int64), None, (2, 3)),
-        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
-    ],
-)
-@pytest.mark.parametrize("return_index", [False])
-@pytest.mark.parametrize("return_counts", [False])
-@pytest.mark.parametrize("return_inverse", [False])
-def test_local_Unique_BroadcastTo(
-    x_val, axis, new_shape, return_index, return_counts, return_inverse
-):
-    x = as_tensor_variable(x_val).type()
-    y = unique(
-        BroadcastTo()(x, tuple(new_shape)),
-        return_index=return_index,
-        return_counts=return_counts,
-        return_inverse=return_inverse,
-        axis=axis,
-    )
-
-    if isinstance(y, list):
-        y, *_ = y
-
-    # This approach allows us to directly confirm that `x` is in the result.
-    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-    y_rewritten_fg = rewrite_graph(
-        y_fg,
-        clone=False,
-        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
-        exclude=["local_Unique_scalar"],
-    )
-    y_rewritten = y_rewritten_fg.outputs[0]
-    y_rewritten_start = y_rewritten
-
-    assert isinstance(y_rewritten_start.owner.op, Unique)
-    assert y_rewritten_start.owner.inputs[0] == x
-    assert not any(
-        isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes
-    )
-
-    default_mode = get_default_mode()
-    # The rewrite has already been applied to `y_rewritten`, so we can--and
-    # should--exclude it from the compilation of both our reference, `y`, and
-    # the rewritten result, `y_rewritten`.
-    rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
-    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
-
-    # Make sure that the original `BroadcastTo` is used to compute the
-    # reference `y` result
-    assert any(
-        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
-    )
-
-    y_exp_val, y_val = y_fn(x_val)
-    assert np.array_equal(y_exp_val, y_val)
-
-
 @pytest.mark.parametrize(
     "x_val, unique_axis, repeats, repeat_axis",
     [
@@ -287,16 +229,3 @@ def test_local_Unique_second(
     y_exp_val, y_val = y_fn(x_val)
     assert np.array_equal(y_exp_val, y_val)
-
-
-def test_local_remove_scalar_BroadcastTo():
-    x = dscalar()
-    y = BroadcastTo()(x, ())
-
-    assert isinstance(y.owner.op, BroadcastTo)
-
-    res = rewrite_graph(
-        y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
-    )
-
-    assert res is x
...
@@ -8,14 +8,12 @@ from pytensor import function
 from pytensor import tensor as at
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Constant, applys_between
+from pytensor.graph.basic import Constant, applys_between, equal_computations
-from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.raise_op import Assert
+from pytensor.tensor import alloc
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CpuContiguous,
     CumOp,
     FillDiagonal,
@@ -47,7 +45,6 @@ from pytensor.tensor.extra_ops import (
     to_one_hot,
     unravel_index,
 )
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
 from pytensor.tensor.type import (
     TensorType,
     dmatrix,
@@ -61,7 +58,6 @@ from pytensor.tensor.type import (
     lscalar,
     matrix,
     scalar,
-    tensor,
     tensor3,
     vector,
 )
@@ -1246,183 +1242,15 @@ def test_broadcast_shape_symbolic_one_symbolic():
     assert res_shape[2].data == 3


-class TestBroadcastTo(utt.InferShapeTester):
-    def setup_method(self):
-        super().setup_method()
-        self.op_class = BroadcastTo
-        self.op = broadcast_to
-
-    def test_avoid_useless_scalars(self):
-        x = scalar()
-        y = broadcast_to(x, ())
-        assert y is x
-
-    def test_avoid_useless_subtensors(self):
-        x = scalar()
-        y = broadcast_to(x, (1, 2))
-        # There shouldn't be any unnecessary `Subtensor` operations
-        # (e.g. from `at.as_tensor((1, 2))[0]`)
-        assert y.owner.inputs[1].owner is None
-        assert y.owner.inputs[2].owner is None
-
-    @pytest.mark.parametrize("linker", ["cvm", "py"])
-    def test_perform(self, linker):
-        a = pytensor.shared(np.full((3, 1, 1), 5))
-        s_0 = iscalar("s_0")
-        s_1 = iscalar("s_1")
-        shape = (s_0, s_1, 1)
-
-        bcast_res = broadcast_to(a, shape)
-        assert bcast_res.broadcastable == (False, False, True)
-
-        bcast_fn = pytensor.function(
-            [s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
-        )
-        bcast_fn.vm.allow_gc = False
-
-        bcast_at = bcast_fn(3, 4)
-        bcast_np = np.broadcast_to(5, (3, 4, 1))
-
-        assert np.array_equal(bcast_at, bcast_np)
-
-        with pytest.raises(ValueError):
-            bcast_fn(5, 4)
-
-        if linker != "py":
-            bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
-            bcast_in = bcast_fn.vm.storage_map[a]
-            bcast_out = bcast_fn.vm.storage_map[bcast_var]
-
-            assert np.shares_memory(bcast_out[0], bcast_in[0])
-
-    def test_make_node_error_handling(self):
-        with pytest.raises(
-            ValueError,
-            match="Broadcast target shape has 1 dims, which is shorter than input with 2 dims",
-        ):
-            broadcast_to(at.zeros((3, 4)), (5,))
-
-    @pytest.mark.skipif(
-        not config.cxx, reason="G++ not available, so we need to skip this test."
-    )
-    @pytest.mark.parametrize("valid", (True, False))
-    def test_memory_leak(self, valid):
-        import gc
-        import tracemalloc
-
-        from pytensor.link.c.cvm import CVM
-
-        n = 100_000
-        x = pytensor.shared(np.ones((1, n), dtype=np.float64))
-        y = broadcast_to(x, (5, n))
-
-        f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm"))
-        assert isinstance(f.vm, CVM)
-
-        assert len(f.maker.fgraph.apply_nodes) == 2
-        assert any(
-            isinstance(node.op, BroadcastTo) for node in f.maker.fgraph.apply_nodes
-        )
-
-        tracemalloc.start()
-
-        blocks_last = None
-        block_diffs = []
-        for i in range(1, 50):
-            if valid:
-                x.set_value(np.ones((1, n)))
-                _ = f()
-            else:
-                x.set_value(np.ones((2, n)))
-                try:
-                    _ = f()
-                except ValueError:
-                    pass
-                else:
-                    raise RuntimeError("Should have failed")
-            _ = gc.collect()
-            blocks_i, _ = tracemalloc.get_traced_memory()
-            if blocks_last is not None:
-                blocks_diff = (blocks_i - blocks_last) // 10**3
-                block_diffs.append(blocks_diff)
-            blocks_last = blocks_i
-
-        tracemalloc.stop()
-        assert np.all(np.array(block_diffs) <= (0 + 1e-8))
-
-    @pytest.mark.parametrize(
-        "fn,input_dims",
-        [
-            [lambda x: broadcast_to(x, (1,)), (1,)],
-            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
-            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
-            [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
-        ],
-    )
-    def test_gradient(self, fn, input_dims):
-        rng = np.random.default_rng(43)
-        utt.verify_grad(
-            fn,
-            [rng.random(input_dims).astype(config.floatX)],
-            n_tests=1,
-            rng=rng,
-        )
-
-    def test_infer_shape(self):
-        rng = np.random.default_rng(43)
-        a = tensor(dtype=config.floatX, shape=(None, 1, None))
-        shape = list(a.shape)
-        out = self.op(a, shape)
-
-        self._compile_and_check(
-            [a] + shape,
-            [out],
-            [rng.random((2, 1, 3)).astype(config.floatX), 2, 1, 3],
-            self.op_class,
-        )
-
-        a = tensor(dtype=config.floatX, shape=(None, 1, None))
-        shape = [iscalar() for i in range(4)]
-        self._compile_and_check(
-            [a] + shape,
-            [self.op(a, shape)],
-            [rng.random((2, 1, 3)).astype(config.floatX), 6, 2, 5, 3],
-            self.op_class,
-        )
-
-    def test_inplace(self):
-        """Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``."""
-        a = at.zeros((5,))
-        d = at.vector("d")
-        c = at.set_subtensor(a[np.r_[0, 1, 3]], d)
-        b = broadcast_to(c, (5,))
-        q = b[np.r_[0, 1, 3]]
-        e = at.set_subtensor(q, np.r_[0, 0, 0])
-
-        opts = RewriteDatabaseQuery(include=["inplace"])
-        py_mode = Mode("py", opts)
-        e_fn = function([d], e, mode=py_mode)
-
-        advincsub_node = e_fn.maker.fgraph.outputs[0].owner
-        assert isinstance(advincsub_node.op, AdvancedIncSubtensor)
-        assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
-
-        assert advincsub_node.op.inplace is False
-
-    def test_rebuild(self):
-        x = vector(shape=(50,))
-        x_test = np.zeros((50,), dtype=config.floatX)
-        i = 0
-        y = broadcast_to(i, x.shape)
-        assert y.type.shape == (50,)
-        assert y.shape.eval({x: x_test}) == (50,)
-        assert y.eval({x: x_test}).shape == (50,)
-
-        x_new = vector(shape=(100,))
-        x_new_test = np.zeros((100,), dtype=config.floatX)
-        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
-        assert y_new.type.shape == (100,)
-        assert y_new.shape.eval({x_new: x_new_test}) == (100,)
-        assert y_new.eval({x_new: x_new_test}).shape == (100,)
+def test_broadcast_to():
+    x = vector("x")
+    y1 = scalar(dtype="int64")
+    y2 = scalar(dtype="int64")
+
+    assert equal_computations(
+        [broadcast_to(x, (y1, y2))],
+        [alloc(x, y1, y2)],
+    )


 def test_broadcast_arrays():
...
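One behavioral change the new `test_broadcast_to` does not capture: the removed op returned a read-only view of its input (see the deleted `test_perform`, which asserted `np.shares_memory`), whereas `Alloc` materializes a fresh buffer. A hedged check of the new copying semantics, using only APIs this test file already imports:

```python
import numpy as np
import pytensor
from pytensor.compile.mode import Mode
from pytensor.tensor.extra_ops import broadcast_to

a = pytensor.shared(np.full((3, 1, 1), 5.0))
out = broadcast_to(a, (3, 4, 1))
f = pytensor.function([], out, mode=Mode(optimizer=None))

res = f()
# Alloc-backed output: writeable, and not aliased to the shared input.
assert res.flags.writeable
assert not np.shares_memory(res, a.get_value(borrow=True))
```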