Unverified 提交 819122e6 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: GitHub

Move JAX linker to new "link" module (#219)

* Move JAX linker to new "link" module closes #188 * Move JAX tests to link module and update import path
上级 2601e7ac
...@@ -52,7 +52,7 @@ def compare_jax_and_py( ...@@ -52,7 +52,7 @@ def compare_jax_and_py(
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts) jax_mode = theano.compile.mode.Mode(theano.link.jax.JAXLinker(), opts)
py_mode = theano.compile.Mode("py", opts) py_mode = theano.compile.Mode("py", opts)
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode) theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
...@@ -833,7 +833,7 @@ def test_jax_BatchedDot(): ...@@ -833,7 +833,7 @@ def test_jax_BatchedDot():
# A dimension mismatch should raise a TypeError for compatibility # A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)] inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts) jax_mode = theano.compile.mode.Mode(theano.link.jax.JAXLinker(), opts)
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode) theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError): with pytest.raises(TypeError):
theano_jax_fn(*inputs) theano_jax_fn(*inputs)
......
...@@ -10,7 +10,7 @@ import theano ...@@ -10,7 +10,7 @@ import theano
import theano.gof.vm import theano.gof.vm
from theano import config, gof from theano import config, gof
from theano.compile.function.types import Supervisor from theano.compile.function.types import Supervisor
from theano.sandbox.jax_linker import JAXLinker from theano.link.jax import JAXLinker
_logger = logging.getLogger("theano.compile.mode") _logger = logging.getLogger("theano.compile.mode")
......
from theano.link.jax.jax_linker import JAXLinker
...@@ -50,7 +50,7 @@ class JAXLinker(PerformLinker): ...@@ -50,7 +50,7 @@ class JAXLinker(PerformLinker):
""" """
import jax import jax
from theano.sandbox.jaxify import jax_funcify from theano.link.jax.jax_dispatch import jax_funcify
output_nodes = [o.owner for o in self.fgraph.outputs] output_nodes = [o.owner for o in self.fgraph.outputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论