提交 dca7f5d9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Scan: prevent alias of outputs

Also simplified test. Shared variables aren't needed for the test and clobber it
上级 06244f30
......@@ -5,13 +5,12 @@ import numpy as np
from numba import types
from numba.extending import overload
from pytensor import In
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
from pytensor.compile.io import In, Out
from pytensor.compile.mode import NUMBA, get_mode
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_arg_string,
create_tuple_string,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
......@@ -89,14 +88,15 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
if outer_mitsot.type.shape[0] == abs(min(taps))
]
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs]
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
],
input_specs=input_specs,
accept_inplace=True,
)
rewriter(fgraph)
output_specs = [Out(x, borrow=False) for x in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph)
......
......@@ -661,3 +661,7 @@ def test_higher_order_derivatives():
def test_grad_until_and_truncate_sequence_taps():
ScanCompatibilityTests.check_grad_until_and_truncate_sequence_taps(mode="NUMBA")
def test_aliased_inner_outputs():
ScanCompatibilityTests.check_aliased_inner_outputs(static_shape=True, mode="NUMBA")
......@@ -3181,40 +3181,9 @@ class TestExamples:
f = function([seq], results[1])
assert np.all(exp_out == f(inp))
def test_shared_borrow(self):
"""
This tests two things. The first is a bug occurring when scan wrongly
used the borrow flag. The second thing it that Scan's infer_shape()
method will be able to remove the Scan node from the graph in this
case.
"""
inp = np.arange(10).reshape(-1, 1).astype(config.floatX)
exp_out = np.zeros((10, 1)).astype(config.floatX)
exp_out[4:] = inp[:-4]
def onestep(x, x_tm4):
return x, x_tm4
seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results = scan(
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
)
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
updates = {sharedvar: results[0][-1:]}
f = function([seq], results[1], updates=updates)
# This fails if scan uses wrongly the borrow flag
assert np.all(exp_out == f(inp))
# This fails if Scan's infer_shape() is unable to remove the Scan
# node from the graph.
f_infershape = function([seq], results[1].shape, mode="FAST_RUN")
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
assert len(scan_nodes_infershape) == 0
@pytest.mark.parametrize("static_shape", (True, False))
def test_aliased_inner_outputs(self, static_shape):
ScanCompatibilityTests.check_aliased_inner_outputs(static_shape, mode=None)
def test_memory_reuse_with_outputs_as_inputs(self):
"""
......@@ -4417,7 +4386,6 @@ class ScanCompatibilityTests:
# FIXME: All implementations of Scan seem to get this one wrong!
# np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
@staticmethod
def check_grad_until_and_truncate_sequence_taps(mode):
"""Test case where we need special behavior of zeroing out sequences in Scan"""
......@@ -4439,3 +4407,66 @@ class ScanCompatibilityTests:
grad_expected = np.array([0, 0, 0, 5, 6, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0])
grad_expected = grad_expected.astype(config.floatX)
np.testing.assert_allclose(grad_res, grad_expected)
@staticmethod
def check_aliased_inner_outputs(static_shape, mode):
"""
This tests two things. The first is a bug occurring when scan wrongly
used the borrow flag. The second thing it that Scan's infer_shape()
method will be able to remove the Scan node from the graph in this
case.
Here is pure python equivalent of the problem we want to avoid:
```python
def scan(seq, initval):
# Due to memory optimization we override values of mitsot as we iterate
# That's why mitsot has shape (4, 1) and not (14, 1)
mitsot = np.zeros((4, 1))
mitsot[:4] = initval
nitsot = np.zeros((10, 1))
for i, s in enumerate(seq):
# Incorrect results
mitsot[(i+4) % 4], nitsot[i] = s, mitsot[i % 4]
# Correct results
# mitsot[(i + 4) % 4], nitsot[i] = s, mitsot[i % 4].copy()
return mitsot[(i + 4) % 4: (i+4 + 1) % 4], nitsot
scan(np.arange(10), np.zeros((4, 1)))
```
"""
def onestep(seq, seq_tm4):
# Recurring output is just each value of seq
# And we further map the tap -4 as a new output
return seq, seq_tm4
# Outer tensors must be atleast matrix, so that they we have vectors in the inner loop
# Otherwise we would be working with scalars and memory alias wouldn't be a concern
seq = matrix(shape=(10, 1) if static_shape else (None, None), name="seq")
init = matrix(shape=(4, 1) if static_shape else (None, None), name="init")
outputs_info = [{"initial": init, "taps": [-4]}, None]
[out_seq, out_seq_tm4] = scan(
fn=onestep,
sequences=seq,
outputs_info=outputs_info,
return_updates=False,
)
f = function([seq, init], [out_seq[-1].ravel(), out_seq_tm4.ravel()], mode=mode)
seq_test_val = np.arange(10, dtype=config.floatX)[:, None]
init_test_val = np.zeros((4, 1), dtype=config.floatX)
res0, res1 = f(seq_test_val, init_test_val)
expected_res0 = np.array([9], dtype=config.floatX)
expected_res1 = np.zeros(10, dtype=config.floatX)
expected_res1[4:] = np.arange(6)
np.testing.assert_array_equal(res0, expected_res0)
np.testing.assert_array_equal(res1, expected_res1)
# This fails if Scan's infer_shape() is unable to remove the Scan
# node from the graph.
f_infershape = function([seq, init], out_seq_tm4[1].shape)
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
assert len(scan_nodes_infershape) == 0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论