Unverified 提交 cd98c61a authored 作者: Kyle Beauchamp's avatar Kyle Beauchamp 提交者: GitHub

Add JAX conversion for Softmax Op (#239)

上级 b666bdbb
...@@ -743,6 +743,10 @@ def test_nnet(): ...@@ -743,6 +743,10 @@ def test_nnet():
fgraph = theano.gof.FunctionGraph([x], [out]) fgraph = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = tt.nnet.softmax(x)
fgraph = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_tensor_basics(): def test_tensor_basics():
y = tt.vector("y") y = tt.vector("y")
......
...@@ -58,6 +58,7 @@ from theano.tensor.nlinalg import ( ...@@ -58,6 +58,7 @@ from theano.tensor.nlinalg import (
QRFull, QRFull,
QRIncomplete, QRIncomplete,
) )
from theano.tensor.nnet import Softmax
from theano.tensor.nnet.sigm import ScalarSoftplus from theano.tensor.nnet.sigm import ScalarSoftplus
from theano.tensor.opt import MakeVector from theano.tensor.opt import MakeVector
from theano.tensor.slinalg import Cholesky, Solve from theano.tensor.slinalg import Cholesky, Solve
...@@ -250,6 +251,14 @@ def jax_funcify_Identity(op): ...@@ -250,6 +251,14 @@ def jax_funcify_Identity(op):
return identity return identity
@jax_funcify.register(Softmax)
def jax_funcify_Softmax(op):
    """Build the JAX-callable implementing Theano's ``Softmax`` op.

    Parameters
    ----------
    op : Softmax
        The op instance being converted; unused here, but required by the
        ``jax_funcify`` dispatch signature.

    Returns
    -------
    callable
        A function mapping an array to its softmax via ``jax.nn.softmax``.
    """

    def softmax_fn(x):
        # Delegate to JAX's implementation; jax.nn.softmax defaults to
        # axis=-1, which presumably matches Theano's row-wise Softmax —
        # NOTE(review): confirm against Theano's Softmax semantics.
        return jax.nn.softmax(x)

    return softmax_fn
@jax_funcify.register(ScalarSoftplus) @jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op): def jax_funcify_ScalarSoftplus(op):
def scalarsoftplus(x): def scalarsoftplus(x):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论