提交 f84d2b00 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move TestTensorInstanceMethods to tests.tensor.test_var

上级 880015b4
...@@ -6,7 +6,6 @@ from tempfile import mkstemp ...@@ -6,7 +6,6 @@ from tempfile import mkstemp
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_array_equal
import aesara import aesara
import aesara.scalar as aes import aesara.scalar as aes
...@@ -115,7 +114,6 @@ from aesara.tensor.type import ( ...@@ -115,7 +114,6 @@ from aesara.tensor.type import (
ivector, ivector,
lscalar, lscalar,
lvector, lvector,
matrices,
matrix, matrix,
row, row,
scalar, scalar,
...@@ -3960,75 +3958,6 @@ class TestInferShape(utt.InferShapeTester): ...@@ -3960,75 +3958,6 @@ class TestInferShape(utt.InferShapeTester):
) )
class TestTensorInstanceMethods:
def setup_method(self):
self.vars = matrices("X", "Y")
self.vals = [m.astype(config.floatX) for m in [random(2, 2), random(2, 2)]]
def test_repeat(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.repeat(2).eval({X: x}), x.repeat(2))
def test_trace(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.trace().eval({X: x}), x.trace())
def test_ravel(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.ravel().eval({X: x}), x.ravel())
def test_diagonal(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.diagonal().eval({X: x}), x.diagonal())
assert_array_equal(X.diagonal(1).eval({X: x}), x.diagonal(1))
assert_array_equal(X.diagonal(-1).eval({X: x}), x.diagonal(-1))
for offset, axis1, axis2 in [(1, 0, 1), (-1, 0, 1), (0, 1, 0), (-2, 1, 0)]:
assert_array_equal(
X.diagonal(offset, axis1, axis2).eval({X: x}),
x.diagonal(offset, axis1, axis2),
)
def test_take(self):
X, _ = self.vars
x, _ = self.vals
indices = [1, 0, 3]
assert_array_equal(X.take(indices).eval({X: x}), x.take(indices))
indices = [1, 0, 1]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
indices = np.array([-10, 5, 12], dtype="int32")
assert_array_equal(
X.take(indices, 1, mode="wrap").eval({X: x}),
x.take(indices, 1, mode="wrap"),
)
assert_array_equal(
X.take(indices, -1, mode="wrap").eval({X: x}),
x.take(indices, -1, mode="wrap"),
)
assert_array_equal(
X.take(indices, 1, mode="clip").eval({X: x}),
x.take(indices, 1, mode="clip"),
)
assert_array_equal(
X.take(indices, -1, mode="clip").eval({X: x}),
x.take(indices, -1, mode="clip"),
)
# Test error handling
with pytest.raises(IndexError):
X.take(indices).eval({X: x})
with pytest.raises(IndexError):
(2 * X.take(indices)).eval({X: x})
with pytest.raises(TypeError):
X.take([0.0])
indices = [[1, 0, 1], [0, 1, 1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])
class TestSwapaxes: class TestSwapaxes:
def test_no_dimensional_input(self): def test_no_dimensional_input(self):
with pytest.raises(IndexError): with pytest.raises(IndexError):
......
...@@ -2,7 +2,7 @@ from copy import copy ...@@ -2,7 +2,7 @@ from copy import copy
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_equal, assert_string_equal from numpy.testing import assert_array_equal, assert_equal, assert_string_equal
import aesara import aesara
import tests.unittest_tools as utt import tests.unittest_tools as utt
...@@ -21,6 +21,7 @@ from aesara.tensor.type import ( ...@@ -21,6 +21,7 @@ from aesara.tensor.type import (
dvector, dvector,
iscalar, iscalar,
ivector, ivector,
matrices,
matrix, matrix,
scalar, scalar,
tensor3, tensor3,
...@@ -32,6 +33,7 @@ from aesara.tensor.var import ( ...@@ -32,6 +33,7 @@ from aesara.tensor.var import (
TensorConstant, TensorConstant,
TensorVariable, TensorVariable,
) )
from tests.tensor.utils import random
pytestmark = pytest.mark.filterwarnings("error") pytestmark = pytest.mark.filterwarnings("error")
...@@ -322,3 +324,74 @@ class TestTensorConstantSignature: ...@@ -322,3 +324,74 @@ class TestTensorConstantSignature:
y_sig = y.signature() y_sig = y.signature()
assert hash(x_sig) == hash(y_sig) assert hash(x_sig) == hash(y_sig)
class TestTensorInstanceMethods:
def setup_method(self):
self.vars = matrices("X", "Y")
self.vals = [
m.astype(aesara.config.floatX) for m in [random(2, 2), random(2, 2)]
]
def test_repeat(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.repeat(2).eval({X: x}), x.repeat(2))
def test_trace(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.trace().eval({X: x}), x.trace())
def test_ravel(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.ravel().eval({X: x}), x.ravel())
def test_diagonal(self):
X, _ = self.vars
x, _ = self.vals
assert_array_equal(X.diagonal().eval({X: x}), x.diagonal())
assert_array_equal(X.diagonal(1).eval({X: x}), x.diagonal(1))
assert_array_equal(X.diagonal(-1).eval({X: x}), x.diagonal(-1))
for offset, axis1, axis2 in [(1, 0, 1), (-1, 0, 1), (0, 1, 0), (-2, 1, 0)]:
assert_array_equal(
X.diagonal(offset, axis1, axis2).eval({X: x}),
x.diagonal(offset, axis1, axis2),
)
def test_take(self):
X, _ = self.vars
x, _ = self.vals
indices = [1, 0, 3]
assert_array_equal(X.take(indices).eval({X: x}), x.take(indices))
indices = [1, 0, 1]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
indices = np.array([-10, 5, 12], dtype="int32")
assert_array_equal(
X.take(indices, 1, mode="wrap").eval({X: x}),
x.take(indices, 1, mode="wrap"),
)
assert_array_equal(
X.take(indices, -1, mode="wrap").eval({X: x}),
x.take(indices, -1, mode="wrap"),
)
assert_array_equal(
X.take(indices, 1, mode="clip").eval({X: x}),
x.take(indices, 1, mode="clip"),
)
assert_array_equal(
X.take(indices, -1, mode="clip").eval({X: x}),
x.take(indices, -1, mode="clip"),
)
# Test error handling
with pytest.raises(IndexError):
X.take(indices).eval({X: x})
with pytest.raises(IndexError):
(2 * X.take(indices)).eval({X: x})
with pytest.raises(TypeError):
X.take([0.0])
indices = [[1, 0, 1], [0, 1, 1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论