提交 d11f3039 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Completely remove ScanInfo fields from Scan Op

上级 7cfe58c4
......@@ -418,7 +418,7 @@ def jax_funcify_Scan(op, **kwargs):
def scan(*outer_inputs):
scan_args = ScanArgs(
list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
)
# `outer_inputs` is a list with the following composite form:
......
......@@ -314,12 +314,6 @@ class ScanMethodsMixin:
def outer_mitmot_outs(self, list_outputs):
return list_outputs[: self.info.n_mit_mot]
def mitmot_taps(self):
return self.info.mit_mot_in_slices
def mitmot_out_taps(self):
return self.info.mit_mot_out_slices[: self.info.n_mit_mot]
def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.info.mit_mot_in_slices)
ntaps_upto_sit_sot = n_mitmot_taps + sum(
......@@ -342,9 +336,6 @@ class ScanMethodsMixin:
self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot
]
def mitsot_taps(self):
return self.info.mit_sot_in_slices
def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum(
len(x)
......@@ -785,12 +776,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.profile = profile
self.allow_gc = allow_gc
self.strict = strict
self.__dict__.update(dataclasses.asdict(info))
self.n_mit_mot = self.info.n_mit_mot
self.n_mit_mot_outs = self.info.n_mit_mot_outs
self.n_mit_sot = self.info.n_mit_sot
self.n_sit_sot = self.info.n_sit_sot
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
......@@ -971,8 +956,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = (
len(self.inner_seqs(self.inner_inputs))
+ len(self.mitmot_taps())
+ len(self.mitsot_taps())
+ len(self.info.mit_mot_in_slices)
+ len(self.info.mit_sot_in_slices)
+ len(self.inner_sitsot(self.inner_inputs))
+ len(self.inner_shared(self.inner_inputs))
+ len(self.inner_non_seqs(self.inner_inputs))
......@@ -1006,7 +991,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitmot = self.inner_mitmot(self.inner_inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.inner_outputs)
for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs))
zip(
self.info.mit_mot_in_slices,
self.info.mit_mot_out_slices[: self.info.n_mit_mot],
self.outer_mitmot(inputs),
)
):
outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot)
......@@ -1057,7 +1046,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_mitsots = self.inner_mitsot(self.inner_inputs)
for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip(
self.mitsot_taps(),
self.info.mit_sot_in_slices,
self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.inner_outputs),
)
......@@ -1383,9 +1372,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_idx = sum(
len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
)
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index(
inp_tap
)
output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
preallocated_mitmot_outs.append(output_idx)
......@@ -1979,7 +1966,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.mitmots_preallocated[mitmot_out_idx]:
mitmot_inp_idx = mitmot_inp_grp_offset + taps.index(out_slice)
inner_inp_idx = self.n_seqs + mitmot_inp_idx
inner_inp_idx = info.n_seqs + mitmot_inp_idx
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
......@@ -2455,13 +2442,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return 1 + iidx
oidx = 1 + info.n_seqs
iidx = iidx - info.n_seqs
for taps in self.mitmot_taps():
for taps in info.mit_mot_in_slices:
if len(taps) > iidx:
return oidx
else:
oidx += 1
iidx -= len(taps)
for taps in self.mitsot_taps():
for taps in info.mit_sot_in_slices:
if len(taps) > iidx:
return oidx
else:
......@@ -2475,7 +2462,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def get_out_idx(iidx):
oidx = 0
for taps in self.mitmot_out_taps():
for taps in info.mit_mot_out_slices[: info.n_mit_mot]:
if len(taps) > iidx:
return oidx
else:
......@@ -2666,7 +2653,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
):
mintap = min(taps)
if idx < info.n_mit_mot:
outmaxtap = np.max(self.mitmot_out_taps()[idx])
outmaxtap = np.max(info.mit_mot_out_slices[: info.n_mit_mot][idx])
else:
outmaxtap = 0
seq = outs[idx]
......@@ -2695,7 +2682,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n = n_steps.tag.test_value
else:
n = inputs[0].tag.test_value
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs)):
for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)):
mintap = np.min(taps)
if hasattr(x[::-1][:mintap], "test_value"):
assert x[::-1][:mintap].tag.test_value.shape[0] == n
......@@ -2710,7 +2697,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
assert x[::-1].tag.test_value.shape[0] == n
outer_inp_seqs += [
x[::-1][: np.min(taps)]
for taps, x in zip(self.mitsot_taps(), self.outer_mitsot_outs(outs))
for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs))
]
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
......
......@@ -81,31 +81,34 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if not isinstance(node.op, Scan):
return False
op = node.op
op_info = op.info
# We only need to take care of sequences and other arguments
st = op.n_seqs
st += int(sum(len(x) for x in chain(op.mit_mot_in_slices, op.mit_sot_in_slices)))
st += op.n_sit_sot
st += op.n_shared_outs
st = op_info.n_seqs
st += int(
sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices))
)
st += op_info.n_sit_sot
st += op_info.n_shared_outs
op_ins = op.inner_inputs
op_outs = op.inner_outputs
# Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end.
out_stuff_inner = op_ins[op.n_seqs : st]
out_stuff_inner = op_ins[op_info.n_seqs : st]
non_seqs = op_ins[st:]
st = (
op.n_seqs
+ op.n_mit_mot
+ op.n_mit_sot
+ op.n_sit_sot
+ op.n_nit_sot
+ op.n_shared_outs
op_info.n_seqs
+ op_info.n_mit_mot
+ op_info.n_mit_sot
+ op_info.n_sit_sot
+ op_info.n_nit_sot
+ op_info.n_shared_outs
+ 1
)
outer_non_seqs = node.inputs[st:]
out_stuff_outer = node.inputs[1 + op.n_seqs : st]
out_stuff_outer = node.inputs[1 + op_info.n_seqs : st]
# To replace constants in the outer graph by clones in the inner graph
givens = {}
......@@ -115,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
nw_outer = [node.inputs[0]]
all_ins = list(graph_inputs(op_outs))
for idx in range(op.n_seqs):
for idx in range(op_info.n_seqs):
node_inp = node.inputs[idx + 1]
if (
isinstance(node_inp, TensorConstant)
......@@ -170,7 +173,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if len(nw_inner) != len(op_ins):
op_outs = clone_replace(op_outs, replace=givens)
nw_info = dataclasses.replace(
op.info, n_seqs=nw_n_seqs, n_non_seqs=len(nw_inner_nonseq)
op_info, n_seqs=nw_n_seqs, n_non_seqs=len(nw_inner_nonseq)
)
nwScan = Scan(
nw_inner,
......@@ -615,7 +618,7 @@ def push_out_seq_scan(fgraph, node):
if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node.inputs)[odx]
st = abs(np.min(op.mitsot_taps()))
st = abs(np.min(op.info.mit_sot_in_slices))
y = set_subtensor(inp[st:], _y)
elif out in op.inner_sitsot_outs(ls):
odx = op.inner_sitsot_outs(ls).index(out)
......@@ -953,7 +956,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op = node.op
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.n_seqs]
ls_begin = node.inputs[: 1 + op.info.n_seqs]
ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs)
......@@ -1044,7 +1047,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# operate inplace.
out_indices = []
for out_idx in range(n_outs):
inp_idx = 1 + op.n_seqs + out_idx
inp_idx = 1 + op.info.n_seqs + out_idx
inp = original_node.inputs[inp_idx]
# If the input is from an eligible allocation node, attempt to
......@@ -1153,11 +1156,16 @@ def save_mem_new_scan(fgraph, node):
# defining ``init_l`` for mit_mot sequences is a bit trickier but
# it is safe to set it to 0
op = node.op
c_outs = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot + op.n_nit_sot
op_info = op.info
c_outs = (
op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
)
init_l = [0 for x in range(op.n_mit_mot)]
init_l += [abs(min(v)) for v in chain(op.mit_sot_in_slices, op.sit_sot_in_slices)]
init_l += [0 for x in range(op.n_nit_sot)]
init_l = [0 for x in range(op_info.n_mit_mot)]
init_l += [
abs(min(v)) for v in chain(op_info.mit_sot_in_slices, op_info.sit_sot_in_slices)
]
init_l += [0 for x in range(op_info.n_nit_sot)]
# 2. Check the clients of each output and see for how many steps
# does scan need to run
......@@ -1198,8 +1206,8 @@ def save_mem_new_scan(fgraph, node):
# Note that for mit_mot outputs and shared outputs we can not change
# the number of intermediate steps stored without affecting the
# result of the op
store_steps = [0 for o in range(op.n_mit_mot)]
store_steps += [-1 for o in node.outputs[op.n_mit_mot : c_outs]]
store_steps = [0 for o in range(op_info.n_mit_mot)]
store_steps += [-1 for o in node.outputs[op_info.n_mit_mot : c_outs]]
# Flag that says if an input has changed and we need to do something
# or not
flag_store = False
......@@ -1237,7 +1245,7 @@ def save_mem_new_scan(fgraph, node):
break
# 2.3.2 extract the begin/end of the first dimension
if i >= op.n_mit_mot:
if i >= op_info.n_mit_mot:
try:
length = shape_of[out][0]
except KeyError:
......@@ -1339,7 +1347,7 @@ def save_mem_new_scan(fgraph, node):
store_steps[i] = 0
break
if i > op.n_mit_mot:
if i > op_info.n_mit_mot:
length = node.inputs[0] + init_l[i]
else:
try:
......@@ -1367,9 +1375,9 @@ def save_mem_new_scan(fgraph, node):
# the pre-allocation mechanism is activated.
prealloc_outs = config.scan__allow_output_prealloc
first_mitsot_idx = node.op.n_mit_mot
first_mitsot_idx = op_info.n_mit_mot
last_sitsot_idx = (
node.op.n_mit_mot + node.op.n_mit_sot + node.op.n_sit_sot - 1
op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot - 1
)
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
......@@ -1401,18 +1409,18 @@ def save_mem_new_scan(fgraph, node):
# to store everything in memory ( or ar orphane and required
# by the inner function .. )
replaced_outs = []
offset = 1 + op.n_seqs + op.n_mit_mot
for idx, _val in enumerate(store_steps[op.n_mit_mot :]):
i = idx + op.n_mit_mot
offset = 1 + op_info.n_seqs + op_info.n_mit_mot
for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
i = idx + op_info.n_mit_mot
if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op.n_mit_mot in required:
if idx + op_info.n_mit_mot in required:
val = 1
else:
val = _val
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot:
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
# In case the input is still an alloc node, we
# actually have two options:
# a) the input is a set_subtensor, in that case we
......@@ -1442,8 +1450,8 @@ def save_mem_new_scan(fgraph, node):
nw_input = nw_inputs[offset + idx][:tmp]
nw_inputs[offset + idx] = nw_input
replaced_outs.append(op.n_mit_mot + idx)
odx = op.n_mit_mot + idx
replaced_outs.append(op_info.n_mit_mot + idx)
odx = op_info.n_mit_mot + idx
old_outputs += [
(
odx,
......@@ -1454,12 +1462,18 @@ def save_mem_new_scan(fgraph, node):
)
]
# If there is no memory pre-allocated for this output
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
pos = op.n_mit_mot + idx + op.n_seqs + 1 + op.n_shared_outs
elif idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot:
pos = (
op_info.n_mit_mot
+ idx
+ op_info.n_seqs
+ 1
+ op_info.n_shared_outs
)
if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val
odx = op.n_mit_mot + idx
odx = op_info.n_mit_mot + idx
replaced_outs.append(odx)
old_outputs += [
(
......@@ -1473,14 +1487,14 @@ def save_mem_new_scan(fgraph, node):
# 3.4. Recompute inputs for everything else based on the new
# number of steps
if global_nsteps is not None:
for idx, val in enumerate(store_steps[op.n_mit_mot :]):
for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
if val == 0:
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op.n_mit_sot + op.n_sit_sot:
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
in_idx = offset + idx
# Number of steps in the initial state
initl = init_l[op.n_mit_mot + idx]
initl = init_l[op_info.n_mit_mot + idx]
# If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input)
......@@ -1501,8 +1515,10 @@ def save_mem_new_scan(fgraph, node):
else:
nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs
elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
):
in_idx = offset + idx + op_info.n_shared_outs
if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps
......@@ -1693,8 +1709,8 @@ class ScanMerge(GlobalOptimizer):
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs))
mit_mot_in_slices += nd.op.mitmot_taps()
mit_mot_out_slices += nd.op.mitmot_out_taps()
mit_mot_in_slices += nd.op.info.mit_mot_in_slices
mit_mot_out_slices += nd.op.info.mit_mot_out_slices[: nd.op.info.n_mit_mot]
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
......@@ -1702,14 +1718,14 @@ class ScanMerge(GlobalOptimizer):
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs))
mit_sot_in_slices += nd.op.mitsot_taps()
mit_sot_in_slices += nd.op.info.mit_sot_in_slices
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
sit_sot_in_slices = ()
for idx, nd in enumerate(nodes):
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx))
sit_sot_in_slices += tuple((-1,) for x in range(nd.op.n_sit_sot))
sit_sot_in_slices += tuple((-1,) for x in range(nd.op.info.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
......@@ -1802,13 +1818,13 @@ class ScanMerge(GlobalOptimizer):
new_inner_outs += inner_outs[idx][gr_idx]
info = ScanInfo(
n_seqs=sum(nd.op.n_seqs for nd in nodes),
n_seqs=sum(nd.op.info.n_seqs for nd in nodes),
mit_mot_in_slices=mit_mot_in_slices,
mit_mot_out_slices=mit_mot_out_slices,
mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices,
n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes),
n_non_seqs=n_non_seqs,
as_while=as_while,
)
......
......@@ -361,15 +361,19 @@ def scan_can_remove_outs(op, out_idxs):
required_inputs = list(graph_inputs(non_removable))
out_ins = []
offset = op.n_seqs
offset = op.info.n_seqs
for idx, tap in enumerate(
chain(op.mit_mot_in_slices, op.mit_sot_in_slices, op.sit_sot_in_slices)
chain(
op.info.mit_mot_in_slices,
op.info.mit_sot_in_slices,
op.info.sit_sot_in_slices,
)
):
n_ins = len(tap)
out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.n_shared_outs)]
out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)]
added = True
out_idxs_mask = [1 for idx in out_idxs]
......@@ -400,8 +404,9 @@ def compress_outs(op, not_required, inputs):
"""
from aesara.scan.op import ScanInfo
op_info = op.info
info = ScanInfo(
n_seqs=op.info.n_seqs,
n_seqs=op_info.n_seqs,
mit_mot_in_slices=(),
mit_mot_out_slices=(),
mit_sot_in_slices=(),
......@@ -409,56 +414,58 @@ def compress_outs(op, not_required, inputs):
n_nit_sot=0,
n_shared_outs=0,
n_non_seqs=0,
as_while=op.info.as_while,
as_while=op_info.as_while,
)
op_inputs = op.inner_inputs[: op.n_seqs]
op_inputs = op.inner_inputs[: op_info.n_seqs]
op_outputs = []
node_inputs = inputs[: op.n_seqs + 1]
node_inputs = inputs[: op_info.n_seqs + 1]
map_old_new = OrderedDict()
offset = 0
ni_offset = op.n_seqs + 1
i_offset = op.n_seqs
ni_offset = op_info.n_seqs + 1
i_offset = op_info.n_seqs
o_offset = 0
curr_pos = 0
for idx in range(op.info.n_mit_mot):
for idx in range(op_info.n_mit_mot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(
info,
mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],),
mit_mot_in_slices=info.mit_mot_in_slices
+ (op_info.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices
+ (op.mit_mot_out_slices[idx],),
+ (op_info.mit_mot_out_slices[idx],),
)
# input taps
for jdx in op.mit_mot_in_slices[idx]:
for jdx in op_info.mit_mot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
for jdx in op.mit_mot_out_slices[idx]:
for jdx in op_info.mit_mot_out_slices[idx]:
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
# node inputs
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += len(op.mit_mot_out_slices[idx])
i_offset += len(op.mit_mot_in_slices[idx])
o_offset += len(op_info.mit_mot_out_slices[idx])
i_offset += len(op_info.mit_mot_in_slices[idx])
offset += op.n_mit_mot
ni_offset += op.n_mit_mot
offset += op_info.n_mit_mot
ni_offset += op_info.n_mit_mot
for idx in range(op.info.n_mit_sot):
for idx in range(op_info.n_mit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(
info,
mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],),
mit_sot_in_slices=info.mit_sot_in_slices
+ (op_info.mit_sot_in_slices[idx],),
)
# input taps
for jdx in op.mit_sot_in_slices[idx]:
for jdx in op_info.mit_sot_in_slices[idx]:
op_inputs += [op.inner_inputs[i_offset]]
i_offset += 1
# output taps
......@@ -468,17 +475,18 @@ def compress_outs(op, not_required, inputs):
node_inputs += [inputs[ni_offset + idx]]
else:
o_offset += 1
i_offset += len(op.mit_sot_in_slices[idx])
i_offset += len(op_info.mit_sot_in_slices[idx])
offset += op.n_mit_sot
ni_offset += op.n_mit_sot
for idx in range(op.info.n_sit_sot):
offset += op_info.n_mit_sot
ni_offset += op_info.n_mit_sot
for idx in range(op_info.n_sit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(
info,
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
sit_sot_in_slices=info.sit_sot_in_slices
+ (op_info.sit_sot_in_slices[idx],),
)
# input taps
op_inputs += [op.inner_inputs[i_offset]]
......@@ -492,23 +500,23 @@ def compress_outs(op, not_required, inputs):
o_offset += 1
i_offset += 1
offset += op.n_sit_sot
ni_offset += op.n_sit_sot
offset += op_info.n_sit_sot
ni_offset += op_info.n_sit_sot
nit_sot_ins = []
for idx in range(op.info.n_nit_sot):
for idx in range(op_info.n_nit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]]
else:
o_offset += 1
offset += op.n_nit_sot
offset += op_info.n_nit_sot
shared_ins = []
for idx in range(op.info.n_shared_outs):
for idx in range(op_info.n_shared_outs):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
......@@ -526,8 +534,8 @@ def compress_outs(op, not_required, inputs):
# other stuff
op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :]
if op.info.as_while:
node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :]
if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset
......
......@@ -97,8 +97,8 @@ def test_ScanArgs():
# Check the properties that allow us to use
# `Scan.get_oinp_iinp_iout_oout_mappings` as-is to implement
# `ScanArgs.var_mappings`
assert scan_args.n_nit_sot == scan_op.n_nit_sot
assert scan_args.n_mit_mot == scan_op.n_mit_mot
assert scan_args.n_nit_sot == scan_op.info.n_nit_sot
assert scan_args.n_mit_mot == scan_op.info.n_mit_mot
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inner_inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论