提交 c8eea207 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix a Scan Cython issue involving output storage lengths

上级 77bb1523
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -62,7 +62,7 @@ numpy.import_array() ...@@ -62,7 +62,7 @@ numpy.import_array()
def get_version(): def get_version():
return 0.325 return 0.326
@cython.cdivision(True) @cython.cdivision(True)
...@@ -253,19 +253,24 @@ def perform( ...@@ -253,19 +253,24 @@ def perform(
continue continue
outer_outputs_idx_0 = outer_outputs_idx[0] outer_outputs_idx_0 = outer_outputs_idx[0]
outer_inputs_offset_idx = outer_inputs[<unsigned int>(seqs_arg_offset + idx)]
if ( outer_outputs_idx_0 is not None and if ( outer_outputs_idx_0 is not None and
outer_outputs_idx_0.shape[1:] == outer_inputs[<unsigned int>(1+ n_seqs + idx)].shape[1:] outer_outputs_idx_0.shape[1:] == outer_inputs_offset_idx.shape[1:]
and outer_outputs_idx_0.shape[0] >= store_steps[idx] ): and outer_outputs_idx_0.shape[0] >= store_steps[idx] ):
# Put in the values of the initial state # Put in the values of the initial state
outer_outputs_idx[0] = outer_outputs_idx_0[:store_steps[idx]]
outer_outputs_idx_0 = outer_outputs_idx_0[:store_steps[idx]]
outer_outputs_idx[0] = outer_outputs_idx_0
if idx > n_mit_mot: if idx > n_mit_mot:
l = - mintaps[idx] l = - mintaps[idx]
outer_outputs_idx_0[:l] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)][:l] outer_outputs_idx_0[:l] = outer_inputs_offset_idx[:l]
else: else:
outer_outputs_idx_0[:] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)] outer_outputs_idx_0[:] = outer_inputs_offset_idx
else: else:
outer_outputs_idx[0] = outer_inputs[<unsigned int>(seqs_arg_offset + idx)].copy() outer_outputs_idx[0] = outer_inputs_offset_idx.copy()
if n_steps == 0: if n_steps == 0:
for idx in range(n_outs, n_outs + n_nit_sot): for idx in range(n_outs, n_outs + n_nit_sot):
......
...@@ -23,7 +23,7 @@ if not config.cxx: ...@@ -23,7 +23,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.325 # must match constant returned in function get_version() version = 0.326 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None scan_perform: Optional[ModuleType] = None
......
...@@ -20,6 +20,7 @@ from tempfile import mkdtemp ...@@ -20,6 +20,7 @@ from tempfile import mkdtemp
import numpy as np import numpy as np
import pytest import pytest
import aesara.tensor as at
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.function.pfunc import rebuild_collect_shared from aesara.compile.function.pfunc import rebuild_collect_shared
...@@ -38,7 +39,6 @@ from aesara.raise_op import assert_op ...@@ -38,7 +39,6 @@ from aesara.raise_op import assert_op
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import until from aesara.scan.utils import until
from aesara.tensor import basic as at
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import dot, mean, sigmoid from aesara.tensor.math import dot, mean, sigmoid
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
...@@ -4028,3 +4028,51 @@ def test_ScanInfo_totals(fn, sequences, outputs_info, non_sequences, n_steps, op ...@@ -4028,3 +4028,51 @@ def test_ScanInfo_totals(fn, sequences, outputs_info, non_sequences, n_steps, op
assert scan_op.info.n_outer_outputs == len(res.owner.outputs) assert scan_op.info.n_outer_outputs == len(res.owner.outputs)
assert scan_op.info.n_inner_inputs == len(res.owner.op.inner_inputs) assert scan_op.info.n_inner_inputs == len(res.owner.op.inner_inputs)
assert scan_op.info.n_inner_outputs == len(res.owner.op.inner_outputs) assert scan_op.info.n_inner_outputs == len(res.owner.op.inner_outputs)
@pytest.mark.parametrize("linker_mode", ["cvm", "py"])
def test_output_storage_reuse(linker_mode):
"""Make sure that outer-output storage is correctly initialized when it's non-``None``/empty."""
if linker_mode == "cvm":
# This implicitly confirms that the Cython version is being used
from aesara.scan import scan_perform_ext # noqa: F401
mode = Mode(linker=linker_mode, optimizer=None)
def fn(n):
"""
Since the following inner-`Scan` is nested, its outer-output storage
will be non-``None`` after the second outer-`Scan` iteration, and all
subsequent iterations will use the previous outer-output storage. Due
to the ``n_step`` changes, the shape of the outer-inputs array that's
allocated for the lagged/sit-sot ``z`` results should differ from the
shape of the previously allocated outer-output array. Since the
outer-output arrays are initialized using the outer-input arrays, the
shape difference needs to be handled correctly.
"""
s_in_y, _ = scan(
fn=lambda z: (z + 1, until(z > 2)),
outputs_info=[
{"taps": [-1], "initial": at.as_tensor(0.0, dtype=np.float64)}
],
mode=mode,
n_steps=n - 1,
allow_gc=False,
)
return s_in_y.sum()
s_y, updates = scan(
fn=fn,
outputs_info=[None],
sequences=[at.as_tensor([3, 2, 1], dtype=np.int64)],
mode=mode,
allow_gc=False,
)
f_cvm = function([], s_y, mode=mode)
res = f_cvm()
assert np.array_equal(res, np.array([3, 1, 0]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论