提交 a7840844 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added ScanArgs

上级 45a2f784
......@@ -17,7 +17,7 @@ from aesara.link.utils import fgraph_to_python
from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan
from aesara.scan.utils import scan_args as ScanArgs
from aesara.scan.utils import ScanArgs
from aesara.tensor.basic import (
Alloc,
AllocDiag,
......
......@@ -80,11 +80,11 @@ from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.scan.op import Scan
from aesara.scan.utils import (
ScanArgs,
compress_outs,
expand_empty,
reconstruct_graph,
safe_new,
scan_args,
scan_can_remove_outs,
)
from aesara.tensor import basic_opt, math_opt
......@@ -751,9 +751,9 @@ class PushOutScanOutput(GlobalOptimizer):
op = node.op
# Use scan_args to parse the inputs and outputs of scan for ease of
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = scan_args(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
new_scan_node = None
clients = {}
......@@ -917,7 +917,7 @@ class PushOutScanOutput(GlobalOptimizer):
fgraph, old_scan_node, old_scan_args, add_as_nitsots
)
new_scan_args = scan_args(
new_scan_args = ScanArgs(
new_scan_node.inputs,
new_scan_node.outputs,
new_scan_node.op.inputs,
......@@ -944,13 +944,12 @@ class PushOutScanOutput(GlobalOptimizer):
old_scan_node.inputs[0] for i in range(nb_new_outs)
]
# Create the scan_args corresponding to the new scan op to
# create
# Create the `ScanArgs` corresponding to the new `Scan` `Op` to create
new_scan_args = copy.copy(old_scan_args)
new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
# Create the scan op from the scan_args
# Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan(
new_scan_args.inner_inputs, new_scan_args.inner_outputs, new_scan_args.info
)
......@@ -1970,7 +1969,7 @@ def scan_merge_inouts(fgraph, node):
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.
a = scan_args(
a = ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
)
......@@ -2016,7 +2015,7 @@ def scan_merge_inouts(fgraph, node):
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
na = ScanArgs(outer_inputs, outputs, op.inputs, op.outputs, op.info)
remove = [node]
else:
na = a
......
差异被折叠。
......@@ -276,12 +276,12 @@ by usage.
Accessing/manipulating Scan's inputs and outputs by type
--------------------------------------------------------
Declared in ``utils.py``, the class ``scan_args`` handles the
Declared in ``utils.py``, the class ``ScanArgs`` handles the
parsing of the inputs and outputs (both inner and outer) to a format
that is easier to analyse and manipulate. Without this class,
analysing Scan's inputs and outputs often required convoluted logic
which make for code that is hard to read and to maintain. Because of
this, you should favor using ``scan_args`` when it is practical and
this, you should favor using ``ScanArgs`` when it is practical and
appropriate to do so.
The scan op also defines a few helper functions for this purpose, such as
......
......@@ -5,7 +5,7 @@ import pytest
import aesara
from aesara import tensor as aet
from aesara.scan.utils import map_variables
from aesara.scan.utils import ScanArgs, map_variables
from aesara.tensor.type import scalar, vector
......@@ -145,3 +145,15 @@ class TestMapVariables:
shared.update = shared + 1
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [t])
def test_ScanArgs():
scan_args = ScanArgs.create_empty()
assert scan_args.n_steps is None
for name in scan_args.field_names:
if name == "n_steps":
continue
assert len(getattr(scan_args, name)) == 0
with pytest.raises(TypeError):
ScanArgs.from_node(aet.ones(2).owner)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论