提交 02c02d72 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Create a JAX compilation mode

上级 54678486
......@@ -17,7 +17,8 @@ def set_theano_flags():
def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose):
jax_mode = theano.compile.Mode(linker="jax")
# jax_mode = theano.compile.Mode(linker="jax")
jax_mode = "JAX"
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
jax_res = theano_jax_fn(*inputs)
......@@ -499,5 +500,3 @@ def test_nnet():
out = tt.nnet.softplus(x)
fgraph = theano.gof.FunctionGraph([x], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......@@ -413,9 +413,15 @@ if theano.config.cxx:
else:
FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(
JAXLinker(), gof.Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
)
predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"JAX": JAX,
}
instantiated_default_mode = None
......
......@@ -596,14 +596,16 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode.
def filter_mode(val):
if val in [
"Mode",
"DebugMode",
"FAST_RUN",
"NanGuardMode",
"FAST_COMPILE",
"DEBUG_MODE",
]:
if (
val
in [
"Mode",
"DebugMode",
"NanGuardMode",
"DEBUG_MODE",
]
or val in theano.compile.mode.predefined_modes
):
return val
# This can be executed before Theano is completly imported, so
# theano.Mode is not always available.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论