提交 8a81a53d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify cholesky infer_shape test and remove slow mark

上级 94522570
......@@ -122,18 +122,20 @@ def test_cholesky_grad_indef():
assert np.all(np.isnan(chol_f(mat)))
@pytest.mark.slow
def test_cholesky_shape():
rng = np.random.default_rng(utt.fetch_seed())
def test_cholesky_infer_shape():
x = matrix()
for l in (cholesky(x), Cholesky(lower=True)(x), Cholesky(lower=False)(x)):
f_chol = pytensor.function([x], l.shape)
f_chol = pytensor.function([x], [cholesky(x).shape, cholesky(x, lower=False).shape])
if config.mode != "FAST_COMPILE":
topo_chol = f_chol.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert sum(node.op.__class__ == Cholesky for node in topo_chol) == 0
for shp in [2, 3, 5]:
m = np.cov(rng.standard_normal((shp, shp + 10))).astype(config.floatX)
np.testing.assert_equal(f_chol(m), (shp, shp))
f_chol.dprint()
assert not any(
isinstance(getattr(node.op, "core_op", node.op), Cholesky)
for node in topo_chol
)
for shp in [2, 3, 5]:
res1, res2 = f_chol(np.eye(shp).astype(x.dtype))
assert tuple(res1) == (shp, shp)
assert tuple(res2) == (shp, shp)
def test_eigvalsh():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论