提交 9a2280b8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move Scan helper methods to ScanMethodsMixin

上级 e85c7fd0
...@@ -237,7 +237,9 @@ N.B.: ...@@ -237,7 +237,9 @@ N.B.:
outer_inputs = s.owner.inputs outer_inputs = s.owner.inputs
inner_to_outer_inputs = { inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o] inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.var_mappings["outer_inp_from_inner_inp"].items() for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
].items()
} }
print("", file=_file) print("", file=_file)
......
差异被折叠。
...@@ -1069,9 +1069,9 @@ class ScanArgs: ...@@ -1069,9 +1069,9 @@ class ScanArgs:
@property @property
def var_mappings(self): def var_mappings(self):
from aesara.scan.op import Scan from aesara.scan.op import ScanMethodsMixin
return Scan.get_oinp_iinp_iout_oout_mappings(self) return ScanMethodsMixin.get_oinp_iinp_iout_oout_mappings(self)
@property @property
def field_names(self): def field_names(self):
......
...@@ -299,7 +299,7 @@ If the goal is to navigate between variables that are associated with the same ...@@ -299,7 +299,7 @@ If the goal is to navigate between variables that are associated with the same
states (ex : going from an outer sequence input to the corresponding inner states (ex : going from an outer sequence input to the corresponding inner
sequence input, going from an inner output associated with a recurrent state sequence input, going from an inner output associated with a recurrent state
to the inner input(s) associated with that same recurrent state, etc.), then to the inner input(s) associated with that same recurrent state, etc.), then
the ``var_mappings`` attribute of the scan op can be used. the `get_oinp_iinp_iout_oout_mappings_mappings` method of the `Scan` `Op` can be used.
This attribute is a dictionary with 12 {key/value} pairs. The keys are listed This attribute is a dictionary with 12 {key/value} pairs. The keys are listed
below : below :
......
...@@ -700,11 +700,12 @@ class TestScan: ...@@ -700,11 +700,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = a.owner.inputs[0].owner scan_node = a.owner.inputs[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"] var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 1, 1: 2} expected_result = {0: 1, 1: 2}
assert result == expected_result assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"] result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 1, 2: 2, 3: 2} expected_result = {0: 1, 1: 1, 2: 2, 3: 2}
assert result == expected_result assert result == expected_result
...@@ -733,11 +734,12 @@ class TestScan: ...@@ -733,11 +734,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = out.owner.inputs[0].owner scan_node = out.owner.inputs[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"] var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 2} expected_result = {0: 2}
assert result == expected_result assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"] result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 2, 2: 2} expected_result = {0: 1, 1: 2, 2: 2}
assert result == expected_result assert result == expected_result
...@@ -1685,11 +1687,12 @@ class TestScan: ...@@ -1685,11 +1687,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = list(updates.values())[0].owner scan_node = list(updates.values())[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"] var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 3, 1: 5, 2: 4} expected_result = {0: 3, 1: 5, 2: 4}
assert result == expected_result assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"] result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6} expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6}
assert result == expected_result assert result == expected_result
...@@ -3491,7 +3494,7 @@ class TestScan: ...@@ -3491,7 +3494,7 @@ class TestScan:
# Compare the mappings with the expected values # Compare the mappings with the expected values
scan_node = scan_outputs[0].owner.inputs[0].owner scan_node = scan_outputs[0].owner.inputs[0].owner
mappings = scan_node.op.var_mappings mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
assert mappings["inner_inp_from_outer_inp"] == { assert mappings["inner_inp_from_outer_inp"] == {
0: [], 0: [],
......
...@@ -253,7 +253,6 @@ def test_ScanArgs(): ...@@ -253,7 +253,6 @@ def test_ScanArgs():
# here we make sure it doesn't (and that all the inputs are the same) # here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs assert scan_args.inputs == scan_op.inputs
assert scan_args.info == scan_op.info assert scan_args.info == scan_op.info
assert scan_args.var_mappings == scan_op.var_mappings
# Check that `ScanArgs.find_among_fields` works # Check that `ScanArgs.find_among_fields` works
test_v = scan_op.inner_seqs(scan_op.inputs)[1] test_v = scan_op.inner_seqs(scan_op.inputs)[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论