提交 d05ab606 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a test for jaxification of a shared variable

上级 6c8689cc
......@@ -623,3 +623,28 @@ def test_identity():
out = theano.scalar.basic.identity(a)
fgraph = theano.gof.FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_shared():
a = theano.shared(np.array([1, 2, 3], dtype=theano.config.floatX))
theano_jax_fn = theano.function([], a, mode="JAX")
jax_res = theano_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, a.get_value())
theano_jax_fn = theano.function([], a * 2, mode="JAX")
jax_res = theano_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, a.get_value() * 2)
# Changed the shared value and make sure that the JAX-compiled
# function also changes.
new_a_value = np.array([3, 4, 5], dtype=theano.config.floatX)
a.set_value(new_a_value)
jax_res = theano_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, new_a_value * 2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论