提交 9d9d9020 authored 作者: nouiz's avatar nouiz

Merge pull request #298 from pascanur/scan_check

Scan check
...@@ -119,6 +119,7 @@ class Scalar(Type): ...@@ -119,6 +119,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType TODO: refactor to be named ScalarType for consistency with TensorType
""" """
ndim = 0
def __init__(self, dtype): def __init__(self, dtype):
if dtype == 'floatX': if dtype == 'floatX':
...@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types ...@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types
class _scalar_py_operators: class _scalar_py_operators:
# So that we can simplify checking code when we have a mixture of Scalar
# variables and Tensor variables
ndim = 0
#UNARY #UNARY
def __abs__(self): return abs_(self) def __abs__(self): return abs_(self)
......
...@@ -149,11 +149,36 @@ class Scan(PureOp): ...@@ -149,11 +149,36 @@ class Scan(PureOp):
self._hash_inner_graph = self.info['gpu_hash'] self._hash_inner_graph = self.info['gpu_hash']
def make_node(self, *inputs): def make_node(self, *inputs):
"""
Conventions:
inner_? - the variable corresponding to ? in the inner function
of scan (the lambda function executed at every time
step)
outer_? - the variable corresponding to ? in the outer graph,
i.e. the main graph (where the scan op lives)
inner_?_out - the variable representing the new value of ? after
executing one step of scan (i.e. outputs given by
the inner function)
"""
assert numpy.all(isinstance(i, gof.Variable) for i in inputs) assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = (len(self.inner_seqs(self.inputs)) +
len(self.mitmot_taps()) +
len(self.mitsot_taps()) +
len(self.inner_sitsot(self.inputs)) +
len(self.inner_shared(self.inputs)) +
len(self.inner_non_seqs(self.inputs)))
assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.")
# assert dtype is consistent # assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan the ' err_msg1 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
'%s %s( the entry number %d) has dtype ' '%s %s (argument number %d) has dtype '
'%s. The corresponding slice %s however has' '%s. The corresponding slice %s however has'
' dtype %s. This should never happen, please ' ' dtype %s. This should never happen, please '
'report to theano-dev mailing list' 'report to theano-dev mailing list'
...@@ -161,7 +186,7 @@ class Scan(PureOp): ...@@ -161,7 +186,7 @@ class Scan(PureOp):
err_msg2 = ('When compiling the inner function of scan the ' err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)' 'initial state (outputs_info in scan nomenclature)'
'of variable %s (the entry number %d)' 'of variable %s (argument number %d)'
' has dtype %s, while the result of the' ' has dtype %s, while the result of the'
' inner function for this output has dtype %s. This ' ' inner function for this output has dtype %s. This '
'could happen if the inner graph of scan results in ' 'could happen if the inner graph of scan results in '
...@@ -185,85 +210,145 @@ class Scan(PureOp): ...@@ -185,85 +210,145 @@ class Scan(PureOp):
# Check if input sequences and variables representing a slice of # Check if input sequences and variables representing a slice of
# them have the same dtype # them have the same dtype
for idx in xrange(self.n_seqs): argoffset = 0
if inputs[1 + idx].dtype != self.inputs[idx].dtype: for idx, (inner_seq, outer_seq) in enumerate(
zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs))):
if inner_seq.type.dtype != outer_seq[idx].type.dtype:
raise ValueError(err_msg1 % ('sequence', raise ValueError(err_msg1 % ('sequence',
str(inputs[1 + idx]), str(outer_seq),
idx, idx,
inputs[1 + idx].dtype, outer_seq.type.dtype,
str(self.inputs[idx]), str(inner_seq),
self.inputs[idx].dtype)) inner_seq.type.dtype))
argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output # - initial state of the output
# - variable representing an input slice of the otuput # - variable representing an input slice of the otuput
# - variable representing an output slice of the otuput # - variable representing an output slice of the otuput
# Maybe checking that ndim fits would be good as well !? ipos = 0
index_i = self.n_seqs opos = 0
index_o = 0 inner_mitmot = self.inner_mitmot(self.inputs)
index = 1 + self.n_seqs inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
start = index for idx, (itaps, otaps, outer_mitmot) in enumerate(
end = index + self.n_mit_mot zip(self.mitmot_taps(),
while index < end: self.mitmot_out_taps(),
for k in self.tap_array[index - start]: self.outer_mitmot(inputs))):
if inputs[index].dtype != self.inputs[index_i].dtype: for k in xrange(len(itaps)):
if (inner_mitmot[ipos+k].type.dtype !=
outer_mitmot.type.dtype or
inner_mitmot[ipos+k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(inputs[index]), str(outer_mitmot),
index, argoffset + idx,
inputs[index].dtype, outer_mitmot.type.dtype,
str(self.inputs[index_i]), str(inner_mitmot[ipos+k]),
self.inputs[index_i].dtype)) inner_mitmot[ipos+k].type.dtype))
index_i += 1 ipos += len(itaps)
for k in self.mit_mot_out_slices[index - start]: for k in xrange(len(otaps)):
if inputs[index].dtype != self.outputs[index_o].dtype: if (inner_mitmot_outs[opos+k].type.dtype !=
raise ValueError(err_msg2 % (str(inputs[index]), outer_mitmot.type.dtype or
index, inner_mitmot_outs[opos+k].ndim != outer_mitmot.ndim - 1):
inputs[index].dtype, raise ValueError(err_msg2 %
self.outputs[index_o].dtype)) (str(outer_mitmot,
index_o += 1 argoffset + idx,
index += 1 outer_mitmot.type.dtype,
# Same checks as above but for outputs of type mit_sot and sit_sot inner_mitmot_outs[opos+k].type.dtype)))
end += self.n_mit_sot + self.n_sit_sot opos += len(otaps)
while index < end: argoffset += len(self.outer_mitmot(inputs))
for k in self.tap_array[index - start]: # Same checks as above but for outputs of type mit_sot
if inputs[index].dtype != self.inputs[index_i].dtype: ipos = 0
raise ValueError(err_msg1 % ('Initial state', inner_mitsots = self.inner_mitsot(self.inputs)
str(inputs[index]), for idx, (itaps, outer_mitsot, inner_mitsot_out) in enumerate(
index, zip(self.mitsot_taps(),
inputs[index].dtype, self.outer_mitsot(inputs),
str(self.inputs[index_i]), self.inner_mitsot_outs(self.outputs))):
self.inputs[index_i].dtype)) for k in xrange(len(itaps)):
index_i += 1 if (inner_mitsots[ipos+k].type.dtype !=
if inputs[index].dtype != self.outputs[index_o].dtype: outer_mitsot.type.dtype or
raise ValueError(err_msg2 % (str(inputs[index]), inner_mitsots[ipos+k].ndim != outer_mitsot.ndim - 1):
index, raise ValueError(err_msg1 % ('initial state (outputs_info'
inputs[index].dtype, ' in scan nomenclature) ',
self.outputs[index_o].dtype)) str(outer_mitsot),
index_o += 1 argoffset + idx,
index += 1 outer_mitsot.type.dtype,
str(inner_mitsot[ipos+k]),
inner_mitsots[ipos+k].type.dtype))
ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg2 %
(str(outer_mitsot,
argoffset + idx,
outer_mitsot.type.dtype,
inner_mitsot_out.type.dtype)))
argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot
for idx, (inner_sitsot, outer_sitsot, inner_sitsot_out) in enumerate(
zip(self.inner_sitsot(self.inputs),
self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs))):
if (inner_sitsot.type.dtype != outer_sitsot.type.dtype or
inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.dtype,
str(inner_sitsot),
inner_sitsot.type.dtype))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 %
(str(outer_sitsot,
argoffset + idx,
outer_sitsot.type.dtype,
inner_sitsot_out.type.dtype)))
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?! # dtype. Maybe even same type ?!
end += self.n_shared_outs for idx, (inner_shared, inner_shared_out, outer_shared) in enumerate(
index_o += self.n_nit_sot zip(self.inner_shared(self.inputs),
while index < end: self.inner_shared_outs(self.outputs),
if (hasattr(inputs[index], 'dtype') and self.outer_shared(inputs))):
inputs[index].dtype != self.outputs[index_o].dtype): if (hasattr(outer_shared, 'dtype') and
raise ValueError(err_msg2 % (str(inputs[index]), (outer_shared.dtype != inner_shared_out.dtype or
index, outer_shared.ndim != inner_shared_out.ndim)):
inputs[index].dtype, raise ValueError(err_msg2 % (str(outer_shared),
self.outputs[index_o].dtype)) idx + argoffset,
index += 1 outer_shared.dtype,
index_o += 1 inner_shared_out.dtype))
for x in inputs[index:index + self.n_nit_sot]:
if (hasattr(outer_shared, 'dtype') and
(outer_shared.dtype != inner_shared.dtype or
outer_shared.ndim != inner_shared.ndim)):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_shared),
argoffset + idx,
outer_shared.dtype,
str(inner_shared),
inner_shared.dtype))
for inner_nonseq, outer_nonseq in zip(
self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)):
if (inner_nonseq.type.dtype != outer_nonseq.type.dtype or
inner_nonseq.type.ndim != outer_nonseq.type.ndim):
raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s')%
(str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization # used by truncated BPTT and by scan space optimization
if (str(x.dtype)[:3] not in ('uin', 'int') or if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or
x.ndim != 0): outer_nitsot.ndim != 0):
raise ValueError('For output %d you need to provide a ' raise ValueError('For output %s you need to provide a '
'scalar int !', x) 'scalar int !', str(outer_nitsot))
apply_node = Apply(self, apply_node = Apply(self,
inputs, inputs,
...@@ -459,25 +544,29 @@ class Scan(PureOp): ...@@ -459,25 +544,29 @@ class Scan(PureOp):
rval.lazy = False rval.lazy = False
return rval return rval
def inner_seqs(self): def inner_seqs(self, list_inputs):
return self.inputs[:self.n_seqs] # Given the list of inner inputs this function grabs those
# corresponding to sequences
return list_inputs[:self.n_seqs]
def outer_seqs(self, node): def outer_seqs(self, list_inputs):
return node.inputs[1:1 + self.n_seqs] # Given the list of outter inputs this function grabs those
# corresponding to sequences
return list_inputs[1:1 + self.n_seqs]
def inner_mitmot(self): def inner_mitmot(self, list_inputs):
n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot]) n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
return self.inputs[self.n_seqs: self.n_seqs + n_taps] return list_inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, node): def outer_mitmot(self, list_inputs):
return node.inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot] return list_inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self): def inner_mitmot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[:n_taps] return list_outputs[:n_taps]
def outer_mitmot_outs(self, node): def outer_mitmot_outs(self, list_outputs):
return node.outputs[:self.n_mit_mot] return list_outputs[:self.n_mit_mot]
def mitmot_taps(self): def mitmot_taps(self):
return self.tap_array[:self.n_mit_mot] return self.tap_array[:self.n_mit_mot]
...@@ -485,98 +574,98 @@ class Scan(PureOp): ...@@ -485,98 +574,98 @@ class Scan(PureOp):
def mitmot_out_taps(self): def mitmot_out_taps(self):
return self.mit_mot_out_slices[:self.n_mit_mot] return self.mit_mot_out_slices[:self.n_mit_mot]
def inner_mitsot(self): def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot]) n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
ntaps_upto_sit_sot = sum(len(x) for x in ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
return self.inputs[self.n_seqs + n_mitmot_taps: return list_inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs + ntaps_upto_sit_sot] self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, node): def outer_mitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot offset = 1 + self.n_seqs + self.n_mit_mot
return node.inputs[offset:offset + self.n_mit_sot] return list_inputs[offset:offset + self.n_mit_sot]
def inner_mitsot_outs(self): def inner_mitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[n_taps:n_taps + self.n_mit_sot] return list_outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, node): def outer_mitsot_outs(self, list_outputs):
return node.outputs[self.n_mit_mot: return list_outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot] self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self): def mitsot_taps(self):
return self.tap_array[self.n_mit_mot: return self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot] self.n_mit_mot + self.n_mit_sot]
def inner_sitsot(self): def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot offset = self.n_seqs + n_taps_upto_sit_sot
return self.inputs[offset:offset + self.n_sit_sot] return list_inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self, node): def outer_sitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return node.inputs[offset:offset + self.n_sit_sot] return list_inputs[offset:offset + self.n_sit_sot]
def inner_sitsot_outs(self): def inner_sitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps offset = self.n_mit_sot + n_taps
return self.outputs[offset:offset + self.n_sit_sot] return list_outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, node): def outer_sitsot_outs(self, list_outputs):
offset = self.n_mit_mot + self.n_mit_sot offset = self.n_mit_mot + self.n_mit_sot
return node.outputs[offset:offset + self.n_sit_sot] return list_outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, node): def outer_nitsot(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs) self.n_sit_sot + self.n_shared_outs)
return node.inputs[offset:offset + self.n_nit_sot] return list_inputs[offset:offset + self.n_nit_sot]
def inner_nitsot_outs(self): def inner_nitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot offset = self.n_mit_sot + n_taps + self.n_sit_sot
return self.outputs[offset:offset + self.n_nit_sot] return list_outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, node): def outer_nitsot_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot) offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return node.outputs[offset:offset + self.n_nit_sot] return list_outputs[offset:offset + self.n_nit_sot]
def inner_shared(self): def inner_shared(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot
return self.inputs[offset:offset + self.n_shared_outs] return list_inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, node): def outer_shared(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot) self.n_sit_sot)
return node.inputs[offset:offset + self.n_shared_outs] return list_inputs[offset:offset + self.n_shared_outs]
def inner_shared_outs(self): def inner_shared_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices) n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot
return self.outputs[offset:offset + self.n_shared_outs] return list_outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, node): def outer_shared_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
return node.outputs[offset:offset + self.n_shared_outs] return list_outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self): def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot + self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)]) self.n_mit_sot)])
offset = (self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot + offset = (self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot +
self.n_shared_outs) self.n_shared_outs)
return self.inputs[offset:] return list_inputs[offset:]
def outer_non_seqs(self, node): def outer_non_seqs(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs) self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return node.inputs[offset:] return list_inputs[offset:]
def execute(self, node, args, outs): def execute(self, node, args, outs):
""" """
......
...@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer):
not isinstance(nd.op, theano.compile.ViewOp) and not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node # and we didn't already looked at this node
not nd in to_remove not nd in to_remove):
):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
...@@ -317,12 +316,12 @@ def scan_make_inplace(node): ...@@ -317,12 +316,12 @@ def scan_make_inplace(node):
info['inplace'] = True info['inplace'] = True
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node) ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node) ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node) ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node) ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node) ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls) n_outs = len(ls)
for idx in xrange(n_outs): for idx in xrange(n_outs):
if ls[idx] in ls[:idx]: if ls[idx] in ls[:idx]:
...@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer):
fslice = slice( fslice = slice(
sanitize(cnf_slice[0].start), sanitize(cnf_slice[0].start),
sanitize(cnf_slice[0].stop), sanitize(cnf_slice[0].stop),
sanitize(cnf_slice[0].step) sanitize(cnf_slice[0].step))
)
else: else:
fslice = sanitize(cnf_slice[0]) fslice = sanitize(cnf_slice[0])
...@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer): ...@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Seq # Seq
inner_ins += rename(nd.op.inner_seqs(), idx) inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_seqs(nd), idx) outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitMot # MitMot
inner_ins += rename(nd.op.inner_mitmot(), idx) inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitmot_outs() inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitmot_taps() info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps() info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd), idx) outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd) outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitSot # MitSot
inner_ins += rename(nd.op.inner_mitsot(), idx) inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitsot_outs() inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitsot_taps() info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd), idx) outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd) outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# SitSot # SitSot
inner_ins += rename(nd.op.inner_sitsot(), idx) inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)] info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs() inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_sitsot(nd), idx) outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd) outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
inner_ins += rename(nd.op.inner_shared(), idx) inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_shared(nd), idx) outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# NitSot # NitSot
inner_outs += nd.op.inner_nitsot_outs() inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_nitsot(nd), idx) outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd) outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
outer_outs += nd.op.outer_shared_outs(nd) outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs += nd.op.inner_shared_outs() inner_outs += nd.op.inner_shared_outs(nd.op.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Non Seqs # Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(), idx) inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd), idx) outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps # Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins outer_ins = [nodes[0].inputs[0]] + outer_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论