Unverified commit ce01b113 authored by Adrian Seyboldt, committed by GitHub

Add JAX conversion for LogSoftmax and a fix for jax_funcify_join (#343)

上级 219a9516
...@@ -52,7 +52,7 @@ from aesara.tensor.nlinalg import ( ...@@ -52,7 +52,7 @@ from aesara.tensor.nlinalg import (
QRFull, QRFull,
QRIncomplete, QRIncomplete,
) )
from aesara.tensor.nnet.basic import Softmax from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.nnet.sigm import ScalarSoftplus from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...@@ -275,6 +275,14 @@ def jax_funcify_Softmax(op): ...@@ -275,6 +275,14 @@ def jax_funcify_Softmax(op):
return softmax return softmax
@jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op):
    """Create the JAX conversion for Aesara's ``LogSoftmax`` ``Op``.

    The ``op`` instance carries no parameters that affect the computation,
    so the returned callable simply delegates to JAX's numerically stable
    fused implementation.
    """

    def log_softmax(value):
        return jax.nn.log_softmax(value)

    return log_softmax
@jax_funcify.register(ScalarSoftplus) @jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op): def jax_funcify_ScalarSoftplus(op):
def scalarsoftplus(x): def scalarsoftplus(x):
...@@ -786,6 +794,8 @@ def jax_funcify_DimShuffle(op): ...@@ -786,6 +794,8 @@ def jax_funcify_DimShuffle(op):
@jax_funcify.register(Join) @jax_funcify.register(Join)
def jax_funcify_Join(op): def jax_funcify_Join(op):
def join(axis, *tensors): def join(axis, *tensors):
# tensors could also be tuples, and in this case they don't have a ndim
tensors = [jnp.asarray(tensor) for tensor in tensors]
view = op.view view = op.view
if (view != -1) and all( if (view != -1) and all(
[ [
......
...@@ -703,6 +703,45 @@ def test_jax_Dimshuffle(): ...@@ -703,6 +703,45 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_jax_Join():
    """Check that ``Join`` graphs transpile to JAX and match the Python results."""
    a = matrix("a")
    b = matrix("b")

    # Join along the first axis: equal and unequal row counts.
    x = aet.join(0, a, b)
    x_fg = FunctionGraph([a, b], [x])
    for left, right in (
        (np.c_[[1.0, 2.0, 3.0]], np.c_[[4.0, 5.0, 6.0]]),
        (np.c_[[1.0, 2.0, 3.0]], np.c_[[4.0, 5.0]]),
    ):
        compare_jax_and_py(
            x_fg,
            [left.astype(config.floatX), right.astype(config.floatX)],
        )

    # Join along the second axis: matching and differing column counts.
    x = aet.join(1, a, b)
    x_fg = FunctionGraph([a, b], [x])
    for left, right in (
        (np.c_[[1.0, 2.0, 3.0]], np.c_[[4.0, 5.0, 6.0]]),
        (np.c_[[1.0, 2.0], [3.0, 4.0]], np.c_[[5.0, 6.0]]),
    ):
        compare_jax_and_py(
            x_fg,
            [left.astype(config.floatX), right.astype(config.floatX)],
        )
def test_jax_variadic_Scalar(): def test_jax_variadic_Scalar():
mu = vector("mu", dtype=config.floatX) mu = vector("mu", dtype=config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX) mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX)
...@@ -777,6 +816,10 @@ def test_nnet(): ...@@ -777,6 +816,10 @@ def test_nnet():
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.logsoftmax(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_tensor_basics(): def test_tensor_basics():
y = vector("y") y = vector("y")
......
Markdown 格式
0%
You are adding 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment