提交 0666fd56 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in storage_input alignment of the JAX backend

When replacing the Shared RNG variables, the input order of the FunctionGraph was not explicitly aligned with the input storage of the function being compiled.
上级 3170c7d8
......@@ -23,7 +23,7 @@ class JAXLinker(JITLinker):
# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
# JAX does not accept RandomState/Generators as inputs, and they will have to
# be typyfied
# be tipyfied
if shared_rng_inputs:
warnings.warn(
f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
......@@ -52,9 +52,16 @@ class JAXLinker(JITLinker):
else: # no break
raise ValueError()
input_storage[input_storage_idx] = new_inp_storage
# We need to change the order of the inputs of the FunctionGraph
# so that the new input is in the same position as to old one,
# to align with the storage_map. We hope this is safe!
old_inp_fgrap_index = fgraph.inputs.index(old_inp)
fgraph.remove_input(
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
old_inp_fgrap_index,
reason="JAXLinker.fgraph_convert",
)
fgraph.inputs.remove(new_inp)
fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
......
......@@ -10,6 +10,7 @@ from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.basic import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.utils import RandomStream
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
......@@ -58,7 +59,42 @@ def test_random_updates(rng_ctor):
)
def test_random_updates_input_storage_order():
@pytest.mark.parametrize("noise_first", (False, True))
def test_replaced_shared_rng_storage_order(noise_first):
# Test that replacing the RNG variable in the linker does not cause
# a disalignment between the compiled graph and the storage_map.
mu = pytensor.shared(np.array(1.0), name="mu")
rng = pytensor.shared(np.random.default_rng(123))
next_rng, noise = pt.random.normal(rng=rng).owner.outputs
if noise_first:
out = noise * mu
else:
out = mu * noise
updates = {
mu: pt.grad(out, mu),
rng: next_rng,
}
f = compile_random_function([], [out], updates=updates, mode="JAX")
# The bug was found when noise used to be the first input of the fgraph
# If this changes, the test may need to be tweaked to keep the save coverage
assert isinstance(
f.input_storage[1 - noise_first].type, RandomType
), "Test may need to be tweaked"
# Confirm that input_storage type and fgraph input order are aligned
for storage, fgrapn_input in zip(f.input_storage, f.maker.fgraph.inputs):
assert storage.type == fgrapn_input.type
assert mu.get_value() == 1
f()
assert mu.get_value() != 1
def test_replaced_shared_rng_storage_ordering_equality():
"""Test case described in issue #314.
This happened when we tried to update the input storage after we clone the shared RNG.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论