Unverified 提交 ce01b113 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Add JAX conversion for LogSoftmax and a fix for jax_funcify_join (#343)

上级 219a9516
......@@ -52,7 +52,7 @@ from aesara.tensor.nlinalg import (
QRFull,
QRIncomplete,
)
from aesara.tensor.nnet.basic import Softmax
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -275,6 +275,14 @@ def jax_funcify_Softmax(op):
return softmax
@jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op):
def log_softmax(x):
return jax.nn.log_softmax(x)
return log_softmax
@jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op):
def scalarsoftplus(x):
......@@ -786,6 +794,8 @@ def jax_funcify_DimShuffle(op):
@jax_funcify.register(Join)
def jax_funcify_Join(op):
def join(axis, *tensors):
# tensors could also be tuples, and in this case they don't have a ndim
tensors = [jnp.asarray(tensor) for tensor in tensors]
view = op.view
if (view != -1) and all(
[
......
......@@ -703,6 +703,45 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_jax_Join():
a = matrix("a")
b = matrix("b")
x = aet.join(0, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0]].astype(config.floatX),
],
)
x = aet.join(1, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
np.c_[[5.0, 6.0]].astype(config.floatX),
],
)
def test_jax_variadic_Scalar():
mu = vector("mu", dtype=config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX)
......@@ -777,6 +816,10 @@ def test_nnet():
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.logsoftmax(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_tensor_basics():
y = vector("y")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论