提交 e85c7fd0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace Scan info dict with ScanInfo dataclass

上级 d83ff33c
...@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError ...@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.utils import TestValueError from aesara.graph.utils import TestValueError
from aesara.scan import utils from aesara.scan import utils
from aesara.scan.op import Scan from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import safe_new, traverse from aesara.scan.utils import safe_new, traverse
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum from aesara.tensor.math import minimum
...@@ -1022,31 +1022,32 @@ def scan( ...@@ -1022,31 +1022,32 @@ def scan(
# Step 7. Create the Scan Op # Step 7. Create the Scan Op
## ##
tap_array = mit_sot_tap_array + [[-1] for x in range(n_sit_sot)] tap_array = tuple(tuple(v) for v in mit_sot_tap_array) + tuple(
(-1,) for x in range(n_sit_sot)
)
if allow_gc is None: if allow_gc is None:
allow_gc = config.scan__allow_gc allow_gc = config.scan__allow_gc
info = OrderedDict()
info = ScanInfo(
info["tap_array"] = tap_array tap_array=tap_array,
info["n_seqs"] = n_seqs n_seqs=n_seqs,
info["n_mit_mot"] = n_mit_mot n_mit_mot=n_mit_mot,
info["n_mit_mot_outs"] = n_mit_mot_outs n_mit_mot_outs=n_mit_mot_outs,
info["mit_mot_out_slices"] = mit_mot_out_slices mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
info["n_mit_sot"] = n_mit_sot n_mit_sot=n_mit_sot,
info["n_sit_sot"] = n_sit_sot n_sit_sot=n_sit_sot,
info["n_shared_outs"] = n_shared_outs n_shared_outs=n_shared_outs,
info["n_nit_sot"] = n_nit_sot n_nit_sot=n_nit_sot,
info["truncate_gradient"] = truncate_gradient truncate_gradient=truncate_gradient,
info["name"] = name name=name,
info["mode"] = mode gpua=False,
info["destroy_map"] = OrderedDict() as_while=as_while,
info["gpua"] = False profile=profile,
info["as_while"] = as_while allow_gc=allow_gc,
info["profile"] = profile strict=strict,
info["allow_gc"] = allow_gc )
info["strict"] = strict
local_op = Scan(inner_inputs, new_outs, info, mode)
local_op = Scan(inner_inputs, new_outs, info)
## ##
# Step 8. Compute the outputs using the scan op # Step 8. Compute the outputs using the scan op
......
差异被折叠。
差异被折叠。
差异被折叠。
-e ./ -e ./
dataclasses>=0.7; python_version < '3.7'
filelock filelock
flake8==3.8.4 flake8==3.8.4
pep8 pep8
......
#!/usr/bin/env python #!/usr/bin/env python
import sys
from setuptools import find_packages, setup from setuptools import find_packages, setup
import versioneer import versioneer
...@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9 ...@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
""" """
CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f] CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f]
install_requires = ["numpy>=1.17.0", "scipy>=0.14", "filelock"]
if sys.version_info[0:2] < (3, 7):
install_requires += ["dataclasses"]
if __name__ == "__main__": if __name__ == "__main__":
setup( setup(
name=NAME, name=NAME,
...@@ -57,7 +64,7 @@ if __name__ == "__main__": ...@@ -57,7 +64,7 @@ if __name__ == "__main__":
license=LICENSE, license=LICENSE,
platforms=PLATFORMS, platforms=PLATFORMS,
packages=find_packages(exclude=["tests", "tests.*"]), packages=find_packages(exclude=["tests", "tests.*"]),
install_requires=["numpy>=1.17.0", "scipy>=0.14", "filelock"], install_requires=install_requires,
package_data={ package_data={
"": [ "": [
"*.txt", "*.txt",
......
...@@ -252,10 +252,7 @@ def test_ScanArgs(): ...@@ -252,10 +252,7 @@ def test_ScanArgs():
# The `scan_args` base class always clones the inner-graph; # The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same) # here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs assert scan_args.inputs == scan_op.inputs
scan_op_info = dict(scan_op.info) assert scan_args.info == scan_op.info
# The `ScanInfo` dictionary has the wrong order and an extra entry
del scan_op_info["strict"]
assert dict(scan_args.info) == scan_op_info
assert scan_args.var_mappings == scan_op.var_mappings assert scan_args.var_mappings == scan_op.var_mappings
# Check that `ScanArgs.find_among_fields` works # Check that `ScanArgs.find_among_fields` works
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论