提交 7c959797 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add tests for ScanArgs and change clone behavior of ScanArgs.from_node

上级 0d0ee9d8
...@@ -895,10 +895,21 @@ class ScanArgs: ...@@ -895,10 +895,21 @@ class ScanArgs:
nested_list_fields = ("inner_in_mit_mot", "inner_in_mit_sot", "inner_out_mit_mot") nested_list_fields = ("inner_in_mit_mot", "inner_in_mit_sot", "inner_out_mit_mot")
def __init__( def __init__(
self, outer_inputs, outer_outputs, _inner_inputs, _inner_outputs, info self,
outer_inputs,
outer_outputs,
_inner_inputs,
_inner_outputs,
info,
clone=True,
): ):
self.n_steps = outer_inputs[0] self.n_steps = outer_inputs[0]
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
if clone:
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
else:
rval = (_inner_inputs, _inner_outputs)
if info["as_while"]: if info["as_while"]:
self.cond = [rval[1][-1]] self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1] inner_outputs = rval[1][:-1]
...@@ -1019,13 +1030,18 @@ class ScanArgs: ...@@ -1019,13 +1030,18 @@ class ScanArgs:
self.other_info[k] = info[k] self.other_info[k] = info[k]
@staticmethod @staticmethod
def from_node(node): def from_node(node, clone=False):
from aesara.scan.op import Scan from aesara.scan.op import Scan
if not isinstance(node.op, Scan): if not isinstance(node.op, Scan):
raise TypeError("{} is not a Scan node".format(node)) raise TypeError("{} is not a Scan node".format(node))
return ScanArgs( return ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.info,
clone=clone,
) )
@classmethod @classmethod
...@@ -1242,8 +1258,6 @@ class ScanArgs: ...@@ -1242,8 +1258,6 @@ class ScanArgs:
return field_info return field_info
def get_dependent_nodes(self, i, seen=None): def get_dependent_nodes(self, i, seen=None):
from aesara.graph import inputs as at_inputs
if seen is None: if seen is None:
seen = {i} seen = {i}
else: else:
...@@ -1302,7 +1316,7 @@ class ScanArgs: ...@@ -1302,7 +1316,7 @@ class ScanArgs:
# If starting from an inner-input, then we need to find any # If starting from an inner-input, then we need to find any
# inner-outputs that depend on it. # inner-outputs that depend on it.
for out_n in self.inner_outputs: for out_n in self.inner_outputs:
if i in at_inputs([out_n]): if i in graph_inputs([out_n]):
if out_n not in seen: if out_n not in seen:
dependent_nodes.add(out_n) dependent_nodes.add(out_n)
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论