Commit 995b6cbc, authored by ricardoV94, committed by Ricardo Vieira

Rename n_untraced_sit_sot_outs

Parent commit: 207b0c6e
...@@ -695,7 +695,7 @@ def scan( ...@@ -695,7 +695,7 @@ def scan(
sit_sot_inner_outputs = [] sit_sot_inner_outputs = []
sit_sot_rightOrder = [] sit_sot_rightOrder = []
n_untraced_sit_sot_outs = 0 n_untraced_sit_sot = 0
untraced_sit_sot_scan_inputs = [] untraced_sit_sot_scan_inputs = []
untraced_sit_sot_inner_inputs = [] untraced_sit_sot_inner_inputs = []
untraced_sit_sot_inner_outputs = [] untraced_sit_sot_inner_outputs = []
...@@ -763,7 +763,7 @@ def scan( ...@@ -763,7 +763,7 @@ def scan(
) )
untraced_sit_sot_scan_inputs.append(actual_arg) untraced_sit_sot_scan_inputs.append(actual_arg)
untraced_sit_sot_inner_inputs.append(arg) untraced_sit_sot_inner_inputs.append(arg)
n_untraced_sit_sot_outs += 1 n_untraced_sit_sot += 1
untraced_sit_sot_rightOrder.append(i) untraced_sit_sot_rightOrder.append(i)
elif init_out.get("taps", None): elif init_out.get("taps", None):
...@@ -839,7 +839,7 @@ def scan( ...@@ -839,7 +839,7 @@ def scan(
else: else:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]] _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
for idx in range(n_untraced_sit_sot_outs): for idx in range(n_untraced_sit_sot):
_ordered_args[untraced_sit_sot_rightOrder[idx]] = [ _ordered_args[untraced_sit_sot_rightOrder[idx]] = [
untraced_sit_sot_inner_inputs[idx] untraced_sit_sot_inner_inputs[idx]
] ]
...@@ -1026,7 +1026,7 @@ def scan( ...@@ -1026,7 +1026,7 @@ def scan(
untraced_sit_sot_inner_inputs.append(new_var) untraced_sit_sot_inner_inputs.append(new_var)
untraced_sit_sot_scan_inputs.append(input.variable) untraced_sit_sot_scan_inputs.append(input.variable)
untraced_sit_sot_inner_outputs.append(input.update) untraced_sit_sot_inner_outputs.append(input.update)
n_untraced_sit_sot_outs += 1 n_untraced_sit_sot += 1
else: else:
no_update_shared_inputs.append(input) no_update_shared_inputs.append(input)
...@@ -1121,7 +1121,7 @@ def scan( ...@@ -1121,7 +1121,7 @@ def scan(
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices), mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array), mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)), sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_untraced_sit_sot_outs=n_untraced_sit_sot_outs, n_untraced_sit_sot=n_untraced_sit_sot,
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args), n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
as_while=as_while, as_while=as_while,
...@@ -1195,14 +1195,12 @@ def scan( ...@@ -1195,14 +1195,12 @@ def scan(
offset += n_nit_sot offset += n_nit_sot
# Legacy support for explicit untraced sit_sot and those built with update dictionary # Legacy support for explicit untraced sit_sot and those built with update dictionary
# Switch to n_untraced_sit_sot_outs after deprecation period # Switch to n_untraced_sit_sot after deprecation period
n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder) n_explicit_untraced_sit_sot = len(untraced_sit_sot_rightOrder)
untraced_sit_sot_outs = scan_outs[ untraced_sit_sot_outs = scan_outs[offset : offset + n_explicit_untraced_sit_sot]
offset : offset + n_explicit_untraced_sit_sot_outs
]
# Legacy support: map shared outputs to their updates # Legacy support: map shared outputs to their updates
offset += n_explicit_untraced_sit_sot_outs offset += n_explicit_untraced_sit_sot
for idx, update_rule in enumerate(scan_outs[offset:]): for idx, update_rule in enumerate(scan_outs[offset:]):
update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
......
...@@ -217,18 +217,18 @@ class ScanInfo: ...@@ -217,18 +217,18 @@ class ScanInfo:
mit_sot_in_slices: tuple mit_sot_in_slices: tuple
sit_sot_in_slices: tuple sit_sot_in_slices: tuple
n_nit_sot: int n_nit_sot: int
n_untraced_sit_sot_outs: int n_untraced_sit_sot: int
n_non_seqs: int n_non_seqs: int
as_while: bool as_while: bool
@property @property
def n_shared_outs(self): def n_shared_outs(self):
warnings.warn( warnings.warn(
"The 'n_shared_outs' property is deprecated. Use 'n_untraced_sit_sot_outs' instead.", "The 'n_shared_outs' property is deprecated. Use 'n_untraced_sit_sot' instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
return self.n_untraced_sit_sot_outs return self.n_untraced_sit_sot
@property @property
def n_mit_mot(self): def n_mit_mot(self):
...@@ -257,7 +257,7 @@ class ScanInfo: ...@@ -257,7 +257,7 @@ class ScanInfo:
+ sum(len(x) for x in self.mit_mot_in_slices) + sum(len(x) for x in self.mit_mot_in_slices)
+ sum(len(x) for x in self.mit_sot_in_slices) + sum(len(x) for x in self.mit_sot_in_slices)
+ self.n_sit_sot + self.n_sit_sot
+ self.n_untraced_sit_sot_outs + self.n_untraced_sit_sot
+ self.n_non_seqs + self.n_non_seqs
) )
...@@ -268,7 +268,7 @@ class ScanInfo: ...@@ -268,7 +268,7 @@ class ScanInfo:
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_untraced_sit_sot_outs + self.n_untraced_sit_sot
+ int(self.as_while) + int(self.as_while)
) )
...@@ -281,7 +281,7 @@ class ScanInfo: ...@@ -281,7 +281,7 @@ class ScanInfo:
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_untraced_sit_sot_outs + self.n_untraced_sit_sot
+ self.n_non_seqs + self.n_non_seqs
) )
...@@ -292,7 +292,7 @@ class ScanInfo: ...@@ -292,7 +292,7 @@ class ScanInfo:
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_untraced_sit_sot_outs + self.n_untraced_sit_sot
) )
@property @property
...@@ -419,7 +419,7 @@ class ScanMethodsMixin: ...@@ -419,7 +419,7 @@ class ScanMethodsMixin:
+ self.info.n_mit_mot + self.info.n_mit_mot
+ self.info.n_mit_sot + self.info.n_mit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_untraced_sit_sot_outs + self.info.n_untraced_sit_sot
) )
return list_inputs[offset : offset + self.info.n_nit_sot] return list_inputs[offset : offset + self.info.n_nit_sot]
...@@ -438,7 +438,7 @@ class ScanMethodsMixin: ...@@ -438,7 +438,7 @@ class ScanMethodsMixin:
for x in chain(self.info.mit_mot_in_slices, self.info.mit_sot_in_slices) for x in chain(self.info.mit_mot_in_slices, self.info.mit_sot_in_slices)
) )
offset = self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot offset = self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot
return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs] return list_inputs[offset : offset + self.info.n_untraced_sit_sot]
def inner_shared(self, list_inputs): def inner_shared(self, list_inputs):
warnings.warn( warnings.warn(
...@@ -456,7 +456,7 @@ class ScanMethodsMixin: ...@@ -456,7 +456,7 @@ class ScanMethodsMixin:
+ self.info.n_mit_sot + self.info.n_mit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
) )
return list_inputs[offset : offset + self.info.n_untraced_sit_sot_outs] return list_inputs[offset : offset + self.info.n_untraced_sit_sot]
def outer_shared(self, list_inputs): def outer_shared(self, list_inputs):
warnings.warn( warnings.warn(
...@@ -471,7 +471,7 @@ class ScanMethodsMixin: ...@@ -471,7 +471,7 @@ class ScanMethodsMixin:
offset = ( offset = (
self.info.n_mit_sot + n_taps + self.info.n_sit_sot + self.info.n_nit_sot self.info.n_mit_sot + n_taps + self.info.n_sit_sot + self.info.n_nit_sot
) )
return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs] return list_outputs[offset : offset + self.info.n_untraced_sit_sot]
def inner_shared_outs(self, list_outputs): def inner_shared_outs(self, list_outputs):
warnings.warn( warnings.warn(
...@@ -488,7 +488,7 @@ class ScanMethodsMixin: ...@@ -488,7 +488,7 @@ class ScanMethodsMixin:
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_nit_sot + self.info.n_nit_sot
) )
return list_outputs[offset : offset + self.info.n_untraced_sit_sot_outs] return list_outputs[offset : offset + self.info.n_untraced_sit_sot]
def outer_shared_outs(self, list_outputs): def outer_shared_outs(self, list_outputs):
warnings.warn( warnings.warn(
...@@ -507,7 +507,7 @@ class ScanMethodsMixin: ...@@ -507,7 +507,7 @@ class ScanMethodsMixin:
self.info.n_seqs self.info.n_seqs
+ n_taps_upto_sit_sot + n_taps_upto_sit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_untraced_sit_sot_outs + self.info.n_untraced_sit_sot
) )
return list_inputs[offset:] return list_inputs[offset:]
...@@ -519,7 +519,7 @@ class ScanMethodsMixin: ...@@ -519,7 +519,7 @@ class ScanMethodsMixin:
+ self.info.n_mit_sot + self.info.n_mit_sot
+ self.info.n_sit_sot + self.info.n_sit_sot
+ self.info.n_nit_sot + self.info.n_nit_sot
+ self.info.n_untraced_sit_sot_outs + self.info.n_untraced_sit_sot
) )
return list_inputs[offset:] return list_inputs[offset:]
...@@ -596,7 +596,7 @@ class ScanMethodsMixin: ...@@ -596,7 +596,7 @@ class ScanMethodsMixin:
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* untraced_sitsot variables. # nitsots come *after* untraced_sitsot variables.
outer_iidx += self.info.n_untraced_sit_sot_outs outer_iidx += self.info.n_untraced_sit_sot
# Handle nitsots variables # Handle nitsots variables
for i in range(self.info.n_nit_sot): for i in range(self.info.n_nit_sot):
...@@ -612,10 +612,10 @@ class ScanMethodsMixin: ...@@ -612,10 +612,10 @@ class ScanMethodsMixin:
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* untraced_sit_sot variables. # nitsots come *after* untraced_sit_sot variables.
outer_iidx -= self.info.n_untraced_sit_sot_outs + self.info.n_nit_sot outer_iidx -= self.info.n_untraced_sit_sot + self.info.n_nit_sot
# Handle untraced_sitsot states # Handle untraced_sitsot states
for i in range(self.info.n_untraced_sit_sot_outs): for i in range(self.info.n_untraced_sit_sot):
outer_input_indices.append(outer_iidx) outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx]) inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx]) inner_output_indices.append([inner_oidx])
...@@ -910,7 +910,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -910,7 +910,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
) )
self.nit_sot_arg_offset = ( self.nit_sot_arg_offset = (
self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot
) )
# Note: This doesn't include `info.n_nit_sot`s, so it's really a count # Note: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs # of the number of outputs generated by taps with inputs
...@@ -1647,7 +1647,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1647,7 +1647,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try: try:
t_fn, n_steps = scan_perform_ext.perform( t_fn, n_steps = scan_perform_ext.perform(
self.info.n_untraced_sit_sot_outs, self.info.n_untraced_sit_sot,
self.info.n_mit_mot_outs, self.info.n_mit_mot_outs,
self.info.n_seqs, self.info.n_seqs,
self.info.n_mit_mot, self.info.n_mit_mot,
...@@ -1846,7 +1846,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1846,7 +1846,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
info.sit_sot_in_slices, info.sit_sot_in_slices,
) )
) )
+ info.n_untraced_sit_sot_outs + info.n_untraced_sit_sot
) )
for idx in range(len(other_args)): for idx in range(len(other_args)):
inner_input_storage[idx + offset].storage[0] = other_args[idx] inner_input_storage[idx + offset].storage[0] = other_args[idx]
...@@ -1892,11 +1892,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1892,11 +1892,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
a_offset = self.untraced_sit_sot_arg_offset a_offset = self.untraced_sit_sot_arg_offset
o_offset = self.n_outs + info.n_nit_sot o_offset = self.n_outs + info.n_nit_sot
if i == 0: if i == 0:
for j in range(info.n_untraced_sit_sot_outs): for j in range(info.n_untraced_sit_sot):
inner_input_storage[offset].storage[0] = inputs[a_offset + j] inner_input_storage[offset].storage[0] = inputs[a_offset + j]
offset += 1 offset += 1
else: else:
for j in range(info.n_untraced_sit_sot_outs): for j in range(info.n_untraced_sit_sot):
inner_input_storage[offset].storage[0] = output_storage[ inner_input_storage[offset].storage[0] = output_storage[
o_offset + j o_offset + j
][0] ][0]
...@@ -1930,12 +1930,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1930,12 +1930,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 4.3. Collect slices for untraced sitsot outputs # 4.3. Collect slices for untraced sitsot outputs
offset += self.n_outs + info.n_nit_sot - info.n_mit_mot offset += self.n_outs + info.n_nit_sot - info.n_mit_mot
for idx in range(info.n_untraced_sit_sot_outs): for idx in range(info.n_untraced_sit_sot):
inner_output_storage[idx + offset].storage[0] = None inner_output_storage[idx + offset].storage[0] = None
# 4.4. If there is a condition add it to the mix # 4.4. If there is a condition add it to the mix
if info.as_while: if info.as_while:
pdx = offset + info.n_untraced_sit_sot_outs pdx = offset + info.n_untraced_sit_sot
inner_output_storage[pdx].storage[0] = None inner_output_storage[pdx].storage[0] = None
# 4.5. Keep a reference to the variables (ndarrays, # 4.5. Keep a reference to the variables (ndarrays,
...@@ -2004,7 +2004,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2004,7 +2004,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dt_fn = time.perf_counter() - t0_fn dt_fn = time.perf_counter() - t0_fn
if info.as_while: if info.as_while:
pdx = offset + info.n_untraced_sit_sot_outs pdx = offset + info.n_untraced_sit_sot
cond = inner_output_storage[pdx].storage[0] == 0 cond = inner_output_storage[pdx].storage[0] == 0
t_fn += dt_fn t_fn += dt_fn
...@@ -2151,7 +2151,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2151,7 +2151,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 5.6 Copy over the values for outputs corresponding to untraced sitsot # 5.6 Copy over the values for outputs corresponding to untraced sitsot
# variables # variables
begin = end begin = end
end += info.n_untraced_sit_sot_outs end += info.n_untraced_sit_sot
for j in range(begin, end): for j in range(begin, end):
jout = j + offset_out jout = j + offset_out
output_storage[j][0] = inner_output_storage[jout].storage[0] output_storage[j][0] = inner_output_storage[jout].storage[0]
...@@ -2301,11 +2301,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2301,11 +2301,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# untraced sit_sot outputs # untraced sit_sot outputs
offset = 1 + info.n_seqs + n_outs offset = 1 + info.n_seqs + n_outs
for idx in range(info.n_untraced_sit_sot_outs): for idx in range(info.n_untraced_sit_sot):
outs_shape += [input_shapes[idx + offset]] outs_shape += [input_shapes[idx + offset]]
# non_sequences # non_sequences
offset += info.n_nit_sot + info.n_untraced_sit_sot_outs offset += info.n_nit_sot + info.n_untraced_sit_sot
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inner_inputs) assert len(inner_ins_shapes) == len(self.inner_inputs)
...@@ -2347,7 +2347,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2347,7 +2347,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# in the inner function. # in the inner function.
r = node.outputs[n_outs + x] r = node.outputs[n_outs + x]
assert r.ndim == 1 + len(out_shape_x) assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset + info.n_untraced_sit_sot_outs + x]] shp = [node.inputs[offset + info.n_untraced_sit_sot + x]]
for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True): for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True):
# Validate shp_i. v_shape_i is either None (if invalid), # Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates # or a (variable, Boolean) tuple. The Boolean indicates
...@@ -2364,7 +2364,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2364,7 +2364,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp.append(v_shp_i[0]) shp.append(v_shp_i[0])
scan_outs.append(tuple(shp)) scan_outs.append(tuple(shp))
scan_outs += list(input_shapes[offset : offset + info.n_untraced_sit_sot_outs]) scan_outs += list(input_shapes[offset : offset + info.n_untraced_sit_sot])
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if info.as_while: if info.as_while:
...@@ -2437,8 +2437,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2437,8 +2437,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return connection_pattern return connection_pattern
def L_op(self, inputs, outs, dC_douts): def L_op(self, inputs, outs, dC_douts):
if not isinstance(outs, list | tuple):
outs = [outs]
# `grad_step` equals the number of steps the original scan node has # `grad_step` equals the number of steps the original scan node has
# done (if the original scan is a while loop than this number is the # done (if the original scan is a while loop than this number is the
# length of the output sequence) # length of the output sequence)
...@@ -2695,7 +2693,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2695,7 +2693,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
pass pass
elif isinstance(dC_dXtm1.type, NullType): elif isinstance(dC_dXtm1.type, NullType):
# The new gradient is undefined, this makes the accumulated # The new gradient is undefined, this makes the accumulated
# gradient undefined as weell # gradient undefined as well
dC_dinps_t[dx + info.n_seqs] = dC_dXtm1 dC_dinps_t[dx + info.n_seqs] = dC_dXtm1
else: else:
dC_dinps_t[dx + info.n_seqs] += dC_dXtm1 dC_dinps_t[dx + info.n_seqs] += dC_dXtm1
...@@ -3062,7 +3060,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3062,7 +3060,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)), sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)),
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_untraced_sit_sot_outs=0, n_untraced_sit_sot=0,
n_non_seqs=len(self.outer_untraced_sit_sot(inputs)) n_non_seqs=len(self.outer_untraced_sit_sot(inputs))
+ len(self.outer_non_seqs(inputs)), + len(self.outer_non_seqs(inputs)),
as_while=False, as_while=False,
...@@ -3149,7 +3147,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3149,7 +3147,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
start = len(gradients) start = len(gradients)
node = outs[0].owner node = outs[0].owner
for idx in range(info.n_untraced_sit_sot_outs): for idx in range(info.n_untraced_sit_sot):
disconnected = True disconnected = True
connected_flags = self.connection_pattern(node)[idx + start] connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags, strict=True): for dC_dout, connected in zip(dC_douts, connected_flags, strict=True):
...@@ -3211,7 +3209,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3211,7 +3209,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self_inputs = self.inner_inputs self_inputs = self.inner_inputs
rop_of_inputs = ( rop_of_inputs = (
self_inputs[: info.n_seqs + self.n_outs] self_inputs[: info.n_seqs + self.n_outs]
+ self_inputs[info.n_seqs + self.n_outs + info.n_untraced_sit_sot_outs :] + self_inputs[info.n_seqs + self.n_outs + info.n_untraced_sit_sot :]
) )
self_outputs = self.inner_outputs self_outputs = self.inner_outputs
...@@ -3221,8 +3219,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3221,8 +3219,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs = self_outputs[:-1] rop_self_outputs = self_outputs[:-1]
else: else:
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
if info.n_untraced_sit_sot_outs > 0: if info.n_untraced_sit_sot > 0:
rop_self_outputs = rop_self_outputs[: -info.n_untraced_sit_sot_outs] rop_self_outputs = rop_self_outputs[: -info.n_untraced_sit_sot]
rop_outs = Rop( rop_outs = Rop(
rop_self_outputs, rop_self_outputs,
rop_of_inputs, rop_of_inputs,
...@@ -3308,9 +3306,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3308,9 +3306,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Untraced outs ... # Untraced outs ...
b = e b = e
e = e + info.n_untraced_sit_sot_outs e = e + info.n_untraced_sit_sot
ib = ie ib = ie
ie = ie + info.n_untraced_sit_sot_outs ie = ie + info.n_untraced_sit_sot
scan_untraced = inputs[b:e] scan_untraced = inputs[b:e]
inner_untraced = self_inputs[ib:ie] inner_untraced = self_inputs[ib:ie]
...@@ -3346,7 +3344,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3346,7 +3344,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
e = e + info.n_nit_sot e = e + info.n_nit_sot
inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e] inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e b = e
e = e + info.n_untraced_sit_sot_outs e = e + info.n_untraced_sit_sot
inner_out_untraced = self_outputs[b:e] inner_out_untraced = self_outputs[b:e]
inner_ins = ( inner_ins = (
...@@ -3385,7 +3383,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3385,7 +3383,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices=new_mit_sot_in_slices, mit_sot_in_slices=new_mit_sot_in_slices,
sit_sot_in_slices=new_sit_sot_in_slices, sit_sot_in_slices=new_sit_sot_in_slices,
n_nit_sot=info.n_nit_sot * 2, n_nit_sot=info.n_nit_sot * 2,
n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs, n_untraced_sit_sot=info.n_untraced_sit_sot,
n_non_seqs=len(inner_other), n_non_seqs=len(inner_other),
as_while=info.as_while, as_while=info.as_while,
) )
...@@ -3417,7 +3415,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3417,7 +3415,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
b = e + info.n_nit_sot b = e + info.n_nit_sot
e = e + info.n_nit_sot * 2 e = e + info.n_nit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
final_outs += [None] * info.n_untraced_sit_sot_outs final_outs += [None] * info.n_untraced_sit_sot
return final_outs return final_outs
......
...@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices)) sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices))
) )
st += op_info.n_sit_sot st += op_info.n_sit_sot
st += op_info.n_untraced_sit_sot_outs st += op_info.n_untraced_sit_sot
op_ins = op.inner_inputs op_ins = op.inner_inputs
op_outs = op.inner_outputs op_outs = op.inner_outputs
...@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
+ op_info.n_mit_sot + op_info.n_mit_sot
+ op_info.n_sit_sot + op_info.n_sit_sot
+ op_info.n_nit_sot + op_info.n_nit_sot
+ op_info.n_untraced_sit_sot_outs + op_info.n_untraced_sit_sot
+ 1 + 1
) )
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
...@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ idx + idx
+ op_info.n_seqs + op_info.n_seqs
+ 1 + 1
+ op_info.n_untraced_sit_sot_outs + op_info.n_untraced_sit_sot
) )
if nw_inputs[pos] == node.inputs[0]: if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = 1 if required_orphan else val nw_inputs[pos] = 1 if required_orphan else val
...@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
elif ( elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
): ):
in_idx = offset + idx + op_info.n_untraced_sit_sot_outs in_idx = offset + idx + op_info.n_untraced_sit_sot
if nw_inputs[in_idx] == node.inputs[0]: if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps nw_inputs[in_idx] = nw_steps
...@@ -1980,9 +1980,7 @@ class ScanMerge(GraphRewriter): ...@@ -1980,9 +1980,7 @@ class ScanMerge(GraphRewriter):
mit_sot_in_slices=mit_sot_in_slices, mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices, sit_sot_in_slices=sit_sot_in_slices,
n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes), n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes),
n_untraced_sit_sot_outs=sum( n_untraced_sit_sot=sum(nd.op.info.n_untraced_sit_sot for nd in nodes),
nd.op.info.n_untraced_sit_sot_outs for nd in nodes
),
n_non_seqs=n_non_seqs, n_non_seqs=n_non_seqs,
as_while=as_while, as_while=as_while,
) )
......
...@@ -371,7 +371,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -371,7 +371,7 @@ def scan_can_remove_outs(op, out_idxs):
offset += n_ins offset += n_ins
out_ins += [[] for k in range(op.info.n_nit_sot)] out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [ out_ins += [
[op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot_outs) [op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot)
] ]
added = True added = True
...@@ -411,7 +411,7 @@ def compress_outs(op, not_required, inputs): ...@@ -411,7 +411,7 @@ def compress_outs(op, not_required, inputs):
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=(), sit_sot_in_slices=(),
n_nit_sot=0, n_nit_sot=0,
n_untraced_sit_sot_outs=0, n_untraced_sit_sot=0,
n_non_seqs=0, n_non_seqs=0,
as_while=op_info.as_while, as_while=op_info.as_while,
) )
...@@ -517,18 +517,18 @@ def compress_outs(op, not_required, inputs): ...@@ -517,18 +517,18 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1) info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot_outs]] nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot]]
else: else:
o_offset += 1 o_offset += 1
offset += op_info.n_nit_sot offset += op_info.n_nit_sot
shared_ins = [] shared_ins = []
for idx in range(op_info.n_untraced_sit_sot_outs): for idx in range(op_info.n_untraced_sit_sot):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs + 1 info, n_untraced_sit_sot=info.n_untraced_sit_sot + 1
) )
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
...@@ -543,9 +543,7 @@ def compress_outs(op, not_required, inputs): ...@@ -543,9 +543,7 @@ def compress_outs(op, not_required, inputs):
# other stuff # other stuff
op_inputs += op.inner_inputs[i_offset:] op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:])) info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ node_inputs += inputs[ni_offset + op_info.n_untraced_sit_sot + op_info.n_nit_sot :]
ni_offset + op_info.n_untraced_sit_sot_outs + op_info.n_nit_sot :
]
if op_info.as_while: if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1 map_old_new[o_offset] = len(op_outputs) - 1
...@@ -664,11 +662,11 @@ class ScanArgs: ...@@ -664,11 +662,11 @@ class ScanArgs:
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
n_untraced_sit_sot_outs = info.n_untraced_sit_sot_outs n_untraced_sit_sot = info.n_untraced_sit_sot
self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot_outs]) self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot])
self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot_outs]) self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot])
p += n_untraced_sit_sot_outs p += n_untraced_sit_sot
q += n_untraced_sit_sot_outs q += n_untraced_sit_sot
n_nit_sot = info.n_nit_sot n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot]) self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
...@@ -708,10 +706,10 @@ class ScanArgs: ...@@ -708,10 +706,10 @@ class ScanArgs:
p += n_nit_sot p += n_nit_sot
q += n_nit_sot q += n_nit_sot
self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot_outs]) self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot])
self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot_outs]) self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot])
p += n_untraced_sit_sot_outs p += n_untraced_sit_sot
q += n_untraced_sit_sot_outs q += n_untraced_sit_sot
assert p == len(outer_outputs) assert p == len(outer_outputs)
assert q == len(inner_outputs) assert q == len(inner_outputs)
...@@ -822,7 +820,7 @@ class ScanArgs: ...@@ -822,7 +820,7 @@ class ScanArgs:
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices), mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot), sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot), n_nit_sot=len(self.outer_in_nit_sot),
n_untraced_sit_sot_outs=len(self.outer_in_shared), n_untraced_sit_sot=len(self.outer_in_shared),
n_non_seqs=len(self.inner_in_non_seqs), n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while, as_while=self.as_while,
) )
......
...@@ -651,7 +651,7 @@ def test_trace_truncation_regression_bug(): ...@@ -651,7 +651,7 @@ def test_trace_truncation_regression_bug():
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=((-1,),), sit_sot_in_slices=((-1,),),
n_nit_sot=0, n_nit_sot=0,
n_untraced_sit_sot_outs=0, n_untraced_sit_sot=0,
n_non_seqs=0, n_non_seqs=0,
as_while=False, as_while=False,
), ),
......
...@@ -86,7 +86,7 @@ from tests.scan.test_basic import ScanCompatibilityTests ...@@ -86,7 +86,7 @@ from tests.scan.test_basic import ScanCompatibilityTests
3, 3,
[], [],
[np.array([0.50100236, 2.16822932, 1.36326596])], [np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_untraced_sit_sot_outs > 0, lambda op: op.info.n_untraced_sit_sot > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
( (
......
...@@ -4095,7 +4095,7 @@ class TestExamples: ...@@ -4095,7 +4095,7 @@ class TestExamples:
[{}], [{}],
[], [],
3, 3,
lambda op: op.info.n_untraced_sit_sot_outs > 0, lambda op: op.info.n_untraced_sit_sot > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
( (
...@@ -4292,7 +4292,7 @@ def test_scan_mode_compatibility(scan_mode): ...@@ -4292,7 +4292,7 @@ def test_scan_mode_compatibility(scan_mode):
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=(), sit_sot_in_slices=(),
n_nit_sot=0, n_nit_sot=0,
n_untraced_sit_sot_outs=0, n_untraced_sit_sot=0,
n_non_seqs=0, n_non_seqs=0,
as_while=False, as_while=False,
) )
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论