提交 b2946560 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Assert something about the output of function calls in `test_blas.py`

上级 be6a1fbb
......@@ -1901,27 +1901,25 @@ class TestGer(unittest_tools.OptimizationTestMixin):
rng = np.random.default_rng(unittest_tools.fetch_seed())
f = self.function([self.x, self.y], outer(self.x, self.y))
self.assertFunctionContains(f, self.ger_destructive)
# TODO FIXME: This is NOT a test.
f(
rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype),
)
).shape == (5, 4)
def test_A_plus_outer(self):
rng = np.random.default_rng(unittest_tools.fetch_seed())
f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y))
self.assertFunctionContains(f, self.ger)
# TODO FIXME: This is NOT a test.
f(
rng.random((5, 4)).astype(self.dtype),
rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype),
)
).shape == (5, 4)
f(
rng.random((5, 4)).astype(self.dtype)[::-1, ::-1],
rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype),
)
).shape == (5, 4)
def test_A_plus_scaled_outer(self):
rng = np.random.default_rng(unittest_tools.fetch_seed())
......@@ -1929,17 +1927,16 @@ class TestGer(unittest_tools.OptimizationTestMixin):
[self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
)
self.assertFunctionContains(f, self.ger)
# TODO FIXME: This is NOT a test.
f(
rng.random((5, 4)).astype(self.dtype),
rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype),
)
).shape == (5, 4)
f(
rng.random((5, 4)).astype(self.dtype)[::-1, ::-1],
rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype),
)
).shape == (5, 4)
def test_scaled_A_plus_scaled_outer(self):
rng = np.random.default_rng(unittest_tools.fetch_seed())
......@@ -1951,17 +1948,16 @@ class TestGer(unittest_tools.OptimizationTestMixin):
# Why gemm? This make the graph simpler did we test that it
# make it faster?
self.assertFunctionContains(f, self.gemm)
# TODO FIXME: This is NOT a test.
f(
rng.random((5, 4)).astype(self.dtype),
rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype),
)
).shape == (5, 4)
f(
rng.random((5, 4)).astype(self.dtype)[::-1, ::-1],
rng.random((5)).astype(self.dtype),
rng.random((4)).astype(self.dtype),
)
).shape == (5, 4)
def given_dtype(self, dtype, M, N):
# test corner case shape and dtype
......@@ -1971,17 +1967,16 @@ class TestGer(unittest_tools.OptimizationTestMixin):
[self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
)
self.assertFunctionContains(f, self.ger)
# TODO FIXME: This is NOT a test.
f(
rng.random((M, N)).astype(self.dtype),
rng.random((M)).astype(self.dtype),
rng.random((N)).astype(self.dtype),
)
).shape == (5, 4)
f(
rng.random((M, N)).astype(self.dtype)[::-1, ::-1],
rng.random((M)).astype(self.dtype),
rng.random((N)).astype(self.dtype),
)
).shape == (5, 4)
def test_f32_0_0(self):
return self.given_dtype("float32", 0, 0)
......@@ -2024,7 +2019,7 @@ class TestGer(unittest_tools.OptimizationTestMixin):
],
)
self.assertFunctionContains(f, self.ger_destructive)
# TODO FIXME: This is NOT a test.
# TODO: Test something about the updated value of `A`
f(
rng.random((4)).astype(self.dtype),
rng.random((5)).astype(self.dtype),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论