Unverified 提交 efa845a3 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Add jax dispatch for `KroneckerProduct` `Op` (#822)

上级 0e29d76b
......@@ -8,6 +8,7 @@ from pytensor.tensor.nlinalg import (
Det,
Eig,
Eigh,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
QRFull,
......@@ -104,6 +105,14 @@ def jax_funcify_BatchedDot(op, **kwargs):
return batched_dot
@jax_funcify.register(KroneckerProduct)
def jax_funcify_KroneckerProduct(op, **kwargs):
def _kron(x, y):
return jnp.kron(x, y)
return _kron
@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
axis = op.axis
......
......@@ -165,3 +165,15 @@ def test_pinv_hermitian():
assert not np.allclose(
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
)
def test_kron():
x = matrix("x")
y = matrix("y")
z = pt_nlinalg.kron(x, y)
fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np, y_np])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论