Commit e88117e6 authored by Ricardo Vieira, committed by Ricardo Vieira

Only require input_ndim and not input_broadcastable in DimShuffle

Parent d68f53f8
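The gist of the change: `DimShuffle` is now constructed from the input's rank (`input_ndim`) instead of its full broadcastable pattern, with keyword-only arguments. A minimal before/after sketch (assuming a PyTensor build that includes this commit):

    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import DimShuffle

    x = pt.matrix("x")

    # Before: the op required the full broadcastable pattern of the input
    # op = DimShuffle((False, False), (1, 0))

    # After: only the number of input dimensions, as keyword arguments
    op = DimShuffle(input_ndim=2, new_order=(1, 0))
    y = op(x)
    assert y.type.ndim == 2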
......@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
from pytensor.tensor.math import dot
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.shape import reshape
from pytensor.tensor.subtensor import DimShuffle
def register_specialize(lopt, *tags, **kwargs):
......@@ -375,7 +374,7 @@ def convolve(
[images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)]
)
tensout = reshape(output, newshp, ndim=3)
output = DimShuffle((False,) * tensout.ndim, (0, 2, 1))(tensout)
output = tensout.transpose(0, 2, 1)
if flatten:
output = pt.flatten(output, 2)
......@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
)
out2 = reshape(out1, pshape, ndim=3)
out3 = DimShuffle(out2.broadcastable, (0, 2, 1))(out2)
out3 = out2.transpose(0, 2, 1)
return pt.flatten(out3, 2), outshp
......@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
# No-op
return _x
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
ret = _x.dimshuffle(axes)
if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T"
......@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
newdims.append(i)
i += 1
gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
gx = gx.dimshuffle(newdims)
assert gx.type.ndim == x.type.ndim
assert all(
s1 == s2
......
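Throughout this file the explicit `DimShuffle(broadcastable, order)` constructions are replaced by the `TensorVariable` helpers `.transpose(...)` and `.dimshuffle(...)`, which build the op from the variable's own ndim. A quick sketch of the equivalence (hypothetical 3d input):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor3("x")
    # transpose(0, 2, 1) is sugar for dimshuffle(0, 2, 1); both build a DimShuffle
    y1 = x.transpose(0, 2, 1)
    y2 = x.dimshuffle(0, 2, 1)
    f = pytensor.function([x], [y1, y2])
    r1, r2 = f(np.zeros((2, 3, 4), dtype=x.dtype))
    assert r1.shape == r2.shape == (2, 4, 3)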
from collections.abc import Sequence
from copy import copy
from textwrap import dedent
from typing import Literal
import numpy as np
from numpy.core.numeric import normalize_axis_tuple
......@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
Parameters
----------
input_broadcastable
The expected broadcastable pattern of the input
input_ndim
The expected number of dimensions of the input
new_order
A list representing the relationship between the input's
dimensions and the output's dimensions. Each element of the
list can either be an index or 'x'. Indices must be encoded
as python integers, not pytensor symbolic integers.
inplace : bool, optional
If True (default), the output will be a view of the input.
Missing indices correspond to dropped dimensions.
Notes
-----
......@@ -77,10 +78,10 @@ class DimShuffle(ExternalCOp):
.. code-block:: python
DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])
This `Op` will only work on 3d tensors with no broadcastable
dimensions. The first dimension will be broadcastable,
This `Op` will only work on 3d tensors.
The first dimension of the output will be broadcastable,
then we will have the third dimension of the input tensor as
the second of the resulting tensor, etc. If the tensor has
shape (20, 30, 40), the resulting tensor will have dimensions
......@@ -88,39 +89,36 @@ class DimShuffle(ExternalCOp):
.. code-block:: python
DimShuffle((True, False), [1])
DimShuffle(input_ndim=2, new_order=[1])
This `Op` will only work on 2d tensors with the first dimension
broadcastable.
The second dimension of the input tensor will be the first dimension of
the resulting tensor.
If the tensor has shape (1, 20), the resulting tensor will have shape
(20, ).
This `Op` will only work on 2d tensors with the first dimension broadcastable.
The second dimension of the input tensor will be the first dimension of the resulting tensor.
If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
Examples
--------
.. code-block:: python
DimShuffle((), ["x"]) # make a 0d (scalar) into a 1d vector
DimShuffle((False, False), [0, 1]) # identity
DimShuffle((False, False), [1, 0]) # inverts the 1st and 2nd dimensions
DimShuffle((False,), ["x", 0]) # make a row out of a 1d vector
# (N to 1xN)
DimShuffle((False,), [0, "x"]) # make a column out of a 1d vector
# (N to Nx1)
DimShuffle((False, False, False), [2, 0, 1]) # AxBxC to CxAxB
DimShuffle((False, False), [0, "x", 1]) # AxB to Ax1xB
DimShuffle((False, False), [1, "x", 0]) # AxB to Bx1xA
The reordering of the dimensions can be done with the numpy.transpose
function.
Adding, subtracting dimensions can be done with reshape.
DimShuffle(input_ndim=0, new_order=["x"]) # make a 0d (scalar) into a 1d vector
DimShuffle(input_ndim=2, new_order=[0, 1]) # identity
DimShuffle(input_ndim=2, new_order=[1, 0]) # transposition
# Make a row out of a 1d vector (N to 1xN)
DimShuffle(input_ndim=1, new_order=["x", 0])
# Make a column out of a 1d vector (N to Nx1)
DimShuffle(input_ndim=1, new_order=[0, "x"])
DimShuffle(input_ndim=3, new_order=[2, 0, 1]) # AxBxC to CxAxB
DimShuffle(input_ndim=2, new_order=[0, "x", 1]) # AxB to Ax1xB
DimShuffle(input_ndim=2, new_order=[1, "x", 0]) # AxB to Bx1xA
Notes
-----
The Python implementation of this Op combines numpy.transpose for reordering the dimensions
and numpy.reshape for dropping and adding broadcastable dimensions.
"""
_f16_ok = True
check_input = False
__props__ = ("input_broadcastable", "new_order", "inplace")
__props__ = ("input_ndim", "new_order", "inplace")
c_func_file = "c_code/dimshuffle.c"
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
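To ground the updated docstring examples: with `input_ndim` the op accepts any 3d input, inserts a length-1 axis for each `"x"`, and propagates unknown dims. A sketch (assuming this commit):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import DimShuffle

    x = pt.tensor3("x")
    y = DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])(x)
    assert y.type.shape == (1, None, 1, None, None)

    f = pytensor.function([x], y)
    assert f(np.zeros((20, 30, 40), dtype=x.dtype)).shape == (1, 40, 1, 20, 30)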
......@@ -133,16 +131,14 @@ class DimShuffle(ExternalCOp):
inplace=scalar_bool,
)
def __init__(self, input_broadcastable, new_order):
def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
super().__init__([self.c_func_file], self.c_func_name)
self.input_broadcastable = tuple(input_broadcastable)
if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
raise ValueError(
f"input_broadcastable must be boolean, {self.input_broadcastable}"
)
self.new_order = tuple(new_order)
if not isinstance(input_ndim, int):
raise TypeError(f"input_ndim must be an integer, got {type(int)}")
self.input_ndim = input_ndim
self.new_order = tuple(new_order)
self.inplace = True
for i, j in enumerate(new_order):
......@@ -152,10 +148,10 @@ class DimShuffle(ExternalCOp):
"DimShuffle indices must be Python ints; got "
f"{j} of type {type(j)}."
)
if j >= len(input_broadcastable):
if j >= input_ndim:
raise ValueError(
f"new_order[{i}] is {j}, but the input only has "
f"{len(input_broadcastable)} axes."
f"{input_ndim} axes."
)
if j in new_order[(i + 1) :]:
raise ValueError(
......@@ -164,19 +160,7 @@ class DimShuffle(ExternalCOp):
)
# List of input dimensions to drop
drop = []
for i, b in enumerate(input_broadcastable):
if i not in new_order:
# We want to drop this dimension because it's not a value in
# `new_order`
if b == 1:
drop.append(i)
else:
# We cannot drop non-broadcastable dimensions
raise ValueError(
"Cannot drop a non-broadcastable dimension: "
f"{input_broadcastable}, {new_order}"
)
drop = [i for i in range(input_ndim) if i not in new_order]
# This is the list of the original dimensions that we keep
self.shuffle = [x for x in new_order if x != "x"]
......@@ -186,7 +170,6 @@ class DimShuffle(ExternalCOp):
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop
input_ndim = len(input_broadcastable)
self.is_left_expand_dims = self.augment and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)
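Behavioural note: construction no longer validates dropped dimensions against a broadcastable pattern; any axis missing from `new_order` is recorded in `drop`, and the length-1 requirement is enforced later, statically in `make_node` and otherwise at runtime. A sketch (variable names are illustrative):

    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import DimShuffle

    op = DimShuffle(input_ndim=3, new_order=(0, 2))  # drops axis 1; always constructible now

    x_ok = pt.tensor("x", shape=(None, 1, None))
    y = op(x_ok)  # fine: the dropped axis is statically length 1

    x_bad = pt.tensor("x", shape=(None, 2, None))
    try:
        op(x_bad)  # make_node rejects a dropped axis with known length != 1
    except TypeError as exc:
        print(exc)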
......@@ -204,30 +187,29 @@ class DimShuffle(ExternalCOp):
# Let's just build the ExternalCOp.
super().__init__([self.c_func_file], self.c_func_name)
def make_node(self, _input):
input = as_tensor_variable(_input)
ib = tuple(s == 1 for s in input.type.shape)
if ib != self.input_broadcastable:
if len(ib) != len(self.input_broadcastable):
def make_node(self, inp):
input = as_tensor_variable(inp)
if input.type.ndim != self.input_ndim:
raise TypeError(
"The number of dimensions of the input is incorrect for this op. "
f"Expected {self.input_ndim}, got {input.type.ndim}."
)
input_static_shape = input.type.shape
# Runtime check for invalid drop
for d in self.drop:
if input_static_shape[d] not in (1, None):
raise TypeError(
"The number of dimensions of the "
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
)
for expected, b in zip(self.input_broadcastable, ib):
if expected and not b:
raise TypeError(
"The broadcastable pattern of the "
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
)
# else, expected == b or not expected and b
# Both case are good.
out_static_shape = []
for dim_idx in self.new_order:
if dim_idx == "x":
out_static_shape.append(1)
else:
out_static_shape.append(input.type.shape[dim_idx])
out_static_shape.append(input_static_shape[dim_idx])
output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
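`make_node` now computes the output's static shape directly: each `"x"` entry contributes a 1, and every kept axis passes its (possibly unknown) static length through. For example (a sketch):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(5, 1, None))
    y = x.dimshuffle(2, 0, "x")  # drop axis 1, reorder, append a broadcastable axis
    assert y.type.shape == (None, 5, 1)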
......@@ -254,12 +236,14 @@ class DimShuffle(ExternalCOp):
if not isinstance(res, np.ndarray | np.memmap):
raise TypeError(res)
# Put dropped axes at the end
res = res.transpose(self.transposition)
shape = list(res.shape[: len(self.shuffle)])
# Define the new shape without dropped axes and including the new ones
new_shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
res = res.reshape(shape)
new_shape.insert(augm, 1)
res = res.reshape(new_shape)
if not self.inplace:
res = np.copy(res)
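`perform` keeps the transpose-then-reshape strategy: the transposition puts dropped axes last, then the reshape removes them and inserts the `"x"` axes. The same computation in plain NumPy (a sketch mirroring the code above, for `input_ndim=3, new_order=(2, "x", 0)`):

    import numpy as np

    a = np.zeros((5, 1, 7))
    shuffle = [2, 0]               # kept input axes in output order (no "x")
    drop = [1]                     # dropped axes; must have length 1
    augment = [1]                  # output positions of "x"
    res = a.transpose(shuffle + drop)            # (7, 5, 1): dropped axis at the end
    new_shape = list(res.shape[: len(shuffle)])  # (7, 5): discard dropped axes
    for aug in augment:
        new_shape.insert(aug, 1)                 # insert the "x" axes
    res = res.reshape(new_shape)
    assert res.shape == (7, 1, 5)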
......@@ -284,22 +268,15 @@ class DimShuffle(ExternalCOp):
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
gz = as_tensor_variable(gz)
grad_order = ["x"] * x.type.ndim
for i, v in enumerate(self.new_order):
if v != "x":
grad_order[v] = i
# Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph.
if inp[0].dtype in discrete_dtypes:
return [inp[0].zeros_like(dtype=config.floatX)]
if x.type.dtype in discrete_dtypes:
return [x.zeros_like(dtype=config.floatX)]
else:
return [
DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
Elemwise(scalar_identity)(gz)
)
]
return [gz.dimshuffle(grad_order)]
class DimShufflePrinter(Printer):
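The gradient inverts the shuffle: for every kept input axis `v`, `grad_order[v]` is the output position `i` that holds it, and axes that were dropped (or never existed) come back as `"x"`. Worked by hand for `new_order=(1, "x", 0)` on a matrix (a sketch):

    new_order = (1, "x", 0)
    x_ndim = 2
    grad_order = ["x"] * x_ndim
    for i, v in enumerate(new_order):
        if v != "x":
            grad_order[v] = i
    assert grad_order == [2, 0]  # gz.dimshuffle([2, 0]) maps the gradient back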
......@@ -409,7 +386,7 @@ class Elemwise(OpenMPOp):
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)
def get_output_info(self, dim_shuffle, *inputs):
def get_output_info(self, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
dimshuffled inputs.
......@@ -427,12 +404,7 @@ class Elemwise(OpenMPOp):
if not difference:
args.append(input)
else:
args.append(
dim_shuffle(
input.type.broadcastable,
["x"] * difference + list(range(length)),
)(input)
)
args.append(input.dimshuffle(["x"] * difference + list(range(length))))
inputs = args
# HERE: all the broadcast dims have the same length now
......@@ -489,7 +461,7 @@ class Elemwise(OpenMPOp):
using DimShuffle.
"""
inputs = [as_tensor_variable(i) for i in inputs]
out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
outputs = [
TensorType(dtype=dtype, shape=shape)()
for dtype, shape in zip(out_dtypes, out_shapes)
......@@ -634,7 +606,7 @@ class Elemwise(OpenMPOp):
res = pytensor.tensor.basic.constant(
np.asarray(r.data), dtype=r.type.dtype
)
return DimShuffle((), ["x"] * nd)(res)
return res.dimshuffle(["x"] * nd)
new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
if isinstance(new_r, list | tuple):
......@@ -1707,13 +1679,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
if not batched_ndims:
return node.op.make_node(x)
input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
# e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
# e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
# e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
# e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
new_order = list(range(batched_ndims)) + [
"x" if (o == "x") else (o + batched_ndims) for o in op.new_order
]
return DimShuffle(input_broadcastable, new_order).make_node(x)
return x.dimshuffle(new_order).owner
def get_normalized_batch_axes(
......
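As the inline examples above note, vectorizing a `DimShuffle` only needs the rank: prepend identity axes for the batch dimensions and shift every original index. The order computation, worked through (a sketch):

    batched_ndims = 2
    op_new_order = (1, "x", 0)  # ds(input_ndim=2, ...)
    new_order = list(range(batched_ndims)) + [
        "x" if o == "x" else o + batched_ndims for o in op_new_order
    ]
    assert new_order == [0, 1, 3, "x", 2]  # ds(input_ndim=4, ...)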
......@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i, specify_broadcastable
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
......@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed
return _x
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
_x = specify_broadcastable(_x, *non_broadcastable_axis)
return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
......
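With construction-time validation gone, `squeeze` no longer needs `specify_broadcastable`: `dimshuffle` now accepts dropping any axis, and the length-1 check surfaces when the function runs (see the updated `TestSqueeze` expectation further down). A sketch:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")                 # length unknown statically
    f = pytensor.function([x], pt.squeeze(x, axis=0))
    assert f(np.zeros(1, dtype=x.dtype)) == 0  # ok: squeezed axis has length 1
    # f(np.zeros(3)) now fails at runtime:
    # ValueError: cannot reshape array of size 3 into shape ()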
from pytensor import printing
from pytensor.printing import pprint
from pytensor.tensor.elemwise import DimShuffle, scalar_elemwise
from pytensor.tensor.elemwise import scalar_elemwise
@scalar_elemwise
......@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1))
return DimShuffle(x.broadcastable, dims)(x)
return x.dimshuffle(dims)
......@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import (
CAReduce,
DimShuffle,
Elemwise,
get_normalized_batch_axes,
scalar_elemwise,
......@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
else:
new_dims.append(i)
i += 1
ds_op = DimShuffle(gz.type.broadcastable, new_dims)
gx = Elemwise(ps.second)(x, ds_op(gz))
gx = Elemwise(ps.second)(x, gz.dimshuffle(new_dims))
return [gx]
def R_op(self, inputs, eval_points):
......
......@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
if isinstance(size_node.op, MakeVector) or (
isinstance(size_node.op, DimShuffle)
and size_node.op.input_broadcastable == ()
and size_node.op.input_ndim == 0
and size_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
......
......@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list(
range(len(new_output_shape))
)
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
return [inner.dimshuffle(dimshuffle_new_order)]
@node_rewriter([AllocEmpty])
......
......@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
"""
op = node.op
if not isinstance(op, DimShuffle):
return False
inp = node.inputs[0]
inode = inp.owner
......@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
# Don't use make_node to have tag.test_value set.
new_inputs = []
for inp in inode.inputs:
new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp)
new_inp = inp.dimshuffle(op.new_order)
new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))
copy_stack_trace(node.outputs[0], new_inputs)
ret = inode.op(*new_inputs, return_list=True)
......@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
if is_dimshuffle_useless(new_order, inp):
return [inp]
elif inode and isinstance(inode.op, DimShuffle):
ret = op.__class__(inp.type.broadcastable, new_order)(inp)
ret = inp.dimshuffle(new_order)
ret = apply_local_dimshuffle_lift(fgraph, ret)
copy_stack_trace(node.outputs[0], ret)
return [ret]
......
......@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
if isinstance(shape_node.op, MakeVector) or (
isinstance(shape_node.op, DimShuffle)
and shape_node.op.input_broadcastable == ()
and shape_node.op.input_ndim == 0
and shape_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
......
......@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if ndims < 2:
return False
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
return cast(bool, node.op.new_order == transpose_order)
return node.op.new_order == transpose_order
return False
......
......@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
if index != output.type.ndim:
inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
copy_stack_trace(output, inner)
new_node = [
DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
inner
)
]
new_node = [inner.dimshuffle(dimshuffle_new_order)]
copy_stack_trace(output, new_node)
return new_node
......
......@@ -344,8 +344,8 @@ class _tensor_py_operators:
"""
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
pattern = pattern[0]
op = pt.elemwise.DimShuffle(list(self.type.broadcastable), pattern)
return op(self)
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
return ds_op(self)
def flatten(self, ndim=1):
return pt.basic.flatten(self, ndim)
......
......@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
......
......@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
......@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
],
)
def test_Dimshuffle(v, new_order):
g = pt_elemwise.DimShuffle(v.broadcastable, new_order)(v)
g = v.dimshuffle(new_order)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
......@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
def test_Dimshuffle_returns_array():
x = pt.vector("x", shape=(1,))
y = 2 * pt_elemwise.DimShuffle([True], [])(x)
y = 2 * x.dimshuffle([])
func = pytensor.function([x], y, mode="NUMBA")
out = func(np.zeros(1, dtype=config.floatX))
assert out.ndim == 0
......@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
non-contiguous arrays, make sure we work around that."""
x = pt.dvector()
idx = pt.vector(dtype="int64")
op = pytensor.tensor.elemwise.DimShuffle([True], [])
op = DimShuffle(input_ndim=1, new_order=[])
out = op(pt.specify_shape(x[idx][::2], (1,)))
func = pytensor.function([x, idx], out, mode="NUMBA")
assert func(np.zeros(3), np.array([1])).ndim == 0
......
......@@ -5,7 +5,6 @@ import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_multiple_input_output():
x = vector("x")
......
......@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
def ds(x, y):
return DimShuffle(x.type.broadcastable, y)(x)
return x.dimshuffle(y)
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
......
......@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
def ds(x, y):
return DimShuffle(x.type.broadcastable, y)(x)
return x.dimshuffle(y)
def rewrite(g, level="fast_run"):
......@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
# If a transpose is applied to the sum
transpose_op = DimShuffle((False, False), (1, 0))
transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
# If the sum is performed with keepdims=True
......@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
assert np.allclose(naive_ret, rewritten_ret)
# If a transpose is applied
transpose_op = DimShuffle((False, False), (1, 0))
transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
rewritten_ret = f(x_val)
......
......@@ -3418,7 +3418,7 @@ def test_unalign():
def test_dimshuffle_duplicate():
x = vector()
with pytest.raises(ValueError, match="may not appear twice"):
DimShuffle((False,), (0, 0))(x)
DimShuffle(input_ndim=1, new_order=(0, 0))(x)
class TestGetUnderlyingScalarConstantValue:
......
......@@ -593,9 +593,9 @@ class TestAsScalar:
b = pt.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle()
assert b2.ndim == 0
d_a = DimShuffle([], [])(a)
d_b = DimShuffle([True, True, True], [0, 2, 1])(b)
d_a2 = DimShuffle([], ["x", "x", "x"])(a)
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
assert _as_scalar(a) == a
assert _as_scalar(b) != b
......@@ -607,13 +607,13 @@ class TestAsScalar:
# Test that it fails on nonscalar constants
a = pt.constant(np.ones(5))
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle([False], [0, "x"])(a)) is None
assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
def test_basic_2(self):
# Test that it works on scalar variables
a = dscalar()
d_a = DimShuffle([], [])(a)
d_a2 = DimShuffle([], ["x", "x"])(a)
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
assert _as_scalar(a) is a
assert _as_scalar(d_a) is a
......@@ -623,13 +623,15 @@ class TestAsScalar:
# Test that it fails on nonscalar variables
a = matrix()
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle([False, False], [0, "x", 1])(a)) is None
assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
class TestRealMatrix:
def test_basic(self):
assert _is_real_matrix(DimShuffle([False, False], [1, 0])(matrix()))
assert not _is_real_matrix(DimShuffle([False], ["x", 0])(dvector()))
assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
assert not _is_real_matrix(
DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
)
"""
......
......@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((1,), ("x", "x"), (1, 1)),
]:
i_shape = [entry if entry == 1 else None for entry in xsh]
ib = [entry == 1 for entry in i_shape]
x = self.type(self.dtype, shape=i_shape)("x")
e = self.op(ib, shuffle)(x)
e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
f = pytensor.function([x], e, mode=Mode(linker=linker))
assert f(np.ones(xsh, dtype=self.dtype)).shape == zsh
# test that DimShuffle.infer_shape works correctly
x = self.type(self.dtype, shape=i_shape)("x")
e = self.op(ib, shuffle)(x)
e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
f = pytensor.function(
[x], e.shape, mode=Mode(linker=linker), on_unused_input="ignore"
)
assert all(f(np.ones(xsh, dtype=self.dtype))) == all(zsh)
# Test when we drop an axis that is not broadcastable
ib = [False, True, False]
x = self.type(self.dtype, shape=(None, 1, None))("x")
with pytest.raises(ValueError):
self.op(ib, shuffle)
x = self.type(self.dtype, shape=(2, 1, None))("x")
with pytest.raises(TypeError):
self.op(input_ndim=3, new_order=shuffle)(x)
# Test when we drop an axis that doesn't have shape 1
ib = [True, True, False]
x = self.type(self.dtype, shape=(1, 1, None))("x")
e = self.op(ib, (1, 2))(x)
f = pytensor.function([x], e.shape, mode=Mode(linker=linker))
with pytest.raises(TypeError):
f(np.ones((2, 1, 4)))
x = self.type(self.dtype, shape=(None, 1, None))("x")
e = self.op(input_ndim=3, new_order=(1, 2))(x)
f = pytensor.function([x], e, mode=Mode(linker=linker))
with pytest.raises(ValueError):
f(np.ones((2, 1, 4), dtype=self.dtype))
# Test that we can't take a dimension multiple times
xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4))
ib = [False, True, False]
x = self.type(self.dtype, shape=(None, 1, None))("x")
with pytest.raises(ValueError):
DimShuffle(ib, shuffle)
DimShuffle(input_ndim=3, new_order=shuffle)
def test_perform(self):
self.with_linker(PerformLinker())
def test_c_or_py(self):
# Shape ops don't have C code,
# but this will test DimShuffle's C code
self.with_linker(OpWiseCLinker())
def test_infer_shape(self):
......@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((1,), ("x", "x")),
]:
i_shape = [entry if entry == 1 else None for entry in xsh]
ib = [(entry == 1) for entry in xsh]
adtens = self.type(self.dtype, shape=i_shape)("x")
adtens_val = np.ones(xsh, dtype=self.dtype)
self._compile_and_check(
[adtens],
[self.op(ib, shuffle)(adtens)],
[self.op(input_ndim=len(xsh), new_order=shuffle)(adtens)],
[adtens_val],
self.op,
warn=False,
......@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1)
def test_valid_input_broadcastable(self):
assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False)
def test_valid_input_ndim(self):
assert DimShuffle(input_ndim=2, new_order=(1, 0)).input_ndim == 2
with pytest.raises(ValueError, match="input_broadcastable must be boolean"):
DimShuffle([None, None], (1, 0))
with pytest.raises(TypeError, match="input_ndim must be an integer"):
DimShuffle(input_ndim=(True, False), new_order=(1, 0))
class TestBroadcast:
......
......@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
assert f([0]) == 0
# Test that we cannot squeeze dimensions whose length is greater than 1
error_txt_1 = re.escape("SpecifyShape: Got shape (3,), expected (1,).")
error_txt_2 = re.escape("SpecifyShape: dim 0 of input has shape 3, expected 1")
match = error_txt_1 if pytensor.config.mode == "FAST_COMPILE" else error_txt_2
with pytest.raises(
AssertionError,
match=match,
ValueError,
match="cannot reshape array of size 3 into shape ()",
):
f([0, 1, 2])
......
......@@ -204,3 +204,12 @@ class TestFFT:
pytensor.config.floatX
)
utt.verify_grad(f_irfft, [inputs_val], eps=eps)
def test_rfft_expanded_dims_grad(self):
# Regression test for https://github.com/pymc-devs/pytensor/issues/969
def test_func(x):
return fft.rfft(x[None, :])
rng = np.random.default_rng(213)
inputs_val = rng.random((N,)).astype(pytensor.config.floatX)
utt.verify_grad(test_func, [inputs_val], rng=rng)
......@@ -4,7 +4,6 @@ import pytest
import pytensor
from pytensor import function
from pytensor.compile.mode import Mode
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import any as pt_any
from pytensor.tensor.math import argmax, argmin, max_and_argmax, mean, prod, std, var
......@@ -40,7 +39,7 @@ class TestKeepDims:
new_dims.append(i)
i += 1
return DimShuffle(y.type.broadcastable, new_dims)(y)
return y.dimshuffle(new_dims)
@pytest.mark.parametrize(
"axis",
......