提交 c0860f86 authored 作者: Margus Niitsoo's avatar Margus Niitsoo 提交者: Ricardo Vieira

Fix transpose numpy compatibility #1142

上级 3bd1bcf6
......@@ -346,7 +346,7 @@ class _tensor_py_operators:
DimShuffle
"""
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
pattern = pattern[0]
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
return ds_op(self)
......
......@@ -451,6 +451,21 @@ class TestTensorInstanceMethods:
with pytest.raises(TypeError, match=msg):
x[0] += 5
def test_transpose(self):
X, _ = self.vars
x, _ = self.vals
# Turn (2,2) -> (1,2)
X, x = X[1:, :], x[1:, :]
assert_array_equal(X.transpose(0, 1).eval({X: x}), x.transpose(0, 1))
assert_array_equal(X.transpose(1, 0).eval({X: x}), x.transpose(1, 0))
# Test handing in tuples, lists and np.arrays
equal_computations([X.transpose((1, 0))], [X.transpose(1, 0)])
equal_computations([X.transpose([1, 0])], [X.transpose(1, 0)])
equal_computations([X.transpose(np.array([1, 0]))], [X.transpose(1, 0)])
def test_deprecated_import():
with pytest.warns(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论