提交 7c4871ce 作者: Brandon T. Willard 提交者: Brandon T. Willard

Remove non-sequence settings from ScanInfo

上级 6f685799
...@@ -1033,6 +1033,13 @@ def scan( ...@@ -1033,6 +1033,13 @@ def scan(
n_sit_sot=n_sit_sot, n_sit_sot=n_sit_sot,
n_shared_outs=n_shared_outs, n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
)
local_op = Scan(
inner_inputs,
new_outs,
info,
mode=mode,
truncate_gradient=truncate_gradient, truncate_gradient=truncate_gradient,
name=name, name=name,
gpua=False, gpua=False,
...@@ -1042,8 +1049,6 @@ def scan( ...@@ -1042,8 +1049,6 @@ def scan(
strict=strict, strict=strict,
) )
local_op = Scan(inner_inputs, new_outs, info, mode)
## ##
# Step 8. Compute the outputs using the scan op # Step 8. Compute the outputs using the scan op
## ##
......
差异被折叠。
...@@ -217,7 +217,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -217,7 +217,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if len(nw_inner) != len(op_ins): if len(nw_inner) != len(op_ins):
op_outs = clone_replace(op_outs, replace=givens) op_outs = clone_replace(op_outs, replace=givens)
nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs) nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs)
nwScan = Scan(nw_inner, op_outs, nw_info, op.mode) nwScan = Scan(
nw_inner,
op_outs,
nw_info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
nw_outs = nwScan(*nw_outer, return_list=True) nw_outs = nwScan(*nw_outer, return_list=True)
return dict([("remove", [node])] + list(zip(node.outputs, nw_outs))) return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else: else:
...@@ -396,7 +408,19 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -396,7 +408,19 @@ class PushOutNonSeqScan(GlobalOptimizer):
op_ins = clean_inputs + nw_inner op_ins = clean_inputs + nw_inner
# Reconstruct node # Reconstruct node
nwScan = Scan(op_ins, op_outs, op.info, op.mode) nwScan = Scan(
op_ins,
op_outs,
op.info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
# Do not call make_node for test_value # Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner
...@@ -666,7 +690,19 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -666,7 +690,19 @@ class PushOutSeqScan(GlobalOptimizer):
nw_info = dataclasses.replace( nw_info = dataclasses.replace(
op.info, n_seqs=op.info.n_seqs + len(nw_inner) op.info, n_seqs=op.info.n_seqs + len(nw_inner)
) )
nwScan = Scan(op_ins, op_outs, nw_info, op.mode) nwScan = Scan(
op_ins,
op_outs,
nw_info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
# Do not call make_node for test_value # Do not call make_node for test_value
nw_node = nwScan( nw_node = nwScan(
*(node.inputs[:1] + nw_outer + node.inputs[1:]), *(node.inputs[:1] + nw_outer + node.inputs[1:]),
...@@ -751,7 +787,9 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -751,7 +787,9 @@ class PushOutScanOutput(GlobalOptimizer):
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of # Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use # use
args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info) args = ScanArgs(
node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
)
new_scan_node = None new_scan_node = None
clients = {} clients = {}
...@@ -921,6 +959,7 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -921,6 +959,7 @@ class PushOutScanOutput(GlobalOptimizer):
new_scan_node.op.inputs, new_scan_node.op.inputs,
new_scan_node.op.outputs, new_scan_node.op.outputs,
new_scan_node.op.info, new_scan_node.op.info,
new_scan_node.op.as_while,
) )
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :] new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
...@@ -952,7 +991,14 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -952,7 +991,14 @@ class PushOutScanOutput(GlobalOptimizer):
new_scan_args.inner_inputs, new_scan_args.inner_inputs,
new_scan_args.inner_outputs, new_scan_args.inner_outputs,
new_scan_args.info, new_scan_args.info,
old_scan_node.op.mode, mode=old_scan_node.op.mode,
gpua=old_scan_node.op.gpua,
as_while=old_scan_node.op.as_while,
profile=old_scan_node.op.profile,
truncate_gradient=old_scan_node.op.truncate_gradient,
# TODO: This seems questionable
name=old_scan_node.op.name,
allow_gc=old_scan_node.op.allow_gc,
) )
# Create the Apply node for the scan op # Create the Apply node for the scan op
...@@ -1059,7 +1105,18 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1059,7 +1105,18 @@ class ScanInplaceOptimizer(GlobalOptimizer):
typeConstructor = self.typeInfer(node) typeConstructor = self.typeInfer(node)
new_op = Scan( new_op = Scan(
op.inputs, op.outputs, op.info, op.mode, typeConstructor=typeConstructor op.inputs,
op.outputs,
op.info,
mode=op.mode,
typeConstructor=typeConstructor,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
) )
destroy_map = op.destroy_map.copy() destroy_map = op.destroy_map.copy()
...@@ -1086,9 +1143,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1086,9 +1143,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
alloc_ops = (Alloc, AllocEmpty) alloc_ops = (Alloc, AllocEmpty)
nodes = fgraph.toposort()[::-1] nodes = fgraph.toposort()[::-1]
scan_nodes = [ scan_nodes = [
x x for x in nodes if (isinstance(x.op, Scan) and x.op.gpua == self.gpua_flag)
for x in nodes
if (isinstance(x.op, Scan) and x.op.info.gpua == self.gpua_flag)
] ]
for scan_idx in range(len(scan_nodes)): for scan_idx in range(len(scan_nodes)):
...@@ -1593,7 +1648,20 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1593,7 +1648,20 @@ class ScanSaveMem(GlobalOptimizer):
return return
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = Scan(inps, outs, info, op.mode)(*node_ins, return_list=True) new_op = Scan(
inps,
outs,
info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
new_outs = new_op(*node_ins, return_list=True)
old_new = [] old_new = []
# 3.7 Get replace pairs for those outputs that do not change # 3.7 Get replace pairs for those outputs that do not change
...@@ -1871,15 +1939,21 @@ class ScanMerge(GlobalOptimizer): ...@@ -1871,15 +1939,21 @@ class ScanMerge(GlobalOptimizer):
n_sit_sot=sum([nd.op.n_sit_sot for nd in nodes]), n_sit_sot=sum([nd.op.n_sit_sot for nd in nodes]),
n_shared_outs=sum([nd.op.n_shared_outs for nd in nodes]), n_shared_outs=sum([nd.op.n_shared_outs for nd in nodes]),
n_nit_sot=sum([nd.op.n_nit_sot for nd in nodes]), n_nit_sot=sum([nd.op.n_nit_sot for nd in nodes]),
truncate_gradient=nodes[0].op.truncate_gradient, )
old_op = nodes[0].op
new_op = Scan(
new_inner_ins,
new_inner_outs,
info,
mode=old_op.mode,
profile=old_op.profile,
truncate_gradient=old_op.truncate_gradient,
allow_gc=old_op.allow_gc,
name="&".join([nd.op.name for nd in nodes]), name="&".join([nd.op.name for nd in nodes]),
gpua=False, gpua=False,
as_while=as_while, as_while=as_while,
profile=nodes[0].op.profile,
allow_gc=nodes[0].op.allow_gc,
) )
new_op = Scan(new_inner_ins, new_inner_outs, info, nodes[0].op.mode)
new_outs = new_op(*outer_ins) new_outs = new_op(*outer_ins)
if not isinstance(new_outs, (list, tuple)): if not isinstance(new_outs, (list, tuple)):
...@@ -2005,7 +2079,12 @@ def scan_merge_inouts(fgraph, node): ...@@ -2005,7 +2079,12 @@ def scan_merge_inouts(fgraph, node):
# Equivalent inputs will be stored in inp_equiv, then a new # Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates. # scan node created without duplicates.
a = ScanArgs( a = ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.info,
node.op.as_while,
) )
inp_equiv = {} inp_equiv = {}
...@@ -2044,13 +2123,32 @@ def scan_merge_inouts(fgraph, node): ...@@ -2044,13 +2123,32 @@ def scan_merge_inouts(fgraph, node):
a_inner_outs = a.inner_outputs a_inner_outs = a.inner_outputs
inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv) inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv)
op = Scan(inner_inputs, inner_outputs, info, node.op.mode) new_op = Scan(
outputs = op(*outer_inputs) inner_inputs,
inner_outputs,
info,
mode=node.op.mode,
gpua=node.op.gpua,
as_while=node.op.as_while,
profile=node.op.profile,
truncate_gradient=node.op.truncate_gradient,
# TODO: This seems questionable
name=node.op.name,
allow_gc=node.op.allow_gc,
)
outputs = new_op(*outer_inputs)
if not isinstance(outputs, (list, tuple)): if not isinstance(outputs, (list, tuple)):
outputs = [outputs] outputs = [outputs]
na = ScanArgs(outer_inputs, outputs, op.inputs, op.outputs, op.info) na = ScanArgs(
outer_inputs,
outputs,
new_op.inputs,
new_op.outputs,
new_op.info,
new_op.as_while,
)
remove = [node] remove = [node]
else: else:
na = a na = a
...@@ -2302,7 +2400,19 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2302,7 +2400,19 @@ class PushOutDot1(GlobalOptimizer):
new_inner_inps, new_inner_outs = reconstruct_graph( new_inner_inps, new_inner_outs = reconstruct_graph(
_new_inner_inps, _new_inner_outs _new_inner_inps, _new_inner_outs
) )
new_op = Scan(new_inner_inps, new_inner_outs, new_info, op.mode) new_op = Scan(
new_inner_inps,
new_inner_outs,
new_info,
mode=op.mode,
gpua=op.gpua,
as_while=op.as_while,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
_scan_inputs = ( _scan_inputs = (
[node.inputs[0]] [node.inputs[0]]
+ outer_seqs + outer_seqs
......
...@@ -701,12 +701,6 @@ def compress_outs(op, not_required, inputs): ...@@ -701,12 +701,6 @@ def compress_outs(op, not_required, inputs):
n_sit_sot=0, n_sit_sot=0,
n_shared_outs=0, n_shared_outs=0,
n_nit_sot=0, n_nit_sot=0,
truncate_gradient=op.info.truncate_gradient,
name=op.info.name,
gpua=op.info.gpua,
as_while=op.info.as_while,
profile=op.info.profile,
allow_gc=op.info.allow_gc,
) )
op_inputs = op.inputs[: op.n_seqs] op_inputs = op.inputs[: op.n_seqs]
...@@ -886,16 +880,18 @@ class ScanArgs: ...@@ -886,16 +880,18 @@ class ScanArgs:
_inner_inputs, _inner_inputs,
_inner_outputs, _inner_outputs,
info, info,
as_while,
clone=True, clone=True,
): ):
self.n_steps = outer_inputs[0] self.n_steps = outer_inputs[0]
self.as_while = as_while
if clone: if clone:
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "") rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
else: else:
rval = (_inner_inputs, _inner_outputs) rval = (_inner_inputs, _inner_outputs)
if info.as_while: if self.as_while:
self.cond = [rval[1][-1]] self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1] inner_outputs = rval[1][:-1]
else: else:
...@@ -1000,18 +996,6 @@ class ScanArgs: ...@@ -1000,18 +996,6 @@ class ScanArgs:
assert p == len(outer_outputs) assert p == len(outer_outputs)
assert q == len(inner_outputs) assert q == len(inner_outputs)
self.other_info = {
k: getattr(info, k)
for k in (
"truncate_gradient",
"name",
"gpua",
"as_while",
"profile",
"allow_gc",
)
}
@staticmethod @staticmethod
def from_node(node, clone=False): def from_node(node, clone=False):
from aesara.scan.op import Scan from aesara.scan.op import Scan
...@@ -1024,6 +1008,7 @@ class ScanArgs: ...@@ -1024,6 +1008,7 @@ class ScanArgs:
node.op.inputs, node.op.inputs,
node.op.outputs, node.op.outputs,
node.op.info, node.op.info,
node.op.as_while,
clone=clone, clone=clone,
) )
...@@ -1041,14 +1026,8 @@ class ScanArgs: ...@@ -1041,14 +1026,8 @@ class ScanArgs:
n_shared_outs=0, n_shared_outs=0,
n_mit_mot_outs=0, n_mit_mot_outs=0,
mit_mot_out_slices=(), mit_mot_out_slices=(),
truncate_gradient=-1,
name=None,
gpua=False,
as_while=False,
profile=False,
allow_gc=False,
) )
res = cls([1], [], [], [], info) res = cls([1], [], [], [], info, False)
res.n_steps = None res.n_steps = None
return res return res
...@@ -1152,7 +1131,6 @@ class ScanArgs: ...@@ -1152,7 +1131,6 @@ class ScanArgs:
n_shared_outs=len(self.outer_in_shared), n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices), n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices), mit_mot_out_slices=tuple(self.mit_mot_out_slices),
**self.other_info,
) )
def get_alt_field(self, var_info, alt_prefix): def get_alt_field(self, var_info, alt_prefix):
...@@ -1341,7 +1319,6 @@ class ScanArgs: ...@@ -1341,7 +1319,6 @@ class ScanArgs:
"mit_mot_out_slices", "mit_mot_out_slices",
"mit_mot_in_slices", "mit_mot_in_slices",
"mit_sot_in_slices", "mit_sot_in_slices",
"other_info",
) )
): ):
setattr(res, attr, copy.copy(getattr(self, attr))) setattr(res, attr, copy.copy(getattr(self, attr)))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论