Commit e85c7fd0 authored by Brandon T. Willard, committed by Brandon T. Willard

Replace Scan info dict with ScanInfo dataclass

Parent d83ff33c
......@@ -13,7 +13,7 @@ from aesara.graph.fg import MissingInputError
from aesara.graph.op import get_test_value
from aesara.graph.utils import TestValueError
from aesara.scan import utils
from aesara.scan.op import Scan
from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import safe_new, traverse
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum
......@@ -1022,31 +1022,32 @@ def scan(
# Step 7. Create the Scan Op
##
tap_array = mit_sot_tap_array + [[-1] for x in range(n_sit_sot)]
tap_array = tuple(tuple(v) for v in mit_sot_tap_array) + tuple(
(-1,) for x in range(n_sit_sot)
)
if allow_gc is None:
allow_gc = config.scan__allow_gc
info = OrderedDict()
info["tap_array"] = tap_array
info["n_seqs"] = n_seqs
info["n_mit_mot"] = n_mit_mot
info["n_mit_mot_outs"] = n_mit_mot_outs
info["mit_mot_out_slices"] = mit_mot_out_slices
info["n_mit_sot"] = n_mit_sot
info["n_sit_sot"] = n_sit_sot
info["n_shared_outs"] = n_shared_outs
info["n_nit_sot"] = n_nit_sot
info["truncate_gradient"] = truncate_gradient
info["name"] = name
info["mode"] = mode
info["destroy_map"] = OrderedDict()
info["gpua"] = False
info["as_while"] = as_while
info["profile"] = profile
info["allow_gc"] = allow_gc
info["strict"] = strict
local_op = Scan(inner_inputs, new_outs, info)
info = ScanInfo(
tap_array=tap_array,
n_seqs=n_seqs,
n_mit_mot=n_mit_mot,
n_mit_mot_outs=n_mit_mot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
n_mit_sot=n_mit_sot,
n_sit_sot=n_sit_sot,
n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot,
truncate_gradient=truncate_gradient,
name=name,
gpua=False,
as_while=as_while,
profile=profile,
allow_gc=allow_gc,
strict=strict,
)
local_op = Scan(inner_inputs, new_outs, info, mode)
##
# Step 8. Compute the outputs using the scan op
......
......@@ -44,11 +44,12 @@ relies on the following elements to work properly :
"""
import copy
import dataclasses
import itertools
import logging
import time
from collections import OrderedDict
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -57,7 +58,7 @@ from aesara import tensor as aet
from aesara.compile.builders import infer_shape
from aesara.compile.function import function
from aesara.compile.io import In, Out
from aesara.compile.mode import AddFeatureOptimizer, get_mode
from aesara.compile.mode import AddFeatureOptimizer, Mode, get_default_mode, get_mode
from aesara.compile.profiling import ScanProfileStats, register_profiler_printer
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
......@@ -76,7 +77,7 @@ from aesara.graph.op import Op, ops_with_inner_function
from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op
from aesara.scan.utils import Validator, forced_replace, hash_listsDictsTuples, safe_new
from aesara.scan.utils import Validator, forced_replace, safe_new
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import minimum
from aesara.tensor.shape import Shape_i
......@@ -88,57 +89,90 @@ from aesara.tensor.var import TensorVariable
_logger = logging.getLogger("aesara.scan.op")
class Scan(Op):
"""
Parameters
----------
inputs
Inputs of the inner function of scan.
outputs
Outputs of the inner function of scan.
info
Dictionary containing different properties of the scan op (like number
of different types of arguments, name, mode, if it should run on GPU or
not, etc.).
typeConstructor
Function that constructs an equivalent to Aesara TensorType.
Notes
-----
``typeConstructor`` had been added to refactor how
Aesara deals with the GPU. If it runs on the GPU, scan needs
to construct certain outputs (those who reside in the GPU
memory) as the GPU-specific type. However we can not import
gpu code in this file (as it is in sandbox, and not available
on each machine) so the workaround is that the GPU
optimization passes to the constructor of this class a
function that is able to construct a GPU type. This way the
class Scan does not need to be aware of the details for the
GPU, it just constructs any tensor using this function (which
by default constructs normal tensors).
"""
@dataclasses.dataclass(frozen=True)
class ScanInfo:
    """Static, hashable configuration of a `Scan` `Op`.

    Groups the counts and tap descriptions of the different `Scan` argument
    kinds (sequences, mit-mot, mit-sot, sit-sot, nit-sot, shared outputs)
    together with a few compilation flags.  Instances are frozen so they can
    be compared and hashed directly in `Scan.__eq__` and `Scan.__hash__`.
    """

    # Tuple of input-tap tuples, one entry per mit-mot/mit-sot/sit-sot output
    # (sit-sot entries are always ``(-1,)``).
    tap_array: tuple
    # Number of outer sequences iterated over.
    n_seqs: int
    # Number of multiple-input-taps/multiple-output-taps outputs.
    n_mit_mot: int
    # Total number of inner-graph outputs produced by the mit-mot outputs.
    n_mit_mot_outs: int
    # Tuple of output-tap tuples, one entry per mit-mot output.
    mit_mot_out_slices: tuple
    # Number of multiple-input-taps/single-output-tap outputs.
    n_mit_sot: int
    # Number of single-input-tap/single-output-tap outputs.
    n_sit_sot: int
    # Number of shared-variable outputs.
    n_shared_outs: int
    # Number of no-input-tap outputs (i.e. outputs with no initial state).
    n_nit_sot: int
    # NOTE(review): callers appear to pass an int here (gradient-truncation
    # step count, with a sentinel meaning "no truncation"), so the ``bool``
    # annotation/default looks suspect — confirm before tightening the type.
    truncate_gradient: bool = False
    # Optional user-facing name of the scan (also used in profiling messages).
    name: Optional[str] = None
    # Whether the op targets the GPU back-end.
    gpua: bool = False
    # Whether the inner graph has a termination condition (while-loop scan).
    as_while: bool = False
    profile: Optional[Union[str, bool]] = None
    # Whether the inner thunk may garbage-collect intermediate storage.
    allow_gc: bool = True
    strict: bool = True
    # Hash of the GPU inner graph.  ``Scan.__init__`` reads ``info.gpu_hash``
    # whenever ``gpua`` is true, so the field must exist on this (frozen)
    # dataclass; without it that access raises ``AttributeError``.
    gpu_hash: Optional[int] = None
TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType]
class Scan(Op):
def __init__(
self,
inputs,
outputs,
info,
typeConstructor=None,
inputs: List[Variable],
outputs: List[Variable],
info: ScanInfo,
mode: Optional[Mode] = None,
typeConstructor: Optional[TensorConstructorType] = None,
):
r"""
Parameters
----------
inputs
Inputs of the inner function of `Scan`.
outputs
Outputs of the inner function of `Scan`.
info
Dictionary containing different properties of the `Scan` `Op` (like
number of different types of arguments, name, mode, if it should run on
GPU or not, etc.).
mode
The compilation mode for the inner graph.
typeConstructor
Function that constructs an equivalent to Aesara `TensorType`.
Notes
-----
`typeConstructor` had been added to refactor how
Aesara deals with the GPU. If it runs on the GPU, scan needs
to construct certain outputs (those who reside in the GPU
memory) as the GPU-specific type. However we can not import
gpu code in this file (as it is in sandbox, and not available
on each machine) so the workaround is that the GPU
optimization passes to the constructor of this class a
function that is able to construct a GPU type. This way the
class `Scan` does not need to be aware of the details for the
GPU, it just constructs any tensor using this function (which
by default constructs normal tensors).
"""
# adding properties into self
self.inputs = inputs
self.outputs = outputs
self.__dict__.update(info)
# I keep a version of info in self, to use in __eq__ and __hash__,
# since info contains all tunable parameters of the op, so for two
# scan to be equal this tunable parameters should be the same
self.info = info
self.__dict__.update(dataclasses.asdict(info))
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = self.name + " sub profile"
else:
message = "Scan sub profile"
self.mode = get_default_mode() if mode is None else mode
self.mode_instance = get_mode(self.mode).clone(
link_kwargs=dict(allow_gc=self.allow_gc), message=message
)
# build a list of output types for any Apply node using this op.
self.output_types = []
idx = 0
jdx = 0
def tensorConstructor(broadcastable, dtype):
return TensorType(broadcastable=broadcastable, dtype=dtype)
......@@ -146,6 +180,8 @@ class Scan(Op):
if typeConstructor is None:
typeConstructor = tensorConstructor
idx = 0
jdx = 0
while idx < self.n_mit_mot_outs:
# Not that for mit_mot there are several output slices per
# output sequence
......@@ -176,24 +212,13 @@ class Scan(Op):
if self.as_while:
self.output_types = self.output_types[:-1]
mode_instance = get_mode(self.mode)
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = self.name + " sub profile"
else:
message = "Scan sub profile"
self.mode_instance = mode_instance.clone(
link_kwargs=dict(allow_gc=self.allow_gc), message=message
)
if not hasattr(self, "name") or self.name is None:
self.name = "scan_fn"
# to have a fair __eq__ comparison later on, we update the info with
# the actual mode used to compile the function and the name of the
# function that we set in case none was given
self.info["name"] = self.name
self.info = dataclasses.replace(self.info, name=self.name)
# Pre-computing some values to speed up perform
self.mintaps = [np.min(x) for x in self.tap_array]
......@@ -205,8 +230,8 @@ class Scan(Op):
self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
if self.info["gpua"]:
self._hash_inner_graph = self.info["gpu_hash"]
if self.info.gpua:
self._hash_inner_graph = self.info.gpu_hash
else:
# Do the missing inputs check here to have the error early.
for var in graph_inputs(self.outputs, self.inputs):
......@@ -256,7 +281,7 @@ class Scan(Op):
# output with type GpuArrayType
from aesara.gpuarray import GpuArrayType
if not self.info.get("gpua", False):
if not self.info.gpua:
for inp in self.inputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError(
......@@ -279,9 +304,7 @@ class Scan(Op):
def __setstate__(self, d):
self.__dict__.update(d)
if "allow_gc" not in self.__dict__:
self.allow_gc = True
self.info["allow_gc"] = True
if not hasattr(self, "var_mappings"):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
......@@ -731,41 +754,20 @@ class Scan(Op):
return apply_node
def __eq__(self, other):
# Check if we are dealing with same type of objects
if not type(self) == type(other):
if type(self) != type(other):
return False
if "destroy_map" not in self.info:
self.info["destroy_map"] = OrderedDict()
if "destroy_map" not in other.info:
other.info["destroy_map"] = OrderedDict()
keys_to_check = [
"truncate_gradient",
"profile",
"n_seqs",
"tap_array",
"as_while",
"n_mit_sot",
"destroy_map",
"n_nit_sot",
"n_shared_outs",
"n_sit_sot",
"gpua",
"n_mit_mot_outs",
"n_mit_mot",
"mit_mot_out_slices",
]
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
if not len(self.inputs) == len(other.inputs):
if self.info != other.info:
return False
elif not len(self.outputs) == len(other.outputs):
# Compare inner graphs
# TODO: Use `self.inner_fgraph == other.inner_fgraph`
if len(self.inputs) != len(other.inputs):
return False
for key in keys_to_check:
if self.info[key] != other.info[key]:
return False
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
if len(self.outputs) != len(other.outputs):
return False
for self_in, other_in in zip(self.inputs, other.inputs):
if self_in.type != other_in.type:
return False
......@@ -801,15 +803,7 @@ class Scan(Op):
return aux_txt
def __hash__(self):
return hash(
(
type(self),
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self._hash_inner_graph,
hash_listsDictsTuples(self.info),
)
)
return hash((type(self), self._hash_inner_graph, self.info))
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
"""
......@@ -864,7 +858,6 @@ class Scan(Op):
wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]]
new_outputs = [x for x in self.outputs]
preallocated_mitmot_outs = []
new_mit_mot_out_slices = copy.deepcopy(self.mit_mot_out_slices)
input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot):
......@@ -894,7 +887,6 @@ class Scan(Op):
)
wrapped_inputs.append(wrapped_inp)
preallocated_mitmot_outs.append(output_idx)
new_mit_mot_out_slices[mitmot_idx].remove(inp_tap)
else:
# Wrap the corresponding input as usual. Leave the
# output as-is.
......@@ -1963,7 +1955,7 @@ class Scan(Op):
outer_oidx = 0
# Handle sequences inputs
for i in range(self.info["n_seqs"]):
for i in range(self.info.n_seqs):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
......@@ -1975,8 +1967,8 @@ class Scan(Op):
outer_oidx += 0
# Handle mitmots, mitsots and sitsots variables
for i in range(len(self.info["tap_array"])):
nb_input_taps = len(self.info["tap_array"][i])
for i in range(len(self.info.tap_array)):
nb_input_taps = len(self.info.tap_array[i])
if i < self.n_mit_mot:
nb_output_taps = len(self.mit_mot_out_slices[i])
......@@ -1999,7 +1991,7 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.info["n_shared_outs"]
outer_iidx += self.info.n_shared_outs
# Handle nitsots variables
for i in range(self.n_nit_sot):
......@@ -2015,10 +2007,10 @@ class Scan(Op):
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx -= self.info["n_shared_outs"] + self.n_nit_sot
outer_iidx -= self.info.n_shared_outs + self.n_nit_sot
# Handle shared states
for i in range(self.info["n_shared_outs"]):
for i in range(self.info.n_shared_outs):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx])
......@@ -2158,10 +2150,10 @@ class Scan(Op):
oidx += 1
iidx -= len(taps)
if iidx < self.info["n_sit_sot"]:
if iidx < self.info.n_sit_sot:
return oidx + iidx
else:
return oidx + iidx + self.info["n_nit_sot"]
return oidx + iidx + self.info.n_nit_sot
def get_out_idx(iidx):
oidx = 0
......@@ -2241,9 +2233,9 @@ class Scan(Op):
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs.
idx_nitsot_start = (
self.info["n_mit_mot"] + self.info["n_mit_sot"] + self.info["n_sit_sot"]
self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
)
idx_nitsot_end = idx_nitsot_start + self.info["n_nit_sot"]
idx_nitsot_end = idx_nitsot_start + self.info.n_nit_sot
if idx < idx_nitsot_start or idx >= idx_nitsot_end:
# What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an
......@@ -2668,27 +2660,23 @@ class Scan(Op):
n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)]
info = OrderedDict()
info["n_seqs"] = len(outer_inp_seqs)
info["n_mit_sot"] = 0
info["tap_array"] = new_tap_array
info["gpua"] = False
info["n_mit_mot"] = len(outer_inp_mitmot)
info["n_mit_mot_outs"] = n_mitmot_outs
info["mit_mot_out_slices"] = mitmot_out_taps
info["truncate_gradient"] = self.truncate_gradient
info["n_sit_sot"] = n_sitsot_outs
info["n_shared_outs"] = 0
info["n_nit_sot"] = n_nit_sot
info["as_while"] = False
info["profile"] = self.profile
info["destroy_map"] = OrderedDict()
if self.name:
info["name"] = "grad_of_" + self.name
else:
info["name"] = None
info["mode"] = self.mode
info["allow_gc"] = self.allow_gc
info = ScanInfo(
n_seqs=len(outer_inp_seqs),
n_mit_sot=0,
tap_array=tuple(tuple(v) for v in new_tap_array),
gpua=False,
n_mit_mot=len(outer_inp_mitmot),
n_mit_mot_outs=n_mitmot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
truncate_gradient=self.truncate_gradient,
n_sit_sot=n_sitsot_outs,
n_shared_outs=0,
n_nit_sot=n_nit_sot,
as_while=False,
profile=self.profile,
name=f"grad_of_{self.name}" if self.name else None,
allow_gc=self.allow_gc,
)
outer_inputs = (
[grad_steps]
......@@ -2709,7 +2697,7 @@ class Scan(Op):
)
inner_gfn_outs = inner_out_mitmot + inner_out_sitsot + inner_out_nitsot
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info)
local_op = Scan(inner_gfn_ins, inner_gfn_outs, info, self.mode)
outputs = local_op(*outer_inputs)
if type(outputs) not in (list, tuple):
outputs = [outputs]
......@@ -2852,8 +2840,8 @@ class Scan(Op):
rop_self_outputs = self_outputs[:-1]
else:
rop_self_outputs = self_outputs
if self.info["n_shared_outs"] > 0:
rop_self_outputs = rop_self_outputs[: -self.info["n_shared_outs"]]
if self.info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -self.info.n_shared_outs]
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs]
......@@ -2867,25 +2855,7 @@ class Scan(Op):
# The only exception is the eval point for the number of sequences, and
# evan point for the number of nit_sot which I think should just be
# ignored (?)
info = OrderedDict()
info["n_seqs"] = self.n_seqs * 2
info["n_mit_sot"] = self.n_mit_sot * 2
info["n_sit_sot"] = self.n_sit_sot * 2
info["n_mit_mot"] = self.n_mit_mot * 2
info["n_nit_sot"] = self.n_nit_sot * 2
info["n_shared_outs"] = self.n_shared_outs
info["gpua"] = False
info["as_while"] = self.as_while
info["profile"] = self.profile
info["truncate_gradient"] = self.truncate_gradient
if self.name:
info["name"] = "rop_of_" + self.name
else:
info["name"] = None
info["mode"] = self.mode
info["allow_gc"] = self.allow_gc
info["mit_mot_out_slices"] = self.mit_mot_out_slices * 2
info["destroy_map"] = OrderedDict()
new_tap_array = []
b = 0
e = self.n_mit_mot
......@@ -2896,7 +2866,6 @@ class Scan(Op):
b = e
e += self.n_sit_sot
new_tap_array += self.tap_array[b:e] * 2
info["tap_array"] = new_tap_array
# Sequences ...
b = 1
......@@ -2993,7 +2962,7 @@ class Scan(Op):
# Outputs
n_mit_mot_outs = int(np.sum([len(x) for x in self.mit_mot_out_slices]))
info["n_mit_mot_outs"] = n_mit_mot_outs * 2
b = 0
e = n_mit_mot_outs
inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e]
......@@ -3039,7 +3008,25 @@ class Scan(Op):
+ scan_other
)
local_op = Scan(inner_ins, inner_outs, info)
info = ScanInfo(
n_seqs=self.n_seqs * 2,
n_mit_sot=self.n_mit_sot * 2,
n_sit_sot=self.n_sit_sot * 2,
n_mit_mot=self.n_mit_mot * 2,
n_nit_sot=self.n_nit_sot * 2,
n_shared_outs=self.n_shared_outs,
n_mit_mot_outs=n_mit_mot_outs * 2,
gpua=False,
as_while=self.as_while,
profile=self.profile,
truncate_gradient=self.truncate_gradient,
name=f"rop_of_{self.name}" if self.name else None,
allow_gc=self.allow_gc,
tap_array=tuple(tuple(v) for v in new_tap_array),
mit_mot_out_slices=tuple(tuple(v) for v in self.mit_mot_out_slices) * 2,
)
local_op = Scan(inner_ins, inner_outs, info, self.mode)
outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple):
outputs = [outputs]
......
......@@ -51,8 +51,8 @@ scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
"""
import copy
import dataclasses
import logging
from collections import OrderedDict
from sys import maxsize
import numpy as np
......@@ -78,7 +78,7 @@ from aesara.graph.fg import InconsistencyError
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.scan.op import Scan
from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import (
ScanArgs,
compress_outs,
......@@ -156,7 +156,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
out_stuff_outer = node.inputs[1 + op.n_seqs : st]
# To replace constants in the outer graph by clones in the inner graph
givens = OrderedDict()
givens = {}
# All the inputs of the inner graph of the new scan
nw_inner = []
# Same for the outer graph, initialized w/ number of steps
......@@ -217,12 +217,10 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
if len(nw_inner) != len(op_ins):
op_outs = clone_replace(op_outs, replace=givens)
nw_info = copy.deepcopy(op.info)
nw_info["n_seqs"] = nw_n_seqs
# DEBUG CHECK
nwScan = Scan(nw_inner, op_outs, nw_info)
nw_info = dataclasses.replace(op.info, n_seqs=nw_n_seqs)
nwScan = Scan(nw_inner, op_outs, nw_info, op.mode)
nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return OrderedDict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else:
return False
......@@ -263,7 +261,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
to_remove_set = set()
to_replace_set = set()
to_replace_map = OrderedDict()
to_replace_map = {}
def add_to_replace(y):
to_replace_set.add(y)
......@@ -377,7 +375,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = OrderedDict()
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(
......@@ -394,7 +392,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
op_ins = clean_inputs + nw_inner
# Reconstruct node
nwScan = Scan(op_ins, op_outs, op.info)
nwScan = Scan(op_ins, op_outs, op.info, op.mode)
# Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer), **dict(return_list=True))[
......@@ -409,7 +407,7 @@ class PushOutNonSeqScan(GlobalOptimizer):
return True
elif not to_keep_set:
# Nothing in the inner graph should be kept
replace_with = OrderedDict()
replace_with = {}
for out, idx in to_replace_map.items():
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
......@@ -481,7 +479,7 @@ class PushOutSeqScan(GlobalOptimizer):
to_remove_set = set()
to_replace_set = set()
to_replace_map = OrderedDict()
to_replace_map = {}
def add_to_replace(y):
to_replace_set.add(y)
......@@ -638,7 +636,7 @@ class PushOutSeqScan(GlobalOptimizer):
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = OrderedDict()
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(
......@@ -656,9 +654,10 @@ class PushOutSeqScan(GlobalOptimizer):
op_ins = nw_inner + clean_inputs
# Reconstruct node
nw_info = op.info.copy()
nw_info["n_seqs"] += len(nw_inner)
nwScan = Scan(op_ins, op_outs, nw_info)
nw_info = dataclasses.replace(
op.info, n_seqs=op.info.n_seqs + len(nw_inner)
)
nwScan = Scan(op_ins, op_outs, nw_info, op.mode)
# Do not call make_node for test_value
nw_node = nwScan(
*(node.inputs[:1] + nw_outer + node.inputs[1:]),
......@@ -673,7 +672,7 @@ class PushOutSeqScan(GlobalOptimizer):
return True
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node):
# Nothing in the inner graph should be kept
replace_with = OrderedDict()
replace_with = {}
for out, idx in to_replace_map.items():
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
......@@ -937,7 +936,10 @@ class PushOutScanOutput(GlobalOptimizer):
# Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan(
new_scan_args.inner_inputs, new_scan_args.inner_outputs, new_scan_args.info
new_scan_args.inner_inputs,
new_scan_args.inner_outputs,
new_scan_args.info,
old_scan_node.op.mode,
)
# Create the Apply node for the scan op
......@@ -1000,13 +1002,6 @@ class ScanInplaceOptimizer(GlobalOptimizer):
op = node.op
info = copy.deepcopy(op.info)
if "destroy_map" not in info:
info["destroy_map"] = OrderedDict()
for out_idx in output_indices:
info["destroy_map"][out_idx] = [out_idx + 1 + op.info["n_seqs"]]
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.n_seqs]
ls = op.outer_mitmot(node.inputs)
......@@ -1048,7 +1043,15 @@ class ScanInplaceOptimizer(GlobalOptimizer):
else:
typeConstructor = self.typeInfer(node)
new_op = Scan(op.inputs, op.outputs, info, typeConstructor=typeConstructor)
new_op = Scan(
op.inputs, op.outputs, op.info, op.mode, typeConstructor=typeConstructor
)
destroy_map = op.destroy_map.copy()
for out_idx in output_indices:
destroy_map[out_idx] = [out_idx + 1 + op.info.n_seqs]
new_op.destroy_map = destroy_map
# Do not call make_node for test_value
new_outs = new_op(*inputs, **dict(return_list=True))
......@@ -1070,7 +1073,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
scan_nodes = [
x
for x in nodes
if (isinstance(x.op, Scan) and x.op.info["gpua"] == self.gpua_flag)
if (isinstance(x.op, Scan) and x.op.info.gpua == self.gpua_flag)
]
for scan_idx in range(len(scan_nodes)):
......@@ -1080,7 +1083,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# them.
original_node = scan_nodes[scan_idx]
op = original_node.op
n_outs = op.info["n_mit_mot"] + op.info["n_mit_sot"] + op.info["n_sit_sot"]
n_outs = op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot
# Generate a list of outputs on which the node could potentially
# operate inplace.
......@@ -1171,7 +1174,7 @@ class ScanSaveMem(GlobalOptimizer):
# Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of
# dictionary.
shape_of = OrderedDict()
shape_of = {}
# 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to
......@@ -1548,7 +1551,7 @@ class ScanSaveMem(GlobalOptimizer):
(inps, outs, info, node_ins, compress_map) = compress_outs(
op, not_required, nw_inputs
)
inv_compress_map = OrderedDict()
inv_compress_map = {}
for k, v in compress_map.items():
inv_compress_map[v] = k
......@@ -1559,7 +1562,9 @@ class ScanSaveMem(GlobalOptimizer):
return
# Do not call make_node for test_value
new_outs = Scan(inps, outs, info)(*node_ins, **dict(return_list=True))
new_outs = Scan(inps, outs, info, op.mode)(
*node_ins, **dict(return_list=True)
)
old_new = []
# 3.7 Get replace pairs for those outputs that do not change
......@@ -1688,24 +1693,6 @@ class ScanMerge(GlobalOptimizer):
else:
as_while = False
info = OrderedDict()
info["tap_array"] = []
info["n_seqs"] = sum([nd.op.n_seqs for nd in nodes])
info["n_mit_mot"] = sum([nd.op.n_mit_mot for nd in nodes])
info["n_mit_mot_outs"] = sum([nd.op.n_mit_mot_outs for nd in nodes])
info["mit_mot_out_slices"] = []
info["n_mit_sot"] = sum([nd.op.n_mit_sot for nd in nodes])
info["n_sit_sot"] = sum([nd.op.n_sit_sot for nd in nodes])
info["n_shared_outs"] = sum([nd.op.n_shared_outs for nd in nodes])
info["n_nit_sot"] = sum([nd.op.n_nit_sot for nd in nodes])
info["truncate_gradient"] = nodes[0].op.truncate_gradient
info["name"] = "&".join([nd.op.name for nd in nodes])
info["mode"] = nodes[0].op.mode
info["gpua"] = False
info["as_while"] = as_while
info["profile"] = nodes[0].op.profile
info["allow_gc"] = nodes[0].op.allow_gc
# We keep the inner_ins and inner_outs of each original node separated.
# To be able to recombine them in the right order after the clone,
# we also need to split them by types (seq, mitmot, ...).
......@@ -1725,12 +1712,15 @@ class ScanMerge(GlobalOptimizer):
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
tap_array = ()
mit_mot_out_slices = ()
for idx, nd in enumerate(nodes):
# MitMot
inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.outputs))
info["tap_array"] += nd.op.mitmot_taps()
info["mit_mot_out_slices"] += nd.op.mitmot_out_taps()
tap_array += nd.op.mitmot_taps()
mit_mot_out_slices += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
......@@ -1738,14 +1728,14 @@ class ScanMerge(GlobalOptimizer):
# MitSot
inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inputs), idx))
inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.outputs))
info["tap_array"] += nd.op.mitsot_taps()
tap_array += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# SitSot
inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inputs), idx))
info["tap_array"] += [[-1] for x in range(nd.op.n_sit_sot)]
tap_array += tuple((-1,) for x in range(nd.op.n_sit_sot))
inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.outputs))
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
......@@ -1834,7 +1824,25 @@ class ScanMerge(GlobalOptimizer):
else:
new_inner_outs += inner_outs[idx][gr_idx]
new_op = Scan(new_inner_ins, new_inner_outs, info)
info = ScanInfo(
tap_array=tap_array,
n_seqs=sum([nd.op.n_seqs for nd in nodes]),
n_mit_mot=sum([nd.op.n_mit_mot for nd in nodes]),
n_mit_mot_outs=sum([nd.op.n_mit_mot_outs for nd in nodes]),
mit_mot_out_slices=mit_mot_out_slices,
n_mit_sot=sum([nd.op.n_mit_sot for nd in nodes]),
n_sit_sot=sum([nd.op.n_sit_sot for nd in nodes]),
n_shared_outs=sum([nd.op.n_shared_outs for nd in nodes]),
n_nit_sot=sum([nd.op.n_nit_sot for nd in nodes]),
truncate_gradient=nodes[0].op.truncate_gradient,
name="&".join([nd.op.name for nd in nodes]),
gpua=False,
as_while=as_while,
profile=nodes[0].op.profile,
allow_gc=nodes[0].op.allow_gc,
)
new_op = Scan(new_inner_ins, new_inner_outs, info, nodes[0].op.mode)
new_outs = new_op(*outer_ins)
if not isinstance(new_outs, (list, tuple)):
......@@ -1932,7 +1940,7 @@ def make_equiv(lo, li):
the equivalence of their corresponding outer inputs.
"""
seeno = OrderedDict()
seeno = {}
left = []
right = []
for o, i in zip(lo, li):
......@@ -1956,7 +1964,7 @@ def scan_merge_inouts(fgraph, node):
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
)
inp_equiv = OrderedDict()
inp_equiv = {}
if has_duplicates(a.outer_in_seqs):
new_outer_seqs = []
......@@ -1992,7 +2000,7 @@ def scan_merge_inouts(fgraph, node):
a_inner_outs = a.inner_outputs
inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv)
op = Scan(inner_inputs, inner_outputs, info)
op = Scan(inner_inputs, inner_outputs, info, node.op.mode)
outputs = op(*outer_inputs)
if not isinstance(outputs, (list, tuple)):
......@@ -2019,7 +2027,7 @@ def scan_merge_inouts(fgraph, node):
left += _left
right += _right
if has_duplicates(na.outer_in_mit_mot):
seen = OrderedDict()
seen = {}
for omm, imm, _sl in zip(
na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices
):
......@@ -2032,7 +2040,7 @@ def scan_merge_inouts(fgraph, node):
seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot):
seen = OrderedDict()
seen = {}
for oms, ims, _sl in zip(
na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices
):
......@@ -2117,9 +2125,7 @@ def scan_merge_inouts(fgraph, node):
new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot
if remove:
return OrderedDict(
[("remove", remove)] + list(zip(node.outputs, na.outer_outputs))
)
return dict([("remove", remove)] + list(zip(node.outputs, na.outer_outputs)))
return na.outer_outputs
......@@ -2214,15 +2220,17 @@ class PushOutDot1(GlobalOptimizer):
inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node)
new_info = op.info.copy()
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
new_info["tap_array"] = (
new_info["tap_array"][: st + idx]
+ new_info["tap_array"][st + idx + 1 :]
new_info = dataclasses.replace(
op.info,
tap_array=(
op.info.tap_array[: st + idx]
+ op.info.tap_array[st + idx + 1 :]
),
n_sit_sot=op.info.n_sit_sot - 1,
n_nit_sot=op.info.n_nit_sot + 1,
)
new_info["n_sit_sot"] -= 1
new_info["n_nit_sot"] += 1
inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :]
outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1 :]
inner_sitsot_outs = (
......@@ -2249,7 +2257,7 @@ class PushOutDot1(GlobalOptimizer):
new_inner_inps, new_inner_outs = reconstruct_graph(
_new_inner_inps, _new_inner_outs
)
new_op = Scan(new_inner_inps, new_inner_outs, new_info)
new_op = Scan(new_inner_inps, new_inner_outs, new_info, op.mode)
_scan_inputs = (
[node.inputs[0]]
+ outer_seqs
......
"""This module provides utility functions for the `Scan` `Op`."""
import copy
import dataclasses
import logging
import warnings
from collections import OrderedDict, namedtuple
......@@ -157,21 +158,6 @@ def traverse(out, x, x_copy, d, visited=None):
return d
# Hashing a dictionary/list/tuple by xoring the hash of each element
def hash_listsDictsTuples(x):
    """Recursively hash `x` by xor-ing the hashes of its elements.

    Dicts contribute the hashes of both their keys and their values; lists
    and tuples contribute the hashes of their items; anything else
    contributes ``hash(x)`` directly.  Because xor is commutative, the
    result is insensitive to element order, and empty containers hash to 0.
    """
    if isinstance(x, dict):
        acc = 0
        for key, value in x.items():
            acc ^= hash_listsDictsTuples(key) ^ hash_listsDictsTuples(value)
        return acc
    if isinstance(x, (list, tuple)):
        acc = 0
        for element in x:
            acc ^= hash_listsDictsTuples(element)
        return acc
    return hash(x)
def map_variables(replacer, graphs, additional_inputs=None):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
......@@ -264,6 +250,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
new_inner_inputs,
new_inner_outputs,
node.op.info,
node.op.mode,
# FIXME: infer this someday?
typeConstructor=None,
)
......@@ -669,7 +656,7 @@ def scan_can_remove_outs(op, out_idxs):
offset = op.n_seqs
lim = op.n_mit_mot + op.n_mit_sot + op.n_sit_sot
for idx in range(lim):
n_ins = len(op.info["tap_array"][idx])
n_ins = len(op.info.tap_array[idx])
out_ins += [op.inputs[offset : offset + n_ins]]
offset += n_ins
out_ins += [[] for k in range(op.n_nit_sot)]
......@@ -702,23 +689,25 @@ def compress_outs(op, not_required, inputs):
node inputs, and changing the dictionary.
"""
info = OrderedDict()
info["tap_array"] = []
info["n_seqs"] = op.info["n_seqs"]
info["n_mit_mot"] = 0
info["n_mit_mot_outs"] = 0
info["mit_mot_out_slices"] = []
info["n_mit_sot"] = 0
info["n_sit_sot"] = 0
info["n_shared_outs"] = 0
info["n_nit_sot"] = 0
info["truncate_gradient"] = op.info["truncate_gradient"]
info["name"] = op.info["name"]
info["gpua"] = op.info["gpua"]
info["mode"] = op.info["mode"]
info["as_while"] = op.info["as_while"]
info["profile"] = op.info["profile"]
info["allow_gc"] = op.info["allow_gc"]
from aesara.scan.op import ScanInfo
info = ScanInfo(
tap_array=(),
n_seqs=op.info.n_seqs,
n_mit_mot=0,
n_mit_mot_outs=0,
mit_mot_out_slices=(),
n_mit_sot=0,
n_sit_sot=0,
n_shared_outs=0,
n_nit_sot=0,
truncate_gradient=op.info.truncate_gradient,
name=op.info.name,
gpua=op.info.gpua,
as_while=op.info.as_while,
profile=op.info.profile,
allow_gc=op.info.allow_gc,
)
op_inputs = op.inputs[: op.n_seqs]
op_outputs = []
......@@ -730,13 +719,17 @@ def compress_outs(op, not_required, inputs):
i_offset = op.n_seqs
o_offset = 0
curr_pos = 0
for idx in range(op.info["n_mit_mot"]):
for idx in range(op.info.n_mit_mot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info["n_mit_mot"] += 1
info["tap_array"] += [op.tap_array[offset + idx]]
info["mit_mot_out_slices"] += [op.mit_mot_out_slices[offset + idx]]
info = dataclasses.replace(
info,
n_mit_mot=info.n_mit_mot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
mit_mot_out_slices=info.mit_mot_out_slices
+ (tuple(op.mit_mot_out_slices[offset + idx]),),
)
# input taps
for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]]
......@@ -750,16 +743,19 @@ def compress_outs(op, not_required, inputs):
else:
o_offset += len(op.mit_mot_out_slices[offset + idx])
i_offset += len(op.tap_array[offset + idx])
info["n_mit_mot_outs"] = len(op_outputs)
info = dataclasses.replace(info, n_mit_mot_outs=len(op_outputs))
offset += op.n_mit_mot
ni_offset += op.n_mit_mot
for idx in range(op.info["n_mit_sot"]):
for idx in range(op.info.n_mit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info["n_mit_sot"] += 1
info["tap_array"] += [op.tap_array[offset + idx]]
info = dataclasses.replace(
info,
n_mit_sot=info.n_mit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
)
# input taps
for jdx in op.tap_array[offset + idx]:
op_inputs += [op.inputs[i_offset]]
......@@ -775,12 +771,15 @@ def compress_outs(op, not_required, inputs):
offset += op.n_mit_sot
ni_offset += op.n_mit_sot
for idx in range(op.info["n_sit_sot"]):
for idx in range(op.info.n_sit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info["n_sit_sot"] += 1
info["tap_array"] += [op.tap_array[offset + idx]]
info = dataclasses.replace(
info,
n_sit_sot=info.n_sit_sot + 1,
tap_array=info.tap_array + (tuple(op.tap_array[offset + idx]),),
)
# input taps
op_inputs += [op.inputs[i_offset]]
i_offset += 1
......@@ -796,11 +795,11 @@ def compress_outs(op, not_required, inputs):
offset += op.n_sit_sot
ni_offset += op.n_sit_sot
nit_sot_ins = []
for idx in range(op.info["n_nit_sot"]):
for idx in range(op.info.n_nit_sot):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info["n_nit_sot"] += 1
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.outputs[o_offset]]
o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op.n_shared_outs]]
......@@ -809,11 +808,11 @@ def compress_outs(op, not_required, inputs):
offset += op.n_nit_sot
shared_ins = []
for idx in range(op.info["n_shared_outs"]):
for idx in range(op.info.n_shared_outs):
if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos
curr_pos += 1
info["n_shared_outs"] += 1
info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1)
op_outputs += [op.outputs[o_offset]]
o_offset += 1
op_inputs += [op.inputs[i_offset]]
......@@ -896,7 +895,7 @@ class ScanArgs:
else:
rval = (_inner_inputs, _inner_outputs)
if info["as_while"]:
if info.as_while:
self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1]
else:
......@@ -907,17 +906,17 @@ class ScanArgs:
p = 1
q = 0
n_seqs = info["n_seqs"]
n_seqs = info.n_seqs
self.outer_in_seqs = outer_inputs[p : p + n_seqs]
self.inner_in_seqs = inner_inputs[q : q + n_seqs]
p += n_seqs
q += n_seqs
n_mit_mot = info["n_mit_mot"]
n_mit_sot = info["n_mit_sot"]
n_mit_mot = info.n_mit_mot
n_mit_sot = info.n_mit_sot
self.mit_mot_in_slices = info["tap_array"][:n_mit_mot]
self.mit_sot_in_slices = info["tap_array"][n_mit_mot : n_mit_mot + n_mit_sot]
self.mit_mot_in_slices = info.tap_array[:n_mit_mot]
self.mit_sot_in_slices = info.tap_array[n_mit_mot : n_mit_mot + n_mit_sot]
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
......@@ -943,19 +942,19 @@ class ScanArgs:
self.outer_in_mit_sot = outer_inputs[p : p + n_mit_sot]
p += n_mit_sot
n_sit_sot = info["n_sit_sot"]
n_sit_sot = info.n_sit_sot
self.outer_in_sit_sot = outer_inputs[p : p + n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q : q + n_sit_sot]
p += n_sit_sot
q += n_sit_sot
n_shared_outs = info["n_shared_outs"]
n_shared_outs = info.n_shared_outs
self.outer_in_shared = outer_inputs[p : p + n_shared_outs]
self.inner_in_shared = inner_inputs[q : q + n_shared_outs]
p += n_shared_outs
q += n_shared_outs
n_nit_sot = info["n_nit_sot"]
n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = outer_inputs[p : p + n_nit_sot]
p += n_nit_sot
......@@ -966,14 +965,14 @@ class ScanArgs:
p = 0
q = 0
self.mit_mot_out_slices = info["mit_mot_out_slices"]
n_mit_mot_outs = info["n_mit_mot_outs"]
self.mit_mot_out_slices = info.mit_mot_out_slices
n_mit_mot_outs = info.n_mit_mot_outs
self.outer_out_mit_mot = outer_outputs[p : p + n_mit_mot]
iomm = inner_outputs[q : q + n_mit_mot_outs]
self.inner_out_mit_mot = []
self.inner_out_mit_mot = ()
qq = 0
for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot.append(iomm[qq : qq + len(sl)])
self.inner_out_mit_mot += (iomm[qq : qq + len(sl)],)
qq += len(sl)
p += n_mit_mot
q += n_mit_mot_outs
......@@ -1001,19 +1000,17 @@ class ScanArgs:
assert p == len(outer_outputs)
assert q == len(inner_outputs)
self.other_info = OrderedDict()
for k in (
"truncate_gradient",
"name",
"mode",
"destroy_map",
"gpua",
"as_while",
"profile",
"allow_gc",
):
if k in info:
self.other_info[k] = info[k]
self.other_info = {
k: getattr(info, k)
for k in (
"truncate_gradient",
"name",
"gpua",
"as_while",
"profile",
"allow_gc",
)
}
@staticmethod
def from_node(node, clone=False):
......@@ -1032,26 +1029,24 @@ class ScanArgs:
@classmethod
def create_empty(cls):
info = OrderedDict(
[
("n_seqs", 0),
("n_mit_mot", 0),
("n_mit_sot", 0),
("tap_array", []),
("n_sit_sot", 0),
("n_nit_sot", 0),
("n_shared_outs", 0),
("n_mit_mot_outs", 0),
("mit_mot_out_slices", []),
("truncate_gradient", -1),
("name", None),
("mode", None),
("destroy_map", OrderedDict()),
("gpua", False),
("as_while", False),
("profile", False),
("allow_gc", False),
]
from aesara.scan.op import ScanInfo
info = ScanInfo(
n_seqs=0,
n_mit_mot=0,
n_mit_sot=0,
tap_array=(),
n_sit_sot=0,
n_nit_sot=0,
n_shared_outs=0,
n_mit_mot_outs=0,
mit_mot_out_slices=(),
truncate_gradient=-1,
name=None,
gpua=False,
as_while=False,
profile=False,
allow_gc=False,
)
res = cls([1], [], [], [], info)
res.n_steps = None
......@@ -1060,7 +1055,7 @@ class ScanArgs:
@property
def n_nit_sot(self):
    """Number of nit-sot outputs, read from ``self.info``.

    NOTE(review): kept only as a hack so that
    `Scan.get_oinp_iinp_iout_oout_mappings` can be reused on
    `ScanArgs` objects.
    """
    # `info` is now a `ScanInfo` dataclass, so the old dict-style
    # lookup (`self.info["n_nit_sot"]`) is replaced by attribute access.
    return self.info.n_nit_sot
@property
def inputs(self):
......@@ -1070,7 +1065,7 @@ class ScanArgs:
@property
def n_mit_mot(self):
    """Number of mit-mot outputs, read from ``self.info``.

    NOTE(review): kept only as a hack so that
    `Scan.get_oinp_iinp_iout_oout_mappings` can be reused on
    `ScanArgs` objects.
    """
    # `info` is now a `ScanInfo` dataclass, so the old dict-style
    # lookup (`self.info["n_mit_mot"]`) is replaced by attribute access.
    return self.info.n_mit_mot
@property
def var_mappings(self):
......@@ -1141,20 +1136,22 @@ class ScanArgs:
@property
def info(self):
    """Build a `ScanInfo` summarizing this `ScanArgs`' variable groupings.

    The counts are derived from the lengths of the outer-input lists, and
    the tap collections are normalized to (nested) tuples, matching the
    immutable fields of the `ScanInfo` dataclass.
    """
    # Imported locally — presumably to avoid a circular import with
    # `aesara.scan.op`; TODO confirm.
    from aesara.scan.op import ScanInfo

    return ScanInfo(
        n_seqs=len(self.outer_in_seqs),
        n_mit_mot=len(self.outer_in_mit_mot),
        n_mit_sot=len(self.outer_in_mit_sot),
        # sit-sot entries always use the single tap -1.
        tap_array=(
            tuple(tuple(v) for v in self.mit_mot_in_slices)
            + tuple(tuple(v) for v in self.mit_sot_in_slices)
            + ((-1,),) * len(self.inner_in_sit_sot)
        ),
        n_sit_sot=len(self.outer_in_sit_sot),
        n_nit_sot=len(self.outer_in_nit_sot),
        n_shared_outs=len(self.outer_in_shared),
        n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
        mit_mot_out_slices=tuple(self.mit_mot_out_slices),
        **self.other_info,
    )
......
-e ./
dataclasses>=0.7; python_version < '3.7'
filelock
flake8==3.8.4
pep8
......
#!/usr/bin/env python
import sys
from setuptools import find_packages, setup
import versioneer
......@@ -43,6 +45,11 @@ Programming Language :: Python :: 3.9
"""
CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f]
install_requires = ["numpy>=1.17.0", "scipy>=0.14", "filelock"]
if sys.version_info[0:2] < (3, 7):
install_requires += ["dataclasses"]
if __name__ == "__main__":
setup(
name=NAME,
......@@ -57,7 +64,7 @@ if __name__ == "__main__":
license=LICENSE,
platforms=PLATFORMS,
packages=find_packages(exclude=["tests", "tests.*"]),
install_requires=["numpy>=1.17.0", "scipy>=0.14", "filelock"],
install_requires=install_requires,
package_data={
"": [
"*.txt",
......
......@@ -252,10 +252,7 @@ def test_ScanArgs():
# The `scan_args` base class always clones the inner-graph;
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs
scan_op_info = dict(scan_op.info)
# The `ScanInfo` dictionary has the wrong order and an extra entry
del scan_op_info["strict"]
assert dict(scan_args.info) == scan_op_info
assert scan_args.info == scan_op.info
assert scan_args.var_mappings == scan_op.var_mappings
# Check that `ScanArgs.find_among_fields` works
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论