提交 ccbf2e98 作者: Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.scan

上级 05da60c3
...@@ -4,7 +4,7 @@ import copy ...@@ -4,7 +4,7 @@ import copy
import dataclasses import dataclasses
import logging import logging
from sys import maxsize from sys import maxsize
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
...@@ -15,8 +15,8 @@ from aesara.compile import optdb ...@@ -15,8 +15,8 @@ from aesara.compile import optdb
from aesara.compile.function.types import deep_copy_op from aesara.compile.function.types import deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply,
Constant, Constant,
Node,
Variable, Variable,
clone_replace, clone_replace,
equal_computations, equal_computations,
...@@ -30,6 +30,7 @@ from aesara.graph.fg import FunctionGraph ...@@ -30,6 +30,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.type import HasShape
from aesara.graph.utils import InconsistencyError from aesara.graph.utils import InconsistencyError
from aesara.scan.op import Scan, ScanInfo from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import ( from aesara.scan.utils import (
...@@ -652,7 +653,7 @@ def inner_sitsot_only_last_step_used( ...@@ -652,7 +653,7 @@ def inner_sitsot_only_last_step_used(
if len(fgraph.clients[outer_var]) == 1: if len(fgraph.clients[outer_var]) == 1:
client = fgraph.clients[outer_var][0][0] client = fgraph.clients[outer_var][0][0]
if client != "output" and isinstance(client.op, Subtensor): if isinstance(client, Apply) and isinstance(client.op, Subtensor):
lst = get_idx_list(client.inputs, client.op.idx_list) lst = get_idx_list(client.inputs, client.op.idx_list)
if len(lst) == 1 and at.extract_constant(lst[0]) == -1: if len(lst) == 1 and at.extract_constant(lst[0]) == -1:
return True return True
...@@ -662,10 +663,12 @@ def inner_sitsot_only_last_step_used( ...@@ -662,10 +663,12 @@ def inner_sitsot_only_last_step_used(
def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int: def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int:
"""Determine the number of dimension a variable would have if it was pushed out of a `Scan`.""" """Determine the number of dimension a variable would have if it was pushed out of a `Scan`."""
assert isinstance(var.type, HasShape)
if var in scan_args.inner_in_non_seqs or isinstance(var, Constant): if var in scan_args.inner_in_non_seqs or isinstance(var, Constant):
outer_ndim = var.ndim outer_ndim = var.type.ndim
else: else:
outer_ndim = var.ndim + 1 outer_ndim = var.type.ndim + 1
return outer_ndim return outer_ndim
...@@ -673,11 +676,11 @@ def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int: ...@@ -673,11 +676,11 @@ def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int:
def push_out_inner_vars( def push_out_inner_vars(
fgraph: FunctionGraph, fgraph: FunctionGraph,
inner_vars: List[Variable], inner_vars: List[Variable],
old_scan_node: Node, old_scan_node: Apply,
old_scan_args: ScanArgs, old_scan_args: ScanArgs,
) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]: ) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]:
outer_vars = [None] * len(inner_vars) tmp_outer_vars: List[Optional[Variable]] = []
new_scan_node = old_scan_node new_scan_node = old_scan_node
new_scan_args = old_scan_args new_scan_args = old_scan_args
replacements: Dict[Variable, Variable] = {} replacements: Dict[Variable, Variable] = {}
...@@ -688,26 +691,32 @@ def push_out_inner_vars( ...@@ -688,26 +691,32 @@ def push_out_inner_vars(
var = inner_vars[idx] var = inner_vars[idx]
new_outer_var: Optional[Variable] = None
if var in old_scan_args.inner_in_seqs: if var in old_scan_args.inner_in_seqs:
idx_seq = old_scan_args.inner_in_seqs.index(var) idx_seq = old_scan_args.inner_in_seqs.index(var)
outer_vars[idx] = old_scan_args.outer_in_seqs[idx_seq] new_outer_var = old_scan_args.outer_in_seqs[idx_seq]
elif var in old_scan_args.inner_in_non_seqs: elif var in old_scan_args.inner_in_non_seqs:
idx_non_seq = old_scan_args.inner_in_non_seqs.index(var) idx_non_seq = old_scan_args.inner_in_non_seqs.index(var)
outer_vars[idx] = old_scan_args.outer_in_non_seqs[idx_non_seq] new_outer_var = old_scan_args.outer_in_non_seqs[idx_non_seq]
elif isinstance(var, Constant): elif isinstance(var, Constant):
outer_vars[idx] = var.clone() new_outer_var = var.clone()
elif var in old_scan_args.inner_out_nit_sot: elif var in old_scan_args.inner_out_nit_sot:
idx_nitsot = old_scan_args.inner_out_nit_sot.index(var) idx_nitsot = old_scan_args.inner_out_nit_sot.index(var)
outer_vars[idx] = old_scan_args.outer_out_nit_sot[idx_nitsot] new_outer_var = old_scan_args.outer_out_nit_sot[idx_nitsot]
tmp_outer_vars.append(new_outer_var)
# For the inner_vars that don't already exist in the outer graph, add # For the inner_vars that don't already exist in the outer graph, add
# them as new nitsot outputs to the scan node. # them as new nitsot outputs to the scan node.
idx_add_as_nitsots = [i for i in range(len(outer_vars)) if outer_vars[i] is None] idx_add_as_nitsots = [i for i, v in enumerate(tmp_outer_vars) if v is None]
add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots] add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
new_outs: List[Variable] = []
if len(add_as_nitsots) > 0: if len(add_as_nitsots) > 0:
new_scan_node, replacements = add_nitsot_outputs( new_scan_node, replacements = add_nitsot_outputs(
...@@ -724,18 +733,25 @@ def push_out_inner_vars( ...@@ -724,18 +733,25 @@ def push_out_inner_vars(
) )
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :] new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
for i in range(len(new_outs)):
outer_vars[idx_add_as_nitsots[i]] = new_outs[i] outer_vars: List[Variable] = []
for i, v in enumerate(tmp_outer_vars):
if i in idx_add_as_nitsots:
outer_vars.append(new_outs.pop(0))
else:
assert v is not None
outer_vars.append(v)
return outer_vars, new_scan_args, replacements return outer_vars, new_scan_args, replacements
def add_nitsot_outputs( def add_nitsot_outputs(
fgraph: FunctionGraph, fgraph: FunctionGraph,
old_scan_node: Node, old_scan_node: Apply,
old_scan_args: ScanArgs, old_scan_args: ScanArgs,
new_outputs_inner, new_outputs_inner,
) -> Tuple[Node, Dict[Variable, Variable]]: ) -> Tuple[Apply, Dict[Variable, Variable]]:
nb_new_outs = len(new_outputs_inner) nb_new_outs = len(new_outputs_inner)
...@@ -764,7 +780,10 @@ def add_nitsot_outputs( ...@@ -764,7 +780,10 @@ def add_nitsot_outputs(
) )
# Create the Apply node for the scan op # Create the Apply node for the scan op
new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[0].owner new_scan_outs = new_scan_op(*new_scan_args.outer_inputs, return_list=True)
assert isinstance(new_scan_outs, list)
new_scan_node = new_scan_outs[0].owner
assert new_scan_node is not None
# Modify the outer graph to make sure the outputs of the new scan are # Modify the outer graph to make sure the outputs of the new scan are
# used instead of the outputs of the old scan # used instead of the outputs of the old scan
...@@ -781,7 +800,7 @@ def add_nitsot_outputs( ...@@ -781,7 +800,7 @@ def add_nitsot_outputs(
# replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs)) # replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs))
# replacements["remove"] = [old_scan_node] # replacements["remove"] = [old_scan_node]
# return new_scan_node, replacements # return new_scan_node, replacements
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove( # type: ignore
list(zip(old_scan_node.outputs, new_node_old_outputs)), list(zip(old_scan_node.outputs, new_node_old_outputs)),
remove=[old_scan_node], remove=[old_scan_node],
reason="scan_pushout_add", reason="scan_pushout_add",
......
...@@ -9,6 +9,8 @@ import logging ...@@ -9,6 +9,8 @@ import logging
import os import os
import sys import sys
from importlib import reload from importlib import reload
from types import ModuleType
from typing import Optional
import aesara import aesara
from aesara.compile.compilelock import lock_ctx from aesara.compile.compilelock import lock_ctx
...@@ -24,6 +26,7 @@ _logger = logging.getLogger("aesara.scan.scan_perform") ...@@ -24,6 +26,7 @@ _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.312 # must match constant returned in function get_version() version = 0.312 # must match constant returned in function get_version()
need_reload = False need_reload = False
scan_perform: Optional[ModuleType] = None
def try_import(): def try_import():
...@@ -107,7 +110,10 @@ except ImportError: ...@@ -107,7 +110,10 @@ except ImportError:
from scan_perform import scan_perform as scan_c from scan_perform import scan_perform as scan_c
assert scan_perform._version == scan_c.get_version() assert (
scan_perform is not None
and scan_perform._version == scan_c.get_version()
)
_logger.info(f"New version {scan_perform._version}") _logger.info(f"New version {scan_perform._version}")
......
...@@ -5,7 +5,8 @@ import dataclasses ...@@ -5,7 +5,8 @@ import dataclasses
import logging import logging
import warnings import warnings
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Callable, List, Optional, Set, Tuple, Union from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union
from typing import cast as type_cast
import numpy as np import numpy as np
...@@ -21,8 +22,9 @@ from aesara.graph.basic import ( ...@@ -21,8 +22,9 @@ from aesara.graph.basic import (
graph_inputs, graph_inputs,
) )
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.type import HasDataType
from aesara.graph.utils import TestValueError from aesara.graph.utils import TestValueError
from aesara.tensor.basic import AllocEmpty, get_scalar_constant_value from aesara.tensor.basic import AllocEmpty, cast, get_scalar_constant_value
from aesara.tensor.subtensor import set_subtensor from aesara.tensor.subtensor import set_subtensor
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
...@@ -55,8 +57,11 @@ def safe_new( ...@@ -55,8 +57,11 @@ def safe_new(
nw_name = None nw_name = None
if isinstance(x, Constant): if isinstance(x, Constant):
if dtype and x.dtype != dtype: # TODO: Do something better about this
casted_x = x.astype(dtype) assert isinstance(x.type, HasDataType)
if dtype and x.type.dtype != dtype:
casted_x = cast(x, dtype)
nwx = type(x)(casted_x.type, x.data, x.name) nwx = type(x)(casted_x.type, x.data, x.name)
nwx.tag = copy.copy(x.tag) nwx.tag = copy.copy(x.tag)
return nwx return nwx
...@@ -89,8 +94,12 @@ def safe_new( ...@@ -89,8 +94,12 @@ def safe_new(
pass pass
# Cast `x` if needed. If `x` has a test value, this will also cast it. # Cast `x` if needed. If `x` has a test value, this will also cast it.
if dtype and x.dtype != dtype: if dtype:
x = x.astype(dtype) # TODO: Do something better about this
assert isinstance(x.type, HasDataType)
if x.type.dtype != dtype:
x = cast(x, dtype)
nw_x = x.type() nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
...@@ -717,13 +726,13 @@ class ScanArgs: ...@@ -717,13 +726,13 @@ class ScanArgs:
def __init__( def __init__(
self, self,
outer_inputs, outer_inputs: Sequence[Variable],
outer_outputs, outer_outputs: Sequence[Variable],
_inner_inputs, _inner_inputs: Sequence[Variable],
_inner_outputs, _inner_outputs: Sequence[Variable],
info, info: "ScanInfo",
as_while, as_while: bool,
clone=True, clone: Optional[bool] = True,
): ):
self.n_steps = outer_inputs[0] self.n_steps = outer_inputs[0]
self.as_while = as_while self.as_while = as_while
...@@ -745,16 +754,16 @@ class ScanArgs: ...@@ -745,16 +754,16 @@ class ScanArgs:
q = 0 q = 0
n_seqs = info.n_seqs n_seqs = info.n_seqs
self.outer_in_seqs = outer_inputs[p : p + n_seqs] self.outer_in_seqs = list(outer_inputs[p : p + n_seqs])
self.inner_in_seqs = inner_inputs[q : q + n_seqs] self.inner_in_seqs = list(inner_inputs[q : q + n_seqs])
p += n_seqs p += n_seqs
q += n_seqs q += n_seqs
n_mit_mot = info.n_mit_mot n_mit_mot = info.n_mit_mot
n_mit_sot = info.n_mit_sot n_mit_sot = info.n_mit_sot
self.mit_mot_in_slices = info.tap_array[:n_mit_mot] self.mit_mot_in_slices = list(info.tap_array[:n_mit_mot])
self.mit_sot_in_slices = info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot] self.mit_sot_in_slices = list(info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot])
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices) n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices) n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
...@@ -775,29 +784,29 @@ class ScanArgs: ...@@ -775,29 +784,29 @@ class ScanArgs:
qq += len(sl) qq += len(sl)
q += n_mit_sot_ins q += n_mit_sot_ins
self.outer_in_mit_mot = outer_inputs[p : p + n_mit_mot] self.outer_in_mit_mot = list(outer_inputs[p : p + n_mit_mot])
p += n_mit_mot p += n_mit_mot
self.outer_in_mit_sot = outer_inputs[p : p + n_mit_sot] self.outer_in_mit_sot = list(outer_inputs[p : p + n_mit_sot])
p += n_mit_sot p += n_mit_sot
n_sit_sot = info.n_sit_sot n_sit_sot = info.n_sit_sot
self.outer_in_sit_sot = outer_inputs[p : p + n_sit_sot] self.outer_in_sit_sot = list(outer_inputs[p : p + n_sit_sot])
self.inner_in_sit_sot = inner_inputs[q : q + n_sit_sot] self.inner_in_sit_sot = list(inner_inputs[q : q + n_sit_sot])
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
n_shared_outs = info.n_shared_outs n_shared_outs = info.n_shared_outs
self.outer_in_shared = outer_inputs[p : p + n_shared_outs] self.outer_in_shared = list(outer_inputs[p : p + n_shared_outs])
self.inner_in_shared = inner_inputs[q : q + n_shared_outs] self.inner_in_shared = list(inner_inputs[q : q + n_shared_outs])
p += n_shared_outs p += n_shared_outs
q += n_shared_outs q += n_shared_outs
n_nit_sot = info.n_nit_sot n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = outer_inputs[p : p + n_nit_sot] self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
p += n_nit_sot p += n_nit_sot
self.outer_in_non_seqs = outer_inputs[p:] self.outer_in_non_seqs = list(outer_inputs[p:])
self.inner_in_non_seqs = inner_inputs[q:] self.inner_in_non_seqs = list(inner_inputs[q:])
# now for the outputs # now for the outputs
p = 0 p = 0
...@@ -805,9 +814,9 @@ class ScanArgs: ...@@ -805,9 +814,9 @@ class ScanArgs:
self.mit_mot_out_slices = info.mit_mot_out_slices self.mit_mot_out_slices = info.mit_mot_out_slices
n_mit_mot_outs = info.n_mit_mot_outs n_mit_mot_outs = info.n_mit_mot_outs
self.outer_out_mit_mot = outer_outputs[p : p + n_mit_mot] self.outer_out_mit_mot = list(outer_outputs[p : p + n_mit_mot])
iomm = inner_outputs[q : q + n_mit_mot_outs] iomm = list(inner_outputs[q : q + n_mit_mot_outs])
self.inner_out_mit_mot = () self.inner_out_mit_mot: Tuple[List[Variable], ...] = ()
qq = 0 qq = 0
for sl in self.mit_mot_out_slices: for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot += (iomm[qq : qq + len(sl)],) self.inner_out_mit_mot += (iomm[qq : qq + len(sl)],)
...@@ -815,23 +824,23 @@ class ScanArgs: ...@@ -815,23 +824,23 @@ class ScanArgs:
p += n_mit_mot p += n_mit_mot
q += n_mit_mot_outs q += n_mit_mot_outs
self.outer_out_mit_sot = outer_outputs[p : p + n_mit_sot] self.outer_out_mit_sot = list(outer_outputs[p : p + n_mit_sot])
self.inner_out_mit_sot = inner_outputs[q : q + n_mit_sot] self.inner_out_mit_sot = list(inner_outputs[q : q + n_mit_sot])
p += n_mit_sot p += n_mit_sot
q += n_mit_sot q += n_mit_sot
self.outer_out_sit_sot = outer_outputs[p : p + n_sit_sot] self.outer_out_sit_sot = list(outer_outputs[p : p + n_sit_sot])
self.inner_out_sit_sot = inner_outputs[q : q + n_sit_sot] self.inner_out_sit_sot = list(inner_outputs[q : q + n_sit_sot])
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
self.outer_out_nit_sot = outer_outputs[p : p + n_nit_sot] self.outer_out_nit_sot = list(outer_outputs[p : p + n_nit_sot])
self.inner_out_nit_sot = inner_outputs[q : q + n_nit_sot] self.inner_out_nit_sot = list(inner_outputs[q : q + n_nit_sot])
p += n_nit_sot p += n_nit_sot
q += n_nit_sot q += n_nit_sot
self.outer_out_shared = outer_outputs[p : p + n_shared_outs] self.outer_out_shared = list(outer_outputs[p : p + n_shared_outs])
self.inner_out_shared = inner_outputs[q : q + n_shared_outs] self.inner_out_shared = list(inner_outputs[q : q + n_shared_outs])
p += n_shared_outs p += n_shared_outs
q += n_shared_outs q += n_shared_outs
...@@ -978,12 +987,18 @@ class ScanArgs: ...@@ -978,12 +987,18 @@ class ScanArgs:
------- -------
The alternate variable. The alternate variable.
""" """
_var_info: FieldInfo
if not isinstance(var_info, FieldInfo): if not isinstance(var_info, FieldInfo):
var_info = self.find_among_fields(var_info) find_var_info = self.find_among_fields(var_info)
if find_var_info is None:
raise ValueError(f"Couldn't find {var_info} among fields")
_var_info = find_var_info
else:
_var_info = var_info
alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :] alt_type = _var_info.name[(_var_info.name.index("_", 6) + 1) :]
alt_var = getattr(self, f"{alt_prefix}_{alt_type}")[var_info.index] alt_var = getattr(self, f"{alt_prefix}_{alt_type}")[_var_info.index]
return alt_var return type_cast(Variable, alt_var)
def find_among_fields( def find_among_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter self, i: Variable, field_filter: Callable[[str], bool] = default_filter
...@@ -1054,7 +1069,7 @@ class ScanArgs: ...@@ -1054,7 +1069,7 @@ class ScanArgs:
return field_info return field_info
def get_dependent_nodes( def get_dependent_nodes(
self, i: Variable, seen: Optional[Set[int]] = None self, i: Variable, seen: Optional[Set[Variable]] = None
) -> Set[Variable]: ) -> Set[Variable]:
if seen is None: if seen is None:
seen = {i} seen = {i}
......
...@@ -227,22 +227,6 @@ check_untyped_defs = False ...@@ -227,22 +227,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.scan.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scan.opt]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scan.scan_perform_ext]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scan.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.nnet.conv3d2d] [mypy-aesara.tensor.nnet.conv3d2d]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论