Commit 4ac1e637 authored by Ricardo Vieira, committed by Ricardo Vieira

Simplify implementation of tile

Deprecate obscure ndim kwarg
Parent: c22e79e1
...@@ -10,7 +10,7 @@ import warnings ...@@ -10,7 +10,7 @@ import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Union
from typing import cast as type_cast from typing import cast as type_cast
import numpy as np import numpy as np
...@@ -33,7 +33,7 @@ from pytensor.graph.type import HasShape, Type ...@@ -33,7 +33,7 @@ from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise, assert_op from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import int32 from pytensor.scalar import int32
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
from pytensor.tensor import ( from pytensor.tensor import (
...@@ -3084,87 +3084,132 @@ def flatten(x, ndim=1): ...@@ -3084,87 +3084,132 @@ def flatten(x, ndim=1):
return x_reshaped return x_reshaped
def tile(
    A: "TensorLike", reps: Union[Sequence[Union[int, "TensorLike"]], "TensorLike"]
) -> TensorVariable:
    """
    Tile input tensor `A` according to `reps`.

    See the docstring of `numpy.tile` for details.

    If `reps` is a PyTensor vector, its length must be statically known.
    You can use `specify_shape` to set the length.

    Parameters
    ----------
    A : TensorLike
        The tensor (or anything convertible to one) to repeat.
    reps : int, sequence of ints/scalar tensors, or integer vector tensor
        Number of repetitions along each dimension. A scalar is treated as
        ``(reps,)``; a symbolic vector must have a statically known length.

    Returns
    -------
    TensorVariable
        Tensor with ``max(A.ndim, len(reps))`` dimensions, tiled like
        ``numpy.tile``.

    Raises
    ------
    ValueError
        If a symbolic ``reps`` vector has unknown length, if ``reps`` has more
        than one dimension, or if any entry of ``reps`` is not a scalar of an
        integer dtype.

    Examples
    --------
    .. testcode::

        import pytensor.tensor as pt

        A = pt.matrix("A", dtype=int)
        A_tiled = pt.tile(A, 2)
        print(A_tiled.eval({A: [[1, 2], [3, 4]]}))

    .. testoutput::

        [[1 2 1 2]
         [3 4 3 4]]

    Reps can be a sequence of constants and/or symbolic integer variables

    .. testcode::

        rep0 = pt.scalar("rep0", dtype=int)
        A_tiled = pt.tile(A, (rep0, 1))
        print(A_tiled.eval({A: [[1, 2], [3, 4]], rep0: 2}))

    .. testoutput::

        [[1 2]
         [3 4]
         [1 2]
         [3 4]]

    Reps can be a single integer vector, in which case its length must be statically known.
    Either of the following is a valid way to specify the length:

    .. testcode::

        reps = pt.vector("reps", dtype=int, shape=(2,))
        A_tiled = pt.tile(A, reps)
        print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [1, 2]}))

    .. testoutput::

        [[1 2 1 2]
         [3 4 3 4]]

    .. testcode::

        reps = pt.vector("reps", dtype=int)
        reps = pt.specify_shape(reps, (2,))
        A_tiled = pt.tile(A, reps)
        print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [2, 2]}))

    .. testoutput::

        [[1 2 1 2]
         [3 4 3 4]
         [1 2 1 2]
         [3 4 3 4]]

    """
    A = as_tensor_variable(A)

    # Convert symbolic reps to a tuple of scalar entries
    if not isinstance(reps, list | tuple):
        reps = as_tensor_variable(reps)
        if reps.type.ndim == 0:
            # A single (symbolic) integer: repeat along the last dimension only
            reps = (reps,)
        elif reps.type.ndim == 1:
            # Unpacking a symbolic vector requires a static length;
            # tuple(reps) raises ValueError when the length is unknown.
            try:
                reps = tuple(reps)
            except ValueError:
                raise ValueError(
                    "Length of repetitions tensor cannot be determined. Use specify_shape to set the length."
                ) from None
        else:
            raise ValueError(
                f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}"
            )

    # Validate that every entry is a scalar of integer dtype
    reps = [as_tensor_variable(rep) for rep in reps]
    if not all(
        rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps
    ):
        raise ValueError(
            f"All reps entries should be scalar integers, got {reps} of type {[rep.type for rep in reps]}"
        )

    len_reps = len(reps)
    out_ndim = builtins.max(len_reps, A.type.ndim)

    # Pad reps on the left (if needed), mirroring numpy.tile semantics
    if len_reps < out_ndim:
        reps = (*((1,) * (out_ndim - len_reps)), *reps)

    # Pad A's shape on the left (if needed)
    elif A.type.ndim < out_ndim:
        A = shape_padleft(A, out_ndim - A.type.ndim)

    # Expand every other dim of A and expand n-reps via Alloc
    # A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1])
    A_shape = A.shape
    interleaved_reps_shape = [
        d for pair in zip(reps, A_shape, strict=True) for d in pair
    ]
    every_other_axis = tuple(range(0, out_ndim * 2, 2))
    A_replicated = alloc(
        expand_dims(A, every_other_axis),
        *interleaved_reps_shape,
    )

    # Combine replicate and original dimensions via reshape
    # A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1])
    tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True))
    return A_replicated.reshape(tiled_shape)
class ARange(Op): class ARange(Op):
......
...@@ -2386,194 +2386,120 @@ def test_is_flat(): ...@@ -2386,194 +2386,120 @@ def test_is_flat():
assert not ptb.is_flat(X.reshape((iscalar(),) * 3)) assert not ptb.is_flat(X.reshape((iscalar(),) * 3))
class TestTile:
    """Tests for `pytensor.tensor.tile` against `numpy.tile` semantics."""

    @pytest.mark.parametrize(
        "A_shape, reps_test",
        [
            ((), (2,)),
            ((5,), (2,)),
            ((2, 4), (2, 3)),
            ((2, 4), (2, 3, 4)),
            ((2, 4, 3), (2, 3)),
            ((2, 4, 3), (2, 3, 4)),
            ((2, 4, 3, 5), (2, 3, 4, 6)),
        ],
    )
    def test_tile_separate_reps_entries(self, A_shape, reps_test):
        # reps given as a list of individual symbolic scalars
        rng = np.random.default_rng(2400)
        A = tensor("A", shape=(None,) * len(A_shape))
        reps = [iscalar(f"r{i}") for i in range(len(reps_test))]
        tile_out = tile(A, reps)

        tile_fn = function([A, *reps], tile_out)
        A_test = rng.standard_normal(A_shape).astype(config.floatX)
        np.testing.assert_array_equal(
            tile_fn(A_test, *reps_test),
            np.tile(A_test, reps_test),
            strict=True,
        )

    @pytest.mark.parametrize("reps", (2, np.array([2, 3, 4])))
    def test_combined_reps_entries(self, reps):
        # reps given as a single constant (scalar or vector), and as the
        # equivalent symbolic variable
        rng = np.random.default_rng(2422)
        A_test = rng.standard_normal((2, 4, 3)).astype(config.floatX)
        expected_eval = np.tile(A_test, reps)

        A = tensor3("A")
        np.testing.assert_array_equal(
            tile(A, reps).eval({A: A_test}),
            expected_eval,
            strict=True,
        )

        sym_reps = as_tensor_variable(reps).type()
        np.testing.assert_array_equal(
            tile(A, sym_reps).eval({A: A_test, sym_reps: reps}),
            expected_eval,
            strict=True,
        )

    def test_mixed_reps_type(self):
        # reps mixing Python ints and symbolic scalars
        A = np.arange(9).reshape(3, 3)
        reps = [2, iscalar("3"), 4]
        np.testing.assert_array_equal(
            tile(A, reps).eval({"3": 3}),
            np.tile(A, [2, 3, 4]),
            strict=True,
        )

    def test_tensorlike_A(self):
        # Test when x is a list
        x_val = [[1.0, 2.0], [3.0, 4.0]]
        assert equal_computations(
            [tile(x_val, (2,))],
            [tile(as_tensor_variable(x_val), (2,))],
        )

    def test_error_unknown_reps_length(self):
        # error raising test: symbolic reps vector with unknown length
        reps = ivector()
        with pytest.raises(ValueError, match="Use specify_shape to set the length"):
            tile(arange(3), reps)

        # fine with specify_shape
        out = tile(arange(3), specify_shape(reps, 2))
        np.testing.assert_array_equal(
            out.eval({reps: [2, 3]}),
            np.tile(np.arange(3), [2, 3]),
            strict=True,
        )

    def test_error_non_integer_reps(self):
        for reps in (
            2.5,
            fscalar(),
            vector(shape=(3,), dtype="float64"),
            [2, fscalar()],
        ):
            with pytest.raises(ValueError):
                tile(arange(3), reps)

    def test_error_reps_ndim(self):
        for reps in (
            matrix(shape=(3, 1), dtype=int),
            [2, vector(shape=(2,), dtype=int)],
        ):
            with pytest.raises(ValueError):
                tile(arange(3), reps)

    def test_tile_grad(self):
        A = tensor3("A")
        reps = vector("reps", shape=(3,), dtype=int)
        A_tile = tile(A, reps)
        grad_tile = grad(A_tile.sum(), A)

        # The gradient should be the product of the tiling dimensions
        # (since the gradients are additive through the tiling operation)
        rng = np.random.default_rng(2489)
        A_test = rng.normal(size=(2, 4, 3)).astype(config.floatX)
        reps_test = [3, 4, 5]
        np.testing.assert_array_equal(
            grad_tile.eval({A: A_test, reps: reps_test}),
            np.full(A_test.shape, np.prod(reps_test).astype(config.floatX)),
            strict=True,
        )
class TestARange: class TestARange:
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment