提交 3aff3d2f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Add tests for shared updates in JAX and NUMBA backends

* Also fixes bug in JITCompiler when first output of inner fgraph is an input variable, as can happen in some specific functions with updates
上级 864ee339
...@@ -641,7 +641,8 @@ class JITLinker(PerformLinker): ...@@ -641,7 +641,8 @@ class JITLinker(PerformLinker):
The JITed function that performs the computations. The JITed function that performs the computations.
""" """
output_nodes = [o.owner for o in self.fgraph.outputs] # This is a bit hackish, but we only return one of the output nodes
output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
converted_fgraph = self.fgraph_convert( converted_fgraph = self.fgraph_convert(
self.fgraph, self.fgraph,
...@@ -678,8 +679,7 @@ class JITLinker(PerformLinker): ...@@ -678,8 +679,7 @@ class JITLinker(PerformLinker):
thunks.append(thunk) thunks.append(thunk)
# This is a bit hackish, but we only return one of the output nodes return thunks, output_nodes, fgraph_jit
return thunks, output_nodes[:1], fgraph_jit
def make_all(self, input_storage=None, output_storage=None, storage_map=None): def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph fgraph = self.fgraph
......
...@@ -243,7 +243,7 @@ def gc_helper(node_list: List[Apply]): ...@@ -243,7 +243,7 @@ def gc_helper(node_list: List[Apply]):
------- -------
2-tuple 2-tuple
FIRST, the set of Variable instances which are computed by node_list, FIRST, the set of Variable instances which are computed by node_list,
and SECOND a dictionary that maps each Variable instance to a the last and SECOND a dictionary that maps each Variable instance to the last
node to use Variable as an input. node to use Variable as an input.
Extended Summary Extended Summary
......
...@@ -182,6 +182,22 @@ def test_shared(): ...@@ -182,6 +182,22 @@ def test_shared():
np.testing.assert_allclose(jax_res, new_a_value * 2) np.testing.assert_allclose(jax_res, new_a_value * 2)
def test_shared_updates():
a = shared(0)
aesara_jax_fn = function([], a, updates={a: a + 1}, mode="JAX")
res1, res2 = aesara_jax_fn(), aesara_jax_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2
a.set_value(5)
res1, res2 = aesara_jax_fn(), aesara_jax_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7
def test_jax_ifelse(): def test_jax_ifelse():
true_vals = np.r_[1, 2, 3] true_vals = np.r_[1, 2, 3]
......
...@@ -837,6 +837,22 @@ def test_shared(): ...@@ -837,6 +837,22 @@ def test_shared():
np.testing.assert_allclose(numba_res, new_a_value * 2) np.testing.assert_allclose(numba_res, new_a_value * 2)
def test_shared_updates():
a = shared(0)
aesara_numba_fn = function([], a, updates={a: a + 1}, mode="NUMBA")
res1, res2 = aesara_numba_fn(), aesara_numba_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2
a.set_value(5)
res1, res2 = aesara_numba_fn(), aesara_numba_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7
# We were seeing some weird results in CI where the following two almost # We were seeing some weird results in CI where the following two almost
# sign-swapped results were being return from Numba and Python, respectively. # sign-swapped results were being return from Numba and Python, respectively.
# The issue might be related to https://github.com/numba/numba/issues/4519. # The issue might be related to https://github.com/numba/numba/issues/4519.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论