提交 ccbf2e98 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.scan

上级 05da60c3
......@@ -4,7 +4,7 @@ import copy
import dataclasses
import logging
from sys import maxsize
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import numpy as np
......@@ -15,8 +15,8 @@ from aesara.compile import optdb
from aesara.compile.function.types import deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import (
Apply,
Constant,
Node,
Variable,
clone_replace,
equal_computations,
......@@ -30,6 +30,7 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.type import HasShape
from aesara.graph.utils import InconsistencyError
from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import (
......@@ -652,7 +653,7 @@ def inner_sitsot_only_last_step_used(
if len(fgraph.clients[outer_var]) == 1:
client = fgraph.clients[outer_var][0][0]
if client != "output" and isinstance(client.op, Subtensor):
if isinstance(client, Apply) and isinstance(client.op, Subtensor):
lst = get_idx_list(client.inputs, client.op.idx_list)
if len(lst) == 1 and at.extract_constant(lst[0]) == -1:
return True
......@@ -662,10 +663,12 @@ def inner_sitsot_only_last_step_used(
def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int:
"""Determine the number of dimension a variable would have if it was pushed out of a `Scan`."""
assert isinstance(var.type, HasShape)
if var in scan_args.inner_in_non_seqs or isinstance(var, Constant):
outer_ndim = var.ndim
outer_ndim = var.type.ndim
else:
outer_ndim = var.ndim + 1
outer_ndim = var.type.ndim + 1
return outer_ndim
......@@ -673,11 +676,11 @@ def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int:
def push_out_inner_vars(
fgraph: FunctionGraph,
inner_vars: List[Variable],
old_scan_node: Node,
old_scan_node: Apply,
old_scan_args: ScanArgs,
) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]:
outer_vars = [None] * len(inner_vars)
tmp_outer_vars: List[Optional[Variable]] = []
new_scan_node = old_scan_node
new_scan_args = old_scan_args
replacements: Dict[Variable, Variable] = {}
......@@ -688,26 +691,32 @@ def push_out_inner_vars(
var = inner_vars[idx]
new_outer_var: Optional[Variable] = None
if var in old_scan_args.inner_in_seqs:
idx_seq = old_scan_args.inner_in_seqs.index(var)
outer_vars[idx] = old_scan_args.outer_in_seqs[idx_seq]
new_outer_var = old_scan_args.outer_in_seqs[idx_seq]
elif var in old_scan_args.inner_in_non_seqs:
idx_non_seq = old_scan_args.inner_in_non_seqs.index(var)
outer_vars[idx] = old_scan_args.outer_in_non_seqs[idx_non_seq]
new_outer_var = old_scan_args.outer_in_non_seqs[idx_non_seq]
elif isinstance(var, Constant):
outer_vars[idx] = var.clone()
new_outer_var = var.clone()
elif var in old_scan_args.inner_out_nit_sot:
idx_nitsot = old_scan_args.inner_out_nit_sot.index(var)
outer_vars[idx] = old_scan_args.outer_out_nit_sot[idx_nitsot]
new_outer_var = old_scan_args.outer_out_nit_sot[idx_nitsot]
tmp_outer_vars.append(new_outer_var)
# For the inner_vars that don't already exist in the outer graph, add
# them as new nitsot outputs to the scan node.
idx_add_as_nitsots = [i for i in range(len(outer_vars)) if outer_vars[i] is None]
idx_add_as_nitsots = [i for i, v in enumerate(tmp_outer_vars) if v is None]
add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
new_outs: List[Variable] = []
if len(add_as_nitsots) > 0:
new_scan_node, replacements = add_nitsot_outputs(
......@@ -724,18 +733,25 @@ def push_out_inner_vars(
)
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :]
for i in range(len(new_outs)):
outer_vars[idx_add_as_nitsots[i]] = new_outs[i]
outer_vars: List[Variable] = []
for i, v in enumerate(tmp_outer_vars):
if i in idx_add_as_nitsots:
outer_vars.append(new_outs.pop(0))
else:
assert v is not None
outer_vars.append(v)
return outer_vars, new_scan_args, replacements
def add_nitsot_outputs(
fgraph: FunctionGraph,
old_scan_node: Node,
old_scan_node: Apply,
old_scan_args: ScanArgs,
new_outputs_inner,
) -> Tuple[Node, Dict[Variable, Variable]]:
) -> Tuple[Apply, Dict[Variable, Variable]]:
nb_new_outs = len(new_outputs_inner)
......@@ -764,7 +780,10 @@ def add_nitsot_outputs(
)
# Create the Apply node for the scan op
new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[0].owner
new_scan_outs = new_scan_op(*new_scan_args.outer_inputs, return_list=True)
assert isinstance(new_scan_outs, list)
new_scan_node = new_scan_outs[0].owner
assert new_scan_node is not None
# Modify the outer graph to make sure the outputs of the new scan are
# used instead of the outputs of the old scan
......@@ -781,7 +800,7 @@ def add_nitsot_outputs(
# replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs))
# replacements["remove"] = [old_scan_node]
# return new_scan_node, replacements
fgraph.replace_all_validate_remove(
fgraph.replace_all_validate_remove( # type: ignore
list(zip(old_scan_node.outputs, new_node_old_outputs)),
remove=[old_scan_node],
reason="scan_pushout_add",
......
......@@ -9,6 +9,8 @@ import logging
import os
import sys
from importlib import reload
from types import ModuleType
from typing import Optional
import aesara
from aesara.compile.compilelock import lock_ctx
......@@ -24,6 +26,7 @@ _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.312 # must match constant returned in function get_version()
need_reload = False
scan_perform: Optional[ModuleType] = None
def try_import():
......@@ -107,7 +110,10 @@ except ImportError:
from scan_perform import scan_perform as scan_c
assert scan_perform._version == scan_c.get_version()
assert (
scan_perform is not None
and scan_perform._version == scan_c.get_version()
)
_logger.info(f"New version {scan_perform._version}")
......
......@@ -5,7 +5,8 @@ import dataclasses
import logging
import warnings
from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Callable, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Tuple, Union
from typing import cast as type_cast
import numpy as np
......@@ -21,8 +22,9 @@ from aesara.graph.basic import (
graph_inputs,
)
from aesara.graph.op import get_test_value
from aesara.graph.type import HasDataType
from aesara.graph.utils import TestValueError
from aesara.tensor.basic import AllocEmpty, get_scalar_constant_value
from aesara.tensor.basic import AllocEmpty, cast, get_scalar_constant_value
from aesara.tensor.subtensor import set_subtensor
from aesara.tensor.var import TensorConstant
......@@ -55,8 +57,11 @@ def safe_new(
nw_name = None
if isinstance(x, Constant):
if dtype and x.dtype != dtype:
casted_x = x.astype(dtype)
# TODO: Do something better about this
assert isinstance(x.type, HasDataType)
if dtype and x.type.dtype != dtype:
casted_x = cast(x, dtype)
nwx = type(x)(casted_x.type, x.data, x.name)
nwx.tag = copy.copy(x.tag)
return nwx
......@@ -89,8 +94,12 @@ def safe_new(
pass
# Cast `x` if needed. If `x` has a test value, this will also cast it.
if dtype and x.dtype != dtype:
x = x.astype(dtype)
if dtype:
# TODO: Do something better about this
assert isinstance(x.type, HasDataType)
if x.type.dtype != dtype:
x = cast(x, dtype)
nw_x = x.type()
nw_x.name = nw_name
......@@ -717,13 +726,13 @@ class ScanArgs:
def __init__(
self,
outer_inputs,
outer_outputs,
_inner_inputs,
_inner_outputs,
info,
as_while,
clone=True,
outer_inputs: Sequence[Variable],
outer_outputs: Sequence[Variable],
_inner_inputs: Sequence[Variable],
_inner_outputs: Sequence[Variable],
info: "ScanInfo",
as_while: bool,
clone: Optional[bool] = True,
):
self.n_steps = outer_inputs[0]
self.as_while = as_while
......@@ -745,16 +754,16 @@ class ScanArgs:
q = 0
n_seqs = info.n_seqs
self.outer_in_seqs = outer_inputs[p : p + n_seqs]
self.inner_in_seqs = inner_inputs[q : q + n_seqs]
self.outer_in_seqs = list(outer_inputs[p : p + n_seqs])
self.inner_in_seqs = list(inner_inputs[q : q + n_seqs])
p += n_seqs
q += n_seqs
n_mit_mot = info.n_mit_mot
n_mit_sot = info.n_mit_sot
self.mit_mot_in_slices = info.tap_array[:n_mit_mot]
self.mit_sot_in_slices = info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot]
self.mit_mot_in_slices = list(info.tap_array[:n_mit_mot])
self.mit_sot_in_slices = list(info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot])
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
......@@ -775,29 +784,29 @@ class ScanArgs:
qq += len(sl)
q += n_mit_sot_ins
self.outer_in_mit_mot = outer_inputs[p : p + n_mit_mot]
self.outer_in_mit_mot = list(outer_inputs[p : p + n_mit_mot])
p += n_mit_mot
self.outer_in_mit_sot = outer_inputs[p : p + n_mit_sot]
self.outer_in_mit_sot = list(outer_inputs[p : p + n_mit_sot])
p += n_mit_sot
n_sit_sot = info.n_sit_sot
self.outer_in_sit_sot = outer_inputs[p : p + n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q : q + n_sit_sot]
self.outer_in_sit_sot = list(outer_inputs[p : p + n_sit_sot])
self.inner_in_sit_sot = list(inner_inputs[q : q + n_sit_sot])
p += n_sit_sot
q += n_sit_sot
n_shared_outs = info.n_shared_outs
self.outer_in_shared = outer_inputs[p : p + n_shared_outs]
self.inner_in_shared = inner_inputs[q : q + n_shared_outs]
self.outer_in_shared = list(outer_inputs[p : p + n_shared_outs])
self.inner_in_shared = list(inner_inputs[q : q + n_shared_outs])
p += n_shared_outs
q += n_shared_outs
n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = outer_inputs[p : p + n_nit_sot]
self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
p += n_nit_sot
self.outer_in_non_seqs = outer_inputs[p:]
self.inner_in_non_seqs = inner_inputs[q:]
self.outer_in_non_seqs = list(outer_inputs[p:])
self.inner_in_non_seqs = list(inner_inputs[q:])
# now for the outputs
p = 0
......@@ -805,9 +814,9 @@ class ScanArgs:
self.mit_mot_out_slices = info.mit_mot_out_slices
n_mit_mot_outs = info.n_mit_mot_outs
self.outer_out_mit_mot = outer_outputs[p : p + n_mit_mot]
iomm = inner_outputs[q : q + n_mit_mot_outs]
self.inner_out_mit_mot = ()
self.outer_out_mit_mot = list(outer_outputs[p : p + n_mit_mot])
iomm = list(inner_outputs[q : q + n_mit_mot_outs])
self.inner_out_mit_mot: Tuple[List[Variable], ...] = ()
qq = 0
for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot += (iomm[qq : qq + len(sl)],)
......@@ -815,23 +824,23 @@ class ScanArgs:
p += n_mit_mot
q += n_mit_mot_outs
self.outer_out_mit_sot = outer_outputs[p : p + n_mit_sot]
self.inner_out_mit_sot = inner_outputs[q : q + n_mit_sot]
self.outer_out_mit_sot = list(outer_outputs[p : p + n_mit_sot])
self.inner_out_mit_sot = list(inner_outputs[q : q + n_mit_sot])
p += n_mit_sot
q += n_mit_sot
self.outer_out_sit_sot = outer_outputs[p : p + n_sit_sot]
self.inner_out_sit_sot = inner_outputs[q : q + n_sit_sot]
self.outer_out_sit_sot = list(outer_outputs[p : p + n_sit_sot])
self.inner_out_sit_sot = list(inner_outputs[q : q + n_sit_sot])
p += n_sit_sot
q += n_sit_sot
self.outer_out_nit_sot = outer_outputs[p : p + n_nit_sot]
self.inner_out_nit_sot = inner_outputs[q : q + n_nit_sot]
self.outer_out_nit_sot = list(outer_outputs[p : p + n_nit_sot])
self.inner_out_nit_sot = list(inner_outputs[q : q + n_nit_sot])
p += n_nit_sot
q += n_nit_sot
self.outer_out_shared = outer_outputs[p : p + n_shared_outs]
self.inner_out_shared = inner_outputs[q : q + n_shared_outs]
self.outer_out_shared = list(outer_outputs[p : p + n_shared_outs])
self.inner_out_shared = list(inner_outputs[q : q + n_shared_outs])
p += n_shared_outs
q += n_shared_outs
......@@ -978,12 +987,18 @@ class ScanArgs:
-------
The alternate variable.
"""
_var_info: FieldInfo
if not isinstance(var_info, FieldInfo):
var_info = self.find_among_fields(var_info)
find_var_info = self.find_among_fields(var_info)
if find_var_info is None:
raise ValueError(f"Couldn't find {var_info} among fields")
_var_info = find_var_info
else:
_var_info = var_info
alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]
alt_var = getattr(self, f"{alt_prefix}_{alt_type}")[var_info.index]
return alt_var
alt_type = _var_info.name[(_var_info.name.index("_", 6) + 1) :]
alt_var = getattr(self, f"{alt_prefix}_{alt_type}")[_var_info.index]
return type_cast(Variable, alt_var)
def find_among_fields(
self, i: Variable, field_filter: Callable[[str], bool] = default_filter
......@@ -1054,7 +1069,7 @@ class ScanArgs:
return field_info
def get_dependent_nodes(
self, i: Variable, seen: Optional[Set[int]] = None
self, i: Variable, seen: Optional[Set[Variable]] = None
) -> Set[Variable]:
if seen is None:
seen = {i}
......
......@@ -227,22 +227,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scan.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scan.opt]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scan.scan_perform_ext]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scan.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.tensor.nnet.conv3d2d]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论