提交 8a7f59e1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add an option to remove all optimizations and ignore result dtypes in JAX testing

上级 44aa0fb5
...@@ -19,17 +19,31 @@ def set_theano_flags(): ...@@ -19,17 +19,31 @@ def set_theano_flags():
def compare_jax_and_py( def compare_jax_and_py(
fgraph, inputs, assert_fn=partial(np.testing.assert_allclose, rtol=1e-4) fgraph,
inputs,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-4),
simplify=False,
must_be_device_array=True,
): ):
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode="JAX") if not simplify:
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts)
py_mode = theano.compile.Mode("py", opts)
else:
py_mode = theano.compile.Mode(linker="py")
jax_mode = "JAX"
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
jax_res = theano_jax_fn(*inputs) jax_res = theano_jax_fn(*inputs)
if isinstance(jax_res, list): if must_be_device_array:
assert all(isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res) if isinstance(jax_res, list):
else: assert all(
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
py_mode = theano.compile.Mode(linker="py")
theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode) theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode)
py_res = theano_py_fn(*inputs) py_res = theano_py_fn(*inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论