提交 89353bd2 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Actually test dtype in `TestGer.given_dtype`

上级 b2946560
......@@ -1959,52 +1959,64 @@ class TestGer(unittest_tools.OptimizationTestMixin):
rng.random((4)).astype(self.dtype),
).shape == (5, 4)
def given_dtype(self, dtype, M, N):
def given_dtype(self, dtype, M, N, *, destructive=True):
# test corner case shape and dtype
rng = np.random.default_rng(unittest_tools.fetch_seed())
f = self.function(
[self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
A = tensor(dtype=dtype, shape=(False, False))
x = tensor(dtype=dtype, shape=(False,))
y = tensor(dtype=dtype, shape=(False,))
f = self.function([A, x, y], A + 0.1 * outer(x, y))
self.assertFunctionContains(
f, self.ger_destructive if destructive else self.ger
)
self.assertFunctionContains(f, self.ger)
f(
rng.random((M, N)).astype(self.dtype),
rng.random((M)).astype(self.dtype),
rng.random((N)).astype(self.dtype),
rng.random((M, N)).astype(dtype),
rng.random((M)).astype(dtype),
rng.random((N)).astype(dtype),
).shape == (5, 4)
f(
rng.random((M, N)).astype(self.dtype)[::-1, ::-1],
rng.random((M)).astype(self.dtype),
rng.random((N)).astype(self.dtype),
rng.random((M, N)).astype(dtype)[::-1, ::-1],
rng.random((M)).astype(dtype),
rng.random((N)).astype(dtype),
).shape == (5, 4)
def test_f32_0_0(self):
return self.given_dtype("float32", 0, 0)
return self.given_dtype("float32", 0, 0, destructive=config.floatX != "float32")
def test_f32_1_0(self):
return self.given_dtype("float32", 1, 0)
return self.given_dtype("float32", 1, 0, destructive=config.floatX != "float32")
def test_f32_0_1(self):
return self.given_dtype("float32", 0, 1)
return self.given_dtype("float32", 0, 1, destructive=config.floatX != "float32")
def test_f32_1_1(self):
return self.given_dtype("float32", 1, 1)
return self.given_dtype("float32", 1, 1, destructive=config.floatX != "float32")
def test_f32_4_4(self):
return self.given_dtype("float32", 4, 4)
return self.given_dtype("float32", 4, 4, destructive=config.floatX != "float32")
def test_f32_7_1(self):
return self.given_dtype("float32", 7, 1)
return self.given_dtype("float32", 7, 1, destructive=config.floatX != "float32")
def test_f32_1_2(self):
return self.given_dtype("float32", 1, 2)
return self.given_dtype("float32", 1, 2, destructive=config.floatX != "float32")
def test_f64_4_5(self):
return self.given_dtype("float64", 4, 5)
return self.given_dtype("float64", 4, 5, destructive=False)
@pytest.mark.xfail(
condition=config.floatX == "float32",
reason="GER from complex64 is not introduced in float32 mode",
)
def test_c64_7_1(self):
return self.given_dtype("complex64", 7, 1)
@pytest.mark.xfail(
raises=AssertionError,
reason="Unclear how this test was supposed to work with complex128",
)
def test_c128_1_9(self):
return self.given_dtype("complex128", 1, 9)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论