提交 7c959797 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add tests for ScanArgs and change clone behavior of ScanArgs.from_node

上级 0d0ee9d8
......@@ -895,10 +895,21 @@ class ScanArgs:
nested_list_fields = ("inner_in_mit_mot", "inner_in_mit_sot", "inner_out_mit_mot")
def __init__(
self, outer_inputs, outer_outputs, _inner_inputs, _inner_outputs, info
self,
outer_inputs,
outer_outputs,
_inner_inputs,
_inner_outputs,
info,
clone=True,
):
self.n_steps = outer_inputs[0]
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
if clone:
rval = reconstruct_graph(_inner_inputs, _inner_outputs, "")
else:
rval = (_inner_inputs, _inner_outputs)
if info["as_while"]:
self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1]
......@@ -1019,13 +1030,18 @@ class ScanArgs:
self.other_info[k] = info[k]
@staticmethod
def from_node(node):
def from_node(node, clone=False):
from aesara.scan.op import Scan
if not isinstance(node.op, Scan):
raise TypeError("{} is not a Scan node".format(node))
return ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
node.inputs,
node.outputs,
node.op.inputs,
node.op.outputs,
node.op.info,
clone=clone,
)
@classmethod
......@@ -1242,8 +1258,6 @@ class ScanArgs:
return field_info
def get_dependent_nodes(self, i, seen=None):
from aesara.graph import inputs as at_inputs
if seen is None:
seen = {i}
else:
......@@ -1302,7 +1316,7 @@ class ScanArgs:
# If starting from an inner-input, then we need to find any
# inner-outputs that depend on it.
for out_n in self.inner_outputs:
if i in at_inputs([out_n]):
if i in graph_inputs([out_n]):
if out_n not in seen:
dependent_nodes.add(out_n)
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论