提交 63c513e1 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Simplify Repeat Op to only work with specific axis and vector repeats

上级 3dcf1fb6
......@@ -220,7 +220,7 @@ def numba_typify(data, dtype=None, **kwargs):
return data
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
def generate_fallback_impl(op, node, storage_map=None, **kwargs):
"""Create a Numba compatible function from a Pytensor `Op`."""
warnings.warn(
......
......@@ -6,7 +6,11 @@ import numpy as np
from pytensor.graph import Apply
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
get_numba_type,
numba_funcify,
)
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import (
......@@ -200,39 +204,10 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
@numba_funcify.register(Repeat)
def numba_funcify_Repeat(op, node, **kwargs):
axis = op.axis
a, _ = node.inputs
use_python = False
if axis is not None:
use_python = True
if use_python:
warnings.warn(
(
"Numba will use object mode to allow the "
"`axis` argument to `numpy.repeat`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba_basic.numba_njit
def repeatop(x, repeats):
with numba.objmode(ret=ret_sig):
ret = np.repeat(x, repeats, axis)
return ret
else:
repeats_ndim = node.inputs[1].ndim
if repeats_ndim == 0:
@numba_basic.numba_njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats.item())
else:
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
if axis == 0 and a.type.ndim == 1:
@numba_basic.numba_njit(inline="always")
def repeatop(x, repeats):
......@@ -240,6 +215,9 @@ def numba_funcify_Repeat(op, node, **kwargs):
return repeatop
else:
return generate_fallback_impl(op, node)
@numba_funcify.register(Unique)
def numba_funcify_Unique(op, node, **kwargs):
......
......@@ -660,53 +660,72 @@ class Repeat(Op):
__props__ = ("axis",)
def __init__(self, axis: int | None = None):
if axis is not None:
if not isinstance(axis, int) or axis < 0:
def __init__(self, axis: int):
if isinstance(axis, int):
if axis < 0:
raise ValueError(
f"Repeat Op only accepts positive integer axis, got {axis}. "
"Use the helper `pt.repeat` to handle negative axis."
)
elif axis is None:
raise ValueError(
f"Repeat only accepts positive integer axis or None, got {axis}"
"Repeat Op only accepts positive integer axis. "
"Use the helper `pt.repeat` to handle axis=None."
)
else:
raise TypeError(
f"Invalid type for axis {axis}, expected int got {type(axis)}"
)
self.axis = axis
def make_node(self, x, repeats):
x = ptb.as_tensor_variable(x)
repeats = ptb.as_tensor_variable(repeats, dtype="int64")
if repeats.dtype not in integer_dtypes:
raise TypeError("repeats.dtype must be an integer.")
if repeats.type.ndim != 1:
if repeats.type.ndim == 0:
raise ValueError(
f"repeats {repeats} must have 1 dimension, got 0. Use the helper `pt.repeat` to handle scalar repeats."
)
else:
raise ValueError(
f"repeats {repeats} must have 1 dimension, got {repeats.type.ndim}"
)
if repeats.type.dtype not in integer_dtypes:
raise TypeError(
f"repeats {repeats} dtype must be an integer, got {repeats.type.dtype}."
)
# Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
ptr_bitwidth = LOCAL_BITWIDTH
if ptr_bitwidth == 64:
numpy_unsupported_dtypes = ("uint64",)
if ptr_bitwidth == 32:
numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
if repeats.dtype in numpy_unsupported_dtypes:
numpy_unsupported_dtypes = (
("uint64",) if LOCAL_BITWIDTH == 64 else ("uint64", "uint32", "int64")
)
if repeats.type.dtype in numpy_unsupported_dtypes:
raise TypeError(
(
f"dtypes {numpy_unsupported_dtypes!s} are not supported by numpy.repeat "
"for the 'repeats' parameter, "
),
repeats.dtype,
f"repeats {repeats} dtype {repeats.type.dtype} are not supported by numpy.repeat"
)
if self.axis is None:
out_shape = [None]
else:
shape = list(x.type.shape)
axis_input_dim_length = shape[self.axis]
axis_output_dim_length = None
if axis_input_dim_length is not None:
# If we have a static dim and constant repeats we can infer the length of the output dim
# Right now we only support homogenous constant repeats
try:
const_reps = ptb.get_scalar_constant_value(repeats)
const_reps = ptb.get_underlying_scalar_constant_value(repeats)
except NotScalarConstantError:
const_reps = None
if const_reps == 1:
out_shape = x.type.shape
pass
else:
out_shape = list(x.type.shape)
out_shape[self.axis] = None
axis_output_dim_length = int(const_reps * axis_input_dim_length)
shape[self.axis] = axis_output_dim_length
out_type = TensorType(x.dtype, shape=out_shape)
out_type = TensorType(x.dtype, shape=shape)
return Apply(self, [x, repeats], [out_type()])
def perform(self, node, inputs, output_storage):
......@@ -720,37 +739,20 @@ class Repeat(Op):
(x, repeats) = inputs
(gz,) = gout
axis = self.axis
if repeats.ndim == 0:
# When axis is a scalar (same number of reps for all elements),
# We can split the repetitions into their own axis with reshape and sum them back
# to the original element location
sum_axis = x.ndim if axis is None else axis + 1
shape = list(x.shape)
shape.insert(sum_axis, repeats)
gx = gz.reshape(shape).sum(axis=sum_axis)
elif repeats.ndim == 1:
# To sum the gradients that belong to the same repeated x,
# We create a repeated eye and dot product it with the gradient.
axis_size = x.size if axis is None else x.shape[axis]
axis_size = x.shape[axis]
repeated_eye = repeat(
ptb.eye(axis_size), repeats, axis=0
) # A sparse repeat would be neat
if axis is None:
gx = gz @ repeated_eye
# Undo the ravelling when axis=None
gx = gx.reshape(x.shape)
else:
# Place gradient axis at end for dot product
gx = ptb.moveaxis(gz, axis, -1)
gx = gx @ repeated_eye
# Place gradient back into the correct axis
gx = ptb.moveaxis(gx, -1, axis)
else:
raise ValueError()
return [gx, DisconnectedType()()]
def infer_shape(self, fgraph, node, ins_shapes):
......@@ -763,21 +765,7 @@ class Repeat(Op):
dtype = None
if repeats.dtype in ("uint8", "uint16", "uint32"):
dtype = "int64"
if axis is None:
if repeats.ndim == 0:
if len(i0_shapes) == 0:
out_shape = [repeats]
else:
res = 1
for d in i0_shapes:
res = res * d
out_shape = (res * repeats,)
else:
out_shape = [pt_sum(repeats, dtype=dtype)]
else:
if repeats.ndim == 0:
out_shape[axis] = out_shape[axis] * repeats
else:
out_shape[axis] = pt_sum(repeats, dtype=dtype)
return [out_shape]
......@@ -843,7 +831,10 @@ def repeat(
"""
a = ptb.as_tensor_variable(a)
if axis is not None:
if axis is None:
axis = 0
a = a.flatten()
else:
axis = normalize_axis_index(axis, a.ndim)
repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
......@@ -851,40 +842,31 @@ def repeat(
if repeats.ndim > 1:
raise ValueError("The dimension of repeats should not exceed 1.")
if repeats.ndim == 1 and not repeats.broadcastable[0]:
if repeats.type.broadcastable == (True,):
# This behaves the same as scalar repeat
repeats = repeats.squeeze()
if repeats.ndim == 1:
# We only use the Repeat Op for vector repeats
return Repeat(axis=axis)(a, repeats)
else:
if repeats.ndim == 1:
repeats = repeats[0]
if a.dtype == "uint64":
# Multiplying int64 (shape) by uint64 (repeats) yields a float64
# Which is not valid for the `reshape` operation at the end
raise TypeError("repeat doesn't support dtype uint64")
if axis is None:
axis = 0
a = a.flatten()
repeat_shape = list(a.shape)
# Scalar repeat, we implement this with canonical Ops broadcast + reshape
a_shape = a.shape
# alloc_shape is the shape of the intermediate tensor which has
# an additional dimension comparing to x. We use alloc to
# allocate space for this intermediate tensor to replicate x
# along that additional dimension.
alloc_shape = repeat_shape[:]
alloc_shape.insert(axis + 1, repeats)
# Replicate a along a new axis (axis+1) repeats times
broadcast_shape = list(a_shape)
broadcast_shape.insert(axis + 1, repeats)
broadcast_a = broadcast_to(ptb.expand_dims(a, axis + 1), broadcast_shape)
# repeat_shape is now the shape of output, where shape[axis] becomes
# shape[axis]*repeats.
# Reshape broadcast_a to the final shape, merging axis and axis+1
repeat_shape = list(a_shape)
repeat_shape[axis] = repeat_shape[axis] * repeats
# After the original tensor is duplicated along the additional
# dimension, we reshape it to the expected output shape
return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
repeat_shape
)
return broadcast_a.reshape(repeat_shape)
class Bartlett(Op):
......
......@@ -212,27 +212,15 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
@pytest.mark.parametrize(
"x, repeats, axis, exc",
[
(
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lscalar(), np.array(0, dtype="int64")),
None,
None,
),
(
(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
None,
None,
),
(
(pt.lvector(), np.arange(2, dtype="int64")),
(pt.lvector(), np.array([1, 1], dtype="int64")),
None,
(pt.lvector(), np.array([1, 3], dtype="int64")),
0,
None,
),
(
(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lvector(), np.array([1, 3], dtype="int64")),
0,
UserWarning,
),
......
......@@ -530,7 +530,7 @@ class TestRepeat(utt.InferShapeTester):
def setup_method(self):
super().setup_method()
self.op_class = Repeat
self.op = Repeat()
self.op = Repeat(axis=0)
# uint64 always fails
# int64 and uint32 also fail if python int are 32-bit
if LOCAL_BITWIDTH == 64:
......@@ -595,40 +595,27 @@ class TestRepeat(utt.InferShapeTester):
def test_infer_shape(self, ndim, dtype):
rng = np.random.default_rng(4282)
x = TensorType(config.floatX, shape=(None,) * ndim)()
a_var = TensorType(config.floatX, shape=(None,) * ndim)("a")
r_var = vector("r", dtype=dtype)
shp = (np.arange(ndim) + 1) * 3
a = rng.random(shp).astype(config.floatX)
for axis in self._possible_axis(ndim):
if axis is not None and axis < 0:
# Operator does not support negative axis
continue
r_var = scalar(dtype=dtype)
r = np.asarray(3, dtype=dtype)
if dtype in self.numpy_unsupported_dtypes:
r_var = vector(dtype=dtype)
with pytest.raises(TypeError):
repeat(x, r_var)
else:
self._compile_and_check(
[x, r_var],
[Repeat(axis=axis)(x, r_var)],
[a, r],
self.op_class,
)
repeat(a_var, r_var, axis=axis)
continue
if axis is None or axis < 0:
# Operator Repeat does not support None or negative axis
continue
r_var = vector(dtype=dtype)
if axis is None:
r = rng.integers(1, 6, size=a.size).astype(dtype)
elif a.size > 0:
r = rng.integers(1, 6, size=a.shape[axis]).astype(dtype)
else:
r = rng.integers(1, 6, size=(10,)).astype(dtype)
self._compile_and_check(
[x, r_var],
[Repeat(axis=axis)(x, r_var)],
[a_var, r_var],
[Repeat(axis=axis)(a_var, r_var)],
[a, r],
self.op_class,
)
......@@ -647,18 +634,38 @@ class TestRepeat(utt.InferShapeTester):
repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
repeats = rng.integers(1, 6, size=repeats_size)
utt.verify_grad(
lambda x: Repeat(axis=axis)(x, repeats),
lambda x: repeat(x, repeats, axis=axis),
[x_test],
)
def test_broadcastable(self):
x = TensorType(config.floatX, shape=(None, 1, None))()
r = Repeat(axis=1)(x, 2)
assert r.broadcastable == (False, False, False)
r = Repeat(axis=1)(x, 1)
assert r.broadcastable == (False, True, False)
r = Repeat(axis=0)(x, 2)
assert r.broadcastable == (False, True, False)
def test_static_shape(self):
x = TensorType(config.floatX, shape=(None, 1, 3))()
symbolic_r = scalar(dtype="int32")
r = repeat(x, 2, axis=0)
assert r.type.shape == (None, 1, 3)
r = repeat(x, 2, axis=1)
assert r.type.shape == (None, 2, 3)
r = repeat(x, [2], axis=1)
assert r.type.shape == (None, 2, 3)
r = repeat(x, symbolic_r, axis=1)
assert r.type.shape == (None, None, 3)
r = repeat(x, 1, axis=1)
assert r.type.shape == (None, 1, 3)
r = repeat(x, 2, axis=2)
assert r.type.shape == (None, 1, 6)
r = repeat(x, [2, 2, 2], axis=2)
assert r.type.shape == (None, 1, 6)
# This case could be implemented in the future
r = repeat(x, [1, 2, 4], axis=2)
assert r.type.shape == (None, 1, None)
class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论