提交 cdd8575b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move Scan's as_while to ScanInfo

上级 81a8741c
...@@ -120,7 +120,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -120,7 +120,7 @@ def numba_funcify_Scan(op, node, **kwargs):
] ]
while_logic = "" while_logic = ""
if op.as_while: if op.info.as_while:
# The inner function will be returning a boolean as last argument # The inner function will be returning a boolean as last argument
inner_out_indexed.append("while_flag") inner_out_indexed.append("while_flag")
while_logic += """ while_logic += """
......
...@@ -1140,6 +1140,7 @@ def scan( ...@@ -1140,6 +1140,7 @@ def scan(
n_shared_outs=n_shared_outs, n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args), n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
as_while=as_while,
) )
local_op = Scan( local_op = Scan(
...@@ -1149,7 +1150,6 @@ def scan( ...@@ -1149,7 +1150,6 @@ def scan(
mode=mode, mode=mode,
truncate_gradient=truncate_gradient, truncate_gradient=truncate_gradient,
name=name, name=name,
as_while=as_while,
profile=profile, profile=profile,
allow_gc=allow_gc, allow_gc=allow_gc,
strict=strict, strict=strict,
......
...@@ -217,6 +217,7 @@ class ScanInfo: ...@@ -217,6 +217,7 @@ class ScanInfo:
n_shared_outs: int n_shared_outs: int
n_nit_sot: int n_nit_sot: int
n_non_seqs: int n_non_seqs: int
as_while: bool
TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType] TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType]
...@@ -670,8 +671,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -670,8 +671,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
as well as profiles for the computation of one step of each instance of as well as profiles for the computation of one step of each instance of
`Scan`. The `name` of the instance appears in those profiles and can `Scan`. The `name` of the instance appears in those profiles and can
greatly help to disambiguate information. greatly help to disambiguate information.
as_while
Whether or not the `Scan` is a ``while``-loop.
profile profile
If ``True`` or a non-empty string, a profile object will be created and If ``True`` or a non-empty string, a profile object will be created and
attached to the inner graph of `Scan`. When `profile` is ``True``, the attached to the inner graph of `Scan`. When `profile` is ``True``, the
...@@ -701,7 +700,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -701,7 +700,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.info = info self.info = info
self.truncate_gradient = truncate_gradient self.truncate_gradient = truncate_gradient
self.name = name self.name = name
self.as_while = as_while
self.profile = profile self.profile = profile
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.strict = strict self.strict = strict
...@@ -753,7 +751,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -753,7 +751,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for o in outputs[end:]: for o in outputs[end:]:
self.output_types.append(o.type) self.output_types.append(o.type)
if self.as_while: if info.as_while:
self.output_types = self.output_types[:-1] self.output_types = self.output_types[:-1]
if not hasattr(self, "name") or self.name is None: if not hasattr(self, "name") or self.name is None:
...@@ -1201,9 +1199,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1201,9 +1199,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.info != other.info: if self.info != other.info:
return False return False
if self.as_while != other.as_while:
return False
if self.profile != other.profile: if self.profile != other.profile:
return False return False
...@@ -1234,7 +1229,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1234,7 +1229,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def __str__(self): def __str__(self):
device_str = "cpu" device_str = "cpu"
if self.as_while: if self.info.as_while:
name = "do_while" name = "do_while"
else: else:
name = "for" name = "for"
...@@ -1261,7 +1256,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1261,7 +1256,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
type(self), type(self),
self._hash_inner_graph, self._hash_inner_graph,
self.info, self.info,
self.as_while,
self.profile, self.profile,
self.truncate_gradient, self.truncate_gradient,
self.name, self.name,
...@@ -1510,7 +1504,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1510,7 +1504,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.info.n_mit_sot, self.info.n_mit_sot,
self.info.n_sit_sot, self.info.n_sit_sot,
self.info.n_nit_sot, self.info.n_nit_sot,
self.as_while, self.info.as_while,
cython_mintaps, cython_mintaps,
self.info.tap_array, self.info.tap_array,
tap_array_len, tap_array_len,
...@@ -1777,7 +1771,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1777,7 +1771,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage[idx + offset].storage[0] = None inner_output_storage[idx + offset].storage[0] = None
# 4.4. If there is a condition add it to the mix # 4.4. If there is a condition add it to the mix
if self.as_while: if info.as_while:
pdx = offset + info.n_shared_outs pdx = offset + info.n_shared_outs
inner_output_storage[pdx].storage[0] = None inner_output_storage[pdx].storage[0] = None
...@@ -1847,7 +1841,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1847,7 +1841,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
raise raise
dt_fn = time.time() - t0_fn dt_fn = time.time() - t0_fn
if self.as_while: if info.as_while:
pdx = offset + info.n_shared_outs pdx = offset + info.n_shared_outs
cond = inner_output_storage[pdx].storage[0] == 0 cond = inner_output_storage[pdx].storage[0] == 0
...@@ -2173,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2173,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]): for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns out_equivalent[in_ns] = out_ns
if self.as_while: if info.as_while:
self_outs = self.outputs[:-1] self_outs = self.outputs[:-1]
else: else:
self_outs = self.outputs self_outs = self.outputs
...@@ -2222,7 +2216,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2222,7 +2216,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]] scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]]
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if self.as_while: if info.as_while:
scan_outs_init = scan_outs scan_outs_init = scan_outs
scan_outs = [] scan_outs = []
for o, x in zip(node.outputs, scan_outs_init): for o, x in zip(node.outputs, scan_outs_init):
...@@ -2312,7 +2306,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2312,7 +2306,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
else: else:
grad_steps = inputs[0] grad_steps = inputs[0]
if self.as_while: if info.as_while:
n_steps = outs[0].shape[0] n_steps = outs[0].shape[0]
# Restrict the number of grad steps according to # Restrict the number of grad steps according to
...@@ -2537,9 +2531,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2537,9 +2531,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dC_dinps_t[dx + info.n_seqs] = dC_dXtm1 dC_dinps_t[dx + info.n_seqs] = dC_dXtm1
else: else:
dC_dinps_t[dx + info.n_seqs] += dC_dXtm1 dC_dinps_t[dx + info.n_seqs] += dC_dXtm1
# Construct scan op
# Seqs if info.as_while:
if self.as_while:
# equivalent to x[:n_steps][::-1] # equivalent to x[:n_steps][::-1]
outer_inp_seqs = [x[n_steps - 1 :: -1] for x in inputs[1 : 1 + info.n_seqs]] outer_inp_seqs = [x[n_steps - 1 :: -1] for x in inputs[1 : 1 + info.n_seqs]]
else: else:
...@@ -2560,7 +2553,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2560,7 +2553,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts): for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
if self.as_while: if info.as_while:
# equivalent to x[:n_steps][::-1] # equivalent to x[:n_steps][::-1]
outer_inp_seqs.append(x[n_steps - 1 :: -1]) outer_inp_seqs.append(x[n_steps - 1 :: -1])
else: else:
...@@ -2572,7 +2565,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2572,7 +2565,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# fct add and we want to keep it for all Scan op. This is # fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test # used in T_Scan.test_grad_multiple_outs_taps to test
# that. # that.
if self.as_while: if info.as_while:
n = n_steps.tag.test_value n = n_steps.tag.test_value
else: else:
n = inputs[0].tag.test_value n = inputs[0].tag.test_value
...@@ -2585,7 +2578,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2585,7 +2578,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
assert x[::-1][:-1].tag.test_value.shape[0] == n assert x[::-1][:-1].tag.test_value.shape[0] == n
for x in self.outer_nitsot_outs(outs): for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, "test_value"): if hasattr(x[::-1].tag, "test_value"):
if self.as_while: if info.as_while:
assert x[n_steps - 1 :: -1].tag.test_value.shape[0] == n assert x[n_steps - 1 :: -1].tag.test_value.shape[0] == n
else: else:
assert x[::-1].tag.test_value.shape[0] == n assert x[::-1].tag.test_value.shape[0] == n
...@@ -2874,7 +2867,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2874,7 +2867,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ outer_inp_seqs + outer_inp_seqs
+ outer_inp_mitmot + outer_inp_mitmot
+ outer_inp_sitsot + outer_inp_sitsot
+ [n_steps if self.as_while else inputs[0] for _ in range(n_nit_sot)] + [n_steps if info.as_while else inputs[0] for _ in range(n_nit_sot)]
+ self.outer_shared(inputs) + self.outer_shared(inputs)
+ self.outer_non_seqs(inputs) + self.outer_non_seqs(inputs)
) )
...@@ -2900,6 +2893,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2900,6 +2893,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_non_seqs=len(self.outer_shared(inputs)) n_non_seqs=len(self.outer_shared(inputs))
+ len(self.outer_non_seqs(inputs)), + len(self.outer_non_seqs(inputs)),
as_while=False,
) )
local_op = Scan( local_op = Scan(
...@@ -2908,7 +2902,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2908,7 +2902,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_info, out_info,
mode=self.mode, mode=self.mode,
truncate_gradient=self.truncate_gradient, truncate_gradient=self.truncate_gradient,
as_while=False,
profile=self.profile, profile=self.profile,
name=f"grad_of_{self.name}" if self.name else None, name=f"grad_of_{self.name}" if self.name else None,
allow_gc=self.allow_gc, allow_gc=self.allow_gc,
...@@ -2930,7 +2923,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2930,7 +2923,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If the forward scan is in as_while mode, we need to pad # If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input # the gradients, so that they match the size of the input
# sequences. # sequences.
if self.as_while: if info.as_while:
n_zeros = inputs[0] - n_steps n_zeros = inputs[0] - n_steps
shp = (n_zeros,) shp = (n_zeros,)
if x.ndim > 1: if x.ndim > 1:
...@@ -2958,7 +2951,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2958,7 +2951,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If the forward scan is in as_while mode, we need to pad # If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input # the gradients, so that they match the size of the input
# sequences. # sequences.
if self.as_while: if info.as_while:
n_zeros = inputs[0] - grad_steps n_zeros = inputs[0] - grad_steps
shp = (n_zeros,) shp = (n_zeros,)
if x.ndim > 1: if x.ndim > 1:
...@@ -3052,7 +3045,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3052,7 +3045,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Step 1. Compute the R_op of the inner function # Step 1. Compute the R_op of the inner function
inner_eval_points = [safe_new(x, "_evalpoint") for x in rop_of_inputs] inner_eval_points = [safe_new(x, "_evalpoint") for x in rop_of_inputs]
if self.as_while: if info.as_while:
rop_self_outputs = self_outputs[:-1] rop_self_outputs = self_outputs[:-1]
else: else:
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
...@@ -3209,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3209,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ inner_out_shared + inner_out_shared
) )
if self.as_while: if info.as_while:
inner_outs += [self_outputs[-1]] inner_outs += [self_outputs[-1]]
scan_inputs = ( scan_inputs = (
[inputs[0]] [inputs[0]]
...@@ -3233,6 +3226,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3233,6 +3226,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
tap_array=tuple(tuple(v) for v in new_tap_array), tap_array=tuple(tuple(v) for v in new_tap_array),
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2, mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
n_non_seqs=len(inner_other), n_non_seqs=len(inner_other),
as_while=info.as_while,
) )
local_op = Scan( local_op = Scan(
...@@ -3240,7 +3234,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3240,7 +3234,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_outs, inner_outs,
out_info, out_info,
mode=self.mode, mode=self.mode,
as_while=self.as_while,
profile=self.profile, profile=self.profile,
truncate_gradient=self.truncate_gradient, truncate_gradient=self.truncate_gradient,
name=f"rop_of_{self.name}" if self.name else None, name=f"rop_of_{self.name}" if self.name else None,
...@@ -3363,7 +3356,6 @@ def _op_debug_information_Scan(op, node): ...@@ -3363,7 +3356,6 @@ def _op_debug_information_Scan(op, node):
inner_inputs, inner_inputs,
inner_outputs, inner_outputs,
node.op.info, node.op.info,
node.op.as_while,
clone=False, clone=False,
) )
......
...@@ -176,7 +176,6 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -176,7 +176,6 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
op_outs, op_outs,
nw_info, nw_info,
mode=op.mode, mode=op.mode,
as_while=op.as_while,
profile=op.profile, profile=op.profile,
truncate_gradient=op.truncate_gradient, truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
...@@ -351,7 +350,6 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -351,7 +350,6 @@ def push_out_non_seq_scan(fgraph, node):
op_outs, op_outs,
new_info, new_info,
mode=op.mode, mode=op.mode,
as_while=op.as_while,
profile=op.profile, profile=op.profile,
truncate_gradient=op.truncate_gradient, truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
...@@ -589,7 +587,6 @@ def push_out_seq_scan(fgraph, node): ...@@ -589,7 +587,6 @@ def push_out_seq_scan(fgraph, node):
op_outs, op_outs,
nw_info, nw_info,
mode=op.mode, mode=op.mode,
as_while=op.as_while,
profile=op.profile, profile=op.profile,
truncate_gradient=op.truncate_gradient, truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
...@@ -606,7 +603,7 @@ def push_out_seq_scan(fgraph, node): ...@@ -606,7 +603,7 @@ def push_out_seq_scan(fgraph, node):
replacements["remove"] = [node] replacements["remove"] = [node]
return replacements return replacements
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node.inputs): elif not to_keep_set and not op.info.as_while and not op.outer_mitmot(node.inputs):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = {} replace_with = {}
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
...@@ -728,7 +725,6 @@ def push_out_inner_vars( ...@@ -728,7 +725,6 @@ def push_out_inner_vars(
new_scan_node.op.inputs, new_scan_node.op.inputs,
new_scan_node.op.outputs, new_scan_node.op.outputs,
new_scan_node.op.info, new_scan_node.op.info,
new_scan_node.op.as_while,
) )
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :] new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
...@@ -770,7 +766,6 @@ def add_nitsot_outputs( ...@@ -770,7 +766,6 @@ def add_nitsot_outputs(
new_scan_args.inner_outputs, new_scan_args.inner_outputs,
new_scan_args.info, new_scan_args.info,
mode=old_scan_node.op.mode, mode=old_scan_node.op.mode,
as_while=old_scan_node.op.as_while,
profile=old_scan_node.op.profile, profile=old_scan_node.op.profile,
truncate_gradient=old_scan_node.op.truncate_gradient, truncate_gradient=old_scan_node.op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
...@@ -818,16 +813,14 @@ def push_out_add_scan(fgraph, node): ...@@ -818,16 +813,14 @@ def push_out_add_scan(fgraph, node):
# Don't perform the optimization on `as_while` `Scan`s. Because these # Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is # `Scan`s don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the moment. # more complicated and this optimization doesn't support it at the moment.
if not (isinstance(node.op, Scan) and not node.op.as_while): if not (isinstance(node.op, Scan) and not node.op.info.as_while):
return False return False
op = node.op op = node.op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of # Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use # use
args = ScanArgs( args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
)
clients = {} clients = {}
local_fgraph_topo = io_toposort( local_fgraph_topo = io_toposort(
...@@ -997,7 +990,6 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -997,7 +990,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op.info, op.info,
mode=op.mode, mode=op.mode,
typeConstructor=typeConstructor, typeConstructor=typeConstructor,
as_while=op.as_while,
profile=op.profile, profile=op.profile,
truncate_gradient=op.truncate_gradient, truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
...@@ -1525,7 +1517,6 @@ def save_mem_new_scan(fgraph, node): ...@@ -1525,7 +1517,6 @@ def save_mem_new_scan(fgraph, node):
outs, outs,
info, info,
mode=op.mode, mode=op.mode,
as_while=op.as_while,
profile=op.profile, profile=op.profile,
truncate_gradient=op.truncate_gradient, truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
...@@ -1662,7 +1653,7 @@ class ScanMerge(GlobalOptimizer): ...@@ -1662,7 +1653,7 @@ class ScanMerge(GlobalOptimizer):
def merge(self, nodes): def merge(self, nodes):
if nodes[0].op.as_while: if nodes[0].op.info.as_while:
as_while = True as_while = True
condition = nodes[0].op.outputs[-1] condition = nodes[0].op.outputs[-1]
else: else:
...@@ -1813,6 +1804,7 @@ class ScanMerge(GlobalOptimizer): ...@@ -1813,6 +1804,7 @@ class ScanMerge(GlobalOptimizer):
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes), n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes), n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes),
n_non_seqs=n_non_seqs, n_non_seqs=n_non_seqs,
as_while=as_while,
) )
old_op = nodes[0].op old_op = nodes[0].op
...@@ -1825,7 +1817,6 @@ class ScanMerge(GlobalOptimizer): ...@@ -1825,7 +1817,6 @@ class ScanMerge(GlobalOptimizer):
truncate_gradient=old_op.truncate_gradient, truncate_gradient=old_op.truncate_gradient,
allow_gc=old_op.allow_gc, allow_gc=old_op.allow_gc,
name="&".join([nd.op.name for nd in nodes]), name="&".join([nd.op.name for nd in nodes]),
as_while=as_while,
) )
new_outs = new_op(*outer_ins) new_outs = new_op(*outer_ins)
...@@ -1846,7 +1837,7 @@ class ScanMerge(GlobalOptimizer): ...@@ -1846,7 +1837,7 @@ class ScanMerge(GlobalOptimizer):
""" """
rep = set_nodes[0] rep = set_nodes[0]
if ( if (
rep.op.as_while != node.op.as_while rep.op.info.as_while != node.op.info.as_while
or node.op.truncate_gradient != rep.op.truncate_gradient or node.op.truncate_gradient != rep.op.truncate_gradient
or node.op.mode != rep.op.mode or node.op.mode != rep.op.mode
): ):
...@@ -1872,7 +1863,7 @@ class ScanMerge(GlobalOptimizer): ...@@ -1872,7 +1863,7 @@ class ScanMerge(GlobalOptimizer):
if is_in_ancestors(node, nd) or is_in_ancestors(nd, node): if is_in_ancestors(node, nd) or is_in_ancestors(nd, node):
return False return False
if not node.op.as_while: if not node.op.info.as_while:
return True return True
cond = node.op.outputs[-1] cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1] rep_cond = rep.op.outputs[-1]
...@@ -1957,7 +1948,6 @@ def scan_merge_inouts(fgraph, node): ...@@ -1957,7 +1948,6 @@ def scan_merge_inouts(fgraph, node):
node.op.inputs, node.op.inputs,
node.op.outputs, node.op.outputs,
node.op.info, node.op.info,
node.op.as_while,
) )
inp_equiv = {} inp_equiv = {}
...@@ -2001,7 +1991,6 @@ def scan_merge_inouts(fgraph, node): ...@@ -2001,7 +1991,6 @@ def scan_merge_inouts(fgraph, node):
inner_outputs, inner_outputs,
info, info,
mode=node.op.mode, mode=node.op.mode,
as_while=node.op.as_while,
profile=node.op.profile, profile=node.op.profile,
truncate_gradient=node.op.truncate_gradient, truncate_gradient=node.op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
...@@ -2019,7 +2008,6 @@ def scan_merge_inouts(fgraph, node): ...@@ -2019,7 +2008,6 @@ def scan_merge_inouts(fgraph, node):
new_op.inputs, new_op.inputs,
new_op.outputs, new_op.outputs,
new_op.info, new_op.info,
new_op.as_while,
) )
remove = [node] remove = [node]
else: else:
...@@ -2266,7 +2254,6 @@ def push_out_dot1_scan(fgraph, node): ...@@ -2266,7 +2254,6 @@ def push_out_dot1_scan(fgraph, node):
new_inner_outs, new_inner_outs,
new_info, new_info,
mode=op.mode, mode=op.mode,
as_while=op.as_while,
profile=op.profile, profile=op.profile,
truncate_gradient=op.truncate_gradient, truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable # TODO: This seems questionable
......
...@@ -408,6 +408,7 @@ def compress_outs(op, not_required, inputs): ...@@ -408,6 +408,7 @@ def compress_outs(op, not_required, inputs):
n_shared_outs=0, n_shared_outs=0,
n_nit_sot=0, n_nit_sot=0,
n_non_seqs=0, n_non_seqs=0,
as_while=op.info.as_while,
) )
op_inputs = op.inputs[: op.n_seqs] op_inputs = op.inputs[: op.n_seqs]
...@@ -528,7 +529,7 @@ def compress_outs(op, not_required, inputs): ...@@ -528,7 +529,7 @@ def compress_outs(op, not_required, inputs):
op_inputs += op.inputs[i_offset:] op_inputs += op.inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inputs[i_offset:])) info = dataclasses.replace(info, n_non_seqs=len(op.inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :] node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :]
if op.as_while: if op.info.as_while:
op_outputs += [op.outputs[o_offset]] op_outputs += [op.outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1 map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset # map_old_new[len(op_outputs)-1] = o_offset
...@@ -582,11 +583,10 @@ class ScanArgs: ...@@ -582,11 +583,10 @@ class ScanArgs:
_inner_inputs: Sequence[Variable], _inner_inputs: Sequence[Variable],
_inner_outputs: Sequence[Variable], _inner_outputs: Sequence[Variable],
info: "ScanInfo", info: "ScanInfo",
as_while: bool,
clone: Optional[bool] = True, clone: Optional[bool] = True,
): ):
self.n_steps = outer_inputs[0] self.n_steps = outer_inputs[0]
self.as_while = as_while self.as_while = info.as_while
if clone: if clone:
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "") rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
...@@ -710,7 +710,6 @@ class ScanArgs: ...@@ -710,7 +710,6 @@ class ScanArgs:
node.op.inputs, node.op.inputs,
node.op.outputs, node.op.outputs,
node.op.info, node.op.info,
node.op.as_while,
clone=clone, clone=clone,
) )
...@@ -815,6 +814,7 @@ class ScanArgs: ...@@ -815,6 +814,7 @@ class ScanArgs:
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices), n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices), mit_mot_out_slices=tuple(self.mit_mot_out_slices),
n_non_seqs=len(self.inner_in_non_seqs), n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while,
) )
def get_alt_field( def get_alt_field(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论