提交 81a8741c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add n_non_seqs to ScanInfo

上级 d58d482a
...@@ -1139,6 +1139,7 @@ def scan( ...@@ -1139,6 +1139,7 @@ def scan(
n_sit_sot=n_sit_sot, n_sit_sot=n_sit_sot,
n_shared_outs=n_shared_outs, n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
) )
local_op = Scan( local_op = Scan(
......
...@@ -216,6 +216,7 @@ class ScanInfo: ...@@ -216,6 +216,7 @@ class ScanInfo:
n_sit_sot: int n_sit_sot: int
n_shared_outs: int n_shared_outs: int
n_nit_sot: int n_nit_sot: int
n_non_seqs: int
TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType] TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType]
...@@ -785,6 +786,28 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -785,6 +786,28 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.mitmots_preallocated, self.mitmots_preallocated,
) = self._mitmot_preallocations() ) = self._mitmot_preallocations()
# The total number of inputs across all multi-input taps
# `tap_array = mit_sot_tap_inputs + (-1,) * n_sit_sot`
# n_mit_mot_sot_inputs = sum(len(x) for x in info.tap_array[: (info.n_mit_mot + info.n_mit_sot)])
n_mit_mot_sot_inputs = info.n_mit_mot + info.n_mit_sot
# [n_steps] + sequences + mit-mots + mit-sots + sit-sots + shared-variables + nit-sots + non-sequences
self.n_outer_inputs = (
1
+ info.n_seqs
+ n_mit_mot_sot_inputs
+ info.n_sit_sot
+ info.n_nit_sot
+ info.n_shared_outs
+ info.n_non_seqs
)
self.n_outer_outputs = (
info.n_mit_mot
+ info.n_mit_sot
+ info.n_sit_sot
+ info.n_nit_sot
+ info.n_shared_outs
)
def _mitmot_preallocations(self): def _mitmot_preallocations(self):
if config.scan__allow_output_prealloc: if config.scan__allow_output_prealloc:
preallocated_mitmot_outs = [] preallocated_mitmot_outs = []
...@@ -1157,7 +1180,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1157,7 +1180,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for t in self.outer_nitsot_outs(self.outputs) for t in self.outer_nitsot_outs(self.outputs)
] ]
apply_node = Apply(self, new_inputs, [t() for t in self.output_types]) outputs = [t() for t in self.output_types]
assert self.n_outer_inputs == len(new_inputs), (
self.n_outer_inputs,
len(new_inputs),
)
assert self.n_outer_outputs == len(outputs), (
self.n_outer_outputs,
len(outputs),
)
apply_node = Apply(self, new_inputs, outputs)
return apply_node return apply_node
def __eq__(self, other): def __eq__(self, other):
...@@ -2835,18 +2869,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2835,18 +2869,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_sitsot_outs = len(outer_inp_sitsot) n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)] new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)]
out_info = ScanInfo(
n_seqs=len(outer_inp_seqs),
n_mit_sot=0,
tap_array=tuple(tuple(v) for v in new_tap_array),
n_mit_mot=len(outer_inp_mitmot),
n_mit_mot_outs=n_mitmot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
n_sit_sot=n_sitsot_outs,
n_shared_outs=0,
n_nit_sot=n_nit_sot,
)
outer_inputs = ( outer_inputs = (
[grad_steps] [grad_steps]
+ outer_inp_seqs + outer_inp_seqs
...@@ -2866,6 +2888,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2866,6 +2888,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot
out_info = ScanInfo(
n_seqs=len(outer_inp_seqs),
n_mit_sot=0,
tap_array=tuple(tuple(v) for v in new_tap_array),
n_mit_mot=len(outer_inp_mitmot),
n_mit_mot_outs=n_mitmot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
n_sit_sot=n_sitsot_outs,
n_shared_outs=0,
n_nit_sot=n_nit_sot,
n_non_seqs=len(self.outer_shared(inputs))
+ len(self.outer_non_seqs(inputs)),
)
local_op = Scan( local_op = Scan(
inner_gfn_ins, inner_gfn_ins,
inner_gfn_outs, inner_gfn_outs,
...@@ -3196,6 +3232,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3196,6 +3232,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_mit_mot_outs=n_mit_mot_outs * 2, n_mit_mot_outs=n_mit_mot_outs * 2,
tap_array=tuple(tuple(v) for v in new_tap_array), tap_array=tuple(tuple(v) for v in new_tap_array),
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2, mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
n_non_seqs=len(inner_other),
) )
local_op = Scan( local_op = Scan(
......
...@@ -168,7 +168,9 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -168,7 +168,9 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = clone_replace(op_outs, replace=givens) op_outs = clone_replace(op_outs, replace=givens)
nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs) nw_info = dataclasses.replace(
op.info, n_seqs=nw_n_seqs, n_non_seqs=len(nw_inner_nonseq)
)
nwScan = Scan( nwScan = Scan(
nw_inner, nw_inner,
op_outs, op_outs,
...@@ -339,11 +341,15 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -339,11 +341,15 @@ def push_out_non_seq_scan(fgraph, node):
op_outs = clone_replace(node_outputs, replace=givens) op_outs = clone_replace(node_outputs, replace=givens)
op_ins = node_inputs + nw_inner op_ins = node_inputs + nw_inner
new_info = dataclasses.replace(
op.info, n_non_seqs=op.info.n_non_seqs + len(nw_outer)
)
# Reconstruct node # Reconstruct node
nwScan = Scan( nwScan = Scan(
op_ins, op_ins,
op_outs, op_outs,
op.info, new_info,
mode=op.mode, mode=op.mode,
as_while=op.as_while, as_while=op.as_while,
profile=op.profile, profile=op.profile,
...@@ -1725,9 +1731,12 @@ class ScanMerge(GlobalOptimizer): ...@@ -1725,9 +1731,12 @@ class ScanMerge(GlobalOptimizer):
outer_outs += nd.op.outer_shared_outs(nd.outputs) outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.outputs)) inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.outputs))
n_non_seqs = 0
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Non Seqs # Non Seqs
inner_ins[idx].append(rename(nd.op.inner_non_seqs(nd.op.inputs), idx)) node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inputs)
n_non_seqs += len(node_inner_non_seqs)
inner_ins[idx].append(rename(node_inner_non_seqs, idx))
outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx) outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps # Add back the number of steps
...@@ -1803,6 +1812,7 @@ class ScanMerge(GlobalOptimizer): ...@@ -1803,6 +1812,7 @@ class ScanMerge(GlobalOptimizer):
n_sit_sot=sum(nd.op.n_sit_sot for nd in nodes), n_sit_sot=sum(nd.op.n_sit_sot for nd in nodes),
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes), n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes), n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes),
n_non_seqs=n_non_seqs,
) )
old_op = nodes[0].op old_op = nodes[0].op
......
...@@ -407,6 +407,7 @@ def compress_outs(op, not_required, inputs): ...@@ -407,6 +407,7 @@ def compress_outs(op, not_required, inputs):
n_sit_sot=0, n_sit_sot=0,
n_shared_outs=0, n_shared_outs=0,
n_nit_sot=0, n_nit_sot=0,
n_non_seqs=0,
) )
op_inputs = op.inputs[: op.n_seqs] op_inputs = op.inputs[: op.n_seqs]
...@@ -525,6 +526,7 @@ def compress_outs(op, not_required, inputs): ...@@ -525,6 +526,7 @@ def compress_outs(op, not_required, inputs):
node_inputs += nit_sot_ins node_inputs += nit_sot_ins
# other stuff # other stuff
op_inputs += op.inputs[i_offset:] op_inputs += op.inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :] node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :]
if op.as_while: if op.as_while:
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
...@@ -812,6 +814,7 @@ class ScanArgs: ...@@ -812,6 +814,7 @@ class ScanArgs:
n_shared_outs=len(self.outer_in_shared), n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices), n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices), mit_mot_out_slices=tuple(self.mit_mot_out_slices),
n_non_seqs=len(self.inner_in_non_seqs),
) )
def get_alt_field( def get_alt_field(
......
...@@ -3976,3 +3976,116 @@ class TestExamples: ...@@ -3976,3 +3976,116 @@ class TestExamples:
# with config.change_flags(mode="DebugMode"): # with config.change_flags(mode="DebugMode"):
# Also, the purpose of this test is not clear. # Also, the purpose of this test is not clear.
self._grad_mout_helper(1, None) self._grad_mout_helper(1, None)
c = scalar("c", dtype="floatX")


@pytest.mark.parametrize(
    "fn, sequences, outputs_info, non_sequences, n_steps, op_check",
    [
        # sequences
        (
            lambda a_t: 2 * a_t,
            [at.arange(10)],
            [{}],
            [],
            None,
            lambda op: op.info.n_seqs > 0,
        ),
        # nit-sot
        (
            lambda: at.as_tensor(2.0),
            [],
            [{}],
            [],
            3,
            lambda op: op.info.n_nit_sot > 0,
        ),
        # nit-sot, non_seq
        (
            lambda c: at.as_tensor(2.0) * c,
            [],
            [{}],
            [c],
            3,
            lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
        ),
        # sit-sot
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}],
            [],
            3,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # sit-sot, while
        (
            lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
            [],
            [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
            [],
            3,
            lambda op: op.info.n_sit_sot > 0,
        ),
        # nit-sot, shared input/output
        (
            lambda: RandomStream().normal(0, 1, name="a"),
            [],
            [{}],
            [],
            3,
            lambda op: op.info.n_shared_outs > 0,
        ),
        # mit-sot (that's also a type of sit-sot)
        (
            lambda a_tm1: 2 * a_tm1,
            [],
            [{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}],
            [],
            6,
            lambda op: op.info.n_mit_sot > 0,
        ),
        # mit-sot
        (
            lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
            [],
            [
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
                {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]},
            ],
            [],
            10,
            lambda op: op.info.n_mit_sot > 0,
        ),
        # TODO: mit-mot (can't be created through the `scan` interface)
    ],
)
def test_n_non_seqs(fn, sequences, outputs_info, non_sequences, n_steps, op_check):
    """Check a `Scan` `Op`'s outer input/output counts against its `Apply` node.

    Each parametrized case builds a graph with `scan` (in strict mode), finds
    the `Scan` node behind the result, and asserts that the `Op`'s
    `n_outer_inputs` and `n_outer_outputs` equal the number of inputs and
    outputs on that node.
    """
    scan_result, _ = scan(
        fn,
        sequences=sequences,
        outputs_info=outputs_info,
        non_sequences=non_sequences,
        n_steps=n_steps,
        strict=True,
    )

    # With several outputs, only the first one is inspected.
    out_var = scan_result[0] if isinstance(scan_result, list) else scan_result

    # The returned variable is not always a direct output of the `Scan` node;
    # when it isn't, step back through the first input of its owner
    # (presumably some post-processing of the raw `Scan` output -- confirm).
    if not isinstance(out_var.owner.op, Scan):
        out_var = out_var.owner.inputs[0]

    scan_node = out_var.owner
    scan_op = scan_node.op
    assert isinstance(scan_op, Scan)

    # NOTE(review): the value returned by `op_check` is deliberately not
    # asserted here; the predicate is only evaluated against the `Op`.
    op_check(scan_op)

    assert scan_op.n_outer_inputs == len(scan_node.inputs)
    assert scan_op.n_outer_outputs == len(scan_node.outputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论