提交 84c78027 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle Scan gradients of non shaped disconnected inputs

上级 4aea87c2
......@@ -72,6 +72,7 @@ from pytensor.graph.basic import (
from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError, MissingInputError
from pytensor.link.c.basic import CLinker
from pytensor.printing import op_debug_information
......@@ -2591,7 +2592,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# mask inputs that get no gradients
for dx in range(len(dC_dinps_t)):
if dC_dinps_t[dx] is None:
dC_dinps_t[dx] = dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
dC_dinps_t[dx] = dC_dinps_t[dx] = (
pt.zeros_like(diff_inputs[dx])
if isinstance(diff_inputs[dx].type, HasShape)
else pt.zeros(())
)
else:
disconnected_dC_dinps_t[dx] = False
for Xt, Xt_placeholder in zip(
......@@ -2965,7 +2970,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
outer_inp_sitsot.append(
pt.zeros(
[grad_steps + 1] + [x.shape[i] for i in range(x.ndim)],
[grad_steps + 1]
+ (list(x.shape) if isinstance(x.type, HasShape) else []),
dtype=y.dtype,
)
)
......
......@@ -2179,6 +2179,72 @@ class TestScan:
assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12
assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2
@pytest.mark.parametrize("case", ("inside-explicit", "inside-implicit", "outside"))
def test_non_shaped_input_disconnected_gradient(self, case):
"""Test that Scan gradient works when non shaped variables are disconnected from the gradient.
Regression test for https://github.com/pymc-devs/pytensor/issues/6
"""
# In all cases rng is disconnected from the output gradient
# Note that when it is an input to the scan (explicit or not) it is still not updated by the scan,
# so it is equivalent to the `outside` case. A rewrite could have legally hoisted the rng out of the scan.
rng = shared(np.random.default_rng())
data = pt.zeros(16)
nonlocal_random_index = pt.random.integers(16, rng=rng)
nonlocal_random_datum = data[nonlocal_random_index]
if case == "outside":
def step(s, random_datum):
return (random_datum + s) ** 2
strict = True
non_sequences = [nonlocal_random_datum]
elif case == "inside-implicit":
def step(s):
return (nonlocal_random_datum + s) ** 2
strict = False
non_sequences = [] # Scan will introduce the non_sequences for us
elif case == "inside-explicit":
def step(s, data, rng):
random_index = pt.random.integers(
16, rng=rng
) # Not updated by the scan
random_datum = data[random_index]
return (random_datum + s) ** 2
strict = (True,)
non_sequences = [data, rng]
else:
raise ValueError(f"Invalid case: {case}")
seq = vector("seq")
xs, _ = scan(
step,
sequences=[seq],
non_sequences=non_sequences,
strict=strict,
)
x0 = xs[0]
np.testing.assert_allclose(
x0.eval({seq: [np.pi, np.nan, np.nan]}),
np.pi**2,
)
np.testing.assert_allclose(
grad(x0, seq)[0].eval({seq: [np.pi, np.nan, np.nan]}),
2 * np.pi,
)
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论