提交 ffdde1cd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement gradient for vector repetitions

Also cleans up implementation and documentation
上级 da4960b8
...@@ -646,12 +646,17 @@ class Repeat(Op): ...@@ -646,12 +646,17 @@ class Repeat(Op):
__props__ = ("axis",) __props__ = ("axis",)
def __init__(self, axis=None): def __init__(self, axis: int | None = None):
if axis is not None:
if not isinstance(axis, int) or axis < 0:
raise ValueError(
f"Repeat only accepts positive integer axis or None, got {axis}"
)
self.axis = axis self.axis = axis
def make_node(self, x, repeats): def make_node(self, x, repeats):
x = ptb.as_tensor_variable(x) x = ptb.as_tensor_variable(x)
repeats = ptb.as_tensor_variable(repeats) repeats = ptb.as_tensor_variable(repeats, dtype="int64")
if repeats.dtype not in integer_dtypes: if repeats.dtype not in integer_dtypes:
raise TypeError("repeats.dtype must be an integer.") raise TypeError("repeats.dtype must be an integer.")
...@@ -687,17 +692,12 @@ class Repeat(Op): ...@@ -687,17 +692,12 @@ class Repeat(Op):
out_shape = list(x.type.shape) out_shape = list(x.type.shape)
out_shape[self.axis] = None out_shape[self.axis] = None
out_type = TensorType( out_type = TensorType(x.dtype, shape=out_shape)
x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
)
return Apply(self, [x, repeats], [out_type()]) return Apply(self, [x, repeats], [out_type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x = inputs[0] [x, repeats] = inputs
repeats = inputs[1] output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)
z = output_storage[0]
z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
def connection_pattern(self, node): def connection_pattern(self, node):
return [[True], [False]] return [[True], [False]]
...@@ -705,40 +705,51 @@ class Repeat(Op): ...@@ -705,40 +705,51 @@ class Repeat(Op):
def grad(self, inputs, gout): def grad(self, inputs, gout):
(x, repeats) = inputs (x, repeats) = inputs
(gz,) = gout (gz,) = gout
axis = self.axis
if repeats.ndim == 0: if repeats.ndim == 0:
if self.axis is None: # When axis is a scalar (same number of reps for all elements),
axis = x.ndim # We can split the repetitions into their own axis with reshape and sum them back
else: # to the original element location
if self.axis >= 0: sum_axis = x.ndim if axis is None else axis + 1
axis = self.axis + 1 shape = list(x.shape)
else: shape.insert(sum_axis, repeats)
axis = self.axis + x.ndim + 1 gx = gz.reshape(shape).sum(axis=sum_axis)
shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats)
return [
gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
DisconnectedType()(),
]
elif repeats.ndim == 1: elif repeats.ndim == 1:
# For this implementation, we would need to specify the length # To sum the gradients that belong to the same repeated x,
# of repeats in order to split gz in the right way to sum # We create a repeated eye and dot product it with the gradient.
# the good part. axis_size = x.size if axis is None else x.shape[axis]
raise NotImplementedError() repeated_eye = repeat(
ptb.eye(axis_size), repeats, axis=0
) # A sparse repeat would be neat
if axis is None:
gx = gz @ repeated_eye
# Undo the ravelling when axis=None
gx = gx.reshape(x.shape)
else:
# Place gradient axis at end for dot product
gx = ptb.moveaxis(gz, axis, -1)
gx = gx @ repeated_eye
# Place gradient back into the correct axis
gx = ptb.moveaxis(gx, -1, axis)
else: else:
raise ValueError() raise ValueError()
return [gx, DisconnectedType()()]
def infer_shape(self, fgraph, node, ins_shapes): def infer_shape(self, fgraph, node, ins_shapes):
i0_shapes = ins_shapes[0] i0_shapes = ins_shapes[0]
repeats = node.inputs[1] repeats = node.inputs[1]
out_shape = list(i0_shapes) out_shape = list(i0_shapes)
axis = self.axis
# uint64 shape are not supported. # uint64 shape are not supported.
dtype = None dtype = None
if repeats.dtype in ("uint8", "uint16", "uint32"): if repeats.dtype in ("uint8", "uint16", "uint32"):
dtype = "int64" dtype = "int64"
if self.axis is None: if axis is None:
if repeats.ndim == 0: if repeats.ndim == 0:
if len(i0_shapes) == 0: if len(i0_shapes) == 0:
out_shape = [repeats] out_shape = [repeats]
...@@ -751,82 +762,115 @@ class Repeat(Op): ...@@ -751,82 +762,115 @@ class Repeat(Op):
out_shape = [pt_sum(repeats, dtype=dtype)] out_shape = [pt_sum(repeats, dtype=dtype)]
else: else:
if repeats.ndim == 0: if repeats.ndim == 0:
out_shape[self.axis] = out_shape[self.axis] * repeats out_shape[axis] = out_shape[axis] * repeats
else: else:
out_shape[self.axis] = pt_sum(repeats, dtype=dtype) out_shape[axis] = pt_sum(repeats, dtype=dtype)
return [out_shape] return [out_shape]
def repeat(x, repeats, axis=None): def repeat(
"""Repeat elements of an array. a: TensorLike, repeats: TensorLike, axis: int or None = None
) -> TensorVariable:
"""Repeat elements of a tensor.
It returns an array which has the same shape as `x`, except along the given See :func:`numpy.repeat` for more information.
`axis`. The `axis` parameter is used to specify the axis along which values
are repeated. By default, a flattened version of `x` is used.
The number of repetitions for each element is `repeats`. `repeats` is
broadcasted to fit the length of the given `axis`.
Parameters Parameters
---------- ----------
x a: tensor_like
Input data, tensor variable. Input tensor
repeats repeats: tensor_like
int, scalar or tensor variable The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
axis : int, optional axis : int, optional
The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
See Also Returns
-------
repeated_tensor: TensorVariable
Output tensor which has the same shape as a, except along the given axis
Examples
-------- --------
tensor.tile
.. testcode::
import pytensor.tensor as pt
a = pt.arange(4).reshape((2, 2))
out = pt.repeat(a, repeats=[2, 3], axis=0)
print(out.eval())
.. testoutput::
[[0 1]
[0 1]
[2 3]
[2 3]
[2 3]]
When axis is None, the array is first flattened and then repeated
.. testcode::
import pytensor.tensor as pt
a = pt.arange(4).reshape((2, 2))
out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None)
print(out.eval())
.. testoutput::
[0 0 1 1 1 3]
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
a = ptb.as_tensor_variable(a)
if axis is not None:
axis = normalize_axis_index(axis, a.ndim)
repeats = ptb.as_tensor_variable(repeats, dtype=np.int64) repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
if repeats.ndim > 1: if repeats.ndim > 1:
raise ValueError("The dimension of repeats should not exceed 1.") raise ValueError("The dimension of repeats should not exceed 1.")
if repeats.ndim == 1 and not repeats.broadcastable[0]: if repeats.ndim == 1 and not repeats.broadcastable[0]:
return Repeat(axis=axis)(x, repeats) # We only use the Repeat Op for vector repeats
return Repeat(axis=axis)(a, repeats)
else: else:
if repeats.ndim == 1: if repeats.ndim == 1:
repeats = repeats[0] repeats = repeats[0]
if x.dtype == "uint64": if a.dtype == "uint64":
# Multiplying int64 (shape) by uint64 (repeats) yields a float64
# Which is not valid for the `reshape` operation at the end
raise TypeError("repeat doesn't support dtype uint64") raise TypeError("repeat doesn't support dtype uint64")
if axis is None: if axis is None:
axis = 0 axis = 0
x = x.flatten() a = a.flatten()
else:
if axis >= x.ndim:
raise ValueError("Axis should not exceed x.ndim-1.")
if axis < 0:
axis = x.ndim + axis
shape = [x.shape[i] for i in range(x.ndim)] repeat_shape = list(a.shape)
# shape_ is the shape of the intermediate tensor which has # alloc_shape is the shape of the intermediate tensor which has
# an additional dimension comparing to x. We use alloc to # an additional dimension comparing to x. We use alloc to
# allocate space for this intermediate tensor to replicate x # allocate space for this intermediate tensor to replicate x
# along that additional dimension. # along that additional dimension.
shape_ = shape[:] alloc_shape = repeat_shape[:]
shape_.insert(axis + 1, repeats) alloc_shape.insert(axis + 1, repeats)
# shape is now the shape of output, where shape[axis] becomes # repeat_shape is now the shape of output, where shape[axis] becomes
# shape[axis]*repeats. # shape[axis]*repeats.
shape[axis] = shape[axis] * repeats repeat_shape[axis] = repeat_shape[axis] * repeats
# dims_ is the dimension of that intermediate tensor.
dims_ = list(np.arange(x.ndim))
dims_.insert(axis + 1, "x")
# After the original tensor is duplicated along the additional # After the original tensor is duplicated along the additional
# dimension, we reshape it to the expected output shape, and # dimension, we reshape it to the expected output shape
# return the output z. return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape) repeat_shape
return z )
class Bartlett(Op): class Bartlett(Op):
......
...@@ -595,7 +595,6 @@ class TestRepeat(utt.InferShapeTester): ...@@ -595,7 +595,6 @@ class TestRepeat(utt.InferShapeTester):
isinstance(n.op, Repeat) for n in f.maker.fgraph.toposort() isinstance(n.op, Repeat) for n in f.maker.fgraph.toposort()
) )
@pytest.mark.slow
@pytest.mark.parametrize("ndim", [1, 3]) @pytest.mark.parametrize("ndim", [1, 3])
@pytest.mark.parametrize("dtype", ["int8", "uint8", "uint64"]) @pytest.mark.parametrize("dtype", ["int8", "uint8", "uint64"])
def test_infer_shape(self, ndim, dtype): def test_infer_shape(self, ndim, dtype):
...@@ -606,6 +605,10 @@ class TestRepeat(utt.InferShapeTester): ...@@ -606,6 +605,10 @@ class TestRepeat(utt.InferShapeTester):
a = rng.random(shp).astype(config.floatX) a = rng.random(shp).astype(config.floatX)
for axis in self._possible_axis(ndim): for axis in self._possible_axis(ndim):
if axis is not None and axis < 0:
# Operator does not support negative axis
continue
r_var = scalar(dtype=dtype) r_var = scalar(dtype=dtype)
r = np.asarray(3, dtype=dtype) r = np.asarray(3, dtype=dtype)
if dtype in self.numpy_unsupported_dtypes: if dtype in self.numpy_unsupported_dtypes:
...@@ -635,12 +638,23 @@ class TestRepeat(utt.InferShapeTester): ...@@ -635,12 +638,23 @@ class TestRepeat(utt.InferShapeTester):
self.op_class, self.op_class,
) )
@pytest.mark.parametrize("ndim", range(3)) @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}")
def test_grad(self, ndim): @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}")
a = np.random.random((10,) * ndim).astype(config.floatX) @pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}")
def test_grad(self, x_ndim, repeats_ndim, axis):
for axis in self._possible_axis(ndim): rng = np.random.default_rng(
utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a]) [653, x_ndim, 2 if axis is None else axis, repeats_ndim]
)
x_test = rng.normal(size=np.arange(3, 3 + x_ndim))
if repeats_ndim == 0:
repeats_size = ()
else:
repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
repeats = rng.integers(1, 6, size=repeats_size)
utt.verify_grad(
lambda x: Repeat(axis=axis)(x, repeats),
[x_test],
)
def test_broadcastable(self): def test_broadcastable(self):
x = TensorType(config.floatX, shape=(None, 1, None))() x = TensorType(config.floatX, shape=(None, 1, None))()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论