提交 05da60c3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use ScanInfo attributes directly in Scan Op

上级 af440de8
......@@ -224,132 +224,160 @@ class ScanMethodsMixin:
def inner_seqs(self, list_inputs):
    """Return the inner inputs that correspond to sequences.

    Given the list of inner inputs, this grabs the first
    ``self.info.n_seqs`` entries.
    """
    return list_inputs[: self.info.n_seqs]
def outer_seqs(self, list_inputs):
    """Return the outer inputs that correspond to sequences.

    Given the list of outer inputs, this skips the first entry and grabs
    the following ``self.info.n_seqs`` entries.
    """
    n_seqs = self.info.n_seqs
    return list_inputs[1 : 1 + n_seqs]
def inner_mitmot(self, list_inputs):
    """Return the inner inputs associated with the mit-mot taps.

    The slice starts right after the ``n_seqs`` sequence entries and
    contains one entry per tap listed in the first ``n_mit_mot`` rows of
    ``self.info.tap_array``.
    """
    info = self.info
    n_taps = sum(len(taps) for taps in info.tap_array[: info.n_mit_mot])
    return list_inputs[info.n_seqs : info.n_seqs + n_taps]
def outer_mitmot(self, list_inputs):
    """Return the ``n_mit_mot`` outer inputs that follow the sequences.

    The slice starts after the first entry and the ``n_seqs`` sequence
    entries of ``list_inputs``.
    """
    info = self.info
    start = 1 + info.n_seqs
    return list_inputs[start : start + info.n_mit_mot]
def inner_mitmot_outs(self, list_outputs):
    """Return the leading inner outputs covered by the mit-mot output slices.

    One entry is taken for every tap listed in
    ``self.info.mit_mot_out_slices``.
    """
    n_taps = sum(len(taps) for taps in self.info.mit_mot_out_slices)
    return list_outputs[:n_taps]
def outer_mitmot_outs(self, list_outputs):
    """Return the first ``self.info.n_mit_mot`` entries of ``list_outputs``."""
    return list_outputs[: self.info.n_mit_mot]
def mitmot_taps(self):
    """Return the first ``n_mit_mot`` rows of ``self.info.tap_array``."""
    info = self.info
    return info.tap_array[: info.n_mit_mot]
def mitmot_out_taps(self):
    """Return the first ``n_mit_mot`` rows of ``self.info.mit_mot_out_slices``."""
    info = self.info
    return info.mit_mot_out_slices[: info.n_mit_mot]
def inner_mitsot(self, list_inputs):
    """Return the inner inputs associated with the mit-sot taps.

    The slice starts after the sequence entries and the mit-mot tap
    entries, and ends after the taps of the first
    ``n_mit_mot + n_mit_sot`` rows of ``self.info.tap_array``.
    """
    info = self.info
    n_mitmot_taps = sum(len(taps) for taps in info.tap_array[: info.n_mit_mot])
    ntaps_upto_sit_sot = sum(
        len(taps) for taps in info.tap_array[: info.n_mit_mot + info.n_mit_sot]
    )
    return list_inputs[
        info.n_seqs + n_mitmot_taps : info.n_seqs + ntaps_upto_sit_sot
    ]
def outer_mitsot(self, list_inputs):
    """Return the ``n_mit_sot`` outer inputs that follow the mit-mot entries.

    The starting offset skips the first entry, the ``n_seqs`` sequence
    entries and the ``n_mit_mot`` entries of ``list_inputs``.
    """
    info = self.info
    offset = 1 + info.n_seqs + info.n_mit_mot
    return list_inputs[offset : offset + info.n_mit_sot]
def inner_mitsot_outs(self, list_outputs):
    """Return the ``n_mit_sot`` inner outputs that follow the mit-mot output taps.

    The starting offset is the total number of taps listed in
    ``self.info.mit_mot_out_slices``.
    """
    info = self.info
    n_taps = sum(len(taps) for taps in info.mit_mot_out_slices)
    return list_outputs[n_taps : n_taps + info.n_mit_sot]
def outer_mitsot_outs(self, list_outputs):
    """Return the ``n_mit_sot`` outer outputs that follow the ``n_mit_mot`` ones."""
    info = self.info
    start = info.n_mit_mot
    return list_outputs[start : start + info.n_mit_sot]
def mitsot_taps(self):
    """Return the ``n_mit_sot`` rows of ``self.info.tap_array`` after the mit-mot rows."""
    info = self.info
    start = info.n_mit_mot
    return info.tap_array[start : start + info.n_mit_sot]
def inner_sitsot(self, list_inputs):
    """Return the ``n_sit_sot`` inner inputs that follow the mit-mot/mit-sot taps.

    The starting offset is ``n_seqs`` plus the number of taps in the first
    ``n_mit_mot + n_mit_sot`` rows of ``self.info.tap_array``.
    """
    info = self.info
    n_taps_upto_sit_sot = sum(
        len(taps) for taps in info.tap_array[: info.n_mit_mot + info.n_mit_sot]
    )
    offset = info.n_seqs + n_taps_upto_sit_sot
    return list_inputs[offset : offset + info.n_sit_sot]
def outer_sitsot(self, list_inputs):
    """Return the ``n_sit_sot`` outer inputs that follow the mit-sot entries.

    The starting offset skips the first entry plus the ``n_seqs``,
    ``n_mit_mot`` and ``n_mit_sot`` entries of ``list_inputs``.
    """
    info = self.info
    offset = 1 + info.n_seqs + info.n_mit_mot + info.n_mit_sot
    return list_inputs[offset : offset + info.n_sit_sot]
def inner_sitsot_outs(self, list_outputs):
    """Return the ``n_sit_sot`` inner outputs that follow the mit-sot outputs.

    The starting offset is the number of taps in
    ``self.info.mit_mot_out_slices`` plus ``n_mit_sot``.
    """
    info = self.info
    n_taps = sum(len(taps) for taps in info.mit_mot_out_slices)
    offset = info.n_mit_sot + n_taps
    return list_outputs[offset : offset + info.n_sit_sot]
def outer_sitsot_outs(self, list_outputs):
    """Return the ``n_sit_sot`` outer outputs after the mit-mot and mit-sot ones."""
    info = self.info
    offset = info.n_mit_mot + info.n_mit_sot
    return list_outputs[offset : offset + info.n_sit_sot]
def outer_nitsot(self, list_inputs):
    """Return the ``n_nit_sot`` outer inputs that follow the shared entries.

    The starting offset skips the first entry plus the ``n_seqs``,
    ``n_mit_mot``, ``n_mit_sot``, ``n_sit_sot`` and ``n_shared_outs``
    entries of ``list_inputs``.
    """
    info = self.info
    offset = (
        1
        + info.n_seqs
        + info.n_mit_mot
        + info.n_mit_sot
        + info.n_sit_sot
        + info.n_shared_outs
    )
    return list_inputs[offset : offset + info.n_nit_sot]
def inner_nitsot_outs(self, list_outputs):
    """Return the ``n_nit_sot`` inner outputs that follow the sit-sot outputs.

    The starting offset is the number of taps in
    ``self.info.mit_mot_out_slices`` plus ``n_mit_sot`` and ``n_sit_sot``.
    """
    info = self.info
    n_taps = sum(len(taps) for taps in info.mit_mot_out_slices)
    offset = info.n_mit_sot + n_taps + info.n_sit_sot
    return list_outputs[offset : offset + info.n_nit_sot]
def outer_nitsot_outs(self, list_outputs):
    """Return the ``n_nit_sot`` outer outputs after the mit-mot, mit-sot and sit-sot ones."""
    info = self.info
    offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
    return list_outputs[offset : offset + info.n_nit_sot]
def inner_shared(self, list_inputs):
    """Return the ``n_shared_outs`` inner inputs that follow the sit-sot entries.

    The starting offset is ``n_seqs`` plus the number of taps in the first
    ``n_mit_mot + n_mit_sot`` rows of ``self.info.tap_array`` plus
    ``n_sit_sot``.
    """
    info = self.info
    n_taps_upto_sit_sot = sum(
        len(taps) for taps in info.tap_array[: info.n_mit_mot + info.n_mit_sot]
    )
    offset = info.n_seqs + n_taps_upto_sit_sot + info.n_sit_sot
    return list_inputs[offset : offset + info.n_shared_outs]
def outer_shared(self, list_inputs):
    """Return the ``n_shared_outs`` outer inputs that follow the sit-sot entries.

    The starting offset skips the first entry plus the ``n_seqs``,
    ``n_mit_mot``, ``n_mit_sot`` and ``n_sit_sot`` entries of
    ``list_inputs``.
    """
    info = self.info
    offset = 1 + info.n_seqs + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
    return list_inputs[offset : offset + info.n_shared_outs]
def inner_shared_outs(self, list_outputs):
    """Return the ``n_shared_outs`` inner outputs that follow the nit-sot outputs.

    The starting offset is the number of taps in
    ``self.info.mit_mot_out_slices`` plus ``n_mit_sot``, ``n_sit_sot`` and
    ``n_nit_sot``.
    """
    info = self.info
    n_taps = sum(len(taps) for taps in info.mit_mot_out_slices)
    offset = info.n_mit_sot + n_taps + info.n_sit_sot + info.n_nit_sot
    return list_outputs[offset : offset + info.n_shared_outs]
def outer_shared_outs(self, list_outputs):
    """Return the ``n_shared_outs`` outer outputs that follow the nit-sot outputs.

    The starting offset is the sum of ``n_mit_mot``, ``n_mit_sot``,
    ``n_sit_sot`` and ``n_nit_sot``.
    """
    info = self.info
    offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
    return list_outputs[offset : offset + info.n_shared_outs]
def inner_non_seqs(self, list_inputs):
    """Return every inner input located after the shared entries.

    The starting offset is ``n_seqs`` plus the number of taps in the first
    ``n_mit_mot + n_mit_sot`` rows of ``self.info.tap_array`` plus
    ``n_sit_sot`` and ``n_shared_outs``; everything from there to the end
    of ``list_inputs`` is returned.
    """
    info = self.info
    n_taps_upto_sit_sot = sum(
        len(taps) for taps in info.tap_array[: info.n_mit_mot + info.n_mit_sot]
    )
    offset = (
        info.n_seqs + n_taps_upto_sit_sot + info.n_sit_sot + info.n_shared_outs
    )
    return list_inputs[offset:]
def outer_non_seqs(self, list_inputs):
    """Return every outer input located after the nit-sot and shared entries.

    The starting offset skips the first entry plus the ``n_seqs``,
    ``n_mit_mot``, ``n_mit_sot``, ``n_sit_sot``, ``n_nit_sot`` and
    ``n_shared_outs`` entries; everything from there to the end of
    ``list_inputs`` is returned.
    """
    info = self.info
    offset = (
        1
        + info.n_seqs
        + info.n_mit_mot
        + info.n_mit_sot
        + info.n_sit_sot
        + info.n_nit_sot
        + info.n_shared_outs
    )
    return list_inputs[offset:]
......@@ -399,8 +427,8 @@ class ScanMethodsMixin:
for i in range(len(self.info.tap_array)):
nb_input_taps = len(self.info.tap_array[i])
if i < self.n_mit_mot:
nb_output_taps = len(self.mit_mot_out_slices[i])
if i < self.info.n_mit_mot:
nb_output_taps = len(self.info.mit_mot_out_slices[i])
else:
nb_output_taps = 1
......@@ -423,7 +451,7 @@ class ScanMethodsMixin:
outer_iidx += self.info.n_shared_outs
# Handle nitsots variables
for i in range(self.n_nit_sot):
for i in range(self.info.n_nit_sot):
outer_input_indices.append(outer_iidx)
inner_input_indices.append([])
inner_output_indices.append([inner_oidx])
......@@ -436,7 +464,7 @@ class ScanMethodsMixin:
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx -= self.info.n_shared_outs + self.n_nit_sot
outer_iidx -= self.info.n_shared_outs + self.info.n_nit_sot
# Handle shared states
for i in range(self.info.n_shared_outs):
......@@ -452,7 +480,7 @@ class ScanMethodsMixin:
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx += self.n_nit_sot
outer_iidx += self.info.n_nit_sot
# Handle non-sequence inputs
# Note : the number of non-sequence inputs is not stored in self.info
......@@ -524,7 +552,9 @@ class ScanMethodsMixin:
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
nb_recurr_outputs = (
self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
)
var_mappings = self.get_oinp_iinp_iout_oout_mappings()
for outer_oidx in range(nb_recurr_outputs):
......@@ -698,7 +728,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
idx = 0
jdx = 0
while idx < self.n_mit_mot_outs:
while idx < info.n_mit_mot_outs:
# Note that for mit_mot there are several output slices per
# output sequence
o = outputs[idx]
......@@ -706,11 +736,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
)
idx += len(self.mit_mot_out_slices[jdx])
idx += len(info.mit_mot_out_slices[jdx])
jdx += 1
# mit_sot / sit_sot / nit_sot
end = idx + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
for o in outputs[idx:end]:
self.output_types.append(
......@@ -728,17 +758,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.name = "scan_fn"
# Pre-computing some values to speed up perform
self.mintaps = [np.min(x) for x in self.tap_array]
self.mintaps += [0 for x in range(self.n_nit_sot)]
self.seqs_arg_offset = 1 + self.n_seqs
self.mintaps = [np.min(x) for x in info.tap_array]
self.mintaps += [0 for x in range(info.n_nit_sot)]
self.seqs_arg_offset = 1 + info.n_seqs
self.shared_arg_offset = (
self.seqs_arg_offset + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
)
self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs
# XXX: This doesn't include `self.n_nit_sot`s, so it's really a count
self.nit_sot_arg_offset = self.shared_arg_offset + info.n_shared_outs
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
# Do the missing inputs check here to have the error early.
for var in graph_inputs(self.outputs, self.inputs):
......@@ -758,15 +788,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if config.scan__allow_output_prealloc:
preallocated_mitmot_outs = []
input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot):
for inp_tap in self.tap_array[mitmot_idx]:
if inp_tap in self.mit_mot_out_slices[mitmot_idx]:
info = self.info
input_idx = info.n_seqs
for mitmot_idx in range(info.n_mit_mot):
for inp_tap in info.tap_array[mitmot_idx]:
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output
output_idx = sum(
[len(m) for m in self.mit_mot_out_slices[:mitmot_idx]]
[len(m) for m in info.mit_mot_out_slices[:mitmot_idx]]
)
output_idx += self.mit_mot_out_slices[mitmot_idx].index(inp_tap)
output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
preallocated_mitmot_outs.append(output_idx)
input_idx += 1
......@@ -781,7 +812,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Store the list of mitmot output taps that have been altered so they
# can be preallocated
mitmots_preallocated = [
i in preallocated_mitmot_outs for i in range(self.n_mit_mot_outs)
i in preallocated_mitmot_outs for i in range(info.n_mit_mot_outs)
]
return preallocated_mitmot_outs, mitmots_preallocated
......@@ -1112,11 +1143,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return isinstance(s.type, TensorType) and s.ndim == 1
self.vector_seqs = [
is_cpu_vector(seq) for seq in new_inputs[1 : 1 + self.n_seqs]
is_cpu_vector(seq) for seq in new_inputs[1 : 1 + self.info.n_seqs]
]
self.vector_outs = [
is_cpu_vector(arg)
for arg in new_inputs[1 + self.n_seqs : (1 + self.n_seqs + self.n_outs)]
for arg in new_inputs[
1 + self.info.n_seqs : (1 + self.info.n_seqs + self.n_outs)
]
]
self.vector_outs += [
isinstance(t.type, TensorType) and t.ndim == 0
......@@ -1174,7 +1207,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted(
range(self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
):
aux_txt += "all_inplace,%s,%s}"
else:
......@@ -1210,7 +1243,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
slices = self.n_mit_mot_outs + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
slices = (
self.info.n_mit_mot_outs
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
if config.scan__allow_output_prealloc:
......@@ -1218,20 +1256,24 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# input and an output, wrap the input such that the corresponding
# output variable becomes an update to be performed on it, possibly
# inplace at the end of the function's execution.
wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]]
wrapped_inputs = [
In(x, borrow=False) for x in self.inputs[: self.info.n_seqs]
]
new_outputs = [x for x in self.outputs]
input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot):
for inp_tap in self.tap_array[mitmot_idx]:
if inp_tap in self.mit_mot_out_slices[mitmot_idx]:
input_idx = self.info.n_seqs
for mitmot_idx in range(self.info.n_mit_mot):
for inp_tap in self.info.tap_array[mitmot_idx]:
if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]:
inp = self.inputs[input_idx]
# Figure out the index of the corresponding output
output_idx = sum(
[len(m) for m in self.mit_mot_out_slices[:mitmot_idx]]
[len(m) for m in self.info.mit_mot_out_slices[:mitmot_idx]]
)
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index(
inp_tap
)
output_idx += self.mit_mot_out_slices[mitmot_idx].index(inp_tap)
# Make it so the input is automatically updated to the
# output value, possibly inplace, at the end of the
......@@ -1273,8 +1315,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# so that it runs before them). This feature will prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation).
mitsot_start = self.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
mitsot_start = self.info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = (
mitsot_start
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
feature = NoOutputFromInplace(mitsot_start, nitsot_end)
opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9))
......@@ -1371,7 +1418,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_mintaps = np.asarray(self.mintaps, dtype="int32")
tap_array_len = tuple(len(x) for x in self.tap_array)
tap_array_len = tuple(len(x) for x in self.info.tap_array)
cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32")
cython_vector_outs = np.asarray(self.vector_outs, dtype="int32")
......@@ -1421,20 +1468,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try:
t_fn, n_steps = scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
self.info.n_shared_outs,
self.info.n_mit_mot_outs,
self.info.n_seqs,
self.info.n_mit_mot,
self.info.n_mit_sot,
self.info.n_sit_sot,
self.info.n_nit_sot,
self.as_while,
cython_mintaps,
self.tap_array,
self.info.tap_array,
tap_array_len,
cython_vector_seqs,
cython_vector_outs,
self.mit_mot_out_slices,
self.info.mit_mot_out_slices,
cython_mitmots_preallocated,
cython_outs_is_tensor,
inner_input_storage,
......@@ -1509,14 +1556,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_steps
X sequence inputs x_1, x_2, ... x_<self.n_seqs>
X sequence inputs x_1, x_2, ... x_<self.info.n_seqs>
Y initial states (u_1, u_2, ... u_<self.n_outs>) for our
outputs. Each must have appropriate length (T_1, T_2, ..., T_Y).
W other inputs w_1, w_2, ... w_W
There are at least ``1 + self.n_seqs + self.n_outs`` inputs, and the
There are at least ``1 + self.info.n_seqs + self.n_outs`` inputs, and the
ones above this number are passed to the scanned function as
non-sequential inputs.
......@@ -1525,6 +1572,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Y sequence outputs y_1, y_2, ... y_<self.n_outs>
"""
info = self.info
# 1. Unzip the number of steps and sequences.
t0_call = time.time()
t_fn = 0
......@@ -1557,7 +1605,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
store_steps += [
arg
for arg in inputs[
self.nit_sot_arg_offset : self.nit_sot_arg_offset + self.n_nit_sot
self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot
]
]
......@@ -1575,7 +1623,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
):
# Put in the values of the initial state
output_storage[idx][0] = output_storage[idx][0][: store_steps[idx]]
if idx > self.n_mit_mot:
if idx > info.n_mit_mot:
l = -self.mintaps[idx]
output_storage[idx][0][:l] = inputs[self.seqs_arg_offset + idx][:l]
else:
......@@ -1584,7 +1632,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_storage[idx][0] = inputs[self.seqs_arg_offset + idx].copy()
if n_steps == 0:
for idx in range(self.n_outs, self.n_outs + self.n_nit_sot):
for idx in range(self.n_outs, self.n_outs + info.n_nit_sot):
out_var = node.outputs[idx]
if isinstance(out_var, TensorVariable):
output_storage[idx][0] = np.empty(
......@@ -1596,13 +1644,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
pos = [
(-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + self.n_nit_sot)
for idx in range(self.n_outs + info.n_nit_sot)
]
offset = self.nit_sot_arg_offset + self.n_nit_sot
offset = self.nit_sot_arg_offset + info.n_nit_sot
other_args = inputs[offset:]
inner_input_storage = self.fn.input_storage
nb_mitmot_in = sum(map(len, self.tap_array[: self.n_mit_mot]))
nb_mitmot_in = sum(map(len, info.tap_array[: info.n_mit_mot]))
old_mitmot_input_storage = [None] * nb_mitmot_in
old_mitmot_input_data = [None] * nb_mitmot_in
inner_output_storage = self.fn.output_storage
......@@ -1610,9 +1658,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
old_inner_output_data = [None] * len(inner_output_storage)
fn = self.fn.fn
offset = (
self.n_seqs
+ sum(map(len, self.tap_array[: self.n_outs]))
+ self.n_shared_outs
info.n_seqs
+ sum(map(len, info.tap_array[: self.n_outs]))
+ info.n_shared_outs
)
for idx in range(len(other_args)):
inner_input_storage[idx + offset].storage[0] = other_args[idx]
......@@ -1624,7 +1672,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
while (i < n_steps) and cond:
# sequences over which scan iterates
# 3. collect input slices
for idx in range(self.n_seqs):
for idx in range(info.n_seqs):
if self.vector_seqs[idx]:
inner_input_storage[idx].storage[0] = seqs[idx][i : i + 1].reshape(
()
......@@ -1632,17 +1680,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
inner_input_storage[idx].storage[0] = seqs[idx][i]
offset = self.n_seqs
offset = info.n_seqs
for idx in range(self.n_outs):
if self.vector_outs[idx]:
for tap in self.tap_array[idx]:
for tap in info.tap_array[idx]:
_idx = (pos[idx] + tap) % store_steps[idx]
inner_input_storage[offset].storage[0] = output_storage[idx][0][
_idx : _idx + 1
].reshape(())
offset += 1
else:
for tap in self.tap_array[idx]:
for tap in info.tap_array[idx]:
_idx = (pos[idx] + tap) % store_steps[idx]
inner_input_storage[offset].storage[0] = output_storage[idx][0][
_idx
......@@ -1650,13 +1698,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
offset += 1
a_offset = self.shared_arg_offset
o_offset = self.n_outs + self.n_nit_sot
o_offset = self.n_outs + info.n_nit_sot
if i == 0:
for j in range(self.n_shared_outs):
for j in range(info.n_shared_outs):
inner_input_storage[offset].storage[0] = inputs[a_offset + j]
offset += 1
else:
for j in range(self.n_shared_outs):
for j in range(info.n_shared_outs):
inner_input_storage[offset].storage[0] = output_storage[
o_offset + j
][0]
......@@ -1666,36 +1714,36 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 4.1. Collect slices for mitmots
offset = 0
for idx in range(self.n_mit_mot_outs):
for idx in range(info.n_mit_mot_outs):
if not self.mitmots_preallocated[idx]:
inner_output_storage[offset].storage[0] = None
offset += 1
# 4.2. Collect slices for mitsots, sitsots and nitsots
if i != 0:
for idx in range(self.n_outs + self.n_nit_sot - self.n_mit_mot):
for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot):
if (
store_steps[idx + self.n_mit_mot] == 1
or self.vector_outs[idx + self.n_mit_mot]
store_steps[idx + info.n_mit_mot] == 1
or self.vector_outs[idx + info.n_mit_mot]
):
inner_output_storage[idx + offset].storage[0] = None
else:
_pos0 = idx + self.n_mit_mot
_pos0 = idx + info.n_mit_mot
inner_output_storage[idx + offset].storage[0] = output_storage[
_pos0
][0][pos[_pos0]]
else:
for idx in range(self.n_outs + self.n_nit_sot - self.n_mit_mot):
for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot):
inner_output_storage[idx + offset].storage[0] = None
# 4.3. Collect slices for shared outputs
offset += self.n_outs + self.n_nit_sot - self.n_mit_mot
for idx in range(self.n_shared_outs):
offset += self.n_outs + info.n_nit_sot - info.n_mit_mot
for idx in range(info.n_shared_outs):
inner_output_storage[idx + offset].storage[0] = None
# 4.4. If there is a condition add it to the mix
if self.as_while:
pdx = offset + self.n_shared_outs
pdx = offset + info.n_shared_outs
inner_output_storage[pdx].storage[0] = None
# 4.5. Keep a reference to the variables (ndarrays,
......@@ -1726,7 +1774,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# reused the allocated object but alter the memory region they
# refer to.
for idx in range(nb_mitmot_in):
var = inner_input_storage[idx + self.n_seqs].storage[0]
var = inner_input_storage[idx + info.n_seqs].storage[0]
old_mitmot_input_storage[idx] = var
if var is None:
......@@ -1765,7 +1813,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dt_fn = time.time() - t0_fn
if self.as_while:
pdx = offset + self.n_shared_outs
pdx = offset + info.n_shared_outs
cond = inner_output_storage[pdx].storage[0] == 0
# 5.2. By calling fn() directly instead of calling the aesara
......@@ -1787,16 +1835,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0
mitmot_out_idx = 0
for j in range(self.n_mit_mot):
for k in self.mit_mot_out_slices[j]:
for j in range(info.n_mit_mot):
for k in info.mit_mot_out_slices[j]:
if self.mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated.
inp_idx = mitmot_inp_offset + self.tap_array[j].index(k)
inp_idx = mitmot_inp_offset + info.tap_array[j].index(k)
# Verify whether the input points to the same data as
# it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx]
new_var = inner_input_storage[self.n_seqs + inp_idx].storage[0]
new_var = inner_input_storage[info.n_seqs + inp_idx].storage[0]
if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx]
same_data = new_var.data == old_data
......@@ -1809,7 +1857,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# and store it in `outs` as usual
if not same_data:
output_storage[j][0][k + pos[j]] = inner_input_storage[
self.n_seqs + inp_idx
info.n_seqs + inp_idx
].storage[0]
else:
......@@ -1822,12 +1870,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitmot_out_idx += 1
mitmot_inp_offset += len(self.tap_array[j])
mitmot_inp_offset += len(info.tap_array[j])
# 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = self.n_mit_mot
begin = info.n_mit_mot
end = self.n_outs
offset_out -= self.n_mit_mot
offset_out -= info.n_mit_mot
for j in range(begin, end):
......@@ -1877,7 +1925,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 5.5 Copy over the values for nit_sot outputs
begin = end
end += self.n_nit_sot
end += info.n_nit_sot
for j in range(begin, end):
if i == 0:
......@@ -1930,7 +1978,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 5.6 Copy over the values for outputs corresponding to shared
# variables
begin = end
end += self.n_shared_outs
end += info.n_shared_outs
for j in range(begin, end):
jout = j + offset_out
output_storage[j][0] = inner_output_storage[jout].storage[0]
......@@ -1939,8 +1987,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
i = i + 1
# 6. Check if you need to re-order output buffers
begin = self.n_mit_mot
end = self.n_outs + self.n_nit_sot
begin = info.n_mit_mot
end = self.n_outs + info.n_nit_sot
for idx in range(begin, end):
if store_steps[idx] < i - self.mintaps[idx] and pos[idx] < store_steps[idx]:
......@@ -2043,30 +2091,31 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# cases extra scans in the graph. See gh-XXX for the
# investigation.
info = self.info
# We skip the first outer input as it is the total or current number
# of iterations.
# sequences
seqs_shape = [x[1:] for x in input_shapes[1 : 1 + self.n_seqs]]
seqs_shape = [x[1:] for x in input_shapes[1 : 1 + info.n_seqs]]
# We disable extra infer_shape for now. See gh-3765.
extra_infer_shape = False
if extra_infer_shape:
inner_seqs = self.inputs[: self.n_seqs]
outer_seqs = node.inputs[1 : 1 + self.n_seqs]
inner_seqs = self.inputs[: info.n_seqs]
outer_seqs = node.inputs[1 : 1 + info.n_seqs]
for in_s, out_s in zip(inner_seqs, outer_seqs):
out_equivalent[in_s] = out_s[0]
# mit_mot, mit_sot, sit_sot
outer_inp_idx = 1 + self.n_seqs
inner_inp_idx = self.n_seqs
outer_inp_idx = 1 + info.n_seqs
inner_inp_idx = info.n_seqs
else:
outer_inp_idx = 0
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
outs_shape = []
for idx in range(n_outs):
mintap = abs(min(self.tap_array[idx]))
for k in self.tap_array[idx]:
outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]
mintap = abs(min(info.tap_array[idx]))
for k in info.tap_array[idx]:
outs_shape += [input_shapes[idx + info.n_seqs + 1][1:]]
if extra_infer_shape:
corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
......@@ -2074,12 +2123,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outer_inp_idx += 1
# shared_outs
offset = 1 + self.n_seqs + n_outs
for idx in range(self.n_shared_outs):
offset = 1 + info.n_seqs + n_outs
for idx in range(info.n_shared_outs):
outs_shape += [input_shapes[idx + offset]]
# non_sequences
offset += self.n_nit_sot + self.n_shared_outs
offset += info.n_nit_sot + info.n_shared_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inputs)
......@@ -2103,11 +2152,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
valid=input_shapes, invalid=self.inputs, valid_equivalent=out_equivalent
)
offset = 1 + self.n_seqs
offset = 1 + info.n_seqs
scan_outs = [x for x in input_shapes[offset : offset + n_outs]]
offset += n_outs
outs_shape_n = self.n_mit_mot_outs + self.n_mit_sot + self.n_sit_sot
for x in range(self.n_nit_sot):
outs_shape_n = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
for x in range(info.n_nit_sot):
out_shape_x = outs_shape[outs_shape_n + x]
if out_shape_x is None:
# This output is not a tensor, and has no shape
......@@ -2118,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# in the inner function.
r = node.outputs[n_outs + x]
assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset + self.n_shared_outs + x]]
shp = [node.inputs[offset + info.n_shared_outs + x]]
for i, shp_i in zip(range(1, r.ndim), out_shape_x):
# Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates
......@@ -2135,7 +2184,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp.append(v_shp_i[0])
scan_outs.append(tuple(shp))
scan_outs += [x for x in input_shapes[offset : offset + self.n_shared_outs]]
scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]]
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if self.as_while:
......@@ -2217,13 +2266,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# We do not know what kind of outputs the original scan has, so we
# try first to see if it has a nit_sot output, then a sit_sot and
# then a mit_sot
if self.n_nit_sot > 0:
info = self.info
if info.n_nit_sot > 0:
grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
elif self.n_sit_sot > 0:
elif info.n_sit_sot > 0:
grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif self.n_mit_sot > 0:
elif info.n_mit_sot > 0:
grad_steps = (
self.outer_mitsot_outs(outs)[0].shape[0] + self.mintaps[self.n_mit_mot]
self.outer_mitsot_outs(outs)[0].shape[0] + self.mintaps[info.n_mit_mot]
)
else:
grad_steps = inputs[0]
......@@ -2255,10 +2305,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
connection_pattern = self.connection_pattern(scan_node)
def get_inp_idx(iidx):
if iidx < self.n_seqs:
if iidx < info.n_seqs:
return 1 + iidx
oidx = 1 + self.n_seqs
iidx = iidx - self.n_seqs
oidx = 1 + info.n_seqs
iidx = iidx - info.n_seqs
for taps in self.mitmot_taps():
if len(taps) > iidx:
return oidx
......@@ -2272,10 +2322,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
oidx += 1
iidx -= len(taps)
if iidx < self.info.n_sit_sot:
if iidx < info.n_sit_sot:
return oidx + iidx
else:
return oidx + iidx + self.info.n_nit_sot
return oidx + iidx + info.n_nit_sot
def get_out_idx(iidx):
oidx = 0
......@@ -2347,7 +2397,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx, Xt in enumerate(diff_outputs):
# We are looking for x[t-1] for a given x[t]
if idx >= self.n_mit_mot_outs:
if idx >= info.n_mit_mot_outs:
Xt_placeholder = safe_new(Xt)
Xts.append(Xt_placeholder)
......@@ -2355,10 +2405,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# or not. NOTE : This cannot be done by using
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs.
idx_nitsot_start = (
self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
)
idx_nitsot_end = idx_nitsot_start + self.info.n_nit_sot
idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
idx_nitsot_end = idx_nitsot_start + info.n_nit_sot
if idx < idx_nitsot_start or idx >= idx_nitsot_end:
# What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an
......@@ -2378,7 +2426,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Get the index of the outer output to which
# the state variable 'inp' corresponds.
outer_oidx = var_mappings["outer_out_from_inner_inp"][
self.n_seqs + pos
info.n_seqs + pos
]
if not isinstance(dC_douts[outer_oidx].type, DisconnectedType):
......@@ -2420,55 +2468,55 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dC_dinps_t[dx] = at.zeros_like(diff_inputs[dx])
else:
disconnected_dC_dinps_t[dx] = False
for Xt, Xt_placeholder in zip(diff_outputs[self.n_mit_mot_outs :], Xts):
for Xt, Xt_placeholder in zip(diff_outputs[info.n_mit_mot_outs :], Xts):
tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder)
dC_dinps_t[dx] = tmp
# construct dX_dtm1
dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs :]):
for pos, x in enumerate(dC_dinps_t[info.n_seqs :]):
# Get the indices of the inner outputs corresponding to the
# pos-th inner input state
idxs = var_mappings["inner_out_from_inner_inp"][self.n_seqs + pos]
idxs = var_mappings["inner_out_from_inner_inp"][info.n_seqs + pos]
# Check if the pos-th input is associated with one of the
# recurrent states
x_is_state = pos < sum([len(t) for t in self.tap_array])
x_is_state = pos < sum([len(t) for t in info.tap_array])
if x_is_state and len(idxs) > 0:
opos = idxs[0]
dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if hasattr(x, "dtype") and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = x.astype(dC_dXts[opos].dtype)
dC_dinps_t[pos + info.n_seqs] = x.astype(dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType):
if isinstance(dC_dinps_t[dx + info.n_seqs].type, NullType):
# The accumulated gradient is undefined
pass
elif isinstance(dC_dXtm1.type, NullType):
# The new gradient is undefined, this makes the accumulated
# gradient undefined as well
dC_dinps_t[dx + self.n_seqs] = dC_dXtm1
dC_dinps_t[dx + info.n_seqs] = dC_dXtm1
else:
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
dC_dinps_t[dx + info.n_seqs] += dC_dXtm1
# Construct scan op
# Seqs
if self.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs = [x[n_steps - 1 :: -1] for x in inputs[1 : 1 + self.n_seqs]]
outer_inp_seqs = [x[n_steps - 1 :: -1] for x in inputs[1 : 1 + info.n_seqs]]
else:
outer_inp_seqs = [x[::-1] for x in inputs[1 : 1 + self.n_seqs]]
for idx in range(self.n_mit_mot + self.n_mit_sot):
mintap = np.min(self.tap_array[idx])
if idx < self.n_mit_mot:
outer_inp_seqs = [x[::-1] for x in inputs[1 : 1 + info.n_seqs]]
for idx in range(info.n_mit_mot + info.n_mit_sot):
mintap = np.min(info.tap_array[idx])
if idx < info.n_mit_mot:
outmaxtap = np.max(self.mitmot_out_taps()[idx])
else:
outmaxtap = 0
seq = outs[idx]
for k in self.tap_array[idx]:
for k in info.tap_array[idx]:
if outmaxtap - k != 0:
nw_seq = seq[k - mintap : -(outmaxtap - k)][::-1]
else:
......@@ -2531,11 +2579,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitmot_out_taps = []
type_outs = []
out_pos = 0
ins_pos = self.n_seqs
ins_pos = info.n_seqs
n_mitmot_outs = 0
n_mitmot_inps = 0
for idx in range(self.n_mit_mot):
for idx in range(info.n_mit_mot):
if isinstance(dC_douts[idx].type, DisconnectedType):
out = outs[idx]
outer_inp_mitmot.append(at.zeros_like(out))
......@@ -2547,19 +2595,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
through_shared = False
disconnected = True
for jdx in range(len(self.mit_mot_out_slices[idx])):
for jdx in range(len(info.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dC_dXts[out_pos])
mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx])
mitmot_inp_taps[idx].append(-info.mit_mot_out_slices[idx][jdx])
n_mitmot_inps += 1
out_pos += 1
for jdx in range(len(self.tap_array[idx])):
tap = -self.tap_array[idx][jdx]
for jdx in range(len(info.tap_array[idx])):
tap = -info.tap_array[idx][jdx]
# Only create a new inner input if there is not already one
# associated with this input tap
if tap not in mitmot_inp_taps[idx]:
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
......@@ -2575,13 +2623,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# input tap, make sure the computation of the new output
# uses it instead of the input it's currently using
if tap in mitmot_inp_taps[idx]:
to_replace = dC_dXtm1s[ins_pos - self.n_seqs]
to_replace = dC_dXtm1s[ins_pos - info.n_seqs]
replacement_idx = len(mitmot_inp_taps[idx]) - mitmot_inp_taps[
idx
].index(tap)
replacement = inner_inp_mitmot[-replacement_idx]
self.tap_array[idx]
info.tap_array[idx]
new_inner_out_mitmot = clone_replace(
new_inner_out_mitmot, replace=[(to_replace, replacement)]
)
......@@ -2597,12 +2645,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
ins_pos += 1
n_mitmot_outs += 1
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_out_taps[idx].append(-info.tap_array[idx][jdx])
# Only add the tap as a new input tap if needed
if tap not in mitmot_inp_taps[idx]:
n_mitmot_inps += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_inp_taps[idx].append(-info.tap_array[idx][jdx])
if undefined_msg:
type_outs.append(undefined_msg)
......@@ -2613,15 +2661,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
type_outs.append("connected")
offset = self.n_mit_mot
for idx in range(self.n_mit_sot):
offset = info.n_mit_mot
for idx in range(info.n_mit_sot):
if isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(outs[idx + offset].zeros_like())
else:
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
mitmot_inp_taps.append([])
mitmot_out_taps.append([])
idx_tap = idx + self.n_mit_mot
idx_tap = idx + info.n_mit_mot
inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1
n_mitmot_inps += 1
......@@ -2629,8 +2677,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
through_shared = False
disconnected = True
mitmot_inp_taps[idx + offset].append(0)
for jdx in range(len(self.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
for jdx in range(len(info.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs])
if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we
......@@ -2642,8 +2690,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
inner_out_mitmot.append(dC_dinps_t[ins_pos])
mitmot_inp_taps[idx + offset].append(-self.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx_tap][jdx])
mitmot_inp_taps[idx + offset].append(-info.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append(-info.tap_array[idx_tap][jdx])
if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False
for _sh in self.inner_shared(self_inputs):
......@@ -2663,8 +2711,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
type_outs.append("connected")
offset += self.n_mit_sot
for idx in range(self.n_sit_sot):
offset += info.n_mit_sot
for idx in range(info.n_sit_sot):
mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1])
through_shared = False
......@@ -2707,14 +2755,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
type_outs.append("connected")
inner_inp_mitmot += [dC_dXts[out_pos], dC_dXtm1s[ins_pos - self.n_seqs]]
inner_inp_mitmot += [
dC_dXts[out_pos],
dC_dXtm1s[ins_pos - info.n_seqs],
]
n_mitmot_outs += 1
out_pos += 1
ins_pos += 1
n_mitmot_inps += 2
n_nit_sot = self.n_seqs
inner_out_nitsot = dC_dinps_t[: self.n_seqs]
n_nit_sot = info.n_seqs
inner_out_nitsot = dC_dinps_t[: info.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot):
through_shared = False
......@@ -2755,7 +2806,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
type_outs.append("connected")
inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs :]
inner_inp_sitsot = dC_dXtm1s[ins_pos - info.n_seqs :]
outer_inp_sitsot = []
for _idx, y in enumerate(inner_inp_sitsot):
x = self.outer_non_seqs(inputs)[_idx]
......@@ -2783,7 +2834,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)]
info = ScanInfo(
out_info = ScanInfo(
n_seqs=len(outer_inp_seqs),
n_mit_sot=0,
tap_array=tuple(tuple(v) for v in new_tap_array),
......@@ -2817,7 +2868,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
local_op = Scan(
inner_gfn_ins,
inner_gfn_outs,
info,
out_info,
mode=self.mode,
truncate_gradient=self.truncate_gradient,
as_while=False,
......@@ -2831,11 +2882,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Re-order the gradients correctly
gradients = [DisconnectedType()()]
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + n_sitsot_outs
offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs
for p, (x, t) in enumerate(
zip(
outputs[offset : offset + self.n_seqs],
type_outs[offset : offset + self.n_seqs],
outputs[offset : offset + info.n_seqs],
type_outs[offset : offset + info.n_seqs],
)
):
if t == "connected":
......@@ -2864,7 +2915,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# t contains the "why_null" string of a NullType
gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
if t == "connected":
# If the forward scan is in as_while mode, we need to pad
......@@ -2886,8 +2937,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append(
grad_undefined(
self,
p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs],
p + 1 + info.n_seqs,
inputs[p + 1 + info.n_seqs],
"Depends on a shared variable",
)
)
......@@ -2897,7 +2948,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
start = len(gradients)
node = outs[0].owner
for idx in range(self.n_shared_outs):
for idx in range(info.n_shared_outs):
disconnected = True
connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags):
......@@ -2913,7 +2964,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
start = len(gradients)
gradients += [DisconnectedType()() for _ in range(self.n_nit_sot)]
gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
begin = end
end = begin + n_sitsot_outs
......@@ -2954,10 +3005,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def R_op(self, inputs, eval_points):
# Step 0. Prepare some shortcut variables
info = self.info
self_inputs = self.inputs
rop_of_inputs = (
self_inputs[: self.n_seqs + self.n_outs]
+ self_inputs[self.n_seqs + self.n_outs + self.n_shared_outs :]
self_inputs[: info.n_seqs + self.n_outs]
+ self_inputs[info.n_seqs + self.n_outs + info.n_shared_outs :]
)
self_outputs = self.outputs
......@@ -2967,8 +3019,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs = self_outputs[:-1]
else:
rop_self_outputs = self_outputs
if self.info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -self.info.n_shared_outs]
if info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -info.n_shared_outs]
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if not isinstance(rop_outs, (list, tuple)):
rop_outs = [rop_outs]
......@@ -2985,20 +3037,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
new_tap_array = []
b = 0
e = self.n_mit_mot
new_tap_array += self.tap_array[b:e] * 2
e = info.n_mit_mot
new_tap_array += info.tap_array[b:e] * 2
b = e
e += self.n_mit_sot
new_tap_array += self.tap_array[b:e] * 2
e += info.n_mit_sot
new_tap_array += info.tap_array[b:e] * 2
b = e
e += self.n_sit_sot
new_tap_array += self.tap_array[b:e] * 2
e += info.n_sit_sot
new_tap_array += info.tap_array[b:e] * 2
# Sequences ...
b = 1
ib = 0
e = 1 + self.n_seqs
ie = self.n_seqs
e = 1 + info.n_seqs
ie = info.n_seqs
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
......@@ -3011,9 +3063,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# MIT_MOT sequences ...
b = e
e = e + self.n_mit_mot
e = e + info.n_mit_mot
ib = ie
ie = ie + int(sum(len(x) for x in self.tap_array[: self.n_mit_mot]))
ie = ie + int(sum(len(x) for x in info.tap_array[: info.n_mit_mot]))
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
......@@ -3026,13 +3078,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# MIT_SOT sequences ...
b = e
e = e + self.n_mit_sot
e = e + info.n_mit_sot
ib = ie
ie = ie + int(
sum(
len(x)
for x in self.tap_array[
self.n_mit_mot : self.n_mit_mot + self.n_mit_sot
for x in info.tap_array[
info.n_mit_mot : info.n_mit_mot + info.n_mit_sot
]
)
)
......@@ -3048,9 +3100,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# SIT_SOT sequences ...
b = e
e = e + self.n_sit_sot
e = e + info.n_sit_sot
ib = ie
ie = ie + self.n_sit_sot
ie = ie + info.n_sit_sot
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
......@@ -3063,15 +3115,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Shared outs ...
b = e
e = e + self.n_shared_outs
e = e + info.n_shared_outs
ib = ie
ie = ie + self.n_shared_outs
ie = ie + info.n_shared_outs
scan_shared = inputs[b:e]
inner_shared = self_inputs[ib:ie]
# NIT_SOT sequences
b = e
e = e + self.n_nit_sot
e = e + info.n_nit_sot
scan_nit_sot = inputs[b:e] * 2
# All other arguments
......@@ -3086,22 +3138,22 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_other = self_inputs[ie:] + inner_eval_points[ib:]
# Outputs
n_mit_mot_outs = int(sum(len(x) for x in self.mit_mot_out_slices))
n_mit_mot_outs = int(sum(len(x) for x in info.mit_mot_out_slices))
b = 0
e = n_mit_mot_outs
inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_mit_sot
e = e + info.n_mit_sot
inner_out_mit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_sit_sot
e = e + info.n_sit_sot
inner_out_sit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_nit_sot
e = e + info.n_nit_sot
inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e
e = e + self.n_shared_outs
e = e + info.n_shared_outs
inner_out_shared = self_outputs[b:e]
inner_ins = (
......@@ -3133,22 +3185,22 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ scan_other
)
info = ScanInfo(
n_seqs=self.n_seqs * 2,
n_mit_sot=self.n_mit_sot * 2,
n_sit_sot=self.n_sit_sot * 2,
n_mit_mot=self.n_mit_mot * 2,
n_nit_sot=self.n_nit_sot * 2,
n_shared_outs=self.n_shared_outs,
out_info = ScanInfo(
n_seqs=info.n_seqs * 2,
n_mit_sot=info.n_mit_sot * 2,
n_sit_sot=info.n_sit_sot * 2,
n_mit_mot=info.n_mit_mot * 2,
n_nit_sot=info.n_nit_sot * 2,
n_shared_outs=info.n_shared_outs,
n_mit_mot_outs=n_mit_mot_outs * 2,
tap_array=tuple(tuple(v) for v in new_tap_array),
mit_mot_out_slices=tuple(tuple(v) for v in self.mit_mot_out_slices) * 2,
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
)
local_op = Scan(
inner_ins,
inner_outs,
info,
out_info,
mode=self.mode,
as_while=self.as_while,
profile=self.profile,
......@@ -3161,19 +3213,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outputs = [outputs]
# Select only the outputs that hold the R_op results
final_outs = []
b = self.n_mit_mot
e = self.n_mit_mot * 2
b = info.n_mit_mot
e = info.n_mit_mot * 2
final_outs += outputs[b:e]
b = e + self.n_mit_sot
e = e + self.n_mit_sot * 2
b = e + info.n_mit_sot
e = e + info.n_mit_sot * 2
final_outs += outputs[b:e]
b = e + self.n_sit_sot
e = e + self.n_sit_sot * 2
b = e + info.n_sit_sot
e = e + info.n_sit_sot * 2
final_outs += outputs[b:e]
b = e + self.n_nit_sot
e = e + self.n_nit_sot * 2
b = e + info.n_nit_sot
e = e + info.n_nit_sot * 2
final_outs += outputs[b:e]
final_outs += [None] * self.n_shared_outs
final_outs += [None] * info.n_shared_outs
return final_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论