Unverified 提交 1a9b04bf authored 作者: Ravin Kumar's avatar Ravin Kumar 提交者: GitHub

Jax eye implementation (#153)

上级 99bb8083
......@@ -222,6 +222,14 @@ def test_jax_basic():
)
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = tt.eye(3)
out_fg = theano.gof.FunctionGraph([], [out])
compare_jax_and_py(out_fg, [])
def test_jax_basic_multiout():
np.random.seed(213234)
......
......@@ -27,6 +27,7 @@ from theano.tensor.basic import (
AllocEmpty,
ARange,
Dot,
Eye,
Join,
MaxAndArgmax,
Reshape,
......@@ -984,3 +985,13 @@ def jax_funcify_RavelMultiIndex(op):
return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)
return ravelmultiindex
@jax_funcify.register(Eye)
def jax_funcify_Eye(op):
dtype = op.dtype
def eye(N, M, k):
return jnp.eye(N, M, k, dtype=dtype)
return eye
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论