提交 02c02d72 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Create a JAX compilation mode

上级 54678486
...@@ -17,7 +17,8 @@ def set_theano_flags(): ...@@ -17,7 +17,8 @@ def set_theano_flags():
def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose): def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose):
jax_mode = theano.compile.Mode(linker="jax") # jax_mode = theano.compile.Mode(linker="jax")
jax_mode = "JAX"
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode) theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
jax_res = theano_jax_fn(*inputs) jax_res = theano_jax_fn(*inputs)
...@@ -499,5 +500,3 @@ def test_nnet(): ...@@ -499,5 +500,3 @@ def test_nnet():
out = tt.nnet.softplus(x) out = tt.nnet.softplus(x)
fgraph = theano.gof.FunctionGraph([x], [out]) fgraph = theano.gof.FunctionGraph([x], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -413,9 +413,15 @@ if theano.config.cxx: ...@@ -413,9 +413,15 @@ if theano.config.cxx:
else: else:
FAST_RUN = Mode("vm", "fast_run") FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(
JAXLinker(), gof.Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
)
predefined_modes = { predefined_modes = {
"FAST_COMPILE": FAST_COMPILE, "FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN, "FAST_RUN": FAST_RUN,
"JAX": JAX,
} }
instantiated_default_mode = None instantiated_default_mode = None
......
...@@ -596,14 +596,16 @@ AddConfigVar( ...@@ -596,14 +596,16 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding # Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode. # new modes, since it is the default mode.
def filter_mode(val): def filter_mode(val):
if val in [ if (
"Mode", val
"DebugMode", in [
"FAST_RUN", "Mode",
"NanGuardMode", "DebugMode",
"FAST_COMPILE", "NanGuardMode",
"DEBUG_MODE", "DEBUG_MODE",
]: ]
or val in theano.compile.mode.predefined_modes
):
return val return val
# This can be executed before Theano is completly imported, so # This can be executed before Theano is completly imported, so
# theano.Mode is not always available. # theano.Mode is not always available.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论