提交 76585fe1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Split ScanInfo.tap_array into mit-mot, mit-sot, sit-sot parts

上级 cdd8575b
...@@ -1123,15 +1123,14 @@ def scan( ...@@ -1123,15 +1123,14 @@ def scan(
# Step 7. Create the Scan Op # Step 7. Create the Scan Op
## ##
tap_array = tuple(tuple(v) for v in mit_sot_tap_array) + tuple(
(-1,) for x in range(n_sit_sot)
)
if allow_gc is None: if allow_gc is None:
allow_gc = config.scan__allow_gc allow_gc = config.scan__allow_gc
info = ScanInfo( info = ScanInfo(
tap_array=tap_array,
n_seqs=n_seqs, n_seqs=n_seqs,
mit_mot_in_slices=(),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_mit_mot=n_mit_mot, n_mit_mot=n_mit_mot,
n_mit_mot_outs=n_mit_mot_outs, n_mit_mot_outs=n_mit_mot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices), mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
......
差异被折叠。
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import copy import copy
import dataclasses import dataclasses
from itertools import chain
from sys import maxsize from sys import maxsize
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -82,7 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -82,7 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
op = node.op op = node.op
# We only need to take care of sequences and other arguments # We only need to take care of sequences and other arguments
st = op.n_seqs st = op.n_seqs
st += int(sum(len(x) for x in op.tap_array[: (op.n_mit_mot + op.n_mit_sot)])) st += int(sum(len(x) for x in chain(op.mit_mot_in_slices, op.mit_sot_in_slices)))
st += op.n_sit_sot st += op.n_sit_sot
st += op.n_shared_outs st += op.n_shared_outs
...@@ -1147,7 +1148,7 @@ def save_mem_new_scan(fgraph, node): ...@@ -1147,7 +1148,7 @@ def save_mem_new_scan(fgraph, node):
c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot
init_l = [0 for x in range(op.n_mit_mot)] init_l = [0 for x in range(op.n_mit_mot)]
init_l += [abs(min(v)) for v in op.tap_array[op.n_mit_mot :]] init_l += [abs(min(v)) for v in chain(op.mit_sot_in_slices, op.sit_sot_in_slices)]
init_l += [0 for x in range(op.n_nit_sot)] init_l += [0 for x in range(op.n_nit_sot)]
# 2. Check the clients of each output and see for how many steps # 2. Check the clients of each output and see for how many steps
# does scan need to run # does scan need to run
...@@ -1678,30 +1679,29 @@ class ScanMerge(GlobalOptimizer): ...@@ -1678,30 +1679,29 @@ class ScanMerge(GlobalOptimizer):
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx) outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
tap_array = ()
mit_mot_out_slices = () mit_mot_out_slices = ()
mit_mot_in_slices = ()
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitMot
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs)) inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
tap_array += nd.op.mitmot_taps() mit_mot_in_slices += nd.op.mitmot_taps()
mit_mot_out_slices += nd.op.mitmot_out_taps() mit_mot_out_slices += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs) outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
mit_sot_in_slices = ()
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitSot
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs)) inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
tap_array += nd.op.mitsot_taps() mit_sot_in_slices += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs) outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
sit_sot_in_slices = ()
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# SitSot
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx)) inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx))
tap_array += tuple((-1,) for x in range(nd.op.n_sit_sot)) sit_sot_in_slices += tuple((-1,) for x in range(nd.op.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs)) inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx) outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs) outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
...@@ -1794,7 +1794,9 @@ class ScanMerge(GlobalOptimizer): ...@@ -1794,7 +1794,9 @@ class ScanMerge(GlobalOptimizer):
new_inner_outs += inner_outs[idx][gr_idx] new_inner_outs += inner_outs[idx][gr_idx]
info = ScanInfo( info = ScanInfo(
tap_array=tap_array, mit_mot_in_slices=mit_mot_in_slices,
mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices,
n_seqs=sum(nd.op.n_seqs for nd in nodes), n_seqs=sum(nd.op.n_seqs for nd in nodes),
n_mit_mot=sum(nd.op.n_mit_mot for nd in nodes), n_mit_mot=sum(nd.op.n_mit_mot for nd in nodes),
n_mit_mot_outs=sum(nd.op.n_mit_mot_outs for nd in nodes), n_mit_mot_outs=sum(nd.op.n_mit_mot_outs for nd in nodes),
...@@ -2212,14 +2214,10 @@ def push_out_dot1_scan(fgraph, node): ...@@ -2212,14 +2214,10 @@ def push_out_dot1_scan(fgraph, node):
inner_non_seqs = op.inner_non_seqs(op.inputs) inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info = dataclasses.replace( new_info = dataclasses.replace(
op.info, op.info,
tap_array=( sit_sot_in_slices=op.info.sit_sot_in_slices[:idx]
op.info.tap_array[: st + idx] + op.info.sit_sot_in_slices[idx + 1 :],
+ op.info.tap_array[st + idx + 1 :]
),
n_sit_sot=op.info.n_sit_sot - 1, n_sit_sot=op.info.n_sit_sot - 1,
n_nit_sot=op.info.n_nit_sot + 1, n_nit_sot=op.info.n_nit_sot + 1,
) )
......
...@@ -4,6 +4,7 @@ import copy ...@@ -4,6 +4,7 @@ import copy
import dataclasses import dataclasses
import logging import logging
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from itertools import chain
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union
from typing import cast as type_cast from typing import cast as type_cast
...@@ -348,11 +349,12 @@ class Validator: ...@@ -348,11 +349,12 @@ class Validator:
def scan_can_remove_outs(op, out_idxs): def scan_can_remove_outs(op, out_idxs):
""" """Look at all outputs defined by indices ``out_idxs`` and determines which can be removed.
Looks at all outputs defined by indices ``out_idxs`` and see whom can be
removed from the scan op without affecting the rest. Return two lists, Returns
the first one with the indices of outs that can be removed, the second -------
with the outputs that can not be removed. two lists, the first one with the indices of outs that can be removed, the
second with the outputs that can not be removed.
""" """
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs] non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
...@@ -360,9 +362,10 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -360,9 +362,10 @@ def scan_can_remove_outs(op, out_idxs):
out_ins = [] out_ins = []
offset = op.n_seqs offset = op.n_seqs
lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot for idx, tap in enumerate(
for idx in range(lim): chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices)
n_ins = len(op.info.tap_array[idx]) ):
n_ins = len(tap)
out_ins += [op.inputs[offset : offset + n_ins]] out_ins += [op.inputs[offset : offset + n_ins]]
offset += n_ins offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)] out_ins += [[] for k in range(op.n_nit_sot)]
...@@ -398,7 +401,9 @@ def compress_outs(op, not_required, inputs): ...@@ -398,7 +401,9 @@ def compress_outs(op, not_required, inputs):
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
info = ScanInfo( info = ScanInfo(
tap_array=(), mit_mot_in_slices=(),
mit_sot_in_slices=(),
sit_sot_in_slices=(),
n_seqs=op.info.n_seqs, n_seqs=op.info.n_seqs,
n_mit_mot=0, n_mit_mot=0,
n_mit_mot_outs=0, n_mit_mot_outs=0,
...@@ -428,23 +433,24 @@ def compress_outs(op, not_required, inputs): ...@@ -428,23 +433,24 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace( info = dataclasses.replace(
info, info,
n_mit_mot=info.n_mit_mot + 1, n_mit_mot=info.n_mit_mot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),), mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices mit_mot_out_slices=info.mit_mot_out_slices
+ (tuple(op.mit_mot_out_slices[offset + idx]),), + (op.mit_mot_out_slices[idx],),
) )
# input taps # input taps
for jdx in op.tap_array[offset + idx]: for jdx in op.mit_mot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
i_offset += 1 i_offset += 1
# output taps # output taps
for jdx in op.mit_mot_out_slices[offset + idx]: for jdx in op.mit_mot_out_slices[idx]:
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
o_offset += 1 o_offset += 1
# node inputs # node inputs
node_inputs += [inputs[ni_offset + idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset += len(op.mit_mot_out_slices[offset + idx]) o_offset += len(op.mit_mot_out_slices[idx])
i_offset += len(op.tap_array[offset + idx]) i_offset += len(op.mit_mot_in_slices[idx])
info = dataclasses.replace(info, n_mit_mot_outs=len(op_outputs)) info = dataclasses.replace(info, n_mit_mot_outs=len(op_outputs))
offset += op.n_mit_mot offset += op.n_mit_mot
ni_offset += op.n_mit_mot ni_offset += op.n_mit_mot
...@@ -456,10 +462,10 @@ def compress_outs(op, not_required, inputs): ...@@ -456,10 +462,10 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace( info = dataclasses.replace(
info, info,
n_mit_sot=info.n_mit_sot + 1, n_mit_sot=info.n_mit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),), mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],),
) )
# input taps # input taps
for jdx in op.tap_array[offset + idx]: for jdx in op.mit_sot_in_slices[idx]:
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
i_offset += 1 i_offset += 1
# output taps # output taps
...@@ -469,7 +475,7 @@ def compress_outs(op, not_required, inputs): ...@@ -469,7 +475,7 @@ def compress_outs(op, not_required, inputs):
node_inputs += [inputs[ni_offset + idx]] node_inputs += [inputs[ni_offset + idx]]
else: else:
o_offset += 1 o_offset += 1
i_offset += len(op.tap_array[offset + idx]) i_offset += len(op.mit_sot_in_slices[idx])
offset += op.n_mit_sot offset += op.n_mit_sot
ni_offset += op.n_mit_sot ni_offset += op.n_mit_sot
...@@ -480,7 +486,7 @@ def compress_outs(op, not_required, inputs): ...@@ -480,7 +486,7 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace( info = dataclasses.replace(
info, info,
n_sit_sot=info.n_sit_sot + 1, n_sit_sot=info.n_sit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),), sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
) )
# input taps # input taps
op_inputs += [op.inputs[i_offset]] op_inputs += [op.inputs[i_offset]]
...@@ -613,8 +619,9 @@ class ScanArgs: ...@@ -613,8 +619,9 @@ class ScanArgs:
n_mit_mot = info.n_mit_mot n_mit_mot = info.n_mit_mot
n_mit_sot = info.n_mit_sot n_mit_sot = info.n_mit_sot
self.mit_mot_in_slices = list(info.tap_array[:n_mit_mot]) self.mit_mot_in_slices = info.mit_mot_in_slices
self.mit_sot_in_slices = list(info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot]) self.mit_sot_in_slices = info.mit_sot_in_slices
self.sit_sot_in_slices = info.sit_sot_in_slices
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices) n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices) n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
...@@ -800,14 +807,12 @@ class ScanArgs: ...@@ -800,14 +807,12 @@ class ScanArgs:
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
return ScanInfo( return ScanInfo(
mit_mot_in_slices=tuple(tuple(v) for v in self.mit_mot_in_slices),
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_seqs=len(self.outer_in_seqs), n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot), n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot), n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(
tuple(tuple(v) for v in self.mit_mot_in_slices)
+ tuple(tuple(v) for v in self.mit_sot_in_slices)
+ ((-1,),) * len(self.inner_in_sit_sot)
),
n_sit_sot=len(self.outer_in_sit_sot), n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot), n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared), n_shared_outs=len(self.outer_in_shared),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论