提交 cf7dc4a8 authored 作者: Anatoly's avatar Anatoly 提交者: Brandon T. Willard

add test for checks in pool_3d

上级 e0a8c79e
...@@ -19,10 +19,12 @@ from aesara.tensor.signal.pool import ( ...@@ -19,10 +19,12 @@ from aesara.tensor.signal.pool import (
) )
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
fmatrix,
dmatrix, dmatrix,
dtensor3, dtensor3,
dtensor4, dtensor4,
ftensor4, ftensor4,
ftensor3,
ivector, ivector,
tensor, tensor,
tensor4, tensor4,
...@@ -1291,7 +1293,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -1291,7 +1293,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
utt.assert_allclose(var_dx, fix_dx) utt.assert_allclose(var_dx, fix_dx)
def test_pool_2d_checks(self): def test_pool_2d_checks(self):
x = ftensor4() x = fmatrix()
with pytest.raises( with pytest.raises(
ValueError, ValueError,
...@@ -1305,7 +1307,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -1305,7 +1307,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="You can't provide a tuple value to both 'st and 'stride'. Please provide a value only to 'stride'."): match=r"You can't provide a tuple value to both 'st and 'stride'."):
pool_2d(input=x, ws=(1, 1), st=(1,1), stride=(1, 1)) pool_2d(input=x, ws=(1, 1), st=(1,1), stride=(1, 1))
with pytest.raises( with pytest.raises(
...@@ -1317,3 +1319,30 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -1317,3 +1319,30 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
out = pool_2d(input=x, ws=(1, 1)) out = pool_2d(input=x, ws=(1, 1))
assert not out.owner.op.ignore_border assert not out.owner.op.ignore_border
def test_pool_3d_checks(self):
x = ftensor3()
with pytest.raises(
ValueError,
match=r"You can't provide a tuple value to both 'ws' and 'ds'."):
pool_3d(input=x, ds=(1, 1, 1), ws=(1, 1, 1))
with pytest.raises(
ValueError,
match="You must provide a tuple value for the window size."):
pool_3d(input=x)
with pytest.raises(
ValueError,
match=r"You can't provide a tuple value to both 'st and 'stride'"):
pool_3d(input=x, ws=(1, 1, 1), st=(1, 1, 1), stride=(1, 1, 1))
with pytest.raises(
NotImplementedError,
match="pool_3d requires a dimension >= 3"):
pool_3d(input=fmatrix(), ws=(1, 1, 1))
with pytest.deprecated_call():
out = pool_3d(input=x, ws=(1, 1, 1))
assert not out.owner.op.ignore_border
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论