提交 e85c7fd0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace Scan info dict with ScanInfo dataclass

上级 d83ff33c
......@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
from aesara.graph.op import get_test_value
from aesara.graph.utils import TestValueError
from aesara.scan import utils
from aesara.scan.op import Scan
from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import safe_new, traverse
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum
......@@ -1022,31 +1022,32 @@ def scan(
# Step 7. Create the Scan Op
##
tap_array = mit_sot_tap_array + [[-1] for x in range(n_sit_sot)]
tap_array = tuple(tuple(v) for v in mit_sot_tap_array) + tuple(
(-1,) for x in range(n_sit_sot)
)
if allow_gc is None:
allow_gc = config.scan__allow_gc
info = OrderedDict()
info["tap_array"] = tap_array
info["n_seqs"] = n_seqs
info["n_mit_mot"] = n_mit_mot
info["n_mit_mot_outs"] = n_mit_mot_outs
info["mit_mot_out_slices"] = mit_mot_out_slices
info["n_mit_sot"] = n_mit_sot
info["n_sit_sot"] = n_sit_sot
info["n_shared_outs"] = n_shared_outs
info["n_nit_sot"] = n_nit_sot
info["truncate_gradient"] = truncate_gradient
info["name"] = name
info["mode"] = mode
info["destroy_map"] = OrderedDict()
info["gpua"] = False
info["as_while"] = as_while
info["profile"] = profile
info["allow_gc"] = allow_gc
info["strict"] = strict
local_op = Scan(inner_inputs, new_outs, info)
info = ScanInfo(
tap_array=tap_array,
n_seqs=n_seqs,
n_mit_mot=n_mit_mot,
n_mit_mot_outs=n_mit_mot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
n_mit_sot=n_mit_sot,
n_sit_sot=n_sit_sot,
n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot,
truncate_gradient=truncate_gradient,
name=name,
gpua=False,
as_while=as_while,
profile=profile,
allow_gc=allow_gc,
strict=strict,
)
local_op = Scan(inner_inputs, new_outs, info, mode)
##
# Step 8. Compute the outputs using the scan op
......
差异被折叠。
差异被折叠。
差异被折叠。
-e ./
dataclasses>=0.7; python_version < '3.7'
filelock
flake8==3.8.4
pep8
......
#!/usr/bin/env python
import sys
from setuptools import find_packages, setup
import versioneer
......@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
"""
CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f]
install_requires = ["numpy>=1.17.0", "scipy>=0.14", "filelock"]
if sys.version_info[0:2] < (3, 7):
install_requires += ["dataclasses"]
if __name__ == "__main__":
setup(
name=NAME,
......@@ -57,7 +64,7 @@ if __name__ == "__main__":
license=LICENSE,
platforms=PLATFORMS,
packages=find_packages(exclude=["tests", "tests.*"]),
install_requires=["numpy>=1.17.0", "scipy>=0.14", "filelock"],
install_requires=install_requires,
package_data={
"": [
"*.txt",
......
......@@ -252,10 +252,7 @@ def test_ScanArgs():
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs
scan_op_info = dict(scan_op.info)
# The `ScanInfo` dictionary has the wrong order and an extra entry
del scan_op_info["strict"]
assert dict(scan_args.info) == scan_op_info
assert scan_args.info == scan_op.info
assert scan_args.var_mappings == scan_op.var_mappings
# Check that `ScanArgs.find_among_fields` works
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论