提交 4ac1e637 作者: Ricardo Vieira 提交者: Ricardo Vieira

Simplify implementation of tile

Deprecate obscure ndim kwarg
上级 c22e79e1
......@@ -10,7 +10,7 @@ import warnings
from collections.abc import Sequence
from functools import partial
from numbers import Number
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from typing import cast as type_cast
import numpy as np
......@@ -33,7 +33,7 @@ from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise, assert_op
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import int32
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
from pytensor.tensor import (
......@@ -3084,87 +3084,132 @@ def flatten(x, ndim=1):
return x_reshaped
def tile(x, reps, ndim=None):
def tile(
A: "TensorLike", reps: Union[Sequence[Union[int, "TensorLike"]], "TensorLike"]
) -> TensorVariable:
"""
Tile input array `x` according to `reps`.
Tile input tensor `A` according to `reps`.
See the docstring of `numpy.tile` for details.
'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),
symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector())
or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]).
If `reps` is a PyTensor vector, its length must be statically known.
You can use `specify_shape` to set the length.
Examples
--------
.. testcode::
import pytensor.tensor as pt
A = pt.matrix("A", dtype=int)
A_tiled = pt.tile(A, 2)
print(A_tiled.eval({A: [[1, 2], [3, 4]]}))
.. testoutput::
[[1 2 1 2]
[3 4 3 4]]
Reps can be a sequence of constants and/or symbolic integer variables.
.. testcode::
rep0 = pt.scalar("rep0", dtype=int)
A_tiled = pt.tile(A, (rep0, 1))
print(A_tiled.eval({A: [[1, 2], [3, 4]], rep0: 2}))
.. testoutput::
[[1 2]
[3 4]
[1 2]
[3 4]]
Reps can be a single integer vector, in which case its length must be statically known.
Either of the following is a valid way to specify the length:
.. testcode::
reps = pt.vector("reps", dtype=int, shape=(2,))
A_tiled = pt.tile(A, reps)
print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [1, 2]}))
.. testoutput::
[[1 2 1 2]
[3 4 3 4]]
.. testcode::
ndim is the number of the dimensions of the output, if it is provided, ndim
should be equal or larger than x.ndim and len(reps), otherwise, we will use
max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to
be provided.
reps = pt.vector("reps", dtype=int)
reps = pt.specify_shape(reps, (2,))
A_tiled = pt.tile(A, reps)
print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [2, 2]}))
.. testoutput::
[[1 2 1 2]
[3 4 3 4]
[1 2 1 2]
[3 4 3 4]]
"""
from pytensor.tensor.math import ge
_x = as_tensor_variable(x)
if ndim is not None and ndim < _x.ndim:
raise ValueError("ndim should be equal or larger than _x.ndim")
A = as_tensor_variable(A)
# If reps is a scalar, integer or vector, we convert it to a list.
# Convert symbolic reps to a tuple
if not isinstance(reps, list | tuple):
reps_astensor = as_tensor_variable(reps)
ndim_check = reps_astensor.ndim
if reps_astensor.dtype not in discrete_dtypes:
raise ValueError("elements of reps must be integer dtype")
# The scalar/integer case
if ndim_check == 0:
reps = [reps]
# The vector case
elif ndim_check == 1:
if ndim is None:
reps = as_tensor_variable(reps)
if reps.type.ndim == 0:
reps = (reps,)
elif reps.type.ndim == 1:
try:
reps = tuple(reps)
except ValueError:
raise ValueError(
"if reps is tensor.vector, you should specify the ndim"
"Length of repetitions tensor cannot be determined. Use specify_shape to set the length."
)
else:
offset = ndim - reps.shape[0]
# assert that reps.shape[0] does not exceed ndim
offset = assert_op(offset, ge(offset, 0))
else:
raise ValueError(
f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}"
)
# if reps.ndim is less than _x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as _x.
reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]
reps = reps_
reps = [as_tensor_variable(rep) for rep in reps]
if not all(
rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps
):
raise ValueError(
f"All reps entries shoud be scalar integers, got {reps} of type {[rep.type for rep in reps]}"
)
# For others, raise an error
else:
raise ValueError("the dimension of reps should not exceed 1")
else:
if ndim is not None and len(reps) > ndim:
raise ValueError("len(reps) should be equal or less than ndim")
if not all(
isinstance(r, int)
or (isinstance(r, TensorVariable) and r.dtype in discrete_dtypes)
for r in reps
):
raise ValueError("elements of reps must be scalars of integer dtype")
len_reps = len(reps)
out_ndim = builtins.max(len_reps, A.type.ndim)
# Pad reps on the left (if needed)
if len_reps < out_ndim:
reps = (*((1,) * (out_ndim - len_reps)), *reps)
# Pad A's shape on the left (if needed)
elif A.type.ndim < out_ndim:
A = shape_padleft(A, out_ndim - A.type.ndim)
# Expand every other dim of A and expand n-reps via Alloc
# A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1])
A_shape = A.shape
interleaved_reps_shape = [
d for pair in zip(reps, A_shape, strict=True) for d in pair
]
every_other_axis = tuple(range(0, out_ndim * 2, 2))
A_replicated = alloc(
expand_dims(A, every_other_axis),
*interleaved_reps_shape,
)
# If reps.ndim is less than _x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as _x
reps = list(reps)
if ndim is None:
ndim = builtins.max(len(reps), _x.ndim)
if len(reps) < ndim:
reps = [1] * (ndim - len(reps)) + reps
_shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)]
alloc_shape = reps + _shape
y = alloc(_x, *alloc_shape)
shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)
shuffle_ind = shuffle_ind.transpose().flatten()
y = y.dimshuffle(*shuffle_ind)
new_shapes = [sh * reps[i] for i, sh in enumerate(_shape)]
y = y.reshape(new_shapes)
return y
# Combine replicate and original dimensions via reshape
# A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1])
tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True))
return A_replicated.reshape(tiled_shape)
class ARange(Op):
......
......@@ -2386,194 +2386,120 @@ def test_is_flat():
assert not ptb.is_flat(X.reshape((iscalar(),) * 3))
def test_tile():
"""
TODO FIXME: Split this apart and parameterize. Also, find out why it's
unreasonably slow.
"""
class TestTile:
@pytest.mark.parametrize(
"A_shape, reps_test",
[
((), (2,)),
((5,), (2,)),
((2, 4), (2, 3)),
((2, 4), (2, 3, 4)),
((2, 4, 3), (2, 3)),
((2, 4, 3), (2, 3, 4)),
((2, 4, 3, 5), (2, 3, 4, 6)),
],
)
def test_tile_separate_reps_entries(self, A_shape, reps_test):
rng = np.random.default_rng(2400)
def run_tile(x, x_, reps, use_symbolic_reps):
if use_symbolic_reps:
rep_symbols = [iscalar() for _ in range(len(reps))]
f = function([x, *rep_symbols], tile(x, rep_symbols))
return f(*([x_, *reps]))
else:
f = function([x], tile(x, reps))
return f(x_)
A = tensor("A", shape=(None,) * len(A_shape))
reps = [iscalar(f"r{i}") for i in range(len(reps_test))]
tile_out = tile(A, reps)
rng = np.random.default_rng(utt.fetch_seed())
tile_fn = function([A, *reps], tile_out)
for use_symbolic_reps in [False, True]:
# Test the one-dimensional case.
x = vector()
x_ = rng.standard_normal(5).astype(config.floatX)
assert np.all(run_tile(x, x_, (2,), use_symbolic_reps) == np.tile(x_, (2,)))
A_test = rng.standard_normal(A_shape).astype(config.floatX)
np.testing.assert_array_equal(
tile_fn(A_test, *reps_test),
np.tile(A_test, reps_test),
strict=True,
)
# Test the two-dimensional case.
x = matrix()
x_ = rng.standard_normal((2, 4)).astype(config.floatX)
assert np.all(run_tile(x, x_, (2, 3), use_symbolic_reps) == np.tile(x_, (2, 3)))
# Test the three-dimensional case.
x = tensor3()
x_ = rng.standard_normal((2, 4, 3)).astype(config.floatX)
assert np.all(
run_tile(x, x_, (2, 3, 4), use_symbolic_reps) == np.tile(x_, (2, 3, 4))
@pytest.mark.parametrize("reps", (2, np.array([2, 3, 4])))
def test_combined_reps_entries(self, reps):
rng = np.random.default_rng(2422)
A_test = rng.standard_normal((2, 4, 3)).astype(config.floatX)
expected_eval = np.tile(A_test, reps)
A = tensor3("A")
np.testing.assert_array_equal(
tile(A, reps).eval({A: A_test}),
expected_eval,
strict=True,
)
# Test the four-dimensional case.
x = tensor4()
x_ = rng.standard_normal((2, 4, 3, 5)).astype(config.floatX)
assert np.all(
run_tile(x, x_, (2, 3, 4, 6), use_symbolic_reps)
== np.tile(x_, (2, 3, 4, 6))
sym_reps = as_tensor_variable(reps).type()
np.testing.assert_array_equal(
tile(A, sym_reps).eval({A: A_test, sym_reps: reps}),
expected_eval,
strict=True,
)
# Test passing a float
x = scalar()
x_val = 1.0
assert np.array_equal(
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
def test_mixed_reps_type(self):
A = np.arange(9).reshape(3, 3)
reps = [2, iscalar("3"), 4]
np.testing.assert_array_equal(
tile(A, reps).eval({"3": 3}),
np.tile(A, [2, 3, 4]),
strict=True,
)
def test_tensorlike_A(self):
# Test when x is a list
x = matrix()
x_val = [[1.0, 2.0], [3.0, 4.0]]
assert np.array_equal(
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
assert equal_computations(
[tile(x_val, (2,))],
[tile(as_tensor_variable(x_val), (2,))],
)
# Test when reps is integer, scalar or vector.
# Test 1,2,3,4-dimensional cases.
# Test input x has the shape [2], [2, 4], [2, 4, 3], [2, 4, 3, 5].
test_shape = [2, 4, 3, 5]
k = 0
for xtype in [vector(), matrix(), tensor3(), tensor4()]:
x = xtype
k = k + 1
x_ = rng.standard_normal(test_shape[0:k]).astype(config.floatX)
# integer:
reps_ = 2
f = function([x], tile(x, reps_))
assert np.all(f(x_) == np.tile(x_, reps_))
# scalar:
reps = iscalar()
reps_ = 2
f = function([x, reps], tile(x, reps))
assert np.all(f(x_, reps_) == np.tile(x_, reps_))
# vector:
reps = ivector()
reps_ = [2] if k == 1 or k == 2 else [2, 3]
ndim_ = k
f = function([x, reps], tile(x, reps, ndim_))
assert np.all(f(x_, reps_) == np.tile(x_, reps_))
# list of integers:
reps_ = [2, 3, 4]
f = function([x], tile(x, reps_))
assert np.all(f(x_) == np.tile(x_, reps_))
# list of integers and scalars:
d = iscalar()
reps = [2, d, 4]
f = function([x, d], tile(x, reps))
reps_ = [2, 3, 4]
assert np.all(f(x_, 3) == np.tile(x_, reps_))
# reps is list, len(reps) > x.ndim, 3 cases below:
r = [2, 3, 4, 5, 6]
reps_ = r[: k + 1] # len(reps_) = x.ndim+1
# (1) ndim = None.
f = function([x], tile(x, reps_))
assert np.all(f(x_) == np.tile(x_, reps_))
# (2) ndim = len(reps).
ndim_ = len(reps_)
f = function([x], tile(x, reps_, ndim_))
assert np.all(f(x_) == np.tile(x_, reps_))
# (3) ndim > len(reps)
ndim_ = len(reps_) + 1
f = function([x], tile(x, reps_, ndim_))
assert np.all(f(x_) == np.tile(x_, [1, *reps_]))
# reps is list, ndim > x.ndim > len(reps):
r = [2, 3, 4, 5]
if k > 1:
ndim_ = k + 1
reps_ = r[: k - 1]
f = function([x], tile(x, reps_, ndim_))
assert np.all(f(x_) == np.tile(x_, [1, 1, *reps_]))
def test_error_unknown_reps_length(self):
# error raising test: ndim not specified when reps is vector
reps = ivector()
with pytest.raises(ValueError):
tile(x, reps)
with pytest.raises(ValueError, match="Use specify_shape to set the length"):
tile(arange(3), reps)
# error raising test: not a integer
for reps in [2.5, fscalar(), fvector()]:
# fine with specify_shape
out = tile(arange(3), specify_shape(reps, 2))
np.testing.assert_array_equal(
out.eval({reps: [2, 3]}),
np.tile(np.arange(3), [2, 3]),
strict=True,
)
def test_error_non_integer_reps(self):
for reps in (
2.5,
fscalar(),
vector(shape=(3,), dtype="float64"),
[2, fscalar()],
):
with pytest.raises(ValueError):
tile(x, reps)
tile(arange(3), reps)
# error raising test: the dimension of reps exceeds 1
reps = imatrix()
with pytest.raises(ValueError):
tile(x, reps)
# error raising test: ndim is not None, ndim < x.ndim
# 3 cases below (reps is list/scalar/vector):
for reps in [[2, 3, 4], iscalar(), ivector()]:
if k > 1:
ndim = k - 1
with pytest.raises(ValueError):
tile(x, reps, ndim)
# error raising test: reps is list, len(reps) > ndim
r = [2, 3, 4, 5, 6]
reps = r[: k + 1]
ndim = k
with pytest.raises(ValueError):
tile(x, reps, ndim)
def test_error_reps_ndim(self):
for reps in (
matrix(shape=(3, 1), dtype=int),
[2, vector(shape=(2,), dtype=int)],
):
with pytest.raises(ValueError):
tile(arange(3), reps)
def test_tile_grad(self):
A = tensor3("A")
reps = vector("reps", shape=(3,), dtype=int)
A_tile = tile(A, reps)
grad_tile = grad(A_tile.sum(), A)
# error raising test:
# reps is vector and len(reps_value) > ndim,
# reps_value is the real value when executing the function.
reps = ivector()
r = [2, 3, 4, 5, 6, 7]
reps_ = r[: k + 2]
ndim_ = k + 1
f = function([x, reps], tile(x, reps, ndim_))
with pytest.raises(AssertionError):
f(x_, reps_)
def test_tile_grad():
def grad_tile(x, reps, np_x):
y = tile(x, reps)
z = y.sum()
g = pytensor.function([x], grad(z, x))
grad_res = g(np_x)
# The gradient should be the product of the tiling dimensions
# (since the gradients are additive through the tiling operation)
assert np.all(grad_res == np.prod(reps))
rng = np.random.default_rng(utt.fetch_seed())
# test vector
grad_tile(vector("x"), [3], rng.standard_normal(5).astype(config.floatX))
# test matrix
grad_tile(matrix("x"), [3, 4], rng.standard_normal((2, 3)).astype(config.floatX))
# test tensor3
grad_tile(
tensor3("x"), [3, 4, 5], rng.standard_normal((2, 4, 3)).astype(config.floatX)
)
# test tensor4
grad_tile(
tensor4("x"),
[3, 4, 5, 6],
rng.standard_normal((2, 4, 3, 5)).astype(config.floatX),
)
rng = np.random.default_rng(2489)
A_test = rng.normal(size=(2, 4, 3)).astype(config.floatX)
reps_test = [3, 4, 5]
np.testing.assert_array_equal(
grad_tile.eval({A: A_test, reps: reps_test}),
np.full(A_test.shape, np.prod(reps_test).astype(config.floatX)),
strict=True,
)
class TestARange:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论