提交 995b6cbc authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Rename n_untraced_sit_sot_outs

上级 207b0c6e
......@@ -695,7 +695,7 @@ def scan(
sit_sot_inner_outputs = []
sit_sot_rightOrder = []
n_untraced_sit_sot_outs = 0
n_untraced_sit_sot = 0
untraced_sit_sot_scan_inputs = []
untraced_sit_sot_inner_inputs = []
untraced_sit_sot_inner_outputs = []
......@@ -763,7 +763,7 @@ def scan(
)
untraced_sit_sot_scan_inputs.append(actual_arg)
untraced_sit_sot_inner_inputs.append(arg)
n_untraced_sit_sot_outs += 1
n_untraced_sit_sot += 1
untraced_sit_sot_rightOrder.append(i)
elif init_out.get("taps", None):
......@@ -839,7 +839,7 @@ def scan(
else:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
for idx in range(n_untraced_sit_sot_outs):
for idx in range(n_untraced_sit_sot):
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [
untraced_sit_sot_inner_inputs[idx]
]
......@@ -1026,7 +1026,7 @@ def scan(
untraced_sit_sot_inner_inputs.append(new_var)
untraced_sit_sot_scan_inputs.append(input.variable)
untraced_sit_sot_inner_outputs.append(input.update)
n_untraced_sit_sot_outs += 1
n_untraced_sit_sot += 1
else:
no_update_shared_inputs.append(input)
......@@ -1121,7 +1121,7 @@ def scan(
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_untraced_sit_sot_outs=n_untraced_sit_sot_outs,
n_untraced_sit_sot=n_untraced_sit_sot,
n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
as_while=as_while,
......@@ -1195,14 +1195,12 @@ def scan(
offset += n_nit_sot
# Legacy support for explicit untraced sit_sot and those built with update dictionary
# Switch to n_untraced_sit_sot_outs after deprecation period
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
untraced_sit_sot_outs = scan_outs[
offset : offset + n_explicit_untraced_sit_sot_outs
]
# Switch to n_untraced_sit_sot after deprecation period
n_explicit_untraced_sit_sot = len(untraced_sit_sot_rightOrder)
untraced_sit_sot_outs = scan_outs[offset : offset + n_explicit_untraced_sit_sot]
# Legacy support: map shared outputs to their updates
offset += n_explicit_untraced_sit_sot_outs
offset += n_explicit_untraced_sit_sot
for idx, update_rule in enumerate(scan_outs[offset:]):
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
......
差异被折叠。
......@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices))
)
st += op_info.n_sit_sot
st += op_info.n_untraced_sit_sot_outs
st += op_info.n_untraced_sit_sot
op_ins = op.inner_inputs
op_outs = op.inner_outputs
......@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
+ op_info.n_mit_sot
+ op_info.n_sit_sot
+ op_info.n_nit_sot
+ op_info.n_untraced_sit_sot_outs
+ op_info.n_untraced_sit_sot
+ 1
)
outer_non_seqs = node.inputs[st:]
......@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ idx
+ op_info.n_seqs
+ 1
+ op_info.n_untraced_sit_sot_outs
+ op_info.n_untraced_sit_sot
)
if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = 1 if required_orphan else val
......@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
):
in_idx = offset + idx + op_info.n_untraced_sit_sot_outs
in_idx = offset + idx + op_info.n_untraced_sit_sot
if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps
......@@ -1980,9 +1980,7 @@ class ScanMerge(GraphRewriter):
mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices,
n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes),
n_untraced_sit_sot_outs=sum(
nd.op.info.n_untraced_sit_sot_outs for nd in nodes
),
n_untraced_sit_sot=sum(nd.op.info.n_untraced_sit_sot for nd in nodes),
n_non_seqs=n_non_seqs,
as_while=as_while,
)
......
......@@ -371,7 +371,7 @@ def scan_can_remove_outs(op, out_idxs):
offset += n_ins
out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [
[op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot_outs)
[op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot)
]
added = True
......@@ -411,7 +411,7 @@ def compress_outs(op, not_required, inputs):
mit_sot_in_slices=(),
sit_sot_in_slices=(),
n_nit_sot=0,
n_untraced_sit_sot_outs=0,
n_untraced_sit_sot=0,
n_non_seqs=0,
as_while=op_info.as_while,
)
......@@ -517,18 +517,18 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot_outs]]
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot]]
else:
o_offset += 1
offset += op_info.n_nit_sot
shared_ins = []
for idx in range(op_info.n_untraced_sit_sot_outs):
for idx in range(op_info.n_untraced_sit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info = dataclasses.replace(
info, n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs + 1
info, n_untraced_sit_sot=info.n_untraced_sit_sot + 1
)
op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1
......@@ -543,9 +543,7 @@ def compress_outs(op, not_required, inputs):
# other stuff
op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[
ni_offset + op_info.n_untraced_sit_sot_outs + op_info.n_nit_sot :
]
node_inputs += inputs[ni_offset + op_info.n_untraced_sit_sot + op_info.n_nit_sot :]
if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1
......@@ -664,11 +662,11 @@ class ScanArgs:
p += n_sit_sot
q += n_sit_sot
n_untraced_sit_sot_outs = info.n_untraced_sit_sot_outs
self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot_outs])
self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot_outs])
p += n_untraced_sit_sot_outs
q += n_untraced_sit_sot_outs
n_untraced_sit_sot = info.n_untraced_sit_sot
self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot])
self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot])
p += n_untraced_sit_sot
q += n_untraced_sit_sot
n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
......@@ -708,10 +706,10 @@ class ScanArgs:
p += n_nit_sot
q += n_nit_sot
self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot_outs])
self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot_outs])
p += n_untraced_sit_sot_outs
q += n_untraced_sit_sot_outs
self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot])
self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot])
p += n_untraced_sit_sot
q += n_untraced_sit_sot
assert p == len(outer_outputs)
assert q == len(inner_outputs)
......@@ -822,7 +820,7 @@ class ScanArgs:
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_untraced_sit_sot_outs=len(self.outer_in_shared),
n_untraced_sit_sot=len(self.outer_in_shared),
n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while,
)
......
......@@ -651,7 +651,7 @@ def test_trace_truncation_regression_bug():
mit_sot_in_slices=(),
sit_sot_in_slices=((-1,),),
n_nit_sot=0,
n_untraced_sit_sot_outs=0,
n_untraced_sit_sot=0,
n_non_seqs=0,
as_while=False,
),
......
......@@ -86,7 +86,7 @@ from tests.scan.test_basic import ScanCompatibilityTests
3,
[],
[np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_untraced_sit_sot_outs > 0,
lambda op: op.info.n_untraced_sit_sot > 0,
),
# mit-sot (that's also a type of sit-sot)
(
......
......@@ -4095,7 +4095,7 @@ class TestExamples:
[{}],
[],
3,
lambda op: op.info.n_untraced_sit_sot_outs > 0,
lambda op: op.info.n_untraced_sit_sot > 0,
),
# mit-sot (that's also a type of sit-sot)
(
......@@ -4292,7 +4292,7 @@ def test_scan_mode_compatibility(scan_mode):
mit_sot_in_slices=(),
sit_sot_in_slices=(),
n_nit_sot=0,
n_untraced_sit_sot_outs=0,
n_untraced_sit_sot=0,
n_non_seqs=0,
as_while=False,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论