Unverified 提交 1a9b04bf authored 作者: Ravin Kumar's avatar Ravin Kumar 提交者: GitHub

Jax eye implementation (#153)

上级 99bb8083
...@@ -222,6 +222,14 @@ def test_jax_basic(): ...@@ -222,6 +222,14 @@ def test_jax_basic():
) )
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = tt.eye(3)
out_fg = theano.gof.FunctionGraph([], [out])
compare_jax_and_py(out_fg, [])
def test_jax_basic_multiout(): def test_jax_basic_multiout():
np.random.seed(213234) np.random.seed(213234)
......
...@@ -27,6 +27,7 @@ from theano.tensor.basic import ( ...@@ -27,6 +27,7 @@ from theano.tensor.basic import (
AllocEmpty, AllocEmpty,
ARange, ARange,
Dot, Dot,
Eye,
Join, Join,
MaxAndArgmax, MaxAndArgmax,
Reshape, Reshape,
...@@ -984,3 +985,13 @@ def jax_funcify_RavelMultiIndex(op): ...@@ -984,3 +985,13 @@ def jax_funcify_RavelMultiIndex(op):
return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order) return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)
return ravelmultiindex return ravelmultiindex
@jax_funcify.register(Eye)
def jax_funcify_Eye(op):
dtype = op.dtype
def eye(N, M, k):
return jnp.eye(N, M, k, dtype=dtype)
return eye
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论