提交 dee152ce authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Make scan support n_steps = 0 and empty input sequences

上级 f63d372d
......@@ -888,6 +888,7 @@ def scan(
n_shared_outs += 1
n_sit_sot = len(sit_sot_inner_inputs)
# Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
nit_sot_inner_outputs = []
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1497,8 +1497,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
raise IndexError(
f"Scan was asked to run for negative number of step {n_steps}"
)
elif n_steps == 0:
raise NotImplementedError("n_steps == 0")
else:
for idx, seq in enumerate(inputs[1 : self.seqs_arg_offset]):
if seq.shape[0] < n_steps:
......@@ -1524,10 +1522,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
]
]
pos = [
(-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + self.n_nit_sot)
]
# 2.1 Create storage space for outputs
for idx in range(self.n_outs):
if idx in self.destroy_map:
......@@ -1550,6 +1544,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
output_storage[idx][0] = inputs[self.seqs_arg_offset + idx].copy()
if n_steps == 0:
for idx in range(self.n_outs, self.n_outs + self.n_nit_sot):
out_var = node.outputs[idx]
if isinstance(out_var, TensorVariable):
output_storage[idx][0] = out_var.type.value_zeros(0)
else:
output_storage[idx][0] = None
return
pos = [
(-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + self.n_nit_sot)
]
offset = self.nit_sot_arg_offset + self.n_nit_sot
other_args = inputs[offset:]
inner_input_storage = self.fn.input_storage
......
......@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op
def get_version():
return 0.301
return 0.302
@cython.boundscheck(False)
def perform(
......@@ -205,8 +205,6 @@ def perform(
raise IndexError(
"Scan was asked to run for negative number of step %d" %
n_steps)
elif n_steps == 0:
raise NotImplementedError("n_steps == 0")
else:
for idx in range(n_seqs):
if args[<unsigned int>(1+idx)].shape[0] < n_steps:
......@@ -230,10 +228,6 @@ def perform(
args[<unsigned int>(idx + n_mit_mot + n_mit_sot + n_sit_sot
+ n_shared_outs + n_seqs+1)]
for idx in range(n_outs + n_nit_sot):
pos[idx] = (-mintaps[idx])%store_steps[idx]
# 2.1 Create storage space for outputs
for idx in range(n_outs):
if destroy_map[idx] != 0:
......@@ -254,6 +248,21 @@ def perform(
else:
outs[idx][0] = args[<unsigned int>(seqs_arg_offset + idx)].copy()
if n_steps == 0:
for idx in range(n_outs, n_outs + n_nit_sot):
if outs_is_tensor[idx]:
# TODO FIXME: Why have an `outs_is_tensor` when you can access
# the node directly?
# (The answer is that you shouldn't have a `node` object to
# access, because it's not going to produce a very efficient
# Cython function!)
outs[idx][0] = node.outputs[idx].type.value_zeros(0)
else:
outs[idx][0] = None
return
for idx in range(n_outs + n_nit_sot):
pos[idx] = -mintaps[idx] % store_steps[idx]
offset = nit_sot_arg_offset + n_nit_sot
other_args = args[offset:]
......@@ -274,7 +283,6 @@ def perform(
for idx in range(len(other_args)):
input_storage[<unsigned int>(idx+offset)].storage[0] = other_args[idx]
i = 0
cond = 1
############## THE MAIN LOOP #########################
......
......@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.301 # must match constant returned in function get_version()
version = 0.302 # must match constant returned in function get_version()
need_reload = False
......
......@@ -408,6 +408,98 @@ class TestScan:
rng = np.random.default_rng(utt.fetch_seed())
my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))
@pytest.mark.parametrize("mode", [Mode(linker="py"), Mode(linker="cvm")])
@pytest.mark.parametrize(
"x_init",
[
scalar("x"),
iscalar("x"),
],
)
def test_no_step(self, mode, x_init):
"""We expect an empty output array when ``n_steps == 0``."""
def f_pow(x_tm1):
return 2 * x_tm1
n_steps = iscalar("n_steps")
values, _ = scan(f_pow, outputs_info=(x_init,), n_steps=n_steps)
update_fn = function((x_init, n_steps), values, mode=mode)
res = update_fn(1.0, 0)
exp_res = np.array([], dtype=values.dtype)
assert np.array_equal(res, exp_res)
assert res.dtype == exp_res.dtype
@pytest.mark.parametrize(
"mode", [Mode(linker="py", optimizer=None), Mode(linker="cvm", optimizer=None)]
)
@pytest.mark.parametrize(
"x",
[
vector("x"),
ivector("x"),
],
)
@pytest.mark.parametrize(
"x_init",
[
scalar("x"),
iscalar("x"),
],
)
def test_no_steps_sit_sot(self, mode, x, x_init):
"""We expect an empty output array when scanning over an empty sequence."""
def inner_fn(x_seq, x_i):
return 2 * x_i
with config.change_flags(mode=mode):
values, _ = scan(inner_fn, outputs_info=(x_init,), sequences=x)
values_fn = function((x_init, x), values)
assert isinstance(values.owner.inputs[0].owner.op, Scan)
x_val = np.array([], dtype=x.dtype)
x_init_val = 1.0
res = values_fn(x_init_val, x_val)
exp_res = np.array([], dtype=values.dtype)
assert np.array_equal(res, exp_res)
assert res.dtype == exp_res.dtype
@pytest.mark.parametrize(
"mode", [Mode(linker="py", optimizer=None), Mode(linker="cvm", optimizer=None)]
)
@pytest.mark.parametrize(
"x",
[
vector("x"),
ivector("x"),
],
)
def test_no_steps_nit_sot(self, mode, x):
"""We expect an empty output array when scanning over an empty sequence."""
def inner_fn(x_i):
return 2 * x_i
with config.change_flags(mode=mode):
values, _ = scan(inner_fn, sequences=x)
values_fn = function((x,), values)
assert isinstance(values.owner.op, Scan)
x_val = np.array([], dtype=x.dtype)
res = values_fn(x_val)
exp_res = np.array([], dtype=values.dtype)
assert np.array_equal(res, exp_res)
assert res.dtype == exp_res.dtype
@pytest.mark.slow
def test_only_nonseq_inputs(self):
# Compile the Aesara function
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论