提交 76585fe1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Split ScanInfo.tap_array into mit-mot, mit-sot, sit-sot parts

上级 cdd8575b
......@@ -1123,15 +1123,14 @@ def scan(
# Step 7. Create the Scan Op
##
tap_array = tuple(tuple(v) for v in mit_sot_tap_array) + tuple(
(-1,) for x in range(n_sit_sot)
)
if allow_gc is None:
allow_gc = config.scan__allow_gc
info = ScanInfo(
tap_array=tap_array,
n_seqs=n_seqs,
mit_mot_in_slices=(),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_mit_mot=n_mit_mot,
n_mit_mot_outs=n_mit_mot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
......
差异被折叠。
......@@ -2,6 +2,7 @@
import copy
import dataclasses
from itertools import chain
from sys import maxsize
from typing import Dict, List, Optional, Tuple
......@@ -82,7 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
op = node.op
# We only need to take care of sequences and other arguments
st = op.n_seqs
st += int(sum(len(x) for x in op.tap_array[: (op.n_mit_mot + op.n_mit_sot)]))
st += int(sum(len(x) for x in chain(op.mit_mot_in_slices, op.mit_sot_in_slices)))
st += op.n_sit_sot
st += op.n_shared_outs
......@@ -1147,7 +1148,7 @@ def save_mem_new_scan(fgraph, node):
c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot
init_l = [0 for x in range(op.n_mit_mot)]
init_l += [abs(min(v)) for v in op.tap_array[op.n_mit_mot :]]
init_l += [abs(min(v)) for v in chain(op.mit_sot_in_slices, op.sit_sot_in_slices)]
init_l += [0 for x in range(op.n_nit_sot)]
# 2. Check the clients of each output and see for how many steps
# does scan need to run
......@@ -1678,30 +1679,29 @@ class ScanMerge(GlobalOptimizer):
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
tap_array = ()
mit_mot_out_slices = ()
mit_mot_in_slices = ()
for idx, nd in enumerate(nodes):
# MitMot
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
tap_array += nd.op.mitmot_taps()
mit_mot_in_slices += nd.op.mitmot_taps()
mit_mot_out_slices += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
mit_sot_in_slices = ()
for idx, nd in enumerate(nodes):
# MitSot
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
tap_array += nd.op.mitsot_taps()
mit_sot_in_slices += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
sit_sot_in_slices = ()
for idx, nd in enumerate(nodes):
# SitSot
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx))
tap_array += tuple((-1,) for x in range(nd.op.n_sit_sot))
sit_sot_in_slices += tuple((-1,) for x in range(nd.op.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
......@@ -1794,7 +1794,9 @@ class ScanMerge(GlobalOptimizer):
new_inner_outs += inner_outs[idx][gr_idx]
info = ScanInfo(
tap_array=tap_array,
mit_mot_in_slices=mit_mot_in_slices,
mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices,
n_seqs=sum(nd.op.n_seqs for nd in nodes),
n_mit_mot=sum(nd.op.n_mit_mot for nd in nodes),
n_mit_mot_outs=sum(nd.op.n_mit_mot_outs for nd in nodes),
......@@ -2212,14 +2214,10 @@ def push_out_dot1_scan(fgraph, node):
inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info = dataclasses.replace(
op.info,
tap_array=(
op.info.tap_array[: st + idx]
+ op.info.tap_array[st + idx + 1 :]
),
sit_sot_in_slices=op.info.sit_sot_in_slices[:idx]
+ op.info.sit_sot_in_slices[idx + 1 :],
n_sit_sot=op.info.n_sit_sot - 1,
n_nit_sot=op.info.n_nit_sot + 1,
)
......
......@@ -4,6 +4,7 @@ import copy
import dataclasses
import logging
from collections import OrderedDict, namedtuple
from itertools import chain
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union
from typing import cast as type_cast
......@@ -348,11 +349,12 @@ class Validator:
def scan_can_remove_outs(op, out_idxs):
"""
Looks at all outputs defined by indices ``out_idxs`` and see whom can be
removed from the scan op without affecting the rest. Return two lists,
the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed.
"""Look at all outputs defined by indices ``out_idxs`` and determines which can be removed.
Returns
-------
two lists, the first one with the indices of outs that can be removed, the
second with the outputs that can not be removed.
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
......@@ -360,9 +362,10 @@ def scan_can_remove_outs(op, out_idxs):
out_ins = []
offset = op.n_seqs
lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot
for idx in range(lim):
n_ins = len(op.info.tap_array[idx])
for idx, tap in enumerate(
chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices)
):
n_ins = len(tap)
out_ins += [op.inputs[offset : offset + n_ins]]
offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)]
......@@ -398,7 +401,9 @@ def compress_outs(op, not_required, inputs):
from aesara.scan.op import ScanInfo
info = ScanInfo(
tap_array=(),
mit_mot_in_slices=(),
mit_sot_in_slices=(),
sit_sot_in_slices=(),
n_seqs=op.info.n_seqs,
n_mit_mot=0,
n_mit_mot_outs=0,
......@@ -428,23 +433,24 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(
info,
n_mit_mot=info.n_mit_mot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices
+ (tuple(op.mit_mot_out_slices[offset + idx]),),
+ (op.mit_mot_out_slices[idx],),
)
# input taps
for jdx in op.tap_array[offset + idx]:
for jdx in op.mit_mot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]]
i_offset += 1
# output taps
for jdx in op.mit_mot_out_slices[offset + idx]:
for jdx in op.mit_mot_out_slices[idx]:
op_outputs += [op.outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += len(op.mit_mot_out_slices[offset + idx])
i_offset += len(op.tap_array[offset + idx])
o_offset += len(op.mit_mot_out_slices[idx])
i_offset += len(op.mit_mot_in_slices[idx])
info = dataclasses.replace(info, n_mit_mot_outs=len(op_outputs))
offset += op.n_mit_mot
ni_offset += op.n_mit_mot
......@@ -456,10 +462,10 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(
info,
n_mit_sot=info.n_mit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],),
)
# input taps
for jdx in op.tap_array[offset + idx]:
for jdx in op.mit_sot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]]
i_offset += 1
# output taps
......@@ -469,7 +475,7 @@ def compress_outs(op, not_required, inputs):
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += 1
i_offset += len(op.tap_array[offset + idx])
i_offset += len(op.mit_sot_in_slices[idx])
offset += op.n_mit_sot
ni_offset += op.n_mit_sot
......@@ -480,7 +486,7 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(
info,
n_sit_sot=info.n_sit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
)
# input taps
op_inputs += [op.inputs[i_offset]]
......@@ -613,8 +619,9 @@ class ScanArgs:
n_mit_mot = info.n_mit_mot
n_mit_sot = info.n_mit_sot
self.mit_mot_in_slices = list(info.tap_array[:n_mit_mot])
self.mit_sot_in_slices = list(info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot])
self.mit_mot_in_slices = info.mit_mot_in_slices
self.mit_sot_in_slices = info.mit_sot_in_slices
self.sit_sot_in_slices = info.sit_sot_in_slices
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
......@@ -800,14 +807,12 @@ class ScanArgs:
from aesara.scan.op import ScanInfo
return ScanInfo(
mit_mot_in_slices=tuple(tuple(v) for v in self.mit_mot_in_slices),
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(
tuple(tuple(v) for v in self.mit_mot_in_slices)
+ tuple(tuple(v) for v in self.mit_sot_in_slices)
+ ((-1,),) * len(self.inner_in_sit_sot)
),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论