Unverified 提交 efa845a3 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Add jax dispatch for `KroneckerProduct` `Op` (#822)

上级 0e29d76b
...@@ -8,6 +8,7 @@ from pytensor.tensor.nlinalg import ( ...@@ -8,6 +8,7 @@ from pytensor.tensor.nlinalg import (
Det, Det,
Eig, Eig,
Eigh, Eigh,
KroneckerProduct,
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
QRFull, QRFull,
...@@ -104,6 +105,14 @@ def jax_funcify_BatchedDot(op, **kwargs): ...@@ -104,6 +105,14 @@ def jax_funcify_BatchedDot(op, **kwargs):
return batched_dot return batched_dot
@jax_funcify.register(KroneckerProduct)
def jax_funcify_KroneckerProduct(op, **kwargs):
def _kron(x, y):
return jnp.kron(x, y)
return _kron
@jax_funcify.register(Max) @jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs): def jax_funcify_Max(op, **kwargs):
axis = op.axis axis = op.axis
......
...@@ -165,3 +165,15 @@ def test_pinv_hermitian(): ...@@ -165,3 +165,15 @@ def test_pinv_hermitian():
assert not np.allclose( assert not np.allclose(
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
) )
def test_kron():
x = matrix("x")
y = matrix("y")
z = pt_nlinalg.kron(x, y)
fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np, y_np])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论