提交 9a2280b8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move Scan helper methods to ScanMethodsMixin

上级 e85c7fd0
......@@ -237,7 +237,9 @@ N.B.:
outer_inputs = s.owner.inputs
inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.var_mappings["outer_inp_from_inner_inp"].items()
for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
].items()
}
print("", file=_file)
......
差异被折叠。
......@@ -1069,9 +1069,9 @@ class ScanArgs:
@property
def var_mappings(self):
from aesara.scan.op import Scan
from aesara.scan.op import ScanMethodsMixin
return Scan.get_oinp_iinp_iout_oout_mappings(self)
return ScanMethodsMixin.get_oinp_iinp_iout_oout_mappings(self)
@property
def field_names(self):
......
......@@ -299,7 +299,7 @@ If the goal is to navigate between variables that are associated with the same
states (ex : going from an outer sequence input to the corresponding inner
sequence input, going from an inner output associated with a recurrent state
to the inner input(s) associated with that same recurrent state, etc.), then
the ``var_mappings`` attribute of the scan op can be used.
the `get_oinp_iinp_iout_oout_mappings_mappings` method of the `Scan` `Op` can be used.
This attribute is a dictionary with 12 {key/value} pairs. The keys are listed
below :
......
......@@ -700,11 +700,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
scan_node = a.owner.inputs[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"]
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 1, 1: 2}
assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"]
result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 1, 2: 2, 3: 2}
assert result == expected_result
......@@ -733,11 +734,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
scan_node = out.owner.inputs[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"]
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 2}
assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"]
result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 2, 2: 2}
assert result == expected_result
......@@ -1685,11 +1687,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
scan_node = list(updates.values())[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"]
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 3, 1: 5, 2: 4}
assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"]
result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6}
assert result == expected_result
......@@ -3491,7 +3494,7 @@ class TestScan:
# Compare the mappings with the expected values
scan_node = scan_outputs[0].owner.inputs[0].owner
mappings = scan_node.op.var_mappings
mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
assert mappings["inner_inp_from_outer_inp"] == {
0: [],
......
......@@ -253,7 +253,6 @@ def test_ScanArgs():
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs
assert scan_args.info == scan_op.info
assert scan_args.var_mappings == scan_op.var_mappings
# Check that `ScanArgs.find_among_fields` works
test_v = scan_op.inner_seqs(scan_op.inputs)[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论