提交 e00abf32 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Inverse need not be a symbolic input in `PermuteRowElements`

上级 3cdcfde4
...@@ -3481,20 +3481,18 @@ class PermuteRowElements(Op): ...@@ -3481,20 +3481,18 @@ class PermuteRowElements(Op):
permutation instead. permutation instead.
""" """
__props__ = () __props__ = ("inverse",)
def __init__(self, inverse: bool):
super().__init__()
self.inverse = inverse
def make_node(self, x, y, inverse): def make_node(self, x, y):
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
if inverse: # as_tensor_variable does not accept booleans
inverse = as_tensor_variable(1)
else:
inverse = as_tensor_variable(0)
# y should contain integers # y should contain integers
assert y.type.dtype in integer_dtypes assert y.type.dtype in integer_dtypes
# Inverse should be an integer scalar
assert inverse.type.ndim == 0 and inverse.type.dtype in integer_dtypes
# Match shapes of x and y # Match shapes of x and y
x_dim = x.type.ndim x_dim = x.type.ndim
...@@ -3511,7 +3509,7 @@ class PermuteRowElements(Op): ...@@ -3511,7 +3509,7 @@ class PermuteRowElements(Op):
] ]
out_type = tensor(dtype=x.type.dtype, shape=out_shape) out_type = tensor(dtype=x.type.dtype, shape=out_shape)
inputlist = [x, y, inverse] inputlist = [x, y]
outputlist = [out_type] outputlist = [out_type]
return Apply(self, inputlist, outputlist) return Apply(self, inputlist, outputlist)
...@@ -3564,7 +3562,7 @@ class PermuteRowElements(Op): ...@@ -3564,7 +3562,7 @@ class PermuteRowElements(Op):
raise ValueError(f"Dimension mismatch: {xs0}, {ys0}") raise ValueError(f"Dimension mismatch: {xs0}, {ys0}")
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y, inverse = inp x, y = inp
(outs,) = out (outs,) = out
x_s = x.shape x_s = x.shape
y_s = y.shape y_s = y.shape
...@@ -3587,7 +3585,7 @@ class PermuteRowElements(Op): ...@@ -3587,7 +3585,7 @@ class PermuteRowElements(Op):
if outs[0] is None or outs[0].shape != out_s: if outs[0] is None or outs[0].shape != out_s:
outs[0] = np.empty(out_s, dtype=x.dtype) outs[0] = np.empty(out_s, dtype=x.dtype)
self._rec_perform(node, x, y, inverse, outs[0], curdim=0) self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0)
def infer_shape(self, fgraph, node, in_shapes): def infer_shape(self, fgraph, node, in_shapes):
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
...@@ -3599,14 +3597,14 @@ class PermuteRowElements(Op): ...@@ -3599,14 +3597,14 @@ class PermuteRowElements(Op):
return [out_shape] return [out_shape]
def grad(self, inp, grads): def grad(self, inp, grads):
from pytensor.tensor.math import Sum, eq from pytensor.tensor.math import Sum
x, y, inverse = inp x, y = inp
(gz,) = grads (gz,) = grads
# First, compute the gradient wrt the broadcasted x. # First, compute the gradient wrt the broadcasted x.
# If 'inverse' is False (0), apply the inverse of y on gz. # If 'inverse' is False (0), apply the inverse of y on gz.
# Else, apply y on gz. # Else, apply y on gz.
gx = permute_row_elements(gz, y, eq(inverse, 0)) gx = permute_row_elements(gz, y, not self.inverse)
# If x has been broadcasted along some axes, we need to sum # If x has been broadcasted along some axes, we need to sum
# the gradient over these axes, but keep the dimension (as # the gradient over these axes, but keep the dimension (as
...@@ -3643,20 +3641,17 @@ class PermuteRowElements(Op): ...@@ -3643,20 +3641,17 @@ class PermuteRowElements(Op):
if x.type.dtype in discrete_dtypes: if x.type.dtype in discrete_dtypes:
gx = x.zeros_like() gx = x.zeros_like()
# The elements of y and of inverse both affect the output, # The elements of y affect the output,
# so they are connected to the output, # so they are connected to the output,
# and the transformation isn't defined if their values # and the transformation isn't defined if their values
# are non-integer, so the gradient with respect to them is # are non-integer, so the gradient with respect to them is
# undefined # undefined
return [gx, grad_undefined(self, 1, y), grad_undefined(self, 1, inverse)] return [gx, grad_undefined(self, 1, y)]
_permute_row_elements = PermuteRowElements()
def permute_row_elements(x, y, inverse=0): def permute_row_elements(x, y, inverse=False):
return _permute_row_elements(x, y, inverse) return PermuteRowElements(inverse=inverse)(x, y)
def inverse_permutation(perm): def inverse_permutation(perm):
......
...@@ -1147,7 +1147,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2): ...@@ -1147,7 +1147,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
val = switch(le(len2, 0), len1 + 1, val) val = switch(le(len2, 0), len1 + 1, val)
val = switch(ge(sl2, len2), len1 + 1, val) val = switch(ge(sl2, len2), len1 + 1, val)
val = switch(lt(sl2, 0), -len1 - 1, val) val = switch(lt(sl2, 0), -len1 - 1, val)
if sl1.step: if sl1.step is not None:
val = switch(eq(sl1.step, 0), len1 + 1, val) val = switch(eq(sl1.step, 0), len1 + 1, val)
return val return val
else: else:
......
...@@ -3972,13 +3972,12 @@ class TestInferShape(utt.InferShapeTester): ...@@ -3972,13 +3972,12 @@ class TestInferShape(utt.InferShapeTester):
advec = dvector() advec = dvector()
aivec = ivector() aivec = ivector()
abool = True
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
advec_val = random(5) advec_val = random(5)
aivec_val = rng.permutation(5).astype("int32") aivec_val = rng.permutation(5).astype("int32")
self._compile_and_check( self._compile_and_check(
[advec, aivec], [advec, aivec],
[PermuteRowElements()(advec, aivec, abool)], [PermuteRowElements(inverse=True)(advec, aivec)],
[advec_val, aivec_val], [advec_val, aivec_val],
PermuteRowElements, PermuteRowElements,
) )
...@@ -3986,7 +3985,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -3986,7 +3985,7 @@ class TestInferShape(utt.InferShapeTester):
admat_val = random(3, 5) admat_val = random(3, 5)
self._compile_and_check( self._compile_and_check(
[admat, aivec], [admat, aivec],
[PermuteRowElements()(admat, aivec, abool)], [PermuteRowElements(inverse=False)(admat, aivec)],
[admat_val, aivec_val], [admat_val, aivec_val],
PermuteRowElements, PermuteRowElements,
) )
...@@ -3995,7 +3994,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -3995,7 +3994,7 @@ class TestInferShape(utt.InferShapeTester):
adtens3_val = random(3, 2, 5) adtens3_val = random(3, 2, 5)
self._compile_and_check( self._compile_and_check(
[adtens3, aivec], [adtens3, aivec],
[PermuteRowElements()(adtens3, aivec, abool)], [PermuteRowElements(inverse=True)(adtens3, aivec)],
[adtens3_val, aivec_val], [adtens3_val, aivec_val],
PermuteRowElements, PermuteRowElements,
) )
...@@ -4008,7 +4007,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -4008,7 +4007,7 @@ class TestInferShape(utt.InferShapeTester):
admat_val = random(3, 5) admat_val = random(3, 5)
self._compile_and_check( self._compile_and_check(
[admat, aimat], [admat, aimat],
[PermuteRowElements()(admat, aimat, abool)], [PermuteRowElements(inverse=False)(admat, aimat)],
[admat_val, aimat_val], [admat_val, aimat_val],
PermuteRowElements, PermuteRowElements,
) )
...@@ -4023,7 +4022,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -4023,7 +4022,7 @@ class TestInferShape(utt.InferShapeTester):
aitens3_val[1, ::, ::] = bimat_val aitens3_val[1, ::, ::] = bimat_val
self._compile_and_check( self._compile_and_check(
[admat, aitens3], [admat, aitens3],
[PermuteRowElements()(admat, aitens3, abool)], [PermuteRowElements(inverse=True)(admat, aitens3)],
[admat_val, aitens3_val], [admat_val, aitens3_val],
PermuteRowElements, PermuteRowElements,
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论