提交 e00abf32，作者：Ricardo Vieira，提交者：Ricardo Vieira

Inverse need not be a symbolic input in `PermuteRowElements`

上级 3cdcfde4
......@@ -3481,20 +3481,18 @@ class PermuteRowElements(Op):
permutation instead.
"""
__props__ = ()
__props__ = ("inverse",)
# Constructor: `inverse` is kept on the op instance instead of being passed as
# a graph input. It is listed in `__props__`, forwarded by `perform` to
# `_rec_perform`, and negated in `grad` when building the gradient permutation.
# NOTE(review): presumably True means "apply the inverse of the permutation
# given by y" -- confirm against the class docstring, which is not fully
# visible here.
def __init__(self, inverse: bool) -> None:
# Run the base-class initialiser first, then record the op-specific setting.
super().__init__()
self.inverse = inverse
def make_node(self, x, y, inverse):
def make_node(self, x, y):
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if inverse: # as_tensor_variable does not accept booleans
inverse = as_tensor_variable(1)
else:
inverse = as_tensor_variable(0)
# y should contain integers
assert y.type.dtype in integer_dtypes
# Inverse should be an integer scalar
assert inverse.type.ndim == 0 and inverse.type.dtype in integer_dtypes
# Match shapes of x and y
x_dim = x.type.ndim
......@@ -3511,7 +3509,7 @@ class PermuteRowElements(Op):
]
out_type = tensor(dtype=x.type.dtype, shape=out_shape)
inputlist = [x, y, inverse]
inputlist = [x, y]
outputlist = [out_type]
return Apply(self, inputlist, outputlist)
......@@ -3564,7 +3562,7 @@ class PermuteRowElements(Op):
raise ValueError(f"Dimension mismatch: {xs0}, {ys0}")
def perform(self, node, inp, out):
x, y, inverse = inp
x, y = inp
(outs,) = out
x_s = x.shape
y_s = y.shape
......@@ -3587,7 +3585,7 @@ class PermuteRowElements(Op):
if outs[0] is None or outs[0].shape != out_s:
outs[0] = np.empty(out_s, dtype=x.dtype)
self._rec_perform(node, x, y, inverse, outs[0], curdim=0)
self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0)
def infer_shape(self, fgraph, node, in_shapes):
from pytensor.tensor.math import maximum
......@@ -3599,14 +3597,14 @@ class PermuteRowElements(Op):
return [out_shape]
def grad(self, inp, grads):
from pytensor.tensor.math import Sum, eq
from pytensor.tensor.math import Sum
x, y, inverse = inp
x, y = inp
(gz,) = grads
# First, compute the gradient wrt the broadcasted x.
# If 'inverse' is False (0), apply the inverse of y on gz.
# Else, apply y on gz.
gx = permute_row_elements(gz, y, eq(inverse, 0))
gx = permute_row_elements(gz, y, not self.inverse)
# If x has been broadcasted along some axes, we need to sum
# the gradient over these axes, but keep the dimension (as
......@@ -3643,20 +3641,17 @@ class PermuteRowElements(Op):
if x.type.dtype in discrete_dtypes:
gx = x.zeros_like()
# The elements of y and of inverse both affect the output,
# The elements of y affect the output,
# so they are connected to the output,
# and the transformation isn't defined if their values
# are non-integer, so the gradient with respect to them is
# undefined
return [gx, grad_undefined(self, 1, y), grad_undefined(self, 1, inverse)]
_permute_row_elements = PermuteRowElements()
return [gx, grad_undefined(self, 1, y)]
def permute_row_elements(x, y, inverse=0):
return _permute_row_elements(x, y, inverse)
def permute_row_elements(x, y, inverse=False):
    """Apply a `PermuteRowElements` op to `x` and `y`.

    Parameters
    ----------
    x
        First argument handed to the op.
    y
        Second argument handed to the op; the op's `make_node` asserts that
        it has an integer dtype.
    inverse
        Forwarded unchanged as the op's ``inverse`` constructor argument.
        NOTE(review): presumably selects the inverse permutation -- confirm
        against the `PermuteRowElements` docstring.

    Returns
    -------
    The result of calling the configured op on ``(x, y)``.
    """
    permute_op = PermuteRowElements(inverse=inverse)
    return permute_op(x, y)
def inverse_permutation(perm):
......
......@@ -1147,7 +1147,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
val = switch(le(len2, 0), len1 + 1, val)
val = switch(ge(sl2, len2), len1 + 1, val)
val = switch(lt(sl2, 0), -len1 - 1, val)
if sl1.step:
if sl1.step is not None:
val = switch(eq(sl1.step, 0), len1 + 1, val)
return val
else:
......
......@@ -3972,13 +3972,12 @@ class TestInferShape(utt.InferShapeTester):
advec = dvector()
aivec = ivector()
abool = True
rng = np.random.default_rng(utt.fetch_seed())
advec_val = random(5)
aivec_val = rng.permutation(5).astype("int32")
self._compile_and_check(
[advec, aivec],
[PermuteRowElements()(advec, aivec, abool)],
[PermuteRowElements(inverse=True)(advec, aivec)],
[advec_val, aivec_val],
PermuteRowElements,
)
......@@ -3986,7 +3985,7 @@ class TestInferShape(utt.InferShapeTester):
admat_val = random(3, 5)
self._compile_and_check(
[admat, aivec],
[PermuteRowElements()(admat, aivec, abool)],
[PermuteRowElements(inverse=False)(admat, aivec)],
[admat_val, aivec_val],
PermuteRowElements,
)
......@@ -3995,7 +3994,7 @@ class TestInferShape(utt.InferShapeTester):
adtens3_val = random(3, 2, 5)
self._compile_and_check(
[adtens3, aivec],
[PermuteRowElements()(adtens3, aivec, abool)],
[PermuteRowElements(inverse=True)(adtens3, aivec)],
[adtens3_val, aivec_val],
PermuteRowElements,
)
......@@ -4008,7 +4007,7 @@ class TestInferShape(utt.InferShapeTester):
admat_val = random(3, 5)
self._compile_and_check(
[admat, aimat],
[PermuteRowElements()(admat, aimat, abool)],
[PermuteRowElements(inverse=False)(admat, aimat)],
[admat_val, aimat_val],
PermuteRowElements,
)
......@@ -4023,7 +4022,7 @@ class TestInferShape(utt.InferShapeTester):
aitens3_val[1, ::, ::] = bimat_val
self._compile_and_check(
[admat, aitens3],
[PermuteRowElements()(admat, aitens3, abool)],
[PermuteRowElements(inverse=True)(admat, aitens3)],
[admat_val, aitens3_val],
PermuteRowElements,
)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论