提交 cdd8575b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move Scan's as_while to ScanInfo

上级 81a8741c
......@@ -120,7 +120,7 @@ def numba_funcify_Scan(op, node, **kwargs):
]
while_logic = ""
if op.as_while:
if op.info.as_while:
# The inner function will be returning a boolean as last argument
inner_out_indexed.append("while_flag")
while_logic += """
......
......@@ -1140,6 +1140,7 @@ def scan(
n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
as_while=as_while,
)
local_op = Scan(
......@@ -1149,7 +1150,6 @@ def scan(
mode=mode,
truncate_gradient=truncate_gradient,
name=name,
as_while=as_while,
profile=profile,
allow_gc=allow_gc,
strict=strict,
......
......@@ -217,6 +217,7 @@ class ScanInfo:
n_shared_outs: int
n_nit_sot: int
n_non_seqs: int
as_while: bool
TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType]
......@@ -670,8 +671,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
as well as profiles for the computation of one step of each instance of
`Scan`. The `name` of the instance appears in those profiles and can
greatly help to disambiguate information.
as_while
Whether or not the `Scan` is a ``while``-loop.
profile
If ``True`` or a non-empty string, a profile object will be created and
attached to the inner graph of `Scan`. When `profile` is ``True``, the
......@@ -701,7 +700,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.info = info
self.truncate_gradient = truncate_gradient
self.name = name
self.as_while = as_while
self.profile = profile
self.allow_gc = allow_gc
self.strict = strict
......@@ -753,7 +751,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for o in outputs[end:]:
self.output_types.append(o.type)
if self.as_while:
if info.as_while:
self.output_types = self.output_types[:-1]
if not hasattr(self, "name") or self.name is None:
......@@ -1201,9 +1199,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.info != other.info:
return False
if self.as_while != other.as_while:
return False
if self.profile != other.profile:
return False
......@@ -1234,7 +1229,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def __str__(self):
device_str = "cpu"
if self.as_while:
if self.info.as_while:
name = "do_while"
else:
name = "for"
......@@ -1261,7 +1256,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
type(self),
self._hash_inner_graph,
self.info,
self.as_while,
self.profile,
self.truncate_gradient,
self.name,
......@@ -1510,7 +1504,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.info.n_mit_sot,
self.info.n_sit_sot,
self.info.n_nit_sot,
self.as_while,
self.info.as_while,
cython_mintaps,
self.info.tap_array,
tap_array_len,
......@@ -1777,7 +1771,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_output_storage[idx + offset].storage[0] = None
# 4.4. If there is a condition add it to the mix
if self.as_while:
if info.as_while:
pdx = offset + info.n_shared_outs
inner_output_storage[pdx].storage[0] = None
......@@ -1847,7 +1841,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
raise
dt_fn = time.time() - t0_fn
if self.as_while:
if info.as_while:
pdx = offset + info.n_shared_outs
cond = inner_output_storage[pdx].storage[0] == 0
......@@ -2173,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns
if self.as_while:
if info.as_while:
self_outs = self.outputs[:-1]
else:
self_outs = self.outputs
......@@ -2222,7 +2216,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]]
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if self.as_while:
if info.as_while:
scan_outs_init = scan_outs
scan_outs = []
for o, x in zip(node.outputs, scan_outs_init):
......@@ -2312,7 +2306,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
else:
grad_steps = inputs[0]
if self.as_while:
if info.as_while:
n_steps = outs[0].shape[0]
# Restrict the number of grad steps according to
......@@ -2537,9 +2531,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dC_dinps_t[dx + info.n_seqs] = dC_dXtm1
else:
dC_dinps_t[dx + info.n_seqs] += dC_dXtm1
# Construct scan op
# Seqs
if self.as_while:
if info.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs = [x[n_steps - 1 :: -1] for x in inputs[1 : 1 + info.n_seqs]]
else:
......@@ -2560,7 +2553,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType):
if self.as_while:
if info.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs.append(x[n_steps - 1 :: -1])
else:
......@@ -2572,7 +2565,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
if self.as_while:
if info.as_while:
n = n_steps.tag.test_value
else:
n = inputs[0].tag.test_value
......@@ -2585,7 +2578,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
assert x[::-1][:-1].tag.test_value.shape[0] == n
for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, "test_value"):
if self.as_while:
if info.as_while:
assert x[n_steps - 1 :: -1].tag.test_value.shape[0] == n
else:
assert x[::-1].tag.test_value.shape[0] == n
......@@ -2874,7 +2867,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ outer_inp_seqs
+ outer_inp_mitmot
+ outer_inp_sitsot
+ [n_steps if self.as_while else inputs[0] for _ in range(n_nit_sot)]
+ [n_steps if info.as_while else inputs[0] for _ in range(n_nit_sot)]
+ self.outer_shared(inputs)
+ self.outer_non_seqs(inputs)
)
......@@ -2900,6 +2893,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_nit_sot=n_nit_sot,
n_non_seqs=len(self.outer_shared(inputs))
+ len(self.outer_non_seqs(inputs)),
as_while=False,
)
local_op = Scan(
......@@ -2908,7 +2902,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_info,
mode=self.mode,
truncate_gradient=self.truncate_gradient,
as_while=False,
profile=self.profile,
name=f"grad_of_{self.name}" if self.name else None,
allow_gc=self.allow_gc,
......@@ -2930,7 +2923,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
if self.as_while:
if info.as_while:
n_zeros = inputs[0] - n_steps
shp = (n_zeros,)
if x.ndim > 1:
......@@ -2958,7 +2951,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If the forward scan is in as_while mode, we need to pad
# the gradients, so that they match the size of the input
# sequences.
if self.as_while:
if info.as_while:
n_zeros = inputs[0] - grad_steps
shp = (n_zeros,)
if x.ndim > 1:
......@@ -3052,7 +3045,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Step 1. Compute the R_op of the inner function
inner_eval_points = [safe_new(x, "_evalpoint") for x in rop_of_inputs]
if self.as_while:
if info.as_while:
rop_self_outputs = self_outputs[:-1]
else:
rop_self_outputs = self_outputs
......@@ -3209,7 +3202,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ inner_out_shared
)
if self.as_while:
if info.as_while:
inner_outs += [self_outputs[-1]]
scan_inputs = (
[inputs[0]]
......@@ -3233,6 +3226,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
tap_array=tuple(tuple(v) for v in new_tap_array),
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
n_non_seqs=len(inner_other),
as_while=info.as_while,
)
local_op = Scan(
......@@ -3240,7 +3234,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_outs,
out_info,
mode=self.mode,
as_while=self.as_while,
profile=self.profile,
truncate_gradient=self.truncate_gradient,
name=f"rop_of_{self.name}" if self.name else None,
......@@ -3363,7 +3356,6 @@ def _op_debug_information_Scan(op, node):
inner_inputs,
inner_outputs,
node.op.info,
node.op.as_while,
clone=False,
)
......
......@@ -176,7 +176,6 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
op_outs,
nw_info,
mode=op.mode,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
......@@ -351,7 +350,6 @@ def push_out_non_seq_scan(fgraph, node):
op_outs,
new_info,
mode=op.mode,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
......@@ -589,7 +587,6 @@ def push_out_seq_scan(fgraph, node):
op_outs,
nw_info,
mode=op.mode,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
......@@ -606,7 +603,7 @@ def push_out_seq_scan(fgraph, node):
replacements["remove"] = [node]
return replacements
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node.inputs):
elif not to_keep_set and not op.info.as_while and not op.outer_mitmot(node.inputs):
# Nothing in the inner graph should be kept
replace_with = {}
for out, idx in to_replace_map.items():
......@@ -728,7 +725,6 @@ def push_out_inner_vars(
new_scan_node.op.inputs,
new_scan_node.op.outputs,
new_scan_node.op.info,
new_scan_node.op.as_while,
)
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
......@@ -770,7 +766,6 @@ def add_nitsot_outputs(
new_scan_args.inner_outputs,
new_scan_args.info,
mode=old_scan_node.op.mode,
as_while=old_scan_node.op.as_while,
profile=old_scan_node.op.profile,
truncate_gradient=old_scan_node.op.truncate_gradient,
# TODO: This seems questionable
......@@ -818,16 +813,14 @@ def push_out_add_scan(fgraph, node):
# Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the moment.
if not (isinstance(node.op, Scan) and not node.op.as_while):
if not (isinstance(node.op, Scan) and not node.op.info.as_while):
return False
op = node.op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
)
args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
clients = {}
local_fgraph_topo = io_toposort(
......@@ -997,7 +990,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op.info,
mode=op.mode,
typeConstructor=typeConstructor,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
......@@ -1525,7 +1517,6 @@ def save_mem_new_scan(fgraph, node):
outs,
info,
mode=op.mode,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
......@@ -1662,7 +1653,7 @@ class ScanMerge(GlobalOptimizer):
def merge(self, nodes):
if nodes[0].op.as_while:
if nodes[0].op.info.as_while:
as_while = True
condition = nodes[0].op.outputs[-1]
else:
......@@ -1813,6 +1804,7 @@ class ScanMerge(GlobalOptimizer):
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes),
n_non_seqs=n_non_seqs,
as_while=as_while,
)
old_op = nodes[0].op
......@@ -1825,7 +1817,6 @@ class ScanMerge(GlobalOptimizer):
truncate_gradient=old_op.truncate_gradient,
allow_gc=old_op.allow_gc,
name="&".join([nd.op.name for nd in nodes]),
as_while=as_while,
)
new_outs = new_op(*outer_ins)
......@@ -1846,7 +1837,7 @@ class ScanMerge(GlobalOptimizer):
"""
rep = set_nodes[0]
if (
rep.op.as_while != node.op.as_while
rep.op.info.as_while != node.op.info.as_while
or node.op.truncate_gradient != rep.op.truncate_gradient
or node.op.mode != rep.op.mode
):
......@@ -1872,7 +1863,7 @@ class ScanMerge(GlobalOptimizer):
if is_in_ancestors(node, nd) or is_in_ancestors(nd, node):
return False
if not node.op.as_while:
if not node.op.info.as_while:
return True
cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1]
......@@ -1957,7 +1948,6 @@ def scan_merge_inouts(fgraph, node):
node.op.inputs,
node.op.outputs,
node.op.info,
node.op.as_while,
)
inp_equiv = {}
......@@ -2001,7 +1991,6 @@ def scan_merge_inouts(fgraph, node):
inner_outputs,
info,
mode=node.op.mode,
as_while=node.op.as_while,
profile=node.op.profile,
truncate_gradient=node.op.truncate_gradient,
# TODO: This seems questionable
......@@ -2019,7 +2008,6 @@ def scan_merge_inouts(fgraph, node):
new_op.inputs,
new_op.outputs,
new_op.info,
new_op.as_while,
)
remove = [node]
else:
......@@ -2266,7 +2254,6 @@ def push_out_dot1_scan(fgraph, node):
new_inner_outs,
new_info,
mode=op.mode,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
......
......@@ -408,6 +408,7 @@ def compress_outs(op, not_required, inputs):
n_shared_outs=0,
n_nit_sot=0,
n_non_seqs=0,
as_while=op.info.as_while,
)
op_inputs = op.inputs[: op.n_seqs]
......@@ -528,7 +529,7 @@ def compress_outs(op, not_required, inputs):
op_inputs += op.inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inputs[i_offset:]))
node_inputs += inputs[ni_offset + op.n_shared_outs + op.n_nit_sot :]
if op.as_while:
if op.info.as_while:
op_outputs += [op.outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1
# map_old_new[len(op_outputs)-1] = o_offset
......@@ -582,11 +583,10 @@ class ScanArgs:
_inner_inputs: Sequence[Variable],
_inner_outputs: Sequence[Variable],
info: "ScanInfo",
as_while: bool,
clone: Optional[bool] = True,
):
self.n_steps = outer_inputs[0]
self.as_while = as_while
self.as_while = info.as_while
if clone:
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
......@@ -710,7 +710,6 @@ class ScanArgs:
node.op.inputs,
node.op.outputs,
node.op.info,
node.op.as_while,
clone=clone,
)
......@@ -815,6 +814,7 @@ class ScanArgs:
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices),
n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while,
)
def get_alt_field(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论