提交 c8eea207 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix a Scan Cython issue involving output storage lengths

上级 77bb1523
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -62,7 +62,7 @@ numpy.import_array()
def get_version():
return 0.325
return 0.326
@cython.cdivision(True)
......@@ -253,19 +253,24 @@ def perform(
continue
outer_outputs_idx_0 = outer_outputs_idx[0]
outer_inputs_offset_idx = outer_inputs[<unsigned int>(seqs_arg_offset + idx)]
if ( outer_outputs_idx_0 is not None and
outer_outputs_idx_0.shape[1:] == outer_inputs[<unsigned int>(1+ n_seqs + idx)].shape[1:]
outer_outputs_idx_0.shape[1:] == outer_inputs_offset_idx.shape[1:]
and outer_outputs_idx_0.shape[0] >= store_steps[idx] ):
# Put in the values of the initial state
outer_outputs_idx[0] = outer_outputs_idx_0[:store_steps[idx]]
outer_outputs_idx_0 = outer_outputs_idx_0[:store_steps[idx]]
outer_outputs_idx[0] = outer_outputs_idx_0
if idx > n_mit_mot:
l = - mintaps[idx]
outer_outputs_idx_0[:l] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)][:l]
outer_outputs_idx_0[:l] = outer_inputs_offset_idx[:l]
else:
outer_outputs_idx_0[:] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)]
outer_outputs_idx_0[:] = outer_inputs_offset_idx
else:
outer_outputs_idx[0] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)].copy()
outer_outputs_idx[0] = outer_inputs_offset_idx.copy()
if n_steps == 0:
for idx in range(n_outs, n_outs + n_nit_sot):
......
......@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.325 # must match constant returned in function get_version()
version = 0.326 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
......
......@@ -20,6 +20,7 @@ from tempfile import mkdtemp
import numpy as np
import pytest
import aesara.tensor as at
from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function
from aesara.compile.function.pfunc import rebuild_collect_shared
......@@ -38,7 +39,6 @@ from aesara.raise_op import assert_op
from aesara.scan.basic import scan
from aesara.scan.op import Scan
from aesara.scan.utils import until
from aesara.tensor import basic as at
from aesara.tensor.math import all as at_all
from aesara.tensor.math import dot, mean, sigmoid
from aesara.tensor.math import sum as at_sum
......@@ -4028,3 +4028,51 @@ def test_ScanInfo_totals(fn, sequences, outputs_info, non_sequences, n_steps, op
assert scan_op.info.n_outer_outputs == len(res.owner.outputs)
assert scan_op.info.n_inner_inputs == len(res.owner.op.inner_inputs)
assert scan_op.info.n_inner_outputs == len(res.owner.op.inner_outputs)
@pytest.mark.parametrize("linker_mode", ["cvm", "py"])
def test_output_storage_reuse(linker_mode):
"""Make sure that outer-output storage is correctly initialized when it's non-``None``/empty."""
if linker_mode == "cvm":
# This implicitly confirms that the Cython version is being used
from aesara.scan import scan_perform_ext # noqa: F401
mode = Mode(linker=linker_mode, optimizer=None)
def fn(n):
"""
Since the following inner-`Scan` is nested, its outer-output storage
will be non-``None`` after the second outer-`Scan` iteration, and all
subsequent iterations will use the previous outer-output storage. Due
to the ``n_step`` changes, the shape of the outer-inputs array that's
allocated for the lagged/sit-sot ``z`` results should differ from the
shape of the previously allocated outer-output array. Since the
outer-output arrays are initialized using the outer-input arrays, the
shape difference needs to be handled correctly.
"""
s_in_y, _ = scan(
fn=lambda z: (z + 1, until(z > 2)),
outputs_info=[
{"taps": [-1], "initial": at.as_tensor(0.0, dtype=np.float64)}
],
mode=mode,
n_steps=n - 1,
allow_gc=False,
)
return s_in_y.sum()
s_y, updates = scan(
fn=fn,
outputs_info=[None],
sequences=[at.as_tensor([3, 2, 1], dtype=np.int64)],
mode=mode,
allow_gc=False,
)
f_cvm = function([], s_y, mode=mode)
res = f_cvm()
assert np.array_equal(res, np.array([3, 1, 0]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论