提交 54678486 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add JAX support for sigmoid scalar Ops

上级 8a8c7e7d
...@@ -150,6 +150,10 @@ def test_jax_basic(): ...@@ -150,6 +150,10 @@ def test_jax_basic():
assert jax_res[0, 0] == -10.0 assert jax_res[0, 0] == -10.0
assert jax_res[0, 1] == -8.0 assert jax_res[0, 1] == -8.0
out = tt.clip(x, y, 5)
out_fg = theano.gof.FunctionGraph([x, y], [out])
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
@pytest.mark.skip(reason="Not fully implemented, yet.") @pytest.mark.skip(reason="Not fully implemented, yet.")
def test_jax_scan(): def test_jax_scan():
...@@ -478,3 +482,22 @@ def test_jax_multioutput(): ...@@ -478,3 +482,22 @@ def test_jax_multioutput():
fgraph = theano.gof.FunctionGraph([x, y], [w, v]) fgraph = theano.gof.FunctionGraph([x, y], [w, v])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_nnet():
    """Check that the JAX conversions of nnet scalar Ops match the Python results.

    Covers ``sigmoid``, ``ultra_fast_sigmoid``, and ``softplus`` applied to a
    small test vector.
    """
    x = tt.vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)

    # Each Op is compiled into its own FunctionGraph and compared against
    # the pure-Python evaluation on the tagged test value.
    for nnet_op in (tt.nnet.sigmoid, tt.nnet.ultra_fast_sigmoid, tt.nnet.softplus):
        fgraph = theano.gof.FunctionGraph([x], [nnet_op(x)])
        _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -50,6 +50,8 @@ from theano.compile.ops import ( ...@@ -50,6 +50,8 @@ from theano.compile.ops import (
) )
from theano.tensor.opt import MakeVector from theano.tensor.opt import MakeVector
from theano.tensor.nnet.sigm import ScalarSoftplus
# XXX: Enabling this will break some shape-based functionality, and severely # XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted. # limit the types of graphs that can be converted.
...@@ -165,7 +167,18 @@ def jax_funcify_ScalarOp(op): ...@@ -165,7 +167,18 @@ def jax_funcify_ScalarOp(op):
@jax_funcify.register(Clip)
def jax_funcify_Clip(op):
    """Return a JAX implementation for the ``Clip`` scalar Op.

    NOTE(review): the nested ``jnp.where`` form is kept as-is rather than
    replaced with ``jnp.clip`` — the two differ when ``min > max`` (here the
    lower bound wins); confirm which semantics the Op requires before
    simplifying.
    """

    def clip(x, min, max):
        clipped_high = jnp.where(x > max, max, x)
        return jnp.where(x < min, min, clipped_high)

    return clip
@jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op):
    """Return a JAX implementation for the ``ScalarSoftplus`` scalar Op.

    Piecewise evaluation: 0 below -30, the identity above 30, and
    ``log1p(exp(x))`` in between.
    """

    def scalarsoftplus(x):
        # All three branches are computed and selected with nested wheres;
        # the cutoffs avoid overflow in exp(x) for large |x|.
        interior = jnp.log1p(jnp.exp(x))
        return jnp.where(x < -30.0, 0.0, jnp.where(x > 30.0, x, interior))

    return scalarsoftplus
@jax_funcify.register(AllocEmpty) @jax_funcify.register(AllocEmpty)
......
...@@ -31,6 +31,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp): ...@@ -31,6 +31,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
""" """
nfunc_spec = ("scipy.special.expit", 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
if x < -30.0: if x < -30.0:
...@@ -196,6 +198,8 @@ class UltraFastScalarSigmoid(scalar.UnaryScalarOp): ...@@ -196,6 +198,8 @@ class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
""" """
nfunc_spec = ("scipy.special.expit", 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
x = 0.5 * x x = 0.5 * x
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论