提交 72e2da25 authored 作者: Anatoly's avatar Anatoly 提交者: Brandon T. Willard

add tests for out_shape input check

上级 cf7dc4a8
......@@ -12,6 +12,7 @@ from aesara.tensor.signal.pool import (
AveragePoolGrad,
DownsampleFactorMaxGradGrad,
MaxPoolGrad,
PoolGrad,
Pool,
max_pool_2d_same_size,
pool_2d,
......@@ -1292,23 +1293,32 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dx, fix_dx)
def test_pool_2d_checks(self):
x = fmatrix()
@staticmethod
def checks_helper(func, x, ws, stride, pad):
with pytest.raises(
ValueError,
match=r"You can't provide a tuple value to both 'ws' and 'ds'."):
pool_2d(input=x, ds=(1,1), ws=(1, 1))
func(x, ds=ws, ws=ws)
with pytest.raises(
ValueError,
match="You must provide a tuple value for the window size."):
pool_2d(input=x)
func(x)
with pytest.raises(
ValueError,
match=r"You can't provide a tuple value to both 'st and 'stride'."):
pool_2d(input=x, ws=(1, 1), st=(1,1), stride=(1, 1))
func(x, ws=ws, st=stride, stride=stride)
with pytest.raises(
ValueError,
match=r"You can't provide a tuple value to both 'padding' and pad."):
func(x, ws=ws, pad=pad, padding=pad)
def test_pool_2d_checks(self):
x = fmatrix()
self.checks_helper(pool_2d, x, ws=(1, 1), stride=(1, 1), pad=(1, 1))
with pytest.raises(
NotImplementedError,
......@@ -1322,20 +1332,7 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
def test_pool_3d_checks(self):
x = ftensor3()
with pytest.raises(
ValueError,
match=r"You can't provide a tuple value to both 'ws' and 'ds'."):
pool_3d(input=x, ds=(1, 1, 1), ws=(1, 1, 1))
with pytest.raises(
ValueError,
match="You must provide a tuple value for the window size."):
pool_3d(input=x)
with pytest.raises(
ValueError,
match=r"You can't provide a tuple value to both 'st and 'stride'"):
pool_3d(input=x, ws=(1, 1, 1), st=(1, 1, 1), stride=(1, 1, 1))
self.checks_helper(pool_3d, x, ws=(1, 1, 1), stride=(1, 1, 1), pad=(1, 1, 1))
with pytest.raises(
NotImplementedError,
......@@ -1346,3 +1343,13 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
out = pool_3d(input=x, ws=(1, 1, 1))
assert not out.owner.op.ignore_border
@pytest.mark.parametrize("func", [Pool.out_shape, PoolGrad.out_shape])
def test_Pool_out_shape_checks(self, func):
x = (10, 10)
self.checks_helper(func, x, ws=(1, 1), stride=(1, 1), pad=(1, 1))
with pytest.raises(
TypeError,
match="imgshape must have at least 3 dimensions"):
func(x, (2, 2), ndim=3)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论