Unverified 提交 cd98c61a authored 作者: Kyle Beauchamp's avatar Kyle Beauchamp 提交者: GitHub

Add JAX conversion for Softmax Op (#239)

上级 b666bdbb
......@@ -743,6 +743,10 @@ def test_nnet():
fgraph = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = tt.nnet.softmax(x)
fgraph = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_tensor_basics():
y = tt.vector("y")
......
......@@ -58,6 +58,7 @@ from theano.tensor.nlinalg import (
QRFull,
QRIncomplete,
)
from theano.tensor.nnet import Softmax
from theano.tensor.nnet.sigm import ScalarSoftplus
from theano.tensor.opt import MakeVector
from theano.tensor.slinalg import Cholesky, Solve
......@@ -250,6 +251,14 @@ def jax_funcify_Identity(op):
return identity
@jax_funcify.register(Softmax)
def jax_funcify_Softmax(op):
def softmax(x):
return jax.nn.softmax(x)
return softmax
@jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op):
def scalarsoftplus(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论