Commit e88117e6 authored by Ricardo Vieira, committed by Ricardo Vieira

Only require input_ndim and not input_broadcastable in DimShuffle

Parent d68f53f8
......@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
from pytensor.tensor.math import dot
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.shape import reshape
from pytensor.tensor.subtensor import DimShuffle
def register_specialize(lopt, *tags, **kwargs):
......@@ -375,7 +374,7 @@ def convolve(
[images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)]
)
tensout = reshape(output, newshp, ndim=3)
output = DimShuffle((False,) * tensout.ndim, (0, 2, 1))(tensout)
output = tensout.transpose(0, 2, 1)
if flatten:
output = pt.flatten(output, 2)
......@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
)
out2 = reshape(out1, pshape, ndim=3)
out3 = DimShuffle(out2.broadcastable, (0, 2, 1))(out2)
out3 = out2.transpose(0, 2, 1)
return pt.flatten(out3, 2), outshp
......@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
# No-op
return _x
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
ret = _x.dimshuffle(axes)
if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T"
......@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
newdims.append(i)
i += 1
gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
gx = gx.dimshuffle(newdims)
assert gx.type.ndim == x.type.ndim
assert all(
s1 == s2
......
......@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i, specify_broadcastable
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
......@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed
return _x
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
_x = specify_broadcastable(_x, *non_broadcastable_axis)
return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
......
from pytensor import printing
from pytensor.printing import pprint
from pytensor.tensor.elemwise import DimShuffle, scalar_elemwise
from pytensor.tensor.elemwise import scalar_elemwise
@scalar_elemwise
......@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1))
return DimShuffle(x.broadcastable, dims)(x)
return x.dimshuffle(dims)
......@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import (
CAReduce,
DimShuffle,
Elemwise,
get_normalized_batch_axes,
scalar_elemwise,
......@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
else:
new_dims.append(i)
i += 1
ds_op = DimShuffle(gz.type.broadcastable, new_dims)
gx = Elemwise(ps.second)(x, ds_op(gz))
gx = Elemwise(ps.second)(x, gz.dimshuffle(new_dims))
return [gx]
def R_op(self, inputs, eval_points):
......
......@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
if isinstance(size_node.op, MakeVector) or (
isinstance(size_node.op, DimShuffle)
and size_node.op.input_broadcastable == ()
and size_node.op.input_ndim == 0
and size_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
......
......@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list(
range(len(new_output_shape))
)
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
return [inner.dimshuffle(dimshuffle_new_order)]
@node_rewriter([AllocEmpty])
......
......@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
"""
op = node.op
if not isinstance(op, DimShuffle):
return False
inp = node.inputs[0]
inode = inp.owner
......@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
# Don't use make_node to have tag.test_value set.
new_inputs = []
for inp in inode.inputs:
new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp)
new_inp = inp.dimshuffle(op.new_order)
new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))
copy_stack_trace(node.outputs[0], new_inputs)
ret = inode.op(*new_inputs, return_list=True)
......@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
if is_dimshuffle_useless(new_order, inp):
return [inp]
elif inode and isinstance(inode.op, DimShuffle):
ret = op.__class__(inp.type.broadcastable, new_order)(inp)
ret = inp.dimshuffle(new_order)
ret = apply_local_dimshuffle_lift(fgraph, ret)
copy_stack_trace(node.outputs[0], ret)
return [ret]
......
......@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
if isinstance(shape_node.op, MakeVector) or (
isinstance(shape_node.op, DimShuffle)
and shape_node.op.input_broadcastable == ()
and shape_node.op.input_ndim == 0
and shape_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
......
......@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if ndims < 2:
return False
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
return cast(bool, node.op.new_order == transpose_order)
return node.op.new_order == transpose_order
return False
......
......@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
if index != output.type.ndim:
inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
copy_stack_trace(output, inner)
new_node = [
DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
inner
)
]
new_node = [inner.dimshuffle(dimshuffle_new_order)]
copy_stack_trace(output, new_node)
return new_node
......
......@@ -344,8 +344,8 @@ class _tensor_py_operators:
"""
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
pattern = pattern[0]
op = pt.elemwise.DimShuffle(list(self.type.broadcastable), pattern)
return op(self)
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
return ds_op(self)
def flatten(self, ndim=1):
return pt.basic.flatten(self, ndim)
......
......@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
......
......@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
......@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
],
)
def test_Dimshuffle(v, new_order):
g = pt_elemwise.DimShuffle(v.broadcastable, new_order)(v)
g = v.dimshuffle(new_order)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
......@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
def test_Dimshuffle_returns_array():
x = pt.vector("x", shape=(1,))
y = 2 * pt_elemwise.DimShuffle([True], [])(x)
y = 2 * x.dimshuffle([])
func = pytensor.function([x], y, mode="NUMBA")
out = func(np.zeros(1, dtype=config.floatX))
assert out.ndim == 0
......@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
non-contiguous arrays, make sure we work around that."""
x = pt.dvector()
idx = pt.vector(dtype="int64")
op = pytensor.tensor.elemwise.DimShuffle([True], [])
op = DimShuffle(input_ndim=1, new_order=[])
out = op(pt.specify_shape(x[idx][::2], (1,)))
func = pytensor.function([x, idx], out, mode="NUMBA")
assert func(np.zeros(3), np.array([1])).ndim == 0
......
......@@ -5,7 +5,6 @@ import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_multiple_input_output():
x = vector("x")
......
......@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
def ds(x, y):
return DimShuffle(x.type.broadcastable, y)(x)
return x.dimshuffle(y)
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
......
......@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
def ds(x, y):
return DimShuffle(x.type.broadcastable, y)(x)
return x.dimshuffle(y)
def rewrite(g, level="fast_run"):
......@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
# If a transpose is applied to the sum
transpose_op = DimShuffle((False, False), (1, 0))
transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
# If the sum is performed with keepdims=True
......@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
assert np.allclose(naive_ret, rewritten_ret)
# If a transpose is applied
transpose_op = DimShuffle((False, False), (1, 0))
transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
rewritten_ret = f(x_val)
......
......@@ -3418,7 +3418,7 @@ def test_unalign():
def test_dimshuffle_duplicate():
x = vector()
with pytest.raises(ValueError, match="may not appear twice"):
DimShuffle((False,), (0, 0))(x)
DimShuffle(input_ndim=1, new_order=(0, 0))(x)
class TestGetUnderlyingScalarConstantValue:
......
......@@ -593,9 +593,9 @@ class TestAsScalar:
b = pt.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle()
assert b2.ndim == 0
d_a = DimShuffle([], [])(a)
d_b = DimShuffle([True, True, True], [0, 2, 1])(b)
d_a2 = DimShuffle([], ["x", "x", "x"])(a)
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
assert _as_scalar(a) == a
assert _as_scalar(b) != b
......@@ -607,13 +607,13 @@ class TestAsScalar:
# Test that it fails on nonscalar constants
a = pt.constant(np.ones(5))
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle([False], [0, "x"])(a)) is None
assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
def test_basic_2(self):
# Test that it works on scalar variables
a = dscalar()
d_a = DimShuffle([], [])(a)
d_a2 = DimShuffle([], ["x", "x"])(a)
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
assert _as_scalar(a) is a
assert _as_scalar(d_a) is a
......@@ -623,13 +623,15 @@ class TestAsScalar:
# Test that it fails on nonscalar variables
a = matrix()
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle([False, False], [0, "x", 1])(a)) is None
assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
class TestRealMatrix:
def test_basic(self):
assert _is_real_matrix(DimShuffle([False, False], [1, 0])(matrix()))
assert not _is_real_matrix(DimShuffle([False], ["x", 0])(dvector()))
assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
assert not _is_real_matrix(
DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
)
"""
......
......@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((1,), ("x", "x"), (1, 1)),
]:
i_shape = [entry if entry == 1 else None for entry in xsh]
ib = [entry == 1 for entry in i_shape]
x = self.type(self.dtype, shape=i_shape)("x")
e = self.op(ib, shuffle)(x)
e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
f = pytensor.function([x], e, mode=Mode(linker=linker))
assert f(np.ones(xsh, dtype=self.dtype)).shape == zsh
# test that DimShuffle.infer_shape work correctly
x = self.type(self.dtype, shape=i_shape)("x")
e = self.op(ib, shuffle)(x)
e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
f = pytensor.function(
[x], e.shape, mode=Mode(linker=linker), on_unused_input="ignore"
)
assert all(f(np.ones(xsh, dtype=self.dtype))) == all(zsh)
# Test when we drop a axis that is not broadcastable
ib = [False, True, False]
x = self.type(self.dtype, shape=(None, 1, None))("x")
with pytest.raises(ValueError):
self.op(ib, shuffle)
x = self.type(self.dtype, shape=(2, 1, None))("x")
with pytest.raises(TypeError):
self.op(input_ndim=3, new_order=shuffle)(x)
# Test when we drop a axis that don't have shape 1
ib = [True, True, False]
x = self.type(self.dtype, shape=(1, 1, None))("x")
e = self.op(ib, (1, 2))(x)
f = pytensor.function([x], e.shape, mode=Mode(linker=linker))
with pytest.raises(TypeError):
f(np.ones((2, 1, 4)))
x = self.type(self.dtype, shape=(None, 1, None))("x")
e = self.op(input_ndim=3, new_order=(1, 2))(x)
f = pytensor.function([x], e, mode=Mode(linker=linker))
with pytest.raises(ValueError):
f(np.ones((2, 1, 4), dtype=self.dtype))
# Test that we can't take a dimensions multiple time
xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4))
ib = [False, True, False]
x = self.type(self.dtype, shape=(None, 1, None))("x")
with pytest.raises(ValueError):
DimShuffle(ib, shuffle)
DimShuffle(input_ndim=3, new_order=shuffle)
def test_perform(self):
self.with_linker(PerformLinker())
def test_c_or_py(self):
# Shape op don't have C code.
# But This will test DimShuffle c code
self.with_linker(OpWiseCLinker())
def test_infer_shape(self):
......@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((1,), ("x", "x")),
]:
i_shape = [entry if entry == 1 else None for entry in xsh]
ib = [(entry == 1) for entry in xsh]
adtens = self.type(self.dtype, shape=i_shape)("x")
adtens_val = np.ones(xsh, dtype=self.dtype)
self._compile_and_check(
[adtens],
[self.op(ib, shuffle)(adtens)],
[self.op(input_ndim=len(xsh), new_order=shuffle)(adtens)],
[adtens_val],
self.op,
warn=False,
......@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1)
def test_valid_input_broadcastable(self):
assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False)
def test_valid_input_ndim(self):
assert DimShuffle(input_ndim=2, new_order=(1, 0)).input_ndim == 2
with pytest.raises(ValueError, match="input_broadcastable must be boolean"):
DimShuffle([None, None], (1, 0))
with pytest.raises(TypeError, match="input_ndim must be an integer"):
DimShuffle(input_ndim=(True, False), new_order=(1, 0))
class TestBroadcast:
......
......@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
assert f([0]) == 0
# Test that we cannot squeeze dimensions whose length is greater than 1
error_txt_1 = re.escape("SpecifyShape: Got shape (3,), expected (1,).")
error_txt_2 = re.escape("SpecifyShape: dim 0 of input has shape 3, expected 1")
match = error_txt_1 if pytensor.config.mode == "FAST_COMPILE" else error_txt_2
with pytest.raises(
AssertionError,
match=match,
ValueError,
match="cannot reshape array of size 3 into shape ()",
):
f([0, 1, 2])
......
......@@ -204,3 +204,12 @@ class TestFFT:
pytensor.config.floatX
)
utt.verify_grad(f_irfft, [inputs_val], eps=eps)
def test_rfft_expanded_dims_grad(self):
    """Check that the gradient of ``rfft`` works when the input gains a new
    length-1 leading axis via ``x[None, :]``.

    Regression test for https://github.com/pymc-devs/pytensor/issues/969
    """

    def test_func(x):
        # Expand dims before the FFT; the grad previously failed on this pattern.
        return fft.rfft(x[None, :])

    rng = np.random.default_rng(213)
    # NOTE(review): `N` is a module-level size constant defined elsewhere in the test file.
    inputs_val = rng.random((N,)).astype(pytensor.config.floatX)
    utt.verify_grad(test_func, [inputs_val], rng=rng)
......@@ -4,7 +4,6 @@ import pytest
import pytensor
from pytensor import function
from pytensor.compile.mode import Mode
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import any as pt_any
from pytensor.tensor.math import argmax, argmin, max_and_argmax, mean, prod, std, var
......@@ -40,7 +39,7 @@ class TestKeepDims:
new_dims.append(i)
i += 1
return DimShuffle(y.type.broadcastable, new_dims)(y)
return y.dimshuffle(new_dims)
@pytest.mark.parametrize(
"axis",
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment