提交 b2946560 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Assert something about the output of function calls in `test_blas.py`

上级 be6a1fbb
...@@ -1901,27 +1901,25 @@ class TestGer(unittest_tools.OptimizationTestMixin): ...@@ -1901,27 +1901,25 @@ class TestGer(unittest_tools.OptimizationTestMixin):
rng = np.random.default_rng(unittest_tools.fetch_seed()) rng = np.random.default_rng(unittest_tools.fetch_seed())
f = self.function([self.x, self.y], outer(self.x, self.y)) f = self.function([self.x, self.y], outer(self.x, self.y))
self.assertFunctionContains(f, self.ger_destructive) self.assertFunctionContains(f, self.ger_destructive)
# TODO FIXME: This is NOT a test.
f( f(
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
) ).shape == (5, 4)
def test_A_plus_outer(self): def test_A_plus_outer(self):
rng = np.random.default_rng(unittest_tools.fetch_seed()) rng = np.random.default_rng(unittest_tools.fetch_seed())
f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y)) f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y))
self.assertFunctionContains(f, self.ger) self.assertFunctionContains(f, self.ger)
# TODO FIXME: This is NOT a test.
f( f(
rng.random((5, 4)).astype(self.dtype), rng.random((5, 4)).astype(self.dtype),
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
) ).shape == (5, 4)
f( f(
rng.random((5, 4)).astype(self.dtype)[::-1, ::-1], rng.random((5, 4)).astype(self.dtype)[::-1, ::-1],
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
) ).shape == (5, 4)
def test_A_plus_scaled_outer(self): def test_A_plus_scaled_outer(self):
rng = np.random.default_rng(unittest_tools.fetch_seed()) rng = np.random.default_rng(unittest_tools.fetch_seed())
...@@ -1929,17 +1927,16 @@ class TestGer(unittest_tools.OptimizationTestMixin): ...@@ -1929,17 +1927,16 @@ class TestGer(unittest_tools.OptimizationTestMixin):
[self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y) [self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
) )
self.assertFunctionContains(f, self.ger) self.assertFunctionContains(f, self.ger)
# TODO FIXME: This is NOT a test.
f( f(
rng.random((5, 4)).astype(self.dtype), rng.random((5, 4)).astype(self.dtype),
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
) ).shape == (5, 4)
f( f(
rng.random((5, 4)).astype(self.dtype)[::-1, ::-1], rng.random((5, 4)).astype(self.dtype)[::-1, ::-1],
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
) ).shape == (5, 4)
def test_scaled_A_plus_scaled_outer(self): def test_scaled_A_plus_scaled_outer(self):
rng = np.random.default_rng(unittest_tools.fetch_seed()) rng = np.random.default_rng(unittest_tools.fetch_seed())
...@@ -1951,17 +1948,16 @@ class TestGer(unittest_tools.OptimizationTestMixin): ...@@ -1951,17 +1948,16 @@ class TestGer(unittest_tools.OptimizationTestMixin):
# Why gemm? This make the graph simpler did we test that it # Why gemm? This make the graph simpler did we test that it
# make it faster? # make it faster?
self.assertFunctionContains(f, self.gemm) self.assertFunctionContains(f, self.gemm)
# TODO FIXME: This is NOT a test.
f( f(
rng.random((5, 4)).astype(self.dtype), rng.random((5, 4)).astype(self.dtype),
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
) ).shape == (5, 4)
f( f(
rng.random((5, 4)).astype(self.dtype)[::-1, ::-1], rng.random((5, 4)).astype(self.dtype)[::-1, ::-1],
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
) ).shape == (5, 4)
def given_dtype(self, dtype, M, N): def given_dtype(self, dtype, M, N):
# test corner case shape and dtype # test corner case shape and dtype
...@@ -1971,17 +1967,16 @@ class TestGer(unittest_tools.OptimizationTestMixin): ...@@ -1971,17 +1967,16 @@ class TestGer(unittest_tools.OptimizationTestMixin):
[self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y) [self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
) )
self.assertFunctionContains(f, self.ger) self.assertFunctionContains(f, self.ger)
# TODO FIXME: This is NOT a test.
f( f(
rng.random((M, N)).astype(self.dtype), rng.random((M, N)).astype(self.dtype),
rng.random((M)).astype(self.dtype), rng.random((M)).astype(self.dtype),
rng.random((N)).astype(self.dtype), rng.random((N)).astype(self.dtype),
) ).shape == (5, 4)
f( f(
rng.random((M, N)).astype(self.dtype)[::-1, ::-1], rng.random((M, N)).astype(self.dtype)[::-1, ::-1],
rng.random((M)).astype(self.dtype), rng.random((M)).astype(self.dtype),
rng.random((N)).astype(self.dtype), rng.random((N)).astype(self.dtype),
) ).shape == (5, 4)
def test_f32_0_0(self): def test_f32_0_0(self):
return self.given_dtype("float32", 0, 0) return self.given_dtype("float32", 0, 0)
...@@ -2024,7 +2019,7 @@ class TestGer(unittest_tools.OptimizationTestMixin): ...@@ -2024,7 +2019,7 @@ class TestGer(unittest_tools.OptimizationTestMixin):
], ],
) )
self.assertFunctionContains(f, self.ger_destructive) self.assertFunctionContains(f, self.ger_destructive)
# TODO FIXME: This is NOT a test. # TODO: Test something about the updated value of `A`
f( f(
rng.random((4)).astype(self.dtype), rng.random((4)).astype(self.dtype),
rng.random((5)).astype(self.dtype), rng.random((5)).astype(self.dtype),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论