提交 05da60c3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use ScanInfo attributes directly in Scan Op

上级 af440de8
...@@ -224,132 +224,160 @@ class ScanMethodsMixin: ...@@ -224,132 +224,160 @@ class ScanMethodsMixin:
def inner_seqs(self, list_inputs): def inner_seqs(self, list_inputs):
# Given the list of inner inputs this function grabs those # Given the list of inner inputs this function grabs those
# corresponding to sequences # corresponding to sequences
return list_inputs[: self.n_seqs] return list_inputs[: self.info.n_seqs]
def outer_seqs(self, list_inputs): def outer_seqs(self, list_inputs):
# Given the list of outer inputs this function grabs those # Given the list of outer inputs this function grabs those
# corresponding to sequences # corresponding to sequences
return list_inputs[1 : 1 + self.n_seqs] return list_inputs[1 : 1 + self.info.n_seqs]
def inner_mitmot(self, list_inputs): def inner_mitmot(self, list_inputs):
n_taps = sum(len(x) for x in self.tap_array[: self.n_mit_mot]) n_taps = sum(len(x) for x in self.info.tap_array[: self.info.n_mit_mot])
return list_inputs[self.n_seqs : self.n_seqs + n_taps] return list_inputs[self.info.n_seqs : self.info.n_seqs + n_taps]
def outer_mitmot(self, list_inputs): def outer_mitmot(self, list_inputs):
return list_inputs[1 + self.n_seqs : 1 + self.n_seqs + self.n_mit_mot] return list_inputs[
1 + self.info.n_seqs : 1 + self.info.n_seqs + self.info.n_mit_mot
]
def inner_mitmot_outs(self, list_outputs): def inner_mitmot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
return list_outputs[:n_taps] return list_outputs[:n_taps]
def outer_mitmot_outs(self, list_outputs): def outer_mitmot_outs(self, list_outputs):
return list_outputs[: self.n_mit_mot] return list_outputs[: self.info.n_mit_mot]
def mitmot_taps(self): def mitmot_taps(self):
return self.tap_array[: self.n_mit_mot] return self.info.tap_array[: self.info.n_mit_mot]
def mitmot_out_taps(self): def mitmot_out_taps(self):
return self.mit_mot_out_slices[: self.n_mit_mot] return self.info.mit_mot_out_slices[: self.info.n_mit_mot]
def inner_mitsot(self, list_inputs): def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[: self.n_mit_mot]) n_mitmot_taps = sum(len(x) for x in self.info.tap_array[: self.info.n_mit_mot])
ntaps_upto_sit_sot = sum( ntaps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)] len(x)
for x in self.info.tap_array[: (self.info.n_mit_mot + self.info.n_mit_sot)]
) )
return list_inputs[ return list_inputs[
self.n_seqs + n_mitmot_taps : self.n_seqs + ntaps_upto_sit_sot self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot
] ]
def outer_mitsot(self, list_inputs): def outer_mitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot offset = 1 + self.info.n_seqs + self.info.n_mit_mot
return list_inputs[offset : offset + self.n_mit_sot] return list_inputs[offset : offset + self.info.n_mit_sot]
def inner_mitsot_outs(self, list_outputs): def inner_mitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
return list_outputs[n_taps : n_taps + self.n_mit_sot] return list_outputs[n_taps : n_taps + self.info.n_mit_sot]
def outer_mitsot_outs(self, list_outputs): def outer_mitsot_outs(self, list_outputs):
return list_outputs[self.n_mit_mot : self.n_mit_mot + self.n_mit_sot] return list_outputs[
self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot
]
def mitsot_taps(self): def mitsot_taps(self):
return self.tap_array[self.n_mit_mot : self.n_mit_mot + self.n_mit_sot] return self.info.tap_array[
self.info.n_mit_mot : self.info.n_mit_mot + self.info.n_mit_sot
]
def inner_sitsot(self, list_inputs): def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum( n_taps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)] len(x)
for x in self.info.tap_array[: (self.info.n_mit_mot + self.info.n_mit_sot)]
) )
offset = self.n_seqs + n_taps_upto_sit_sot offset = self.info.n_seqs + n_taps_upto_sit_sot
return list_inputs[offset : offset + self.n_sit_sot] return list_inputs[offset : offset + self.info.n_sit_sot]
def outer_sitsot(self, list_inputs): def outer_sitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot offset = 1 + self.info.n_seqs + self.info.n_mit_mot + self.info.n_mit_sot
return list_inputs[offset : offset + self.n_sit_sot] return list_inputs[offset : offset + self.info.n_sit_sot]
def inner_sitsot_outs(self, list_outputs): def inner_sitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps offset = self.info.n_mit_sot + n_taps
return list_outputs[offset : offset + self.n_sit_sot] return list_outputs[offset : offset + self.info.n_sit_sot]
def outer_sitsot_outs(self, list_outputs): def outer_sitsot_outs(self, list_outputs):
offset = self.n_mit_mot + self.n_mit_sot offset = self.info.n_mit_mot + self.info.n_mit_sot
return list_outputs[offset : offset + self.n_sit_sot] return list_outputs[offset : offset + self.info.n_sit_sot]
def outer_nitsot(self, list_inputs): def outer_nitsot(self, list_inputs):
offset = ( offset = (
1 1
+ self.n_seqs + self.info.n_seqs
+ self.n_mit_mot + self.info.n_mit_mot
+ self.n_mit_sot + self.info.n_mit_sot
+ self.n_sit_sot + self.info.n_sit_sot
+ self.n_shared_outs + self.info.n_shared_outs
) )
return list_inputs[offset : offset + self.n_nit_sot] return list_inputs[offset : offset + self.info.n_nit_sot]
def inner_nitsot_outs(self, list_outputs): def inner_nitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot offset = self.info.n_mit_sot + n_taps + self.info.n_sit_sot
return list_outputs[offset : offset + self.n_nit_sot] return list_outputs[offset : offset + self.info.n_nit_sot]
def outer_nitsot_outs(self, list_outputs): def outer_nitsot_outs(self, list_outputs):
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot offset = self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
return list_outputs[offset : offset + self.n_nit_sot] return list_outputs[offset : offset + self.info.n_nit_sot]
def inner_shared(self, list_inputs): def inner_shared(self, list_inputs):
n_taps_upto_sit_sot = sum( n_taps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)] len(x)
for x in self.info.tap_array[: (self.info.n_mit_mot + self.info.n_mit_sot)]
) )
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot offset = self.info.n_seqs + n_taps_upto_sit_sot + self.info.n_sit_sot
return list_inputs[offset : offset + self.n_shared_outs] return list_inputs[offset : offset + self.info.n_shared_outs]
def outer_shared(self, list_inputs): def outer_shared(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot offset = (
return list_inputs[offset : offset + self.n_shared_outs] 1
+ self.info.n_seqs
+ self.info.n_mit_mot
+ self.info.n_mit_sot
+ self.info.n_sit_sot
)
return list_inputs[offset : offset + self.info.n_shared_outs]
def inner_shared_outs(self, list_outputs): def inner_shared_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot offset = (
return list_outputs[offset : offset + self.n_shared_outs] self.info.n_mit_sot + n_taps + self.info.n_sit_sot + self.info.n_nit_sot
)
return list_outputs[offset : offset + self.info.n_shared_outs]
def outer_shared_outs(self, list_outputs): def outer_shared_outs(self, list_outputs):
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot offset = (
return list_outputs[offset : offset + self.n_shared_outs] self.info.n_mit_mot
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
return list_outputs[offset : offset + self.info.n_shared_outs]
def inner_non_seqs(self, list_inputs): def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum( n_taps_upto_sit_sot = sum(
len(x) for x in self.tap_array[: (self.n_mit_mot + self.n_mit_sot)] len(x)
for x in self.info.tap_array[: (self.info.n_mit_mot + self.info.n_mit_sot)]
)
offset = (
self.info.n_seqs
+ n_taps_upto_sit_sot
+ self.info.n_sit_sot
+ self.info.n_shared_outs
) )
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot + self.n_shared_outs
return list_inputs[offset:] return list_inputs[offset:]
def outer_non_seqs(self, list_inputs): def outer_non_seqs(self, list_inputs):
offset = ( offset = (
1 1
+ self.n_seqs + self.info.n_seqs
+ self.n_mit_mot + self.info.n_mit_mot
+ self.n_mit_sot + self.info.n_mit_sot
+ self.n_sit_sot + self.info.n_sit_sot
+ self.n_nit_sot + self.info.n_nit_sot
+ self.n_shared_outs + self.info.n_shared_outs
) )
return list_inputs[offset:] return list_inputs[offset:]
...@@ -399,8 +427,8 @@ class ScanMethodsMixin: ...@@ -399,8 +427,8 @@ class ScanMethodsMixin:
for i in range(len(self.info.tap_array)): for i in range(len(self.info.tap_array)):
nb_input_taps = len(self.info.tap_array[i]) nb_input_taps = len(self.info.tap_array[i])
if i < self.n_mit_mot: if i < self.info.n_mit_mot:
nb_output_taps = len(self.mit_mot_out_slices[i]) nb_output_taps = len(self.info.mit_mot_out_slices[i])
else: else:
nb_output_taps = 1 nb_output_taps = 1
...@@ -423,7 +451,7 @@ class ScanMethodsMixin: ...@@ -423,7 +451,7 @@ class ScanMethodsMixin:
outer_iidx += self.info.n_shared_outs outer_iidx += self.info.n_shared_outs
# Handle nitsots variables # Handle nitsots variables
for i in range(self.n_nit_sot): for i in range(self.info.n_nit_sot):
outer_input_indices.append(outer_iidx) outer_input_indices.append(outer_iidx)
inner_input_indices.append([]) inner_input_indices.append([])
inner_output_indices.append([inner_oidx]) inner_output_indices.append([inner_oidx])
...@@ -436,7 +464,7 @@ class ScanMethodsMixin: ...@@ -436,7 +464,7 @@ class ScanMethodsMixin:
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables. # nitsots come *after* shared variables.
outer_iidx -= self.info.n_shared_outs + self.n_nit_sot outer_iidx -= self.info.n_shared_outs + self.info.n_nit_sot
# Handle shared states # Handle shared states
for i in range(self.info.n_shared_outs): for i in range(self.info.n_shared_outs):
...@@ -452,7 +480,7 @@ class ScanMethodsMixin: ...@@ -452,7 +480,7 @@ class ScanMethodsMixin:
# This is needed because, for outer inputs (and for outer inputs only) # This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables. # nitsots come *after* shared variables.
outer_iidx += self.n_nit_sot outer_iidx += self.info.n_nit_sot
# Handle non-sequence inputs # Handle non-sequence inputs
# Note : the number of non-sequence inputs is not stored in self.info # Note : the number of non-sequence inputs is not stored in self.info
...@@ -524,7 +552,9 @@ class ScanMethodsMixin: ...@@ -524,7 +552,9 @@ class ScanMethodsMixin:
# For every recurrent output, iterate over the associated inner # For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype # inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot nb_recurr_outputs = (
self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot
)
var_mappings = self.get_oinp_iinp_iout_oout_mappings() var_mappings = self.get_oinp_iinp_iout_oout_mappings()
for outer_oidx in range(nb_recurr_outputs): for outer_oidx in range(nb_recurr_outputs):
...@@ -698,7 +728,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -698,7 +728,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
idx = 0 idx = 0
jdx = 0 jdx = 0
while idx < self.n_mit_mot_outs: while idx < info.n_mit_mot_outs:
# Not that for mit_mot there are several output slices per # Not that for mit_mot there are several output slices per
# output sequence # output sequence
o = outputs[idx] o = outputs[idx]
...@@ -706,11 +736,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -706,11 +736,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
typeConstructor((False,) + o.type.broadcastable, o.type.dtype) typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
) )
idx += len(self.mit_mot_out_slices[jdx]) idx += len(info.mit_mot_out_slices[jdx])
jdx += 1 jdx += 1
# mit_sot / sit_sot / nit_sot # mit_sot / sit_sot / nit_sot
end = idx + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
for o in outputs[idx:end]: for o in outputs[idx:end]:
self.output_types.append( self.output_types.append(
...@@ -728,17 +758,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -728,17 +758,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.name = "scan_fn" self.name = "scan_fn"
# Pre-computing some values to speed up perform # Pre-computing some values to speed up perform
self.mintaps = [np.min(x) for x in self.tap_array] self.mintaps = [np.min(x) for x in info.tap_array]
self.mintaps += [0 for x in range(self.n_nit_sot)] self.mintaps += [0 for x in range(info.n_nit_sot)]
self.seqs_arg_offset = 1 + self.n_seqs self.seqs_arg_offset = 1 + info.n_seqs
self.shared_arg_offset = ( self.shared_arg_offset = (
self.seqs_arg_offset + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
) )
self.nit_sot_arg_offset = self.shared_arg_offset + self.n_shared_outs self.nit_sot_arg_offset = self.shared_arg_offset + info.n_shared_outs
# XXX: This doesn't include `self.n_nit_sot`s, so it's really a count # XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs # of the number of outputs generated by taps with inputs
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
# Do the missing inputs check here to have the error early. # Do the missing inputs check here to have the error early.
for var in graph_inputs(self.outputs, self.inputs): for var in graph_inputs(self.outputs, self.inputs):
...@@ -758,15 +788,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -758,15 +788,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if config.scan__allow_output_prealloc: if config.scan__allow_output_prealloc:
preallocated_mitmot_outs = [] preallocated_mitmot_outs = []
input_idx = self.n_seqs info = self.info
for mitmot_idx in range(self.n_mit_mot): input_idx = info.n_seqs
for inp_tap in self.tap_array[mitmot_idx]: for mitmot_idx in range(info.n_mit_mot):
if inp_tap in self.mit_mot_out_slices[mitmot_idx]: for inp_tap in info.tap_array[mitmot_idx]:
if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output # Figure out the index of the corresponding output
output_idx = sum( output_idx = sum(
[len(m) for m in self.mit_mot_out_slices[:mitmot_idx]] [len(m) for m in info.mit_mot_out_slices[:mitmot_idx]]
) )
output_idx += self.mit_mot_out_slices[mitmot_idx].index(inp_tap) output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
preallocated_mitmot_outs.append(output_idx) preallocated_mitmot_outs.append(output_idx)
input_idx += 1 input_idx += 1
...@@ -781,7 +812,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -781,7 +812,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Store the list of mitmot output taps that have been altered so they # Store the list of mitmot output taps that have been altered so they
# can be preallocated # can be preallocated
mitmots_preallocated = [ mitmots_preallocated = [
i in preallocated_mitmot_outs for i in range(self.n_mit_mot_outs) i in preallocated_mitmot_outs for i in range(info.n_mit_mot_outs)
] ]
return preallocated_mitmot_outs, mitmots_preallocated return preallocated_mitmot_outs, mitmots_preallocated
...@@ -1112,11 +1143,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1112,11 +1143,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return isinstance(s.type, TensorType) and s.ndim == 1 return isinstance(s.type, TensorType) and s.ndim == 1
self.vector_seqs = [ self.vector_seqs = [
is_cpu_vector(seq) for seq in new_inputs[1 : 1 + self.n_seqs] is_cpu_vector(seq) for seq in new_inputs[1 : 1 + self.info.n_seqs]
] ]
self.vector_outs = [ self.vector_outs = [
is_cpu_vector(arg) is_cpu_vector(arg)
for arg in new_inputs[1 + self.n_seqs : (1 + self.n_seqs + self.n_outs)] for arg in new_inputs[
1 + self.info.n_seqs : (1 + self.info.n_seqs + self.n_outs)
]
] ]
self.vector_outs += [ self.vector_outs += [
isinstance(t.type, TensorType) and t.ndim == 0 isinstance(t.type, TensorType) and t.ndim == 0
...@@ -1174,7 +1207,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1174,7 +1207,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if len(self.destroy_map.keys()) > 0: if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace # Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted( if sorted(self.destroy_map.keys()) == sorted(
range(self.n_mit_mot + self.n_mit_sot + self.n_sit_sot) range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
): ):
aux_txt += "all_inplace,%s,%s}" aux_txt += "all_inplace,%s,%s}"
else: else:
...@@ -1210,7 +1243,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1210,7 +1243,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# If a shared variable is the result of a ViewOp it is a clear # If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of # indication that we need to copy that value after the perform of
# scan is done # scan is done
slices = self.n_mit_mot_outs + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot slices = (
self.info.n_mit_mot_outs
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
if config.scan__allow_output_prealloc: if config.scan__allow_output_prealloc:
...@@ -1218,20 +1256,24 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1218,20 +1256,24 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# input and an output, wrap the input such that the corresponding # input and an output, wrap the input such that the corresponding
# output variable becomes an update to be performed on it, possibly # output variable becomes an update to be performed on it, possibly
# inplace at the end of the functions's execution. # inplace at the end of the functions's execution.
wrapped_inputs = [In(x, borrow=False) for x in self.inputs[: self.n_seqs]] wrapped_inputs = [
In(x, borrow=False) for x in self.inputs[: self.info.n_seqs]
]
new_outputs = [x for x in self.outputs] new_outputs = [x for x in self.outputs]
input_idx = self.n_seqs input_idx = self.info.n_seqs
for mitmot_idx in range(self.n_mit_mot): for mitmot_idx in range(self.info.n_mit_mot):
for inp_tap in self.tap_array[mitmot_idx]: for inp_tap in self.info.tap_array[mitmot_idx]:
if inp_tap in self.mit_mot_out_slices[mitmot_idx]: if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]:
inp = self.inputs[input_idx] inp = self.inputs[input_idx]
# Figure out the index of the corresponding output # Figure out the index of the corresponding output
output_idx = sum( output_idx = sum(
[len(m) for m in self.mit_mot_out_slices[:mitmot_idx]] [len(m) for m in self.info.mit_mot_out_slices[:mitmot_idx]]
)
output_idx += self.info.mit_mot_out_slices[mitmot_idx].index(
inp_tap
) )
output_idx += self.mit_mot_out_slices[mitmot_idx].index(inp_tap)
# Make it so the input is automatically updated to the # Make it so the input is automatically updated to the
# output value, possibly inplace, at the end of the # output value, possibly inplace, at the end of the
...@@ -1273,8 +1315,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1273,8 +1315,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# so that it runs before them). This feature will prevent mitsot, # so that it runs before them). This feature will prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow # sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation). # their preallocation).
mitsot_start = self.n_mit_mot_outs - len(self.preallocated_mitmot_outs) mitsot_start = self.info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot nitsot_end = (
mitsot_start
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
feature = NoOutputFromInplace(mitsot_start, nitsot_end) feature = NoOutputFromInplace(mitsot_start, nitsot_end)
opt = AddFeatureOptimizer(feature) opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9)) compilation_mode = self.mode_instance.register((opt, 49.9))
...@@ -1371,7 +1418,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1371,7 +1418,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
cython_mintaps = np.asarray(self.mintaps, dtype="int32") cython_mintaps = np.asarray(self.mintaps, dtype="int32")
tap_array_len = tuple(len(x) for x in self.tap_array) tap_array_len = tuple(len(x) for x in self.info.tap_array)
cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32") cython_vector_seqs = np.asarray(self.vector_seqs, dtype="int32")
cython_vector_outs = np.asarray(self.vector_outs, dtype="int32") cython_vector_outs = np.asarray(self.vector_outs, dtype="int32")
...@@ -1421,20 +1468,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1421,20 +1468,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try: try:
t_fn, n_steps = scan_perform_ext.perform( t_fn, n_steps = scan_perform_ext.perform(
self.n_shared_outs, self.info.n_shared_outs,
self.n_mit_mot_outs, self.info.n_mit_mot_outs,
self.n_seqs, self.info.n_seqs,
self.n_mit_mot, self.info.n_mit_mot,
self.n_mit_sot, self.info.n_mit_sot,
self.n_sit_sot, self.info.n_sit_sot,
self.n_nit_sot, self.info.n_nit_sot,
self.as_while, self.as_while,
cython_mintaps, cython_mintaps,
self.tap_array, self.info.tap_array,
tap_array_len, tap_array_len,
cython_vector_seqs, cython_vector_seqs,
cython_vector_outs, cython_vector_outs,
self.mit_mot_out_slices, self.info.mit_mot_out_slices,
cython_mitmots_preallocated, cython_mitmots_preallocated,
cython_outs_is_tensor, cython_outs_is_tensor,
inner_input_storage, inner_input_storage,
...@@ -1509,14 +1556,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1509,14 +1556,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_steps n_steps
X sequence inputs x_1, x_2, ... x_<self.n_seqs> X sequence inputs x_1, x_2, ... x_<self.info.n_seqs>
Y initial states (u_1, u_2, ... u_<self.n_outs>) for our Y initial states (u_1, u_2, ... u_<self.n_outs>) for our
outputs. Each must have appropriate length (T_1, T_2, ..., T_Y). outputs. Each must have appropriate length (T_1, T_2, ..., T_Y).
W other inputs w_1, w_2, ... w_W W other inputs w_1, w_2, ... w_W
There are at least ``1 + self.n_seqs + self.n_outs`` inputs, and the There are at least ``1 + self.info.n_seqs + self.n_outs`` inputs, and the
ones above this number are passed to the scanned function as ones above this number are passed to the scanned function as
non-sequential inputs. non-sequential inputs.
...@@ -1525,6 +1572,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1525,6 +1572,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Y sequence outputs y_1, y_2, ... y_<self.n_outs> Y sequence outputs y_1, y_2, ... y_<self.n_outs>
""" """
info = self.info
# 1. Unzip the number of steps and sequences. # 1. Unzip the number of steps and sequences.
t0_call = time.time() t0_call = time.time()
t_fn = 0 t_fn = 0
...@@ -1557,7 +1605,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1557,7 +1605,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
store_steps += [ store_steps += [
arg arg
for arg in inputs[ for arg in inputs[
self.nit_sot_arg_offset : self.nit_sot_arg_offset + self.n_nit_sot self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot
] ]
] ]
...@@ -1575,7 +1623,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1575,7 +1623,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
): ):
# Put in the values of the initial state # Put in the values of the initial state
output_storage[idx][0] = output_storage[idx][0][: store_steps[idx]] output_storage[idx][0] = output_storage[idx][0][: store_steps[idx]]
if idx > self.n_mit_mot: if idx > info.n_mit_mot:
l = -self.mintaps[idx] l = -self.mintaps[idx]
output_storage[idx][0][:l] = inputs[self.seqs_arg_offset + idx][:l] output_storage[idx][0][:l] = inputs[self.seqs_arg_offset + idx][:l]
else: else:
...@@ -1584,7 +1632,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1584,7 +1632,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_storage[idx][0] = inputs[self.seqs_arg_offset + idx].copy() output_storage[idx][0] = inputs[self.seqs_arg_offset + idx].copy()
if n_steps == 0: if n_steps == 0:
for idx in range(self.n_outs, self.n_outs + self.n_nit_sot): for idx in range(self.n_outs, self.n_outs + info.n_nit_sot):
out_var = node.outputs[idx] out_var = node.outputs[idx]
if isinstance(out_var, TensorVariable): if isinstance(out_var, TensorVariable):
output_storage[idx][0] = np.empty( output_storage[idx][0] = np.empty(
...@@ -1596,13 +1644,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1596,13 +1644,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
pos = [ pos = [
(-self.mintaps[idx]) % store_steps[idx] (-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + self.n_nit_sot) for idx in range(self.n_outs + info.n_nit_sot)
] ]
offset = self.nit_sot_arg_offset + self.n_nit_sot offset = self.nit_sot_arg_offset + info.n_nit_sot
other_args = inputs[offset:] other_args = inputs[offset:]
inner_input_storage = self.fn.input_storage inner_input_storage = self.fn.input_storage
nb_mitmot_in = sum(map(len, self.tap_array[: self.n_mit_mot])) nb_mitmot_in = sum(map(len, info.tap_array[: info.n_mit_mot]))
old_mitmot_input_storage = [None] * nb_mitmot_in old_mitmot_input_storage = [None] * nb_mitmot_in
old_mitmot_input_data = [None] * nb_mitmot_in old_mitmot_input_data = [None] * nb_mitmot_in
inner_output_storage = self.fn.output_storage inner_output_storage = self.fn.output_storage
...@@ -1610,9 +1658,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1610,9 +1658,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
old_inner_output_data = [None] * len(inner_output_storage) old_inner_output_data = [None] * len(inner_output_storage)
fn = self.fn.fn fn = self.fn.fn
offset = ( offset = (
self.n_seqs info.n_seqs
+ sum(map(len, self.tap_array[: self.n_outs])) + sum(map(len, info.tap_array[: self.n_outs]))
+ self.n_shared_outs + info.n_shared_outs
) )
for idx in range(len(other_args)): for idx in range(len(other_args)):
inner_input_storage[idx + offset].storage[0] = other_args[idx] inner_input_storage[idx + offset].storage[0] = other_args[idx]
...@@ -1624,7 +1672,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1624,7 +1672,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
while (i < n_steps) and cond: while (i < n_steps) and cond:
# sequences over which scan iterates # sequences over which scan iterates
# 3. collect input slices # 3. collect input slices
for idx in range(self.n_seqs): for idx in range(info.n_seqs):
if self.vector_seqs[idx]: if self.vector_seqs[idx]:
inner_input_storage[idx].storage[0] = seqs[idx][i : i + 1].reshape( inner_input_storage[idx].storage[0] = seqs[idx][i : i + 1].reshape(
() ()
...@@ -1632,17 +1680,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1632,17 +1680,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
inner_input_storage[idx].storage[0] = seqs[idx][i] inner_input_storage[idx].storage[0] = seqs[idx][i]
offset = self.n_seqs offset = info.n_seqs
for idx in range(self.n_outs): for idx in range(self.n_outs):
if self.vector_outs[idx]: if self.vector_outs[idx]:
for tap in self.tap_array[idx]: for tap in info.tap_array[idx]:
_idx = (pos[idx] + tap) % store_steps[idx] _idx = (pos[idx] + tap) % store_steps[idx]
inner_input_storage[offset].storage[0] = output_storage[idx][0][ inner_input_storage[offset].storage[0] = output_storage[idx][0][
_idx : _idx + 1 _idx : _idx + 1
].reshape(()) ].reshape(())
offset += 1 offset += 1
else: else:
for tap in self.tap_array[idx]: for tap in info.tap_array[idx]:
_idx = (pos[idx] + tap) % store_steps[idx] _idx = (pos[idx] + tap) % store_steps[idx]
inner_input_storage[offset].storage[0] = output_storage[idx][0][ inner_input_storage[offset].storage[0] = output_storage[idx][0][
_idx _idx
...@@ -1650,13 +1698,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1650,13 +1698,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
offset += 1 offset += 1
a_offset = self.shared_arg_offset a_offset = self.shared_arg_offset
o_offset = self.n_outs + self.n_nit_sot o_offset = self.n_outs + info.n_nit_sot
if i == 0: if i == 0:
for j in range(self.n_shared_outs): for j in range(info.n_shared_outs):
inner_input_storage[offset].storage[0] = inputs[a_offset + j] inner_input_storage[offset].storage[0] = inputs[a_offset + j]
offset += 1 offset += 1
else: else:
for j in range(self.n_shared_outs): for j in range(info.n_shared_outs):
inner_input_storage[offset].storage[0] = output_storage[ inner_input_storage[offset].storage[0] = output_storage[
o_offset + j o_offset + j
][0] ][0]
...@@ -1666,36 +1714,36 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1666,36 +1714,36 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 4.1. Collect slices for mitmots # 4.1. Collect slices for mitmots
offset = 0 offset = 0
for idx in range(self.n_mit_mot_outs): for idx in range(info.n_mit_mot_outs):
if not self.mitmots_preallocated[idx]: if not self.mitmots_preallocated[idx]:
inner_output_storage[offset].storage[0] = None inner_output_storage[offset].storage[0] = None
offset += 1 offset += 1
# 4.2. Collect slices for mitsots, sitsots and nitsots # 4.2. Collect slices for mitsots, sitsots and nitsots
if i != 0: if i != 0:
for idx in range(self.n_outs + self.n_nit_sot - self.n_mit_mot): for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot):
if ( if (
store_steps[idx + self.n_mit_mot] == 1 store_steps[idx + info.n_mit_mot] == 1
or self.vector_outs[idx + self.n_mit_mot] or self.vector_outs[idx + info.n_mit_mot]
): ):
inner_output_storage[idx + offset].storage[0] = None inner_output_storage[idx + offset].storage[0] = None
else: else:
_pos0 = idx + self.n_mit_mot _pos0 = idx + info.n_mit_mot
inner_output_storage[idx + offset].storage[0] = output_storage[ inner_output_storage[idx + offset].storage[0] = output_storage[
_pos0 _pos0
][0][pos[_pos0]] ][0][pos[_pos0]]
else: else:
for idx in range(self.n_outs + self.n_nit_sot - self.n_mit_mot): for idx in range(self.n_outs + info.n_nit_sot - info.n_mit_mot):
inner_output_storage[idx + offset].storage[0] = None inner_output_storage[idx + offset].storage[0] = None
# 4.3. Collect slices for shared outputs # 4.3. Collect slices for shared outputs
offset += self.n_outs + self.n_nit_sot - self.n_mit_mot offset += self.n_outs + info.n_nit_sot - info.n_mit_mot
for idx in range(self.n_shared_outs): for idx in range(info.n_shared_outs):
inner_output_storage[idx + offset].storage[0] = None inner_output_storage[idx + offset].storage[0] = None
# 4.4. If there is a condition add it to the mix # 4.4. If there is a condition add it to the mix
if self.as_while: if self.as_while:
pdx = offset + self.n_shared_outs pdx = offset + info.n_shared_outs
inner_output_storage[pdx].storage[0] = None inner_output_storage[pdx].storage[0] = None
# 4.5. Keep a reference to the variables (ndarrays, # 4.5. Keep a reference to the variables (ndarrays,
...@@ -1726,7 +1774,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1726,7 +1774,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# reused the allocated object but alter the memory region they # reused the allocated object but alter the memory region they
# refer to. # refer to.
for idx in range(nb_mitmot_in): for idx in range(nb_mitmot_in):
var = inner_input_storage[idx + self.n_seqs].storage[0] var = inner_input_storage[idx + info.n_seqs].storage[0]
old_mitmot_input_storage[idx] = var old_mitmot_input_storage[idx] = var
if var is None: if var is None:
...@@ -1765,7 +1813,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1765,7 +1813,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dt_fn = time.time() - t0_fn dt_fn = time.time() - t0_fn
if self.as_while: if self.as_while:
pdx = offset + self.n_shared_outs pdx = offset + info.n_shared_outs
cond = inner_output_storage[pdx].storage[0] == 0 cond = inner_output_storage[pdx].storage[0] == 0
# 5.2. By calling fn() directly instead of calling the aesara # 5.2. By calling fn() directly instead of calling the aesara
...@@ -1787,16 +1835,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1787,16 +1835,16 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 5.3 Copy over the values for mit_mot outputs # 5.3 Copy over the values for mit_mot outputs
mitmot_inp_offset = 0 mitmot_inp_offset = 0
mitmot_out_idx = 0 mitmot_out_idx = 0
for j in range(self.n_mit_mot): for j in range(info.n_mit_mot):
for k in self.mit_mot_out_slices[j]: for k in info.mit_mot_out_slices[j]:
if self.mitmots_preallocated[mitmot_out_idx]: if self.mitmots_preallocated[mitmot_out_idx]:
# This output tap has been preallocated. # This output tap has been preallocated.
inp_idx = mitmot_inp_offset + self.tap_array[j].index(k) inp_idx = mitmot_inp_offset + info.tap_array[j].index(k)
# Verify whether the input points to the same data as # Verify whether the input points to the same data as
# it did before the execution of the inner function. # it did before the execution of the inner function.
old_var = old_mitmot_input_storage[inp_idx] old_var = old_mitmot_input_storage[inp_idx]
new_var = inner_input_storage[self.n_seqs + inp_idx].storage[0] new_var = inner_input_storage[info.n_seqs + inp_idx].storage[0]
if old_var is new_var: if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx] old_data = old_mitmot_input_data[inp_idx]
same_data = new_var.data == old_data same_data = new_var.data == old_data
...@@ -1809,7 +1857,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1809,7 +1857,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# and store it in `outs` as usual # and store it in `outs` as usual
if not same_data: if not same_data:
output_storage[j][0][k + pos[j]] = inner_input_storage[ output_storage[j][0][k + pos[j]] = inner_input_storage[
self.n_seqs + inp_idx info.n_seqs + inp_idx
].storage[0] ].storage[0]
else: else:
...@@ -1822,12 +1870,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1822,12 +1870,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitmot_out_idx += 1 mitmot_out_idx += 1
mitmot_inp_offset += len(self.tap_array[j]) mitmot_inp_offset += len(info.tap_array[j])
# 5.4 Copy over the values for mit_sot/sit_sot outputs # 5.4 Copy over the values for mit_sot/sit_sot outputs
begin = self.n_mit_mot begin = info.n_mit_mot
end = self.n_outs end = self.n_outs
offset_out -= self.n_mit_mot offset_out -= info.n_mit_mot
for j in range(begin, end): for j in range(begin, end):
...@@ -1877,7 +1925,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1877,7 +1925,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 5.5 Copy over the values for nit_sot outputs # 5.5 Copy over the values for nit_sot outputs
begin = end begin = end
end += self.n_nit_sot end += info.n_nit_sot
for j in range(begin, end): for j in range(begin, end):
if i == 0: if i == 0:
...@@ -1930,7 +1978,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1930,7 +1978,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# 5.6 Copy over the values for outputs corresponding to shared # 5.6 Copy over the values for outputs corresponding to shared
# variables # variables
begin = end begin = end
end += self.n_shared_outs end += info.n_shared_outs
for j in range(begin, end): for j in range(begin, end):
jout = j + offset_out jout = j + offset_out
output_storage[j][0] = inner_output_storage[jout].storage[0] output_storage[j][0] = inner_output_storage[jout].storage[0]
...@@ -1939,8 +1987,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1939,8 +1987,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
i = i + 1 i = i + 1
# 6. Check if you need to re-order output buffers # 6. Check if you need to re-order output buffers
begin = self.n_mit_mot begin = info.n_mit_mot
end = self.n_outs + self.n_nit_sot end = self.n_outs + info.n_nit_sot
for idx in range(begin, end): for idx in range(begin, end):
if store_steps[idx] < i - self.mintaps[idx] and pos[idx] < store_steps[idx]: if store_steps[idx] < i - self.mintaps[idx] and pos[idx] < store_steps[idx]:
...@@ -2043,30 +2091,31 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2043,30 +2091,31 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# cases extra scans in the graph. See gh-XXX for the # cases extra scans in the graph. See gh-XXX for the
# investigation. # investigation.
info = self.info
# We skip the first outer input as it is the total or current number # We skip the first outer input as it is the total or current number
# of iterations. # of iterations.
# sequences # sequences
seqs_shape = [x[1:] for x in input_shapes[1 : 1 + self.n_seqs]] seqs_shape = [x[1:] for x in input_shapes[1 : 1 + info.n_seqs]]
# We disable extra infer_shape for now. See gh-3765. # We disable extra infer_shape for now. See gh-3765.
extra_infer_shape = False extra_infer_shape = False
if extra_infer_shape: if extra_infer_shape:
inner_seqs = self.inputs[: self.n_seqs] inner_seqs = self.inputs[: info.n_seqs]
outer_seqs = node.inputs[1 : 1 + self.n_seqs] outer_seqs = node.inputs[1 : 1 + info.n_seqs]
for in_s, out_s in zip(inner_seqs, outer_seqs): for in_s, out_s in zip(inner_seqs, outer_seqs):
out_equivalent[in_s] = out_s[0] out_equivalent[in_s] = out_s[0]
# mit_mot, mit_sot, sit_sot # mit_mot, mit_sot, sit_sot
outer_inp_idx = 1 + self.n_seqs outer_inp_idx = 1 + info.n_seqs
inner_inp_idx = self.n_seqs inner_inp_idx = info.n_seqs
else: else:
outer_inp_idx = 0 outer_inp_idx = 0
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
outs_shape = [] outs_shape = []
for idx in range(n_outs): for idx in range(n_outs):
mintap = abs(min(self.tap_array[idx])) mintap = abs(min(info.tap_array[idx]))
for k in self.tap_array[idx]: for k in info.tap_array[idx]:
outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]] outs_shape += [input_shapes[idx + info.n_seqs + 1][1:]]
if extra_infer_shape: if extra_infer_shape:
corresponding_tap = node.inputs[outer_inp_idx][mintap + k] corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
...@@ -2074,12 +2123,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2074,12 +2123,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outer_inp_idx += 1 outer_inp_idx += 1
# shared_outs # shared_outs
offset = 1 + self.n_seqs + n_outs offset = 1 + info.n_seqs + n_outs
for idx in range(self.n_shared_outs): for idx in range(info.n_shared_outs):
outs_shape += [input_shapes[idx + offset]] outs_shape += [input_shapes[idx + offset]]
# non_sequences # non_sequences
offset += self.n_nit_sot + self.n_shared_outs offset += info.n_nit_sot + info.n_shared_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inputs) assert len(inner_ins_shapes) == len(self.inputs)
...@@ -2103,11 +2152,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2103,11 +2152,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
valid=input_shapes, invalid=self.inputs, valid_equivalent=out_equivalent valid=input_shapes, invalid=self.inputs, valid_equivalent=out_equivalent
) )
offset = 1 + self.n_seqs offset = 1 + info.n_seqs
scan_outs = [x for x in input_shapes[offset : offset + n_outs]] scan_outs = [x for x in input_shapes[offset : offset + n_outs]]
offset += n_outs offset += n_outs
outs_shape_n = self.n_mit_mot_outs + self.n_mit_sot + self.n_sit_sot outs_shape_n = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
for x in range(self.n_nit_sot): for x in range(info.n_nit_sot):
out_shape_x = outs_shape[outs_shape_n + x] out_shape_x = outs_shape[outs_shape_n + x]
if out_shape_x is None: if out_shape_x is None:
# This output is not a tensor, and has no shape # This output is not a tensor, and has no shape
...@@ -2118,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2118,7 +2167,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# in the inner function. # in the inner function.
r = node.outputs[n_outs + x] r = node.outputs[n_outs + x]
assert r.ndim == 1 + len(out_shape_x) assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset + self.n_shared_outs + x]] shp = [node.inputs[offset + info.n_shared_outs + x]]
for i, shp_i in zip(range(1, r.ndim), out_shape_x): for i, shp_i in zip(range(1, r.ndim), out_shape_x):
# Validate shp_i. v_shape_i is either None (if invalid), # Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates # or a (variable, Boolean) tuple. The Boolean indicates
...@@ -2135,7 +2184,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2135,7 +2184,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp.append(v_shp_i[0]) shp.append(v_shp_i[0])
scan_outs.append(tuple(shp)) scan_outs.append(tuple(shp))
scan_outs += [x for x in input_shapes[offset : offset + self.n_shared_outs]] scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]]
# if we are dealing with a repeat-until, then we do not know the # if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i # leading dimension so we replace it for every entry with Shape_i
if self.as_while: if self.as_while:
...@@ -2217,13 +2266,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2217,13 +2266,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# We do not know what kind of outputs the original scan has, so we # We do not know what kind of outputs the original scan has, so we
# try first to see if it has a nit_sot output, then a sit_sot and # try first to see if it has a nit_sot output, then a sit_sot and
# then a mit_sot # then a mit_sot
if self.n_nit_sot > 0: info = self.info
if info.n_nit_sot > 0:
grad_steps = self.outer_nitsot_outs(outs)[0].shape[0] grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
elif self.n_sit_sot > 0: elif info.n_sit_sot > 0:
grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1 grad_steps = self.outer_sitsot_outs(outs)[0].shape[0] - 1
elif self.n_mit_sot > 0: elif info.n_mit_sot > 0:
grad_steps = ( grad_steps = (
self.outer_mitsot_outs(outs)[0].shape[0] + self.mintaps[self.n_mit_mot] self.outer_mitsot_outs(outs)[0].shape[0] + self.mintaps[info.n_mit_mot]
) )
else: else:
grad_steps = inputs[0] grad_steps = inputs[0]
...@@ -2255,10 +2305,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2255,10 +2305,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
connection_pattern = self.connection_pattern(scan_node) connection_pattern = self.connection_pattern(scan_node)
def get_inp_idx(iidx): def get_inp_idx(iidx):
if iidx < self.n_seqs: if iidx < info.n_seqs:
return 1 + iidx return 1 + iidx
oidx = 1 + self.n_seqs oidx = 1 + info.n_seqs
iidx = iidx - self.n_seqs iidx = iidx - info.n_seqs
for taps in self.mitmot_taps(): for taps in self.mitmot_taps():
if len(taps) > iidx: if len(taps) > iidx:
return oidx return oidx
...@@ -2272,10 +2322,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2272,10 +2322,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
oidx += 1 oidx += 1
iidx -= len(taps) iidx -= len(taps)
if iidx < self.info.n_sit_sot: if iidx < info.n_sit_sot:
return oidx + iidx return oidx + iidx
else: else:
return oidx + iidx + self.info.n_nit_sot return oidx + iidx + info.n_nit_sot
def get_out_idx(iidx): def get_out_idx(iidx):
oidx = 0 oidx = 0
...@@ -2347,7 +2397,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2347,7 +2397,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for idx, Xt in enumerate(diff_outputs): for idx, Xt in enumerate(diff_outputs):
# We are looking for x[t-1] for a given x[t] # We are looking for x[t-1] for a given x[t]
if idx >= self.n_mit_mot_outs: if idx >= info.n_mit_mot_outs:
Xt_placeholder = safe_new(Xt) Xt_placeholder = safe_new(Xt)
Xts.append(Xt_placeholder) Xts.append(Xt_placeholder)
...@@ -2355,10 +2405,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2355,10 +2405,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# or not. NOTE : This cannot be done by using # or not. NOTE : This cannot be done by using
# "if Xt not in self.inner_nitsot_outs(self_outputs)" because # "if Xt not in self.inner_nitsot_outs(self_outputs)" because
# the exact same variable can be used as multiple outputs. # the exact same variable can be used as multiple outputs.
idx_nitsot_start = ( idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot idx_nitsot_end = idx_nitsot_start + info.n_nit_sot
)
idx_nitsot_end = idx_nitsot_start + self.info.n_nit_sot
if idx < idx_nitsot_start or idx >= idx_nitsot_end: if idx < idx_nitsot_start or idx >= idx_nitsot_end:
# What we do here is loop through dC_douts and collect all # What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an # those that are connected to the specific one and do an
...@@ -2378,7 +2426,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2378,7 +2426,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Get the index of the outer output that to which # Get the index of the outer output that to which
# the state variable 'inp' corresponds. # the state variable 'inp' corresponds.
outer_oidx = var_mappings["outer_out_from_inner_inp"][ outer_oidx = var_mappings["outer_out_from_inner_inp"][
self.n_seqs + pos info.n_seqs + pos
] ]
if not isinstance(dC_douts[outer_oidx].type, DisconnectedType): if not isinstance(dC_douts[outer_oidx].type, DisconnectedType):
...@@ -2420,55 +2468,55 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2420,55 +2468,55 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dC_dinps_t[dx] = at.zeros_like(diff_inputs[dx]) dC_dinps_t[dx] = at.zeros_like(diff_inputs[dx])
else: else:
disconnected_dC_dinps_t[dx] = False disconnected_dC_dinps_t[dx] = False
for Xt, Xt_placeholder in zip(diff_outputs[self.n_mit_mot_outs :], Xts): for Xt, Xt_placeholder in zip(diff_outputs[info.n_mit_mot_outs :], Xts):
tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder) tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder)
dC_dinps_t[dx] = tmp dC_dinps_t[dx] = tmp
# construct dX_dtm1 # construct dX_dtm1
dC_dXtm1s = [] dC_dXtm1s = []
for pos, x in enumerate(dC_dinps_t[self.n_seqs :]): for pos, x in enumerate(dC_dinps_t[info.n_seqs :]):
# Get the index of the first inner input corresponding to the # Get the index of the first inner input corresponding to the
# pos-ieth inner input state # pos-ieth inner input state
idxs = var_mappings["inner_out_from_inner_inp"][self.n_seqs + pos] idxs = var_mappings["inner_out_from_inner_inp"][info.n_seqs + pos]
# Check if the pos-th input is associated with one of the # Check if the pos-th input is associated with one of the
# recurrent states # recurrent states
x_is_state = pos < sum([len(t) for t in self.tap_array]) x_is_state = pos < sum([len(t) for t in info.tap_array])
if x_is_state and len(idxs) > 0: if x_is_state and len(idxs) > 0:
opos = idxs[0] opos = idxs[0]
dC_dXtm1s.append(safe_new(dC_dXts[opos])) dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if hasattr(x, "dtype") and x.dtype != dC_dXts[opos].dtype: if hasattr(x, "dtype") and x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = x.astype(dC_dXts[opos].dtype) dC_dinps_t[pos + info.n_seqs] = x.astype(dC_dXts[opos].dtype)
else: else:
dC_dXtm1s.append(safe_new(x)) dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
if isinstance(dC_dinps_t[dx + self.n_seqs].type, NullType): if isinstance(dC_dinps_t[dx + info.n_seqs].type, NullType):
# The accumulated gradient is undefined # The accumulated gradient is undefined
pass pass
elif isinstance(dC_dXtm1.type, NullType): elif isinstance(dC_dXtm1.type, NullType):
# The new gradient is undefined, this makes the accumulated # The new gradient is undefined, this makes the accumulated
# gradient undefined as weell # gradient undefined as weell
dC_dinps_t[dx + self.n_seqs] = dC_dXtm1 dC_dinps_t[dx + info.n_seqs] = dC_dXtm1
else: else:
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1 dC_dinps_t[dx + info.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
# Seqs # Seqs
if self.as_while: if self.as_while:
# equivalent to x[:n_steps][::-1] # equivalent to x[:n_steps][::-1]
outer_inp_seqs = [x[n_steps - 1 :: -1] for x in inputs[1 : 1 + self.n_seqs]] outer_inp_seqs = [x[n_steps - 1 :: -1] for x in inputs[1 : 1 + info.n_seqs]]
else: else:
outer_inp_seqs = [x[::-1] for x in inputs[1 : 1 + self.n_seqs]] outer_inp_seqs = [x[::-1] for x in inputs[1 : 1 + info.n_seqs]]
for idx in range(self.n_mit_mot + self.n_mit_sot): for idx in range(info.n_mit_mot + info.n_mit_sot):
mintap = np.min(self.tap_array[idx]) mintap = np.min(info.tap_array[idx])
if idx < self.n_mit_mot: if idx < info.n_mit_mot:
outmaxtap = np.max(self.mitmot_out_taps()[idx]) outmaxtap = np.max(self.mitmot_out_taps()[idx])
else: else:
outmaxtap = 0 outmaxtap = 0
seq = outs[idx] seq = outs[idx]
for k in self.tap_array[idx]: for k in info.tap_array[idx]:
if outmaxtap - k != 0: if outmaxtap - k != 0:
nw_seq = seq[k - mintap : -(outmaxtap - k)][::-1] nw_seq = seq[k - mintap : -(outmaxtap - k)][::-1]
else: else:
...@@ -2531,11 +2579,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2531,11 +2579,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitmot_out_taps = [] mitmot_out_taps = []
type_outs = [] type_outs = []
out_pos = 0 out_pos = 0
ins_pos = self.n_seqs ins_pos = info.n_seqs
n_mitmot_outs = 0 n_mitmot_outs = 0
n_mitmot_inps = 0 n_mitmot_inps = 0
for idx in range(self.n_mit_mot): for idx in range(info.n_mit_mot):
if isinstance(dC_douts[idx].type, DisconnectedType): if isinstance(dC_douts[idx].type, DisconnectedType):
out = outs[idx] out = outs[idx]
outer_inp_mitmot.append(at.zeros_like(out)) outer_inp_mitmot.append(at.zeros_like(out))
...@@ -2547,19 +2595,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2547,19 +2595,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
through_shared = False through_shared = False
disconnected = True disconnected = True
for jdx in range(len(self.mit_mot_out_slices[idx])): for jdx in range(len(info.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dC_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx]) mitmot_inp_taps[idx].append(-info.mit_mot_out_slices[idx][jdx])
n_mitmot_inps += 1 n_mitmot_inps += 1
out_pos += 1 out_pos += 1
for jdx in range(len(self.tap_array[idx])): for jdx in range(len(info.tap_array[idx])):
tap = -self.tap_array[idx][jdx] tap = -info.tap_array[idx][jdx]
# Only create a new inner input if there is not already one # Only create a new inner input if there is not already one
# associated with this input tap # associated with this input tap
if tap not in mitmot_inp_taps[idx]: if tap not in mitmot_inp_taps[idx]:
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs])
if isinstance(dC_dinps_t[ins_pos].type, NullType): if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we # We cannot use Null in the inner graph, so we
...@@ -2575,13 +2623,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2575,13 +2623,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# input tap, make sure the computation of the new output # input tap, make sure the computation of the new output
# uses it instead of the input it's currently using # uses it instead of the input it's currently using
if tap in mitmot_inp_taps[idx]: if tap in mitmot_inp_taps[idx]:
to_replace = dC_dXtm1s[ins_pos - self.n_seqs] to_replace = dC_dXtm1s[ins_pos - info.n_seqs]
replacement_idx = len(mitmot_inp_taps[idx]) - mitmot_inp_taps[ replacement_idx = len(mitmot_inp_taps[idx]) - mitmot_inp_taps[
idx idx
].index(tap) ].index(tap)
replacement = inner_inp_mitmot[-replacement_idx] replacement = inner_inp_mitmot[-replacement_idx]
self.tap_array[idx] info.tap_array[idx]
new_inner_out_mitmot = clone_replace( new_inner_out_mitmot = clone_replace(
new_inner_out_mitmot, replace=[(to_replace, replacement)] new_inner_out_mitmot, replace=[(to_replace, replacement)]
) )
...@@ -2597,12 +2645,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2597,12 +2645,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
ins_pos += 1 ins_pos += 1
n_mitmot_outs += 1 n_mitmot_outs += 1
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx]) mitmot_out_taps[idx].append(-info.tap_array[idx][jdx])
# Only add the tap as a new input tap if needed # Only add the tap as a new input tap if needed
if tap not in mitmot_inp_taps[idx]: if tap not in mitmot_inp_taps[idx]:
n_mitmot_inps += 1 n_mitmot_inps += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx]) mitmot_inp_taps[idx].append(-info.tap_array[idx][jdx])
if undefined_msg: if undefined_msg:
type_outs.append(undefined_msg) type_outs.append(undefined_msg)
...@@ -2613,15 +2661,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2613,15 +2661,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
type_outs.append("connected") type_outs.append("connected")
offset = self.n_mit_mot offset = info.n_mit_mot
for idx in range(self.n_mit_sot): for idx in range(info.n_mit_sot):
if isinstance(dC_douts[idx + offset].type, DisconnectedType): if isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(outs[idx + offset].zeros_like()) outer_inp_mitmot.append(outs[idx + offset].zeros_like())
else: else:
outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
mitmot_inp_taps.append([]) mitmot_inp_taps.append([])
mitmot_out_taps.append([]) mitmot_out_taps.append([])
idx_tap = idx + self.n_mit_mot idx_tap = idx + info.n_mit_mot
inner_inp_mitmot.append(dC_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1 out_pos += 1
n_mitmot_inps += 1 n_mitmot_inps += 1
...@@ -2629,8 +2677,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2629,8 +2677,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
through_shared = False through_shared = False
disconnected = True disconnected = True
mitmot_inp_taps[idx + offset].append(0) mitmot_inp_taps[idx + offset].append(0)
for jdx in range(len(self.tap_array[idx_tap])): for jdx in range(len(info.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs])
if isinstance(dC_dinps_t[ins_pos].type, NullType): if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we # We cannot use Null in the inner graph, so we
...@@ -2642,8 +2690,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2642,8 +2690,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
inner_out_mitmot.append(dC_dinps_t[ins_pos]) inner_out_mitmot.append(dC_dinps_t[ins_pos])
mitmot_inp_taps[idx + offset].append(-self.tap_array[idx_tap][jdx]) mitmot_inp_taps[idx + offset].append(-info.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx_tap][jdx]) mitmot_out_taps[idx].append(-info.tap_array[idx_tap][jdx])
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False disconnected = False
for _sh in self.inner_shared(self_inputs): for _sh in self.inner_shared(self_inputs):
...@@ -2663,8 +2711,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2663,8 +2711,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
type_outs.append("connected") type_outs.append("connected")
offset += self.n_mit_sot offset += info.n_mit_sot
for idx in range(self.n_sit_sot): for idx in range(info.n_sit_sot):
mitmot_inp_taps.append([0, 1]) mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1]) mitmot_out_taps.append([1])
through_shared = False through_shared = False
...@@ -2707,14 +2755,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2707,14 +2755,17 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
type_outs.append("connected") type_outs.append("connected")
inner_inp_mitmot += [dC_dXts[out_pos], dC_dXtm1s[ins_pos - self.n_seqs]] inner_inp_mitmot += [
dC_dXts[out_pos],
dC_dXtm1s[ins_pos - info.n_seqs],
]
n_mitmot_outs += 1 n_mitmot_outs += 1
out_pos += 1 out_pos += 1
ins_pos += 1 ins_pos += 1
n_mitmot_inps += 2 n_mitmot_inps += 2
n_nit_sot = self.n_seqs n_nit_sot = info.n_seqs
inner_out_nitsot = dC_dinps_t[: self.n_seqs] inner_out_nitsot = dC_dinps_t[: info.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:] inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot): for _p, vl in enumerate(inner_out_sitsot):
through_shared = False through_shared = False
...@@ -2755,7 +2806,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2755,7 +2806,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
type_outs.append("connected") type_outs.append("connected")
inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs :] inner_inp_sitsot = dC_dXtm1s[ins_pos - info.n_seqs :]
outer_inp_sitsot = [] outer_inp_sitsot = []
for _idx, y in enumerate(inner_inp_sitsot): for _idx, y in enumerate(inner_inp_sitsot):
x = self.outer_non_seqs(inputs)[_idx] x = self.outer_non_seqs(inputs)[_idx]
...@@ -2783,7 +2834,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2783,7 +2834,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
n_sitsot_outs = len(outer_inp_sitsot) n_sitsot_outs = len(outer_inp_sitsot)
new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)] new_tap_array = mitmot_inp_taps + [[-1] for k in range(n_sitsot_outs)]
info = ScanInfo( out_info = ScanInfo(
n_seqs=len(outer_inp_seqs), n_seqs=len(outer_inp_seqs),
n_mit_sot=0, n_mit_sot=0,
tap_array=tuple(tuple(v) for v in new_tap_array), tap_array=tuple(tuple(v) for v in new_tap_array),
...@@ -2817,7 +2868,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2817,7 +2868,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
local_op = Scan( local_op = Scan(
inner_gfn_ins, inner_gfn_ins,
inner_gfn_outs, inner_gfn_outs,
info, out_info,
mode=self.mode, mode=self.mode,
truncate_gradient=self.truncate_gradient, truncate_gradient=self.truncate_gradient,
as_while=False, as_while=False,
...@@ -2831,11 +2882,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2831,11 +2882,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Re-order the gradients correctly # Re-order the gradients correctly
gradients = [DisconnectedType()()] gradients = [DisconnectedType()()]
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + n_sitsot_outs offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs
for p, (x, t) in enumerate( for p, (x, t) in enumerate(
zip( zip(
outputs[offset : offset + self.n_seqs], outputs[offset : offset + info.n_seqs],
type_outs[offset : offset + self.n_seqs], type_outs[offset : offset + info.n_seqs],
) )
): ):
if t == "connected": if t == "connected":
...@@ -2864,7 +2915,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2864,7 +2915,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# t contains the "why_null" string of a NullType # t contains the "why_null" string of a NullType
gradients.append(NullType(t)()) gradients.append(NullType(t)())
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])): for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])):
if t == "connected": if t == "connected":
# If the forward scan is in as_while mode, we need to pad # If the forward scan is in as_while mode, we need to pad
...@@ -2886,8 +2937,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2886,8 +2937,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients.append( gradients.append(
grad_undefined( grad_undefined(
self, self,
p + 1 + self.n_seqs, p + 1 + info.n_seqs,
inputs[p + 1 + self.n_seqs], inputs[p + 1 + info.n_seqs],
"Depends on a shared variable", "Depends on a shared variable",
) )
) )
...@@ -2897,7 +2948,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2897,7 +2948,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
start = len(gradients) start = len(gradients)
node = outs[0].owner node = outs[0].owner
for idx in range(self.n_shared_outs): for idx in range(info.n_shared_outs):
disconnected = True disconnected = True
connected_flags = self.connection_pattern(node)[idx + start] connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags): for dC_dout, connected in zip(dC_douts, connected_flags):
...@@ -2913,7 +2964,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2913,7 +2964,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
start = len(gradients) start = len(gradients)
gradients += [DisconnectedType()() for _ in range(self.n_nit_sot)] gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
begin = end begin = end
end = begin + n_sitsot_outs end = begin + n_sitsot_outs
...@@ -2954,10 +3005,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2954,10 +3005,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# Step 0. Prepare some shortcut variable # Step 0. Prepare some shortcut variable
info = self.info
self_inputs = self.inputs self_inputs = self.inputs
rop_of_inputs = ( rop_of_inputs = (
self_inputs[: self.n_seqs + self.n_outs] self_inputs[: info.n_seqs + self.n_outs]
+ self_inputs[self.n_seqs + self.n_outs + self.n_shared_outs :] + self_inputs[info.n_seqs + self.n_outs + info.n_shared_outs :]
) )
self_outputs = self.outputs self_outputs = self.outputs
...@@ -2967,8 +3019,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2967,8 +3019,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs = self_outputs[:-1] rop_self_outputs = self_outputs[:-1]
else: else:
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
if self.info.n_shared_outs > 0: if info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -self.info.n_shared_outs] rop_self_outputs = rop_self_outputs[: -info.n_shared_outs]
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points) rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if not isinstance(rop_outs, (list, tuple)): if not isinstance(rop_outs, (list, tuple)):
rop_outs = [rop_outs] rop_outs = [rop_outs]
...@@ -2985,20 +3037,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2985,20 +3037,20 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
new_tap_array = [] new_tap_array = []
b = 0 b = 0
e = self.n_mit_mot e = info.n_mit_mot
new_tap_array += self.tap_array[b:e] * 2 new_tap_array += info.tap_array[b:e] * 2
b = e b = e
e += self.n_mit_sot e += info.n_mit_sot
new_tap_array += self.tap_array[b:e] * 2 new_tap_array += info.tap_array[b:e] * 2
b = e b = e
e += self.n_sit_sot e += info.n_sit_sot
new_tap_array += self.tap_array[b:e] * 2 new_tap_array += info.tap_array[b:e] * 2
# Sequences ... # Sequences ...
b = 1 b = 1
ib = 0 ib = 0
e = 1 + self.n_seqs e = 1 + info.n_seqs
ie = self.n_seqs ie = info.n_seqs
clean_eval_points = [] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]): for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None: if evp is not None:
...@@ -3011,9 +3063,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3011,9 +3063,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# MIT_MOT sequences ... # MIT_MOT sequences ...
b = e b = e
e = e + self.n_mit_mot e = e + info.n_mit_mot
ib = ie ib = ie
ie = ie + int(sum(len(x) for x in self.tap_array[: self.n_mit_mot])) ie = ie + int(sum(len(x) for x in info.tap_array[: info.n_mit_mot]))
clean_eval_points = [] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]): for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None: if evp is not None:
...@@ -3026,13 +3078,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3026,13 +3078,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# MIT_SOT sequences ... # MIT_SOT sequences ...
b = e b = e
e = e + self.n_mit_sot e = e + info.n_mit_sot
ib = ie ib = ie
ie = ie + int( ie = ie + int(
sum( sum(
len(x) len(x)
for x in self.tap_array[ for x in info.tap_array[
self.n_mit_mot : self.n_mit_mot + self.n_mit_sot info.n_mit_mot : info.n_mit_mot + info.n_mit_sot
] ]
) )
) )
...@@ -3048,9 +3100,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3048,9 +3100,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# SIT_SOT sequences ... # SIT_SOT sequences ...
b = e b = e
e = e + self.n_sit_sot e = e + info.n_sit_sot
ib = ie ib = ie
ie = ie + self.n_sit_sot ie = ie + info.n_sit_sot
clean_eval_points = [] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]): for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None: if evp is not None:
...@@ -3063,15 +3115,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3063,15 +3115,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# Shared outs ... # Shared outs ...
b = e b = e
e = e + self.n_shared_outs e = e + info.n_shared_outs
ib = ie ib = ie
ie = ie + self.n_shared_outs ie = ie + info.n_shared_outs
scan_shared = inputs[b:e] scan_shared = inputs[b:e]
inner_shared = self_inputs[ib:ie] inner_shared = self_inputs[ib:ie]
# NIT_SOT sequences # NIT_SOT sequences
b = e b = e
e = e + self.n_nit_sot e = e + info.n_nit_sot
scan_nit_sot = inputs[b:e] * 2 scan_nit_sot = inputs[b:e] * 2
# All other arguments # All other arguments
...@@ -3086,22 +3138,22 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3086,22 +3138,22 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_other = self_inputs[ie:] + inner_eval_points[ib:] inner_other = self_inputs[ie:] + inner_eval_points[ib:]
# Outputs # Outputs
n_mit_mot_outs = int(sum(len(x) for x in self.mit_mot_out_slices)) n_mit_mot_outs = int(sum(len(x) for x in info.mit_mot_out_slices))
b = 0 b = 0
e = n_mit_mot_outs e = n_mit_mot_outs
inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e] inner_out_mit_mot = self_outputs[b:e] + rop_outs[b:e]
b = e b = e
e = e + self.n_mit_sot e = e + info.n_mit_sot
inner_out_mit_sot = self_outputs[b:e] + rop_outs[b:e] inner_out_mit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e b = e
e = e + self.n_sit_sot e = e + info.n_sit_sot
inner_out_sit_sot = self_outputs[b:e] + rop_outs[b:e] inner_out_sit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e b = e
e = e + self.n_nit_sot e = e + info.n_nit_sot
inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e] inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e b = e
e = e + self.n_shared_outs e = e + info.n_shared_outs
inner_out_shared = self_outputs[b:e] inner_out_shared = self_outputs[b:e]
inner_ins = ( inner_ins = (
...@@ -3133,22 +3185,22 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3133,22 +3185,22 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ scan_other + scan_other
) )
info = ScanInfo( out_info = ScanInfo(
n_seqs=self.n_seqs * 2, n_seqs=info.n_seqs * 2,
n_mit_sot=self.n_mit_sot * 2, n_mit_sot=info.n_mit_sot * 2,
n_sit_sot=self.n_sit_sot * 2, n_sit_sot=info.n_sit_sot * 2,
n_mit_mot=self.n_mit_mot * 2, n_mit_mot=info.n_mit_mot * 2,
n_nit_sot=self.n_nit_sot * 2, n_nit_sot=info.n_nit_sot * 2,
n_shared_outs=self.n_shared_outs, n_shared_outs=info.n_shared_outs,
n_mit_mot_outs=n_mit_mot_outs * 2, n_mit_mot_outs=n_mit_mot_outs * 2,
tap_array=tuple(tuple(v) for v in new_tap_array), tap_array=tuple(tuple(v) for v in new_tap_array),
mit_mot_out_slices=tuple(tuple(v) for v in self.mit_mot_out_slices) * 2, mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
) )
local_op = Scan( local_op = Scan(
inner_ins, inner_ins,
inner_outs, inner_outs,
info, out_info,
mode=self.mode, mode=self.mode,
as_while=self.as_while, as_while=self.as_while,
profile=self.profile, profile=self.profile,
...@@ -3161,19 +3213,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3161,19 +3213,19 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
outputs = [outputs] outputs = [outputs]
# Select only the result of the R_op results # Select only the result of the R_op results
final_outs = [] final_outs = []
b = self.n_mit_mot b = info.n_mit_mot
e = self.n_mit_mot * 2 e = info.n_mit_mot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_mit_sot b = e + info.n_mit_sot
e = e + self.n_mit_sot * 2 e = e + info.n_mit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_sit_sot b = e + info.n_sit_sot
e = e + self.n_sit_sot * 2 e = e + info.n_sit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_nit_sot b = e + info.n_nit_sot
e = e + self.n_nit_sot * 2 e = e + info.n_nit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
final_outs += [None] * self.n_shared_outs final_outs += [None] * info.n_shared_outs
return final_outs return final_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论