提交 e88117e6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Only require input_ndim and not input_broadcastable in DimShuffle

上级 d68f53f8
...@@ -19,7 +19,6 @@ from pytensor.graph.op import Op ...@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
from pytensor.tensor.math import dot from pytensor.tensor.math import dot
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.subtensor import DimShuffle
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
...@@ -375,7 +374,7 @@ def convolve( ...@@ -375,7 +374,7 @@ def convolve(
[images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)] [images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)]
) )
tensout = reshape(output, newshp, ndim=3) tensout = reshape(output, newshp, ndim=3)
output = DimShuffle((False,) * tensout.ndim, (0, 2, 1))(tensout) output = tensout.transpose(0, 2, 1)
if flatten: if flatten:
output = pt.flatten(output, 2) output = pt.flatten(output, 2)
...@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp): ...@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
) )
out2 = reshape(out1, pshape, ndim=3) out2 = reshape(out1, pshape, ndim=3)
out3 = DimShuffle(out2.broadcastable, (0, 2, 1))(out2) out3 = out2.transpose(0, 2, 1)
return pt.flatten(out3, 2), outshp return pt.flatten(out3, 2), outshp
...@@ -2042,7 +2042,7 @@ def transpose(x, axes=None): ...@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
# No-op # No-op
return _x return _x
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x) ret = _x.dimshuffle(axes)
if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)): if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T" ret.name = _x.name + ".T"
...@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op): ...@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
newdims.append(i) newdims.append(i)
i += 1 i += 1
gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx) gx = gx.dimshuffle(newdims)
assert gx.type.ndim == x.type.ndim assert gx.type.ndim == x.type.ndim
assert all( assert all(
s1 == s2 s1 == s2
......
...@@ -41,7 +41,7 @@ from pytensor.tensor.math import ( ...@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
) )
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i, specify_broadcastable from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -609,11 +609,6 @@ def squeeze(x, axis=None): ...@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed # Nothing could be squeezed
return _x return _x
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
_x = specify_broadcastable(_x, *non_broadcastable_axis)
return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis]) return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
......
from pytensor import printing from pytensor import printing
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.tensor.elemwise import DimShuffle, scalar_elemwise from pytensor.tensor.elemwise import scalar_elemwise
@scalar_elemwise @scalar_elemwise
...@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right")) ...@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def transpose_inplace(x, **kwargs): def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage" "Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1)) dims = list(range(x.ndim - 1, -1, -1))
return DimShuffle(x.broadcastable, dims)(x) return x.dimshuffle(dims)
...@@ -33,7 +33,6 @@ from pytensor.tensor.basic import ( ...@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import ( from pytensor.tensor.elemwise import (
CAReduce, CAReduce,
DimShuffle,
Elemwise, Elemwise,
get_normalized_batch_axes, get_normalized_batch_axes,
scalar_elemwise, scalar_elemwise,
...@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce): ...@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
else: else:
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
ds_op = DimShuffle(gz.type.broadcastable, new_dims) gx = Elemwise(ps.second)(x, gz.dimshuffle(new_dims))
gx = Elemwise(ps.second)(x, ds_op(gz))
return [gx] return [gx]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
...@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node): ...@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
if isinstance(size_node.op, MakeVector) or ( if isinstance(size_node.op, MakeVector) or (
isinstance(size_node.op, DimShuffle) isinstance(size_node.op, DimShuffle)
and size_node.op.input_broadcastable == () and size_node.op.input_ndim == 0
and size_node.op.new_order == ("x",) and size_node.op.new_order == ("x",)
): ):
# Here PyTensor converted a tuple or list to a tensor # Here PyTensor converted a tuple or list to a tensor
......
...@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node): ...@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list( dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list(
range(len(new_output_shape)) range(len(new_output_shape))
) )
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] return [inner.dimshuffle(dimshuffle_new_order)]
@node_rewriter([AllocEmpty]) @node_rewriter([AllocEmpty])
......
...@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
""" """
op = node.op op = node.op
if not isinstance(op, DimShuffle):
return False
inp = node.inputs[0] inp = node.inputs[0]
inode = inp.owner inode = inp.owner
...@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
# Don't use make_node to have tag.test_value set. # Don't use make_node to have tag.test_value set.
new_inputs = [] new_inputs = []
for inp in inode.inputs: for inp in inode.inputs:
new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp) new_inp = inp.dimshuffle(op.new_order)
new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp)) new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))
copy_stack_trace(node.outputs[0], new_inputs) copy_stack_trace(node.outputs[0], new_inputs)
ret = inode.op(*new_inputs, return_list=True) ret = inode.op(*new_inputs, return_list=True)
...@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
if is_dimshuffle_useless(new_order, inp): if is_dimshuffle_useless(new_order, inp):
return [inp] return [inp]
elif inode and isinstance(inode.op, DimShuffle): elif inode and isinstance(inode.op, DimShuffle):
ret = op.__class__(inp.type.broadcastable, new_order)(inp) ret = inp.dimshuffle(new_order)
ret = apply_local_dimshuffle_lift(fgraph, ret) ret = apply_local_dimshuffle_lift(fgraph, ret)
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
......
...@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node): ...@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
if isinstance(shape_node.op, MakeVector) or ( if isinstance(shape_node.op, MakeVector) or (
isinstance(shape_node.op, DimShuffle) isinstance(shape_node.op, DimShuffle)
and shape_node.op.input_broadcastable == () and shape_node.op.input_ndim == 0
and shape_node.op.new_order == ("x",) and shape_node.op.new_order == ("x",)
): ):
# Here PyTensor converted a tuple or list to a tensor # Here PyTensor converted a tuple or list to a tensor
......
...@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool: ...@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if ndims < 2: if ndims < 2:
return False return False
transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2) transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
return cast(bool, node.op.new_order == transpose_order) return node.op.new_order == transpose_order
return False return False
......
...@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
if index != output.type.ndim: if index != output.type.ndim:
inner = op.__class__(len(new_output_shape))(inp, new_output_shape) inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
copy_stack_trace(output, inner) copy_stack_trace(output, inner)
new_node = [ new_node = [inner.dimshuffle(dimshuffle_new_order)]
DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
inner
)
]
copy_stack_trace(output, new_node) copy_stack_trace(output, new_node)
return new_node return new_node
......
...@@ -344,8 +344,8 @@ class _tensor_py_operators: ...@@ -344,8 +344,8 @@ class _tensor_py_operators:
""" """
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)): if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
pattern = pattern[0] pattern = pattern[0]
op = pt.elemwise.DimShuffle(list(self.type.broadcastable), pattern) ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
return op(self) return ds_op(self)
def flatten(self, ndim=1): def flatten(self, ndim=1):
return pt.basic.flatten(self, ndim) return pt.basic.flatten(self, ndim)
......
...@@ -39,7 +39,7 @@ def test_jax_Dimshuffle(): ...@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1)) a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt) x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x]) x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
......
...@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable ...@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
...@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark): ...@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
], ],
) )
def test_Dimshuffle(v, new_order): def test_Dimshuffle(v, new_order):
g = pt_elemwise.DimShuffle(v.broadcastable, new_order)(v) g = v.dimshuffle(new_order)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, g_fg,
...@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order): ...@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
def test_Dimshuffle_returns_array(): def test_Dimshuffle_returns_array():
x = pt.vector("x", shape=(1,)) x = pt.vector("x", shape=(1,))
y = 2 * pt_elemwise.DimShuffle([True], [])(x) y = 2 * x.dimshuffle([])
func = pytensor.function([x], y, mode="NUMBA") func = pytensor.function([x], y, mode="NUMBA")
out = func(np.zeros(1, dtype=config.floatX)) out = func(np.zeros(1, dtype=config.floatX))
assert out.ndim == 0 assert out.ndim == 0
...@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous(): ...@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
non-contiguous arrays, make sure we work around thpt.""" non-contiguous arrays, make sure we work around thpt."""
x = pt.dvector() x = pt.dvector()
idx = pt.vector(dtype="int64") idx = pt.vector(dtype="int64")
op = pytensor.tensor.elemwise.DimShuffle([True], []) op = DimShuffle(input_ndim=1, new_order=[])
out = op(pt.specify_shape(x[idx][::2], (1,))) out = op(pt.specify_shape(x[idx][::2], (1,)))
func = pytensor.function([x, idx], out, mode="NUMBA") func = pytensor.function([x, idx], out, mode="NUMBA")
assert func(np.zeros(3), np.array([1])).ndim == 0 assert func(np.zeros(3), np.array([1])).ndim == 0
......
...@@ -5,7 +5,6 @@ import pytensor.tensor as pt ...@@ -5,7 +5,6 @@ import pytensor.tensor as pt
import pytensor.tensor.math as ptm import pytensor.tensor.math as ptm
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, tensor3, vector from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle(): ...@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
x_fg = FunctionGraph([a_pt], [x]) x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_multiple_input_output(): def test_multiple_input_output():
x = vector("x") x = vector("x")
......
...@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift) ...@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
def ds(x, y): def ds(x, y):
return DimShuffle(x.type.broadcastable, y)(x) return x.dimshuffle(y)
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)): def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
......
...@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites) ...@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
def ds(x, y): def ds(x, y):
return DimShuffle(x.type.broadcastable, y)(x) return x.dimshuffle(y)
def rewrite(g, level="fast_run"): def rewrite(g, level="fast_run"):
...@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum(): ...@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None) check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
# If a transpose is applied to the sum # If a transpose is applied to the sum
transpose_op = DimShuffle((False, False), (1, 0)) transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op) check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
# If the sum is performed with keepdims=True # If the sum is performed with keepdims=True
...@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one(): ...@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
assert np.allclose(naive_ret, rewritten_ret) assert np.allclose(naive_ret, rewritten_ret)
# If a transpose is applied # If a transpose is applied
transpose_op = DimShuffle((False, False), (1, 0)) transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op) f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T) naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
rewritten_ret = f(x_val) rewritten_ret = f(x_val)
......
...@@ -3418,7 +3418,7 @@ def test_unalign(): ...@@ -3418,7 +3418,7 @@ def test_unalign():
def test_dimshuffle_duplicate(): def test_dimshuffle_duplicate():
x = vector() x = vector()
with pytest.raises(ValueError, match="may not appear twice"): with pytest.raises(ValueError, match="may not appear twice"):
DimShuffle((False,), (0, 0))(x) DimShuffle(input_ndim=1, new_order=(0, 0))(x)
class TestGetUnderlyingScalarConstantValue: class TestGetUnderlyingScalarConstantValue:
......
...@@ -593,9 +593,9 @@ class TestAsScalar: ...@@ -593,9 +593,9 @@ class TestAsScalar:
b = pt.constant(np.asarray([[[0.5]]])) b = pt.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle() b2 = b.dimshuffle()
assert b2.ndim == 0 assert b2.ndim == 0
d_a = DimShuffle([], [])(a) d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_b = DimShuffle([True, True, True], [0, 2, 1])(b) d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
d_a2 = DimShuffle([], ["x", "x", "x"])(a) d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
assert _as_scalar(a) == a assert _as_scalar(a) == a
assert _as_scalar(b) != b assert _as_scalar(b) != b
...@@ -607,13 +607,13 @@ class TestAsScalar: ...@@ -607,13 +607,13 @@ class TestAsScalar:
# Test that it fails on nonscalar constants # Test that it fails on nonscalar constants
a = pt.constant(np.ones(5)) a = pt.constant(np.ones(5))
assert _as_scalar(a) is None assert _as_scalar(a) is None
assert _as_scalar(DimShuffle([False], [0, "x"])(a)) is None assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
def test_basic_2(self): def test_basic_2(self):
# Test that it works on scalar variables # Test that it works on scalar variables
a = dscalar() a = dscalar()
d_a = DimShuffle([], [])(a) d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_a2 = DimShuffle([], ["x", "x"])(a) d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
assert _as_scalar(a) is a assert _as_scalar(a) is a
assert _as_scalar(d_a) is a assert _as_scalar(d_a) is a
...@@ -623,13 +623,15 @@ class TestAsScalar: ...@@ -623,13 +623,15 @@ class TestAsScalar:
# Test that it fails on nonscalar variables # Test that it fails on nonscalar variables
a = matrix() a = matrix()
assert _as_scalar(a) is None assert _as_scalar(a) is None
assert _as_scalar(DimShuffle([False, False], [0, "x", 1])(a)) is None assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
class TestRealMatrix: class TestRealMatrix:
def test_basic(self): def test_basic(self):
assert _is_real_matrix(DimShuffle([False, False], [1, 0])(matrix())) assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
assert not _is_real_matrix(DimShuffle([False], ["x", 0])(dvector())) assert not _is_real_matrix(
DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
)
""" """
......
...@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((1,), ("x", "x"), (1, 1)), ((1,), ("x", "x"), (1, 1)),
]: ]:
i_shape = [entry if entry == 1 else None for entry in xsh] i_shape = [entry if entry == 1 else None for entry in xsh]
ib = [entry == 1 for entry in i_shape]
x = self.type(self.dtype, shape=i_shape)("x") x = self.type(self.dtype, shape=i_shape)("x")
e = self.op(ib, shuffle)(x) e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
f = pytensor.function([x], e, mode=Mode(linker=linker)) f = pytensor.function([x], e, mode=Mode(linker=linker))
assert f(np.ones(xsh, dtype=self.dtype)).shape == zsh assert f(np.ones(xsh, dtype=self.dtype)).shape == zsh
# test that DimShuffle.infer_shape work correctly # test that DimShuffle.infer_shape work correctly
x = self.type(self.dtype, shape=i_shape)("x") x = self.type(self.dtype, shape=i_shape)("x")
e = self.op(ib, shuffle)(x) e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
f = pytensor.function( f = pytensor.function(
[x], e.shape, mode=Mode(linker=linker), on_unused_input="ignore" [x], e.shape, mode=Mode(linker=linker), on_unused_input="ignore"
) )
assert all(f(np.ones(xsh, dtype=self.dtype))) == all(zsh) assert all(f(np.ones(xsh, dtype=self.dtype))) == all(zsh)
# Test when we drop a axis that is not broadcastable # Test when we drop a axis that is not broadcastable
ib = [False, True, False] x = self.type(self.dtype, shape=(2, 1, None))("x")
x = self.type(self.dtype, shape=(None, 1, None))("x") with pytest.raises(TypeError):
with pytest.raises(ValueError): self.op(input_ndim=3, new_order=shuffle)(x)
self.op(ib, shuffle)
# Test when we drop a axis that don't have shape 1 # Test when we drop a axis that don't have shape 1
ib = [True, True, False] x = self.type(self.dtype, shape=(None, 1, None))("x")
x = self.type(self.dtype, shape=(1, 1, None))("x") e = self.op(input_ndim=3, new_order=(1, 2))(x)
e = self.op(ib, (1, 2))(x) f = pytensor.function([x], e, mode=Mode(linker=linker))
f = pytensor.function([x], e.shape, mode=Mode(linker=linker)) with pytest.raises(ValueError):
with pytest.raises(TypeError): f(np.ones((2, 1, 4), dtype=self.dtype))
f(np.ones((2, 1, 4)))
# Test that we can't take a dimensions multiple time # Test that we can't take a dimensions multiple time
xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4)) xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4))
ib = [False, True, False]
x = self.type(self.dtype, shape=(None, 1, None))("x") x = self.type(self.dtype, shape=(None, 1, None))("x")
with pytest.raises(ValueError): with pytest.raises(ValueError):
DimShuffle(ib, shuffle) DimShuffle(input_ndim=3, new_order=shuffle)
def test_perform(self): def test_perform(self):
self.with_linker(PerformLinker()) self.with_linker(PerformLinker())
def test_c_or_py(self): def test_c_or_py(self):
# Shape op don't have C code.
# But This will test DimShuffle c code
self.with_linker(OpWiseCLinker()) self.with_linker(OpWiseCLinker())
def test_infer_shape(self): def test_infer_shape(self):
...@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((1,), ("x", "x")), ((1,), ("x", "x")),
]: ]:
i_shape = [entry if entry == 1 else None for entry in xsh] i_shape = [entry if entry == 1 else None for entry in xsh]
ib = [(entry == 1) for entry in xsh]
adtens = self.type(self.dtype, shape=i_shape)("x") adtens = self.type(self.dtype, shape=i_shape)("x")
adtens_val = np.ones(xsh, dtype=self.dtype) adtens_val = np.ones(xsh, dtype=self.dtype)
self._compile_and_check( self._compile_and_check(
[adtens], [adtens],
[self.op(ib, shuffle)(adtens)], [self.op(input_ndim=len(xsh), new_order=shuffle)(adtens)],
[adtens_val], [adtens_val],
self.op, self.op,
warn=False, warn=False,
...@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y = x.dimshuffle([0, 1, "x"]) y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1) assert y.type.shape == (1, 2, 1)
def test_valid_input_broadcastable(self): def test_valid_input_ndim(self):
assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False) assert DimShuffle(input_ndim=2, new_order=(1, 0)).input_ndim == 2
with pytest.raises(ValueError, match="input_broadcastable must be boolean"): with pytest.raises(TypeError, match="input_ndim must be an integer"):
DimShuffle([None, None], (1, 0)) DimShuffle(input_ndim=(True, False), new_order=(1, 0))
class TestBroadcast: class TestBroadcast:
......
...@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester): ...@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
assert f([0]) == 0 assert f([0]) == 0
# Test that we cannot squeeze dimensions whose length is greater than 1 # Test that we cannot squeeze dimensions whose length is greater than 1
error_txt_1 = re.escape("SpecifyShape: Got shape (3,), expected (1,).")
error_txt_2 = re.escape("SpecifyShape: dim 0 of input has shape 3, expected 1")
match = error_txt_1 if pytensor.config.mode == "FAST_COMPILE" else error_txt_2
with pytest.raises( with pytest.raises(
AssertionError, ValueError,
match=match, match="cannot reshape array of size 3 into shape ()",
): ):
f([0, 1, 2]) f([0, 1, 2])
......
...@@ -204,3 +204,12 @@ class TestFFT: ...@@ -204,3 +204,12 @@ class TestFFT:
pytensor.config.floatX pytensor.config.floatX
) )
utt.verify_grad(f_irfft, [inputs_val], eps=eps) utt.verify_grad(f_irfft, [inputs_val], eps=eps)
def test_rfft_expanded_dims_grad(self):
# Regression test for https://github.com/pymc-devs/pytensor/issues/969
def test_func(x):
return fft.rfft(x[None, :])
rng = np.random.default_rng(213)
inputs_val = rng.random((N,)).astype(pytensor.config.floatX)
utt.verify_grad(test_func, [inputs_val], rng=rng)
...@@ -4,7 +4,6 @@ import pytest ...@@ -4,7 +4,6 @@ import pytest
import pytensor import pytensor
from pytensor import function from pytensor import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import any as pt_any from pytensor.tensor.math import any as pt_any
from pytensor.tensor.math import argmax, argmin, max_and_argmax, mean, prod, std, var from pytensor.tensor.math import argmax, argmin, max_and_argmax, mean, prod, std, var
...@@ -40,7 +39,7 @@ class TestKeepDims: ...@@ -40,7 +39,7 @@ class TestKeepDims:
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
return DimShuffle(y.type.broadcastable, new_dims)(y) return y.dimshuffle(new_dims)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"axis", "axis",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论