提交 06244f30 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Scan: zero out unwritten buffers

上级 f3a7d94f
...@@ -254,6 +254,17 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -254,6 +254,17 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
""" """
).strip() ).strip()
) )
else:
# And regular loops should zero out unused entries of the output buffer
# These show up with truncated gradients of while loops
output_storage_post_proc_stmts.append(
dedent(
f"""
elif {storage_size} > (i + {max_offset}):
{outer_in_name}[i + {max_offset}:] = 0
"""
).strip()
)
# Special in-loop statements that create (nit-sot) storage arrays after a # Special in-loop statements that create (nit-sot) storage arrays after a
# single iteration is performed. This is necessary because we don't know # single iteration is performed. This is necessary because we don't know
......
...@@ -657,3 +657,7 @@ class TestScanMITSOTBuffer: ...@@ -657,3 +657,7 @@ class TestScanMITSOTBuffer:
def test_higher_order_derivatives(): def test_higher_order_derivatives():
ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA") ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")
def test_grad_until_and_truncate_sequence_taps():
    """Run the shared grad-until + truncate + sequence-taps check in NUMBA mode."""
    shared_check = ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps
    shared_check(mode="NUMBA")
...@@ -2621,22 +2621,7 @@ class TestGradUntil: ...@@ -2621,22 +2621,7 @@ class TestGradUntil:
utt.assert_allclose(pytensor_gradient, self.numpy_gradient) utt.assert_allclose(pytensor_gradient, self.numpy_gradient)
def test_grad_until_and_truncate_sequence_taps(self): def test_grad_until_and_truncate_sequence_taps(self):
n = 3 ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode=None)
r = scan(
lambda x, y, u: (x * y, until(y > u)),
sequences=dict(input=self.x, taps=[-2, 0]),
non_sequences=[self.threshold],
truncate_gradient=n,
return_updates=False,
)
g = grad(r.sum(), self.x)
f = function([self.x, self.threshold], [r, g])
_pytensor_output, pytensor_gradient = f(self.seq, 6)
# Gradient computed by hand:
numpy_grad = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
numpy_grad = numpy_grad.astype(config.floatX)
utt.assert_allclose(pytensor_gradient, numpy_grad)
def test_mintap_onestep(): def test_mintap_onestep():
...@@ -4431,3 +4416,26 @@ class ScanCompatibilityTests: ...@@ -4431,3 +4416,26 @@ class ScanCompatibilityTests:
np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14) np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14)
# FIXME: All implementations of Scan seem to get this one wrong! # FIXME: All implementations of Scan seem to get this one wrong!
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13) # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
@staticmethod
def check_grad_until_and_truncate_sequence_taps(mode):
    """Test case where we need special behavior of zeroing out sequences in Scan.

    Builds a while-Scan (via ``until``) over a tapped sequence with a
    truncated gradient, then checks the gradient against a hand-computed
    reference.
    """
    seq = pt.vector("x")
    cutoff = pt.scalar(name="threshold", dtype="int64")

    # Step: multiply the two taps together; stop once the current tap
    # exceeds the threshold.
    def step(x, y, u):
        return x * y, until(y > u)

    out = scan(
        step,
        sequences={"input": seq, "taps": [-2, 0]},
        non_sequences=[cutoff],
        truncate_gradient=3,
        return_updates=False,
    )
    grad_expr = grad(out.sum(), seq)
    fn = function([seq, cutoff], [out, grad_expr], mode=mode)
    _, grad_val = fn(np.arange(15, dtype=seq.dtype), 6)

    # Gradient computed by hand:
    expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0]).astype(
        config.floatX
    )
    np.testing.assert_allclose(grad_val, expected)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论