提交 7c4871ce 作者: Brandon T. Willard 提交者: Brandon T. Willard

Remove non-sequence settings from ScanInfo

上级 6f685799
......@@ -1033,6 +1033,13 @@ def scan(
n_sit_sot=n_sit_sot,
n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot,
)
local_op = Scan(
inner_inputs,
new_outs,
info,
mode=mode,
truncate_gradient=truncate_gradient,
name=name,
gpua=False,
......@@ -1042,8 +1049,6 @@ def scan(
strict=strict,
)
local_op = Scan(inner_inputs, new_outs, info, mode)
##
# Step 8. Compute the outputs using the scan op
##
......
差异被折叠。
......@@ -217,7 +217,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if len(nw_inner) != len(op_ins):
op_outs = clone_replace(op_outs, replace=givens)
nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs)
nwScan = Scan(nw_inner, op_outs, nw_info, op.mode)
nwScan = Scan(
nw_inner,
op_outs,
nw_info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
nw_outs = nwScan(*nw_outer, return_list=True)
return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else:
......@@ -396,7 +408,19 @@ class PushOutNonSeqScan(GlobalOptimizer):
op_ins = clean_inputs + nw_inner
# Reconstruct node
nwScan = Scan(op_ins, op_outs, op.info, op.mode)
nwScan = Scan(
op_ins,
op_outs,
op.info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
# Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner
......@@ -666,7 +690,19 @@ class PushOutSeqScan(GlobalOptimizer):
nw_info = dataclasses.replace(
op.info, n_seqs=op.info.n_seqs + len(nw_inner)
)
nwScan = Scan(op_ins, op_outs, nw_info, op.mode)
nwScan = Scan(
op_ins,
op_outs,
nw_info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
# Do not call make_node for test_value
nw_node = nwScan(
*(node.inputs[:1] + nw_outer + node.inputs[1:]),
......@@ -751,7 +787,9 @@ class PushOutScanOutput(GlobalOptimizer):
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
args = ScanArgs(
node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
)
new_scan_node = None
clients = {}
......@@ -921,6 +959,7 @@ class PushOutScanOutput(GlobalOptimizer):
new_scan_node.op.inputs,
new_scan_node.op.outputs,
new_scan_node.op.info,
new_scan_node.op.as_while,
)
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
......@@ -952,7 +991,14 @@ class PushOutScanOutput(GlobalOptimizer):
new_scan_args.inner_inputs,
new_scan_args.inner_outputs,
new_scan_args.info,
old_scan_node.op.mode,
mode=old_scan_node.op.mode,
gpua=old_scan_node.op.gpua,
as_while=old_scan_node.op.as_while,
profile=old_scan_node.op.profile,
truncate_gradient=old_scan_node.op.truncate_gradient,
# TODO: This seems questionable
name=old_scan_node.op.name,
allow_gc=old_scan_node.op.allow_gc,
)
# Create the Apply node for the scan op
......@@ -1059,7 +1105,18 @@ class ScanInplaceOptimizer(GlobalOptimizer):
typeConstructor = self.typeInfer(node)
new_op = Scan(
op.inputs, op.outputs, op.info, op.mode, typeConstructor=typeConstructor
op.inputs,
op.outputs,
op.info,
mode=op.mode,
typeConstructor=typeConstructor,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
destroy_map = op.destroy_map.copy()
......@@ -1086,9 +1143,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
alloc_ops = (Alloc, AllocEmpty)
nodes = fgraph.toposort()[::-1]
scan_nodes = [
x
for x in nodes
if (isinstance(x.op, Scan) and x.op.info.gpua == self.gpua_flag)
x for x in nodes if (isinstance(x.op, Scan) and x.op.gpua == self.gpua_flag)
]
for scan_idx in range(len(scan_nodes)):
......@@ -1593,7 +1648,20 @@ class ScanSaveMem(GlobalOptimizer):
return
# Do not call make_node for test_value
new_outs = Scan(inps, outs, info, op.mode)(*node_ins, return_list=True)
new_op = Scan(
inps,
outs,
info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
new_outs = new_op(*node_ins, return_list=True)
old_new = []
# 3.7 Get replace pairs for those outputs that do not change
......@@ -1871,15 +1939,21 @@ class ScanMerge(GlobalOptimizer):
n_sit_sot=sum([nd.op.n_sit_sot for nd in nodes]),
n_shared_outs=sum([nd.op.n_shared_outs for nd in nodes]),
n_nit_sot=sum([nd.op.n_nit_sot for nd in nodes]),
truncate_gradient=nodes[0].op.truncate_gradient,
)
old_op = nodes[0].op
new_op = Scan(
new_inner_ins,
new_inner_outs,
info,
mode=old_op.mode,
profile=old_op.profile,
truncate_gradient=old_op.truncate_gradient,
allow_gc=old_op.allow_gc,
name="&".join([nd.op.name for nd in nodes]),
gpua=False,
as_while=as_while,
profile=nodes[0].op.profile,
allow_gc=nodes[0].op.allow_gc,
)
new_op = Scan(new_inner_ins, new_inner_outs, info, nodes[0].op.mode)
new_outs = new_op(*outer_ins)
if not isinstance(new_outs, (list, tuple)):
......@@ -2005,7 +2079,12 @@ def scan_merge_inouts(fgraph, node):
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.
a = ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.info,
node.op.as_while,
)
inp_equiv = {}
......@@ -2044,13 +2123,32 @@ def scan_merge_inouts(fgraph, node):
a_inner_outs = a.inner_outputs
inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv)
op = Scan(inner_inputs, inner_outputs, info, node.op.mode)
outputs = op(*outer_inputs)
new_op = Scan(
inner_inputs,
inner_outputs,
info,
mode=node.op.mode,
gpua=node.op.gpua,
as_while=node.op.as_while,
profile=node.op.profile,
truncate_gradient=node.op.truncate_gradient,
# TODO: This seems questionable
name=node.op.name,
allow_gc=node.op.allow_gc,
)
outputs = new_op(*outer_inputs)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
na = ScanArgs(outer_inputs, outputs, op.inputs, op.outputs, op.info)
na = ScanArgs(
outer_inputs,
outputs,
new_op.inputs,
new_op.outputs,
new_op.info,
new_op.as_while,
)
remove = [node]
else:
na = a
......@@ -2302,7 +2400,19 @@ class PushOutDot1(GlobalOptimizer):
new_inner_inps, new_inner_outs = reconstruct_graph(
_new_inner_inps, _new_inner_outs
)
new_op = Scan(new_inner_inps, new_inner_outs, new_info, op.mode)
new_op = Scan(
new_inner_inps,
new_inner_outs,
new_info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
_scan_inputs = (
[node.inputs[0]]
+ outer_seqs
......
......@@ -701,12 +701,6 @@ def compress_outs(op, not_required, inputs):
n_sit_sot=0,
n_shared_outs=0,
n_nit_sot=0,
truncate_gradient=op.info.truncate_gradient,
name=op.info.name,
gpua=op.info.gpua,
as_while=op.info.as_while,
profile=op.info.profile,
allow_gc=op.info.allow_gc,
)
op_inputs = op.inputs[: op.n_seqs]
......@@ -886,16 +880,18 @@ class ScanArgs:
_inner_inputs,
_inner_outputs,
info,
as_while,
clone=True,
):
self.n_steps = outer_inputs[0]
self.as_while = as_while
if clone:
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
else:
rval = (_inner_inputs, _inner_outputs)
if info.as_while:
if self.as_while:
self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1]
else:
......@@ -1000,18 +996,6 @@ class ScanArgs:
assert p == len(outer_outputs)
assert q == len(inner_outputs)
self.other_info = {
k: getattr(info, k)
for k in (
"truncate_gradient",
"name",
"gpua",
"as_while",
"profile",
"allow_gc",
)
}
@staticmethod
def from_node(node, clone=False):
from aesara.scan.op import Scan
......@@ -1024,6 +1008,7 @@ class ScanArgs:
node.op.inputs,
node.op.outputs,
node.op.info,
node.op.as_while,
clone=clone,
)
......@@ -1041,14 +1026,8 @@ class ScanArgs:
n_shared_outs=0,
n_mit_mot_outs=0,
mit_mot_out_slices=(),
truncate_gradient=-1,
name=None,
gpua=False,
as_while=False,
profile=False,
allow_gc=False,
)
res = cls([1], [], [], [], info)
res = cls([1], [], [], [], info, False)
res.n_steps = None
return res
......@@ -1152,7 +1131,6 @@ class ScanArgs:
n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices),
**self.other_info,
)
def get_alt_field(self, var_info, alt_prefix):
......@@ -1341,7 +1319,6 @@ class ScanArgs:
"mit_mot_out_slices",
"mit_mot_in_slices",
"mit_sot_in_slices",
"other_info",
)
):
setattr(res, attr, copy.copy(getattr(self, attr)))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论