提交 6cadc763 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Consider hermitian in jax dispatch of pinv

上级 5a8f3137
......@@ -89,7 +89,7 @@ def jax_funcify_Dot(op, **kwargs):
@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
def pinv(x):
return jnp.linalg.pinv(x)
return jnp.linalg.pinv(x, hermitian=op.hermitian)
return pinv
......
......@@ -143,3 +143,34 @@ def test_pinv():
fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np])
def test_pinv_hermitian():
A = matrix("A", dtype="complex128")
A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]]
A_not_h_test = A_h_test + 0 + 1j
A_inv = at_nlinalg.pinv(A, hermitian=False)
jax_fn = function([A], A_inv, mode="JAX")
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))
assert np.allclose(
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
)
assert not np.allclose(
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
)
A_inv = at_nlinalg.pinv(A, hermitian=True)
jax_fn = function([A], A_inv, mode="JAX")
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
assert np.allclose(jax_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))
assert not np.allclose(
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
)
# Numpy fails differently than JAX when hermitian assumption is violated
assert not np.allclose(
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论