Unverified 提交 0150ddf6 authored 作者: Ravin Kumar's avatar Ravin Kumar 提交者: GitHub

Add docstring to compare_jax_and_py (#155)

上级 1a9b04bf
...@@ -22,19 +22,38 @@ def compare_jax_and_py( ...@@ -22,19 +22,38 @@ def compare_jax_and_py(
fgraph, fgraph,
inputs, inputs,
assert_fn=None, assert_fn=None,
simplify=False,
must_be_device_array=True, must_be_device_array=True,
): ):
"""Function to compare python graph output and jax compiled output for testing equality
In the tests below computational graphs are defined in Theano. These graphs are then passed to
this function which then compiles the graphs in both jax and python, runs the calculation
in both and checks if the results are the same
Parameters
----------
fgraph: theano.gof.FunctionGraph
Theano function Graph object
inputs: iter
Inputs for function graph
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
must_be_device_array: Bool
Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes
if this device array is found it indicates if the result was computed by jax
Returns
-------
jax_res
"""
if assert_fn is None: if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
if not simplify: opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"]) jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts)
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts) py_mode = theano.compile.Mode("py", opts)
py_mode = theano.compile.Mode("py", opts)
else:
py_mode = theano.compile.Mode(linker="py")
jax_mode = "JAX"
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode) theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
jax_res = theano_jax_fn(*inputs) jax_res = theano_jax_fn(*inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论