提交 9a2280b8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move Scan helper methods to ScanMethodsMixin

上级 e85c7fd0
......@@ -237,7 +237,9 @@ N.B.:
outer_inputs = s.owner.inputs
inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.var_mappings["outer_inp_from_inner_inp"].items()
for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
].items()
}
print("", file=_file)
......
......@@ -112,7 +112,383 @@ class ScanInfo:
TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType]
class Scan(Op):
class ScanMethodsMixin:
    """Slicing and index-mapping helpers shared by scan-like objects.

    Every method reads the scan description from ``self``: the counters
    ``n_seqs``, ``n_mit_mot``, ``n_mit_sot``, ``n_sit_sot``, ``n_nit_sot`` and
    ``n_shared_outs``, the tap descriptions ``tap_array`` and
    ``mit_mot_out_slices``, and (for the mapping/validation methods) ``info``,
    ``inputs``, ``outputs`` and ``name``.

    The ``inner_*`` methods slice a list of inner-graph variables.  The
    ``outer_*`` methods slice a list of outer variables and also accept an
    `Apply` node, in which case the node's inputs (or outputs) are sliced.
    """

    def inner_seqs(self, list_inputs):
        """Return the inner inputs that correspond to sequences."""
        stop = self.n_seqs
        return list_inputs[:stop]

    def outer_seqs(self, list_inputs):
        """Return the outer inputs that correspond to sequences."""
        outer = list_inputs.inputs if isinstance(list_inputs, Apply) else list_inputs
        # Outer input 0 is skipped; the sequences come right after it.
        return outer[1 : 1 + self.n_seqs]

    def inner_mitmot(self, list_inputs):
        """Return the inner inputs (one per tap) of the mit-mot states."""
        width = sum(map(len, self.tap_array[: self.n_mit_mot]))
        start = self.n_seqs
        return list_inputs[start : start + width]

    def outer_mitmot(self, list_inputs):
        """Return the outer inputs of the mit-mot states."""
        outer = list_inputs.inputs if isinstance(list_inputs, Apply) else list_inputs
        start = 1 + self.n_seqs
        return outer[start : start + self.n_mit_mot]

    def inner_mitmot_outs(self, list_outputs):
        """Return the inner outputs (one per output tap) of the mit-mot states."""
        width = sum(map(len, self.mit_mot_out_slices))
        return list_outputs[:width]

    def outer_mitmot_outs(self, list_outputs):
        """Return the outer outputs of the mit-mot states."""
        outer = (
            list_outputs.outputs if isinstance(list_outputs, Apply) else list_outputs
        )
        return outer[: self.n_mit_mot]

    def mitmot_taps(self):
        """Return the input taps of the mit-mot states."""
        stop = self.n_mit_mot
        return self.tap_array[:stop]

    def mitmot_out_taps(self):
        """Return the output taps of the mit-mot states."""
        stop = self.n_mit_mot
        return self.mit_mot_out_slices[:stop]

    def inner_mitsot(self, list_inputs):
        """Return the inner inputs (one per tap) of the mit-sot states."""
        taps_before = sum(map(len, self.tap_array[: self.n_mit_mot]))
        taps_through = sum(
            map(len, self.tap_array[: (self.n_mit_mot + self.n_mit_sot)])
        )
        start = self.n_seqs + taps_before
        stop = self.n_seqs + taps_through
        return list_inputs[start:stop]

    def outer_mitsot(self, list_inputs):
        """Return the outer inputs of the mit-sot states."""
        outer = list_inputs.inputs if isinstance(list_inputs, Apply) else list_inputs
        start = 1 + self.n_seqs + self.n_mit_mot
        return outer[start : start + self.n_mit_sot]

    def inner_mitsot_outs(self, list_outputs):
        """Return the inner outputs of the mit-sot states."""
        start = sum(map(len, self.mit_mot_out_slices))
        return list_outputs[start : start + self.n_mit_sot]

    def outer_mitsot_outs(self, list_outputs):
        """Return the outer outputs of the mit-sot states."""
        outer = (
            list_outputs.outputs if isinstance(list_outputs, Apply) else list_outputs
        )
        start = self.n_mit_mot
        return outer[start : start + self.n_mit_sot]

    def mitsot_taps(self):
        """Return the input taps of the mit-sot states."""
        start = self.n_mit_mot
        return self.tap_array[start : start + self.n_mit_sot]

    def inner_sitsot(self, list_inputs):
        """Return the inner inputs of the sit-sot states."""
        taps_through = sum(
            map(len, self.tap_array[: (self.n_mit_mot + self.n_mit_sot)])
        )
        start = self.n_seqs + taps_through
        return list_inputs[start : start + self.n_sit_sot]

    def outer_sitsot(self, list_inputs):
        """Return the outer inputs of the sit-sot states."""
        outer = list_inputs.inputs if isinstance(list_inputs, Apply) else list_inputs
        start = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
        return outer[start : start + self.n_sit_sot]

    def inner_sitsot_outs(self, list_outputs):
        """Return the inner outputs of the sit-sot states."""
        mitmot_width = sum(map(len, self.mit_mot_out_slices))
        start = self.n_mit_sot + mitmot_width
        return list_outputs[start : start + self.n_sit_sot]

    def outer_sitsot_outs(self, list_outputs):
        """Return the outer outputs of the sit-sot states."""
        outer = (
            list_outputs.outputs if isinstance(list_outputs, Apply) else list_outputs
        )
        start = self.n_mit_mot + self.n_mit_sot
        return outer[start : start + self.n_sit_sot]

    def outer_nitsot(self, list_inputs):
        """Return the outer inputs of the nit-sot outputs."""
        outer = list_inputs.inputs if isinstance(list_inputs, Apply) else list_inputs
        # Among the outer inputs the nit-sots are placed after the shared
        # variables, hence `n_shared_outs` is part of the offset.
        start = 1 + sum(
            (
                self.n_seqs,
                self.n_mit_mot,
                self.n_mit_sot,
                self.n_sit_sot,
                self.n_shared_outs,
            )
        )
        return outer[start : start + self.n_nit_sot]

    def inner_nitsot_outs(self, list_outputs):
        """Return the inner outputs of the nit-sot outputs."""
        mitmot_width = sum(map(len, self.mit_mot_out_slices))
        start = self.n_mit_sot + mitmot_width + self.n_sit_sot
        return list_outputs[start : start + self.n_nit_sot]

    def outer_nitsot_outs(self, list_outputs):
        """Return the outer outputs of the nit-sot outputs."""
        outer = (
            list_outputs.outputs if isinstance(list_outputs, Apply) else list_outputs
        )
        start = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
        return outer[start : start + self.n_nit_sot]

    def inner_shared(self, list_inputs):
        """Return the inner inputs of the shared states."""
        taps_through = sum(
            map(len, self.tap_array[: (self.n_mit_mot + self.n_mit_sot)])
        )
        start = self.n_seqs + taps_through + self.n_sit_sot
        return list_inputs[start : start + self.n_shared_outs]

    def outer_shared(self, list_inputs):
        """Return the outer inputs of the shared states."""
        outer = list_inputs.inputs if isinstance(list_inputs, Apply) else list_inputs
        start = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
        return outer[start : start + self.n_shared_outs]

    def inner_shared_outs(self, list_outputs):
        """Return the inner outputs of the shared states."""
        mitmot_width = sum(map(len, self.mit_mot_out_slices))
        start = self.n_mit_sot + mitmot_width + self.n_sit_sot + self.n_nit_sot
        return list_outputs[start : start + self.n_shared_outs]

    def outer_shared_outs(self, list_outputs):
        """Return the outer outputs of the shared states."""
        outer = (
            list_outputs.outputs if isinstance(list_outputs, Apply) else list_outputs
        )
        start = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
        return outer[start : start + self.n_shared_outs]

    def inner_non_seqs(self, list_inputs):
        """Return the trailing inner inputs that are non-sequences."""
        taps_through = sum(
            map(len, self.tap_array[: (self.n_mit_mot + self.n_mit_sot)])
        )
        start = self.n_seqs + taps_through + self.n_sit_sot + self.n_shared_outs
        return list_inputs[start:]

    def outer_non_seqs(self, list_inputs):
        """Return the trailing outer inputs that are non-sequences."""
        outer = list_inputs.inputs if isinstance(list_inputs, Apply) else list_inputs
        start = 1 + sum(
            (
                self.n_seqs,
                self.n_mit_mot,
                self.n_mit_sot,
                self.n_sit_sot,
                self.n_nit_sot,
                self.n_shared_outs,
            )
        )
        return outer[start:]

    def get_oinp_iinp_iout_oout_mappings(self):
        """Build the index mappings between inner and outer inputs/outputs.

        Returns a dictionary with twelve entries named ``"<a>_from_<b>"``,
        where ``<a>`` and ``<b>`` are among ``outer_inp``, ``inner_inp``,
        ``inner_out`` and ``outer_out``.  Each entry is itself a dictionary
        keyed by an index of kind ``<b>``.  Values that designate outer
        variables are single integers (``-1`` when there is no associated
        outer output); values that designate inner variables are lists of
        indices, since several inner variables can belong to the same state.
        """
        # One record per outer input, laid out as
        # (outer input, inner inputs, inner outputs, outer output).  The first
        # record describes outer input 0 (the timestep index), which has no
        # inner variable and no outer output.
        records = [(0, [], [], -1)]

        next_outer_in = 1
        next_inner_in = 0
        next_inner_out = 0
        next_outer_out = 0

        # Sequences: one inner input each, no output of either kind.
        for _ in range(self.info.n_seqs):
            records.append((next_outer_in, [next_inner_in], [], -1))
            next_outer_in += 1
            next_inner_in += 1

        # Mit-mots, mit-sots and sit-sots: one inner input per input tap;
        # mit-mots have one inner output per output tap, the others exactly one.
        for pos, taps in enumerate(self.info.tap_array):
            n_in_taps = len(taps)
            if pos < self.n_mit_mot:
                n_out_taps = len(self.mit_mot_out_slices[pos])
            else:
                n_out_taps = 1

            records.append(
                (
                    next_outer_in,
                    list(range(next_inner_in, next_inner_in + n_in_taps)),
                    list(range(next_inner_out, next_inner_out + n_out_taps)),
                    next_outer_out,
                )
            )
            next_outer_in += 1
            next_inner_in += n_in_taps
            next_inner_out += n_out_taps
            next_outer_out += 1

        # For the outer inputs only, nit-sots are located *after* the shared
        # variables: jump over the shared block before numbering the nit-sots.
        next_outer_in += self.info.n_shared_outs

        # Nit-sots: no inner input, one inner output.
        for _ in range(self.n_nit_sot):
            records.append((next_outer_in, [], [next_inner_out], next_outer_out))
            next_outer_in += 1
            next_inner_out += 1
            next_outer_out += 1

        # Rewind to the start of the shared block that was jumped over.
        next_outer_in -= self.info.n_shared_outs + self.n_nit_sot

        # Shared states: one inner input and one inner output each.
        for _ in range(self.info.n_shared_outs):
            records.append(
                (next_outer_in, [next_inner_in], [next_inner_out], next_outer_out)
            )
            next_outer_in += 1
            next_inner_in += 1
            next_inner_out += 1
            next_outer_out += 1

        # Move past the nit-sot outer inputs that were already numbered.
        next_outer_in += self.n_nit_sot

        # Non-sequences: their number is not read from `self.info`; it is the
        # count of inner inputs that have not been assigned yet.
        for _ in range(len(self.inputs) - next_inner_in):
            records.append((next_outer_in, [next_inner_in], [], -1))
            next_outer_in += 1
            next_inner_in += 1

        # Derive the twelve pairwise mappings from the records.
        mappings = {
            key: {}
            for key in (
                "outer_inp_from_outer_out",
                "inner_inp_from_outer_out",
                "inner_out_from_outer_out",
                "inner_inp_from_outer_inp",
                "inner_out_from_outer_inp",
                "outer_out_from_outer_inp",
                "outer_inp_from_inner_inp",
                "inner_out_from_inner_inp",
                "outer_out_from_inner_inp",
                "outer_inp_from_inner_out",
                "inner_inp_from_inner_out",
                "outer_out_from_inner_out",
            )
        }

        for oinp, iinps, iouts, oout in records:
            if oout != -1:
                mappings["outer_inp_from_outer_out"][oout] = oinp
                mappings["inner_inp_from_outer_out"][oout] = iinps
                mappings["inner_out_from_outer_out"][oout] = iouts

            if oinp != -1:
                mappings["inner_inp_from_outer_inp"][oinp] = iinps
                mappings["inner_out_from_outer_inp"][oinp] = iouts
                mappings["outer_out_from_outer_inp"][oinp] = oout

            for iinp in iinps:
                mappings["outer_inp_from_inner_inp"][iinp] = oinp
                mappings["inner_out_from_inner_inp"][iinp] = iouts
                mappings["outer_out_from_inner_inp"][iinp] = oout

            for iout in iouts:
                mappings["outer_inp_from_inner_out"][iout] = oinp
                mappings["inner_inp_from_inner_out"][iout] = iinps
                mappings["outer_out_from_inner_out"][iout] = oout

        return mappings

    def validate_inner_graph(self):
        """Run basic coherence checks on the inner graph.

        Raises
        ------
        TypeError
            If an inner input and an inner output associated with the same
            recurrent state have different types, or if ``self.info.gpua`` is
            false while an inner input or output has a `GpuArrayType` type.
        """
        # Recurrent states are the mit-mots, mit-sots and sit-sots; for each
        # one, every (inner input, inner output) pair must share one type.
        n_recurrent = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
        mappings = self.get_oinp_iinp_iout_oout_mappings()
        inputs_of = mappings["inner_inp_from_outer_out"]
        outputs_of = mappings["inner_out_from_outer_out"]

        for state in range(n_recurrent):
            pairs = itertools.product(inputs_of[state], outputs_of[state])
            for in_pos, out_pos in pairs:
                type_input = self.inputs[in_pos].type
                type_output = self.outputs[out_pos].type
                if type_input != type_output:
                    raise TypeError(
                        "Inconsistency in the inner graph of "
                        f"scan '{self.name}' : an input and an output are "
                        "associated with the same recurrent state "
                        "and should have the same type but have "
                        f"type '{type_input}' and '{type_output}' respectively."
                    )

        # When the 'gpua' flag is false (the scan is not meant to use the
        # gpuarray backend), no inner input or output may be a GpuArrayType.
        from aesara.gpuarray import GpuArrayType

        if self.info.gpua:
            return

        for inner_input in self.inputs:
            if isinstance(inner_input.type, GpuArrayType):
                raise TypeError(
                    "Inconsistency in the inner graph of "
                    f"scan '{self.name}' : one of the inputs to the "
                    "inner graph is of type GpuArrayType but "
                    "the attributes of the scan op indicate "
                    "that it shouldn't be the case"
                )

        for inner_output in self.outputs:
            if isinstance(inner_output.type, GpuArrayType):
                raise TypeError(
                    "Inconsistency in the inner graph of "
                    f"scan '{self.name}' : one of the outputs to the "
                    "inner graph is of type GpuArrayType but "
                    "the attributes of the scan op indicate "
                    "that it shouldn't be the case"
                )
class Scan(Op, ScanMethodsMixin):
def __init__(
self,
inputs: List[Variable],
......@@ -242,73 +618,9 @@ class Scan(Op):
)
self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner
# inputs and inner outputs to determine with variables are associated
# with the same states.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
def validate_inner_graph(self):
"""
Perform some elementary validations on the inner graph to ensure
that it is coherent.
"""
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
for outer_oidx in range(nb_recurr_outputs):
inner_iidxs = self.var_mappings["inner_inp_from_outer_out"][outer_oidx]
inner_oidxs = self.var_mappings["inner_out_from_outer_out"][outer_oidx]
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs, inner_oidxs):
type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
if type_input != type_output:
raise TypeError(
"Inconsistency in the inner graph of "
f"scan '{self.name}' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
f"type '{type_input}' and '{type_output}' respectively."
)
# If scan has the flag 'gpua' set to false (meaning that is shouldn't
# use the gpuarray gpu backend ), ensure that is has no input and no
# output with type GpuArrayType
from aesara.gpuarray import GpuArrayType
if not self.info.gpua:
for inp in self.inputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError(
"Inconsistency in the inner graph of "
f"scan '{self.name}' : one of the inputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case"
)
for out in self.outputs:
if isinstance(out.type, GpuArrayType):
raise TypeError(
"Inconsistency in the inner graph of "
f"scan '{self.name}' : one of the outputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case"
)
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, "var_mappings"):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
if hasattr(self, "fn"):
if not hasattr(self, "thunk_mit_mot_out_slices"):
# The thunk has been compiled before mit_mot preallocation
......@@ -1010,222 +1322,66 @@ class Scan(Op):
if self.destroy_map:
cython_destroy_map = [
x in self.destroy_map for x in range(len(node.outputs))
]
else:
cython_destroy_map = [0 for x in range(len(node.outputs))]
cython_destroy_map = np.asarray(cython_destroy_map, dtype="int32")
from . import scan_perform_ext
def p(node, args, outs):
return scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
self.fn.fn,
self.fn,
cython_destroy_map,
args,
outs,
self,
node,
)
except (ImportError, MissingGXX):
p = self.perform
# default arguments are stored in the closure of `rval`
# Big ugly hack since we can't get the real value of allow_gc
# for the englobing function.
allow_gc = config.allow_gc and not self.allow_gc
def rval(
p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
if allow_gc:
self.fn.free()
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.perform = p
rval.lazy = False
return rval
def inner_seqs(self, list_inputs):
# Given the list of inner inputs this function grabs those
# corresponding to sequences
return list_inputs[: self.n_seqs]
def outer_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
# Given the list of outer inputs this function grabs those
# corresponding to sequences
return list_inputs[1 : 1 + self.n_seqs]
def inner_mitmot(self, list_inputs):
n_taps = sum(len(x) for x in self.tap_array[: self.n_mit_mot])
return list_inputs[self.n_seqs : self.n_seqs + n_taps]
def outer_mitmot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
return list_inputs[1 + self.n_seqs : 1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return list_outputs[:n_taps]
def outer_mitmot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[: self.n_mit_mot]
def mitmot_taps(self):
return self.tap_array[: self.n_mit_mot]
def mitmot_out_taps(self):
return self.mit_mot_out_slices[: self.n_mit_mot]
def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[: self.n_mit_mot])
ntaps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)]
)
return list_inputs[
self.n_seqs + n_mitmot_taps : self.n_seqs + ntaps_upto_sit_sot
]
def outer_mitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot
return list_inputs[offset : offset + self.n_mit_sot]
def inner_mitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return list_outputs[n_taps : n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[self.n_mit_mot : self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self):
return self.tap_array[self.n_mit_mot : self.n_mit_mot + self.n_mit_sot]
def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)]
)
offset = self.n_seqs + n_taps_upto_sit_sot
return list_inputs[offset : offset + self.n_sit_sot]
def outer_sitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return list_inputs[offset : offset + self.n_sit_sot]
def inner_sitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps
return list_outputs[offset : offset + self.n_sit_sot]
def outer_sitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot
return list_outputs[offset : offset + self.n_sit_sot]
def outer_nitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (
1
+ self.n_seqs
+ self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_shared_outs
)
return list_inputs[offset : offset + self.n_nit_sot]
def inner_nitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot
return list_outputs[offset : offset + self.n_nit_sot]
def outer_nitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
return list_outputs[offset : offset + self.n_nit_sot]
]
else:
cython_destroy_map = [0 for x in range(len(node.outputs))]
cython_destroy_map = np.asarray(cython_destroy_map, dtype="int32")
from . import scan_perform_ext
def inner_shared(self, list_inputs):
n_taps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)]
)
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot
return list_inputs[offset : offset + self.n_shared_outs]
def p(node, args, outs):
return scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
cython_mitmots_preallocated,
cython_inps_is_tensor,
cython_outs_is_tensor,
self.fn.fn,
self.fn,
cython_destroy_map,
args,
outs,
self,
node,
)
def outer_shared(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
return list_inputs[offset : offset + self.n_shared_outs]
except (ImportError, MissingGXX):
p = self.perform
def inner_shared_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot
return list_outputs[offset : offset + self.n_shared_outs]
# default arguments are stored in the closure of `rval`
def outer_shared_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
return list_outputs[offset : offset + self.n_shared_outs]
# Big ugly hack since we can't get the real value of allow_gc
# for the englobing function.
allow_gc = config.allow_gc and not self.allow_gc
def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)]
)
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot + self.n_shared_outs
return list_inputs[offset:]
def rval(
p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
if allow_gc:
self.fn.free()
return r
def outer_non_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (
1
+ self.n_seqs
+ self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs
)
return list_inputs[offset:]
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.perform = p
rval.lazy = False
return rval
def perform(self, node, inputs, output_storage, params=None):
"""Compute the scan operation in Python.
......@@ -1885,11 +2041,13 @@ class Scan(Op):
# over every possible pairing of their corresponding inner inputs
# and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected.
var_mappings = self.get_oinp_iinp_iout_oout_mappings()
for outer_oidx in range(len(node.outputs)):
inner_oidxs = self.var_mappings["inner_out_from_outer_out"][outer_oidx]
inner_oidxs = var_mappings["inner_out_from_outer_out"][outer_oidx]
for outer_iidx in range(len(node.inputs)):
inner_iidxs = self.var_mappings["inner_inp_from_outer_inp"][outer_iidx]
inner_iidxs = var_mappings["inner_inp_from_outer_inp"][outer_iidx]
for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs:
......@@ -1913,7 +2071,7 @@ class Scan(Op):
# Get the idx of the outer input corresponding to that
# outer output
j_inp_idx = self.var_mappings["outer_inp_from_outer_out"][jidx]
j_inp_idx = var_mappings["outer_inp_from_outer_out"][jidx]
if j_inp_idx != -1:
if connection_pattern[j_inp_idx][iidx] is True:
......@@ -1924,168 +2082,6 @@ class Scan(Op):
node.tag.connection_pattern = connection_pattern
return connection_pattern
def get_oinp_iinp_iout_oout_mappings(self):
"""
Compute and return dictionary mappings between the inputs and
outputs of the inner function and the inputs and outputs of the Scan
node in the outer graph.
The return value is a dictionary in which the keys are the names of
the individual mappings and the values are the mapping dictionaries
themselves. In dictionaries representing mappings to outer variables,
the values are individual integer indices. In dictionaries
representing mappings to inner variables, the values are sequences of
indices because multiple inner variables can be associated with the
same state.
"""
# Lists for outer variables contain individual indices, lists for
# inner variables contain sequences of indices because many inner
# variables can be associated with the same outer variable. The list
# and indices are initialized already containing the data associated
# with the timestep index, the first outer input.
outer_input_indices = [0]
inner_input_indices = [[]]
inner_output_indices = [[]]
outer_output_indices = [-1]
outer_iidx = 1
inner_iidx = 0
inner_oidx = 0
outer_oidx = 0
# Handle sequences inputs
for i in range(self.info.n_seqs):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
outer_output_indices.append(-1)
outer_iidx += 1
inner_iidx += 1
inner_oidx += 0
outer_oidx += 0
# Handle mitmots, mitsots and sitsots variables
for i in range(len(self.info.tap_array)):
nb_input_taps = len(self.info.tap_array[i])
if i < self.n_mit_mot:
nb_output_taps = len(self.mit_mot_out_slices[i])
else:
nb_output_taps = 1
outer_input_indices.append(outer_iidx)
inner_input_indices.append(
list(range(inner_iidx, inner_iidx + nb_input_taps))
)
inner_output_indices.append(
list(range(inner_oidx, inner_oidx + nb_output_taps))
)
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += nb_input_taps
inner_oidx += nb_output_taps
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.info.n_shared_outs
# Handle nitsots variables
for i in range(self.n_nit_sot):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([])
inner_output_indices.append([inner_oidx])
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += 0
inner_oidx += 1
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx -= self.info.n_shared_outs + self.n_nit_sot
# Handle shared states
for i in range(self.info.n_shared_outs):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([inner_oidx])
outer_output_indices.append(outer_oidx)
outer_iidx += 1
inner_iidx += 1
inner_oidx += 1
outer_oidx += 1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.n_nit_sot
# Handle non-sequence inputs
# Note : the number of non-sequence inputs is not stored in self.info
# so it has to be inferred from the number of inner inputs that remain
# to be handled
for i in range(len(self.inputs) - inner_iidx):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([inner_iidx])
inner_output_indices.append([])
outer_output_indices.append(-1)
outer_iidx += 1
inner_iidx += 1
inner_oidx += 0
outer_oidx += 0
# With the global mapping inferred, the individual mappings
# can be produced
mappings = {
"outer_inp_from_outer_out": {},
"inner_inp_from_outer_out": {},
"inner_out_from_outer_out": {},
"inner_inp_from_outer_inp": {},
"inner_out_from_outer_inp": {},
"outer_out_from_outer_inp": {},
"outer_inp_from_inner_inp": {},
"inner_out_from_inner_inp": {},
"outer_out_from_inner_inp": {},
"outer_inp_from_inner_out": {},
"inner_inp_from_inner_out": {},
"outer_out_from_inner_out": {},
}
for (oinp, iinp, iout, oout) in zip(
outer_input_indices,
inner_input_indices,
inner_output_indices,
outer_output_indices,
):
if oout != -1:
mappings["outer_inp_from_outer_out"][oout] = oinp
mappings["inner_inp_from_outer_out"][oout] = iinp
mappings["inner_out_from_outer_out"][oout] = iout
if oinp != -1:
mappings["inner_inp_from_outer_inp"][oinp] = iinp
mappings["inner_out_from_outer_inp"][oinp] = iout
mappings["outer_out_from_outer_inp"][oinp] = oout
for idx in iinp:
mappings["outer_inp_from_inner_inp"][idx] = oinp
mappings["inner_out_from_inner_inp"][idx] = iout
mappings["outer_out_from_inner_inp"][idx] = oout
for idx in iout:
mappings["outer_inp_from_inner_out"][idx] = oinp
mappings["inner_inp_from_inner_out"][idx] = iinp
mappings["outer_out_from_inner_out"][idx] = oout
return mappings
def L_op(self, inputs, outs, dC_douts):
if not isinstance(outs, (list, tuple)):
outs = [outs]
......@@ -2217,6 +2213,7 @@ class Scan(Op):
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
var_mappings = self.get_oinp_iinp_iout_oout_mappings()
dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs]
dC_dXts = []
......@@ -2254,7 +2251,7 @@ class Scan(Op):
if inp in graph_inputs([Xt]):
# Get the index of the outer output that to which
# the state variable 'inp' corresponds.
outer_oidx = self.var_mappings["outer_out_from_inner_inp"][
outer_oidx = var_mappings["outer_out_from_inner_inp"][
self.n_seqs + pos
]
......@@ -2307,7 +2304,7 @@ class Scan(Op):
# Get the index of the first inner input corresponding to the
# pos-ieth inner input state
idxs = self.var_mappings["inner_out_from_inner_inp"][self.n_seqs + pos]
idxs = var_mappings["inner_out_from_inner_inp"][self.n_seqs + pos]
# Check if the pos-th input is associated with one of the
# recurrent states
......
......@@ -1069,9 +1069,9 @@ class ScanArgs:
@property
def var_mappings(self):
from aesara.scan.op import Scan
from aesara.scan.op import ScanMethodsMixin
return Scan.get_oinp_iinp_iout_oout_mappings(self)
return ScanMethodsMixin.get_oinp_iinp_iout_oout_mappings(self)
@property
def field_names(self):
......
......@@ -299,7 +299,7 @@ If the goal is to navigate between variables that are associated with the same
states (ex : going from an outer sequence input to the corresponding inner
sequence input, going from an inner output associated with a recurrent state
to the inner input(s) associated with that same recurrent state, etc.), then
the ``var_mappings`` attribute of the scan op can be used.
the `get_oinp_iinp_iout_oout_mappings` method of the `Scan` `Op` can be used.
This method returns a dictionary with 12 {key/value} pairs. The keys are listed
below :
......
......@@ -700,11 +700,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
scan_node = a.owner.inputs[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"]
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 1, 1: 2}
assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"]
result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 1, 2: 2, 3: 2}
assert result == expected_result
......@@ -733,11 +734,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
scan_node = out.owner.inputs[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"]
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 2}
assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"]
result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 2, 2: 2}
assert result == expected_result
......@@ -1685,11 +1687,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
scan_node = list(updates.values())[0].owner
result = scan_node.op.var_mappings["outer_inp_from_outer_out"]
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
expected_result = {0: 3, 1: 5, 2: 4}
assert result == expected_result
result = scan_node.op.var_mappings["outer_inp_from_inner_inp"]
result = var_mappings["outer_inp_from_inner_inp"]
expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6}
assert result == expected_result
......@@ -3491,7 +3494,7 @@ class TestScan:
# Compare the mappings with the expected values
scan_node = scan_outputs[0].owner.inputs[0].owner
mappings = scan_node.op.var_mappings
mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
assert mappings["inner_inp_from_outer_inp"] == {
0: [],
......
......@@ -253,7 +253,6 @@ def test_ScanArgs():
# here we make sure it doesn't (and that all the inputs are the same)
assert scan_args.inputs == scan_op.inputs
assert scan_args.info == scan_op.info
assert scan_args.var_mappings == scan_op.var_mappings
# Check that `ScanArgs.find_among_fields` works
test_v = scan_op.inner_seqs(scan_op.inputs)[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论