提交 93bfa1bd authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in JAX cloning of RNG shared variables

上级 2b7f95cf
......@@ -44,7 +44,14 @@ class JAXLinker(JITLinker):
new_inp_storage = [new_inp.get_value(borrow=True)]
storage_map[new_inp] = new_inp_storage
old_inp_storage = storage_map.pop(old_inp)
input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
# Find index of old_inp_storage in input_storage
for input_storage_idx, input_storage_item in enumerate(input_storage):
# We have to establish equality based on identity because input_storage may contain numpy arrays
if input_storage_item is old_inp_storage:
break
else: # no break
raise ValueError()
input_storage[input_storage_idx] = new_inp_storage
fgraph.remove_input(
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
)
......
......@@ -63,6 +63,43 @@ def test_random_updates(rng_ctor):
)
def test_random_updates_input_storage_order():
"""Test case described in issue #314.
This happened when we tried to update the input storage after we clone the shared RNG.
We used to call `input_storage.index(old_input_storage)` which would fail when the input_storage contained
numpy arrays before the RNG value, which would fail the equality check.
"""
pt_rng = RandomStream(1)
batchshape = (3, 1, 4, 4)
inp_shared = pytensor.shared(
np.zeros(batchshape, dtype="float64"), name="inp_shared"
)
inp = at.tensor4(dtype="float64", name="inp")
inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5)
# This function replaces inp by input_shared in the update expression
# This is what caused the RNG to appear later than inp_shared in the input_storage
with pytest.warns(
UserWarning,
match=r"The RandomType SharedVariables \[.+\] will not be used",
):
fn = pytensor.function(
inputs=[],
outputs=[],
updates={inp_shared: inp_update},
givens={inp: inp_shared},
mode="JAX",
)
fn()
np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
fn()
np.testing.assert_allclose(inp_shared.get_value(), 10, rtol=1e-3)
@pytest.mark.parametrize(
"rv_op, dist_params, base_size, cdf_name, params_conv",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论