提交 87002260 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move get_vector_length Subtensor tests to test_subtensor

上级 19ad27c9
...@@ -960,27 +960,6 @@ def test_get_vector_length(): ...@@ -960,27 +960,6 @@ def test_get_vector_length():
res = get_vector_length(x) res = get_vector_length(x)
assert res == 4 assert res == 4
# Test `Subtensor`s
x = as_tensor_variable(np.arange(4))
assert get_vector_length(x[2:4]) == 2
assert get_vector_length(x[2:]) == 2
assert get_vector_length(x[1:4]) == 3
assert get_vector_length(x[2:2]) == 0
assert get_vector_length(x[1:10]) == 3
# Test step
assert get_vector_length(x[1:10:2]) == 2
# Test neg start
assert get_vector_length(x[-1:4]) == 1
assert get_vector_length(x[-6:4]) == 4
# test neg stop
assert get_vector_length(x[1:-2]) == 1
assert get_vector_length(x[1:-1]) == 2
assert get_vector_length(lvector()[1:1]) == 0
assert get_vector_length(lvector()[-1:-1:3]) == 0
with pytest.raises(ValueError, match="^Length of .*"):
get_vector_length(x[lscalar() :])
# Test `Join`s # Test `Join`s
z = join(0, as_tensor_variable(1, ndim=1), as_tensor_variable(x.shape[0], ndim=1)) z = join(0, as_tensor_variable(1, ndim=1), as_tensor_variable(x.shape[0], ndim=1))
assert isinstance(z.owner.op, Join) assert isinstance(z.owner.op, Join)
......
...@@ -15,6 +15,7 @@ from aesara.configdefaults import config ...@@ -15,6 +15,7 @@ from aesara.configdefaults import config
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.opt_utils import is_same_graph from aesara.graph.opt_utils import is_same_graph
from aesara.scalar.basic import as_scalar from aesara.scalar.basic import as_scalar
from aesara.tensor import get_vector_length
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import exp, isinf from aesara.tensor.math import exp, isinf
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
...@@ -2503,3 +2504,25 @@ def test_symbolic_slice(): ...@@ -2503,3 +2504,25 @@ def test_symbolic_slice():
a, b = x.shape[:2] a, b = x.shape[:2]
output = a.eval({x: np.zeros((5, 4, 3, 2), dtype=config.floatX)}) output = a.eval({x: np.zeros((5, 4, 3, 2), dtype=config.floatX)})
assert output == np.array(5) assert output == np.array(5)
def test_get_vector_length():
x = aet.as_tensor_variable(np.arange(4))
assert get_vector_length(x[2:4]) == 2
assert get_vector_length(x[2:]) == 2
assert get_vector_length(x[1:4]) == 3
assert get_vector_length(x[2:2]) == 0
assert get_vector_length(x[1:10]) == 3
# Test step
assert get_vector_length(x[1:10:2]) == 2
# Test neg start
assert get_vector_length(x[-1:4]) == 1
assert get_vector_length(x[-6:4]) == 4
# test neg stop
assert get_vector_length(x[1:-2]) == 1
assert get_vector_length(x[1:-1]) == 2
assert get_vector_length(lvector()[1:1]) == 0
assert get_vector_length(lvector()[-1:-1:3]) == 0
with pytest.raises(ValueError, match="^Length of .*"):
get_vector_length(x[lscalar() :])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论