提交 9d9d9020 authored 作者: nouiz's avatar nouiz

Merge pull request #298 from pascanur/scan_check

Scan check
......@@ -119,6 +119,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType
"""
ndim = 0
def __init__(self, dtype):
if dtype == 'floatX':
......@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types
class _scalar_py_operators:
# So that we can simplify checking code when we have a mixture of Scalar
# variables and Tensor variables
ndim = 0
#UNARY
def __abs__(self): return abs_(self)
......
......@@ -149,11 +149,36 @@ class Scan(PureOp):
self._hash_inner_graph = self.info['gpu_hash']
def make_node(self, *inputs):
"""
Conventions:
inner_? - the variable corresponding to ? in the inner function
of scan (the lambda function executed at every time
step)
outer_? - the variable corresponding to ? in the outer graph,
i.e. the main graph (where the scan op lives)
inner_?_out - the variable representing the new value of ? after
executing one step of scan (i.e. outputs given by
the inner function)
"""
# Every outer input must already be a graph variable.  The builtin
# all() is used on purpose: numpy.all() does not iterate a generator
# expression, so it never looked at the individual inputs and the
# assertion could not fail.
assert all(isinstance(i, gof.Variable) for i in inputs)
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan.
# On the outer side the first entry (number of steps) and the nit_sot
# entries are subtracted before comparing.
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
# mit_mot and mit_sot are counted once per state (length of their tap
# lists), matching the single outer argument each of them has.
n_inner_ins = (len(self.inner_seqs(self.inputs)) +
               len(self.mitmot_taps()) +
               len(self.mitsot_taps()) +
               len(self.inner_sitsot(self.inputs)) +
               len(self.inner_shared(self.inputs)) +
               len(self.inner_non_seqs(self.inputs)))
assert n_outer_ins == n_inner_ins, \
    ("The number of inputs given to the inner function of scan"
     " does not match the number of inputs given to scan.")
# assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan the '
'following error has been encountered: The '
'%s %s( the entry number %d) has dtype '
'%s %s (argument number %d) has dtype '
'%s. The corresponding slice %s however has'
' dtype %s. This should never happen, please '
'report to theano-dev mailing list'
......@@ -161,7 +186,7 @@ class Scan(PureOp):
err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)'
'of variable %s (the entry number %d)'
'of variable %s (argument number %d)'
' has dtype %s, while the result of the'
' inner function for this output has dtype %s. This '
'could happen if the inner graph of scan results in '
......@@ -185,85 +210,145 @@ class Scan(PureOp):
# Check if input sequences and variables representing a slice of
# them have the same dtype
for idx in xrange(self.n_seqs):
if inputs[1 + idx].dtype != self.inputs[idx].dtype:
argoffset = 0
for idx, (inner_seq, outer_seq) in enumerate(
zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs))):
if inner_seq.type.dtype != outer_seq[idx].type.dtype:
raise ValueError(err_msg1 % ('sequence',
str(inputs[1 + idx]),
str(outer_seq),
idx,
inputs[1 + idx].dtype,
str(self.inputs[idx]),
self.inputs[idx].dtype))
outer_seq.type.dtype,
str(inner_seq),
inner_seq.type.dtype))
argoffset += len(self.outer_seqs(inputs))
# Check that these 3 things have the same dtype for mit_mot:
# - initial state of the output
# - variable representing an input slice of the output
# - variable representing an output slice of the output
# Maybe checking that ndim fits would be good as well !?
index_i = self.n_seqs
index_o = 0
index = 1 + self.n_seqs
start = index
end = index + self.n_mit_mot
while index < end:
for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
ipos = 0
opos = 0
inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
for idx, (itaps, otaps, outer_mitmot) in enumerate(
zip(self.mitmot_taps(),
self.mitmot_out_taps(),
self.outer_mitmot(inputs))):
for k in xrange(len(itaps)):
if (inner_mitmot[ipos+k].type.dtype !=
outer_mitmot.type.dtype or
inner_mitmot[ipos+k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(inputs[index]),
index,
inputs[index].dtype,
str(self.inputs[index_i]),
self.inputs[index_i].dtype))
index_i += 1
for k in self.mit_mot_out_slices[index - start]:
if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index_o += 1
index += 1
# Same checks as above but for outputs of type mit_sot and sit_sot
end += self.n_mit_sot + self.n_sit_sot
while index < end:
for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1 % ('Initial state',
str(inputs[index]),
index,
inputs[index].dtype,
str(self.inputs[index_i]),
self.inputs[index_i].dtype))
index_i += 1
if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index_o += 1
index += 1
str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
str(inner_mitmot[ipos+k]),
inner_mitmot[ipos+k].type.dtype))
ipos += len(itaps)
# Each output tap of this mit_mot state must have the dtype of the outer
# initial state and exactly one dimension fewer than it.
for k in xrange(len(otaps)):
    if (inner_mitmot_outs[opos+k].type.dtype !=
            outer_mitmot.type.dtype or
        inner_mitmot_outs[opos+k].ndim != outer_mitmot.ndim - 1):
        # str() must wrap only the variable; the remaining entries are
        # separate members of the tuple consumed by err_msg2.
        raise ValueError(err_msg2 %
                         (str(outer_mitmot),
                          argoffset + idx,
                          outer_mitmot.type.dtype,
                          inner_mitmot_outs[opos+k].type.dtype))
opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot
# `ipos` is the position, inside `inner_mitsots`, of the first tap that
# belongs to the mit_sot state currently being checked.
ipos = 0
inner_mitsots = self.inner_mitsot(self.inputs)
for idx, (itaps, outer_mitsot, inner_mitsot_out) in enumerate(
        zip(self.mitsot_taps(),
            self.outer_mitsot(inputs),
            self.inner_mitsot_outs(self.outputs))):
    # Every input tap must have the dtype of the outer initial state
    # and one dimension fewer than it.
    for k in xrange(len(itaps)):
        if (inner_mitsots[ipos+k].type.dtype !=
                outer_mitsot.type.dtype or
            inner_mitsots[ipos+k].ndim != outer_mitsot.ndim - 1):
            raise ValueError(err_msg1 % ('initial state (outputs_info'
                                         ' in scan nomenclature) ',
                                         str(outer_mitsot),
                                         argoffset + idx,
                                         outer_mitsot.type.dtype,
                                         # the local list is named
                                         # `inner_mitsots`
                                         str(inner_mitsots[ipos+k]),
                                         inner_mitsots[ipos+k].type.dtype))
    ipos += len(itaps)
    # The value produced by the inner function for this state is held
    # to the same dtype / ndim constraint.
    if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
            inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
        # str() wraps only the variable; the other entries are separate
        # members of the tuple consumed by err_msg2.
        raise ValueError(err_msg2 %
                         (str(outer_mitsot),
                          argoffset + idx,
                          outer_mitsot.type.dtype,
                          inner_mitsot_out.type.dtype))
# Keep the running argument number used in the error messages in step
# with the outer arguments already checked.
argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot
for idx, (inner_sitsot, outer_sitsot, inner_sitsot_out) in enumerate(
        zip(self.inner_sitsot(self.inputs),
            self.outer_sitsot(inputs),
            self.inner_sitsot_outs(self.outputs))):
    # The inner input must have the dtype of the outer initial state and
    # one dimension fewer than it.
    if (inner_sitsot.type.dtype != outer_sitsot.type.dtype or
            inner_sitsot.ndim != outer_sitsot.ndim - 1):
        raise ValueError(err_msg1 % ('initial state (outputs_info'
                                     ' in scan nomenclature) ',
                                     str(outer_sitsot),
                                     argoffset + idx,
                                     outer_sitsot.type.dtype,
                                     str(inner_sitsot),
                                     inner_sitsot.type.dtype))
    # The value produced by the inner function is held to the same
    # constraint.
    if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
            inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
        # str() wraps only the variable; the other entries are separate
        # members of the tuple consumed by err_msg2.
        raise ValueError(err_msg2 %
                         (str(outer_sitsot),
                          argoffset + idx,
                          outer_sitsot.type.dtype,
                          inner_sitsot_out.type.dtype))
# Keep the running argument number used in the error messages in step
# with the outer arguments already checked.
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variables and their update rules have the same
# dtype. Maybe even the same type ?!
end += self.n_shared_outs
index_o += self.n_nit_sot
while index < end:
if (hasattr(inputs[index], 'dtype') and
inputs[index].dtype != self.outputs[index_o].dtype):
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index += 1
index_o += 1
for x in inputs[index:index + self.n_nit_sot]:
# For each shared variable, compare the outer variable with both its
# inner placeholder and its inner update expression: dtype and number
# of dimensions must agree.  Outer entries that expose no `dtype`
# attribute are not checked at all.
shared_triples = zip(self.inner_shared(self.inputs),
                     self.inner_shared_outs(self.outputs),
                     self.outer_shared(inputs))
for idx, (inner_shared, inner_shared_out, outer_shared) in enumerate(
        shared_triples):
    if not hasattr(outer_shared, 'dtype'):
        continue
    # Update expression versus outer shared variable.
    if (outer_shared.dtype != inner_shared_out.dtype or
            outer_shared.ndim != inner_shared_out.ndim):
        raise ValueError(err_msg2 % (str(outer_shared),
                                     idx + argoffset,
                                     outer_shared.dtype,
                                     inner_shared_out.dtype))
    # Inner placeholder versus outer shared variable.
    if (outer_shared.dtype != inner_shared.dtype or
            outer_shared.ndim != inner_shared.ndim):
        raise ValueError(err_msg1 % ('initial state (outputs_info'
                                     ' in scan nomenclature) ',
                                     str(outer_shared),
                                     argoffset + idx,
                                     outer_shared.dtype,
                                     str(inner_shared),
                                     inner_shared.dtype))
# Each non-sequence argument given to the scan node must have the same
# dtype and the same number of dimensions as the inner variable it is
# paired with.
inner_non_seqs = self.inner_non_seqs(self.inputs)
outer_non_seqs = self.outer_non_seqs(inputs)
for inner_nonseq, outer_nonseq in zip(inner_non_seqs, outer_non_seqs):
    if (inner_nonseq.type.dtype != outer_nonseq.type.dtype or
            inner_nonseq.type.ndim != outer_nonseq.type.ndim):
        mismatch_msg = ('Argument %s given to scan node does not'
                        ' match its correspondance %s')
        raise ValueError(mismatch_msg %
                         (str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input an int/uint that
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
if (str(x.dtype)[:3] not in ('uin', 'int') or
x.ndim != 0):
raise ValueError('For output %d you need to provide a '
'scalar int !', x)
# A nit_sot argument must be a 0-dimensional variable whose dtype name
# starts with 'int' or 'uin' (i.e. a scalar signed/unsigned integer).
if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or
        outer_nitsot.ndim != 0):
    # Interpolate the variable into the message with `%`; passing it as
    # a second argument to ValueError leaves the '%s' unformatted.
    raise ValueError('For output %s you need to provide a '
                     'scalar int !' % str(outer_nitsot))
apply_node = Apply(self,
inputs,
......@@ -459,25 +544,29 @@ class Scan(PureOp):
rval.lazy = False
return rval
def inner_seqs(self):
return self.inputs[:self.n_seqs]
def inner_seqs(self, list_inputs):
    """Return the sequence entries of a list of inner inputs.

    The sequences are the first `self.n_seqs` entries of `list_inputs`.
    """
    n_seqs = self.n_seqs
    return list_inputs[:n_seqs]
def outer_seqs(self, node):
return node.inputs[1:1 + self.n_seqs]
def outer_seqs(self, list_inputs):
    """Return the sequence entries of a list of outer inputs.

    Entry 0 of `list_inputs` is skipped; the next `self.n_seqs` entries
    are returned.
    """
    start = 1
    stop = start + self.n_seqs
    return list_inputs[start:stop]
def inner_mitmot(self):
def inner_mitmot(self, list_inputs):
n_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
return self.inputs[self.n_seqs: self.n_seqs + n_taps]
return list_inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, node):
return node.inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def outer_mitmot(self, list_inputs):
    """Return the mit_mot entries of a list of outer inputs.

    They are the `self.n_mit_mot` entries that follow entry 0 and the
    `self.n_seqs` sequence entries.
    """
    start = 1 + self.n_seqs
    stop = start + self.n_mit_mot
    return list_inputs[start:stop]
def inner_mitmot_outs(self):
def inner_mitmot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[:n_taps]
return list_outputs[:n_taps]
def outer_mitmot_outs(self, node):
return node.outputs[:self.n_mit_mot]
def outer_mitmot_outs(self, list_outputs):
    """Return the mit_mot entries of a list of outer outputs.

    They are the first `self.n_mit_mot` entries of `list_outputs`.
    """
    n_mit_mot = self.n_mit_mot
    return list_outputs[:n_mit_mot]
def mitmot_taps(self):
    """Return the first `self.n_mit_mot` entries of `self.tap_array`."""
    n_mit_mot = self.n_mit_mot
    return self.tap_array[:n_mit_mot]
......@@ -485,98 +574,98 @@ class Scan(PureOp):
def mitmot_out_taps(self):
    """Return the first `self.n_mit_mot` entries of
    `self.mit_mot_out_slices`."""
    n_mit_mot = self.n_mit_mot
    return self.mit_mot_out_slices[:n_mit_mot]
def inner_mitsot(self):
def inner_mitsot(self, list_inputs):
n_mitmot_taps = sum(len(x) for x in self.tap_array[:self.n_mit_mot])
ntaps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
return self.inputs[self.n_seqs + n_mitmot_taps:
return list_inputs[self.n_seqs + n_mitmot_taps:
self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, node):
def outer_mitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot
return node.inputs[offset:offset + self.n_mit_sot]
return list_inputs[offset:offset + self.n_mit_sot]
def inner_mitsot_outs(self):
def inner_mitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
return self.outputs[n_taps:n_taps + self.n_mit_sot]
return list_outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, node):
return node.outputs[self.n_mit_mot:
def outer_mitsot_outs(self, list_outputs):
return list_outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self):
    """Return the mit_sot entries of `self.tap_array`.

    They are the `self.n_mit_sot` entries that follow the first
    `self.n_mit_mot` entries.
    """
    start = self.n_mit_mot
    stop = start + self.n_mit_sot
    return self.tap_array[start:stop]
def inner_sitsot(self):
def inner_sitsot(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot
return self.inputs[offset:offset + self.n_sit_sot]
return list_inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self, node):
def outer_sitsot(self, list_inputs):
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return node.inputs[offset:offset + self.n_sit_sot]
return list_inputs[offset:offset + self.n_sit_sot]
def inner_sitsot_outs(self):
def inner_sitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps
return self.outputs[offset:offset + self.n_sit_sot]
return list_outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, node):
def outer_sitsot_outs(self, list_outputs):
offset = self.n_mit_mot + self.n_mit_sot
return node.outputs[offset:offset + self.n_sit_sot]
return list_outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, node):
def outer_nitsot(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs)
return node.inputs[offset:offset + self.n_nit_sot]
return list_inputs[offset:offset + self.n_nit_sot]
def inner_nitsot_outs(self):
def inner_nitsot_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot
return self.outputs[offset:offset + self.n_nit_sot]
return list_outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, node):
def outer_nitsot_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return node.outputs[offset:offset + self.n_nit_sot]
return list_outputs[offset:offset + self.n_nit_sot]
def inner_shared(self):
def inner_shared(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot
return self.inputs[offset:offset + self.n_shared_outs]
return list_inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, node):
def outer_shared(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot)
return node.inputs[offset:offset + self.n_shared_outs]
return list_inputs[offset:offset + self.n_shared_outs]
def inner_shared_outs(self):
def inner_shared_outs(self, list_outputs):
n_taps = sum(len(x) for x in self.mit_mot_out_slices)
offset = self.n_mit_sot + n_taps + self.n_sit_sot + self.n_nit_sot
return self.outputs[offset:offset + self.n_shared_outs]
return list_outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, node):
def outer_shared_outs(self, list_outputs):
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot)
return node.outputs[offset:offset + self.n_shared_outs]
return list_outputs[offset:offset + self.n_shared_outs]
def inner_non_seqs(self):
def inner_non_seqs(self, list_inputs):
n_taps_upto_sit_sot = sum(len(x) for x in
self.tap_array[:(self.n_mit_mot +
self.n_mit_sot)])
offset = (self.n_seqs + n_taps_upto_sit_sot + self.n_sit_sot +
self.n_shared_outs)
return self.inputs[offset:]
return list_inputs[offset:]
def outer_non_seqs(self, node):
def outer_non_seqs(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return node.inputs[offset:]
return list_inputs[offset:]
def execute(self, node, args, outs):
"""
......
......@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer):
not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we haven't already looked at this node
not nd in to_remove
):
not nd in to_remove):
# We have a candidate node to remove
# Step 1. Reconstruct it on outside
......@@ -317,12 +316,12 @@ def scan_make_inplace(node):
info['inplace'] = True
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node)
ls += op.outer_mitsot(node)
ls += op.outer_sitsot(node)
ls_end = op.outer_shared(node)
ls_end += op.outer_nitsot(node)
ls_end += op.outer_non_seqs(node)
ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
......@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer):
fslice = slice(
sanitize(cnf_slice[0].start),
sanitize(cnf_slice[0].stop),
sanitize(cnf_slice[0].step)
)
sanitize(cnf_slice[0].step))
else:
fslice = sanitize(cnf_slice[0])
......@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes):
# Seq
inner_ins += rename(nd.op.inner_seqs(), idx)
outer_ins += rename(nd.op.outer_seqs(nd), idx)
inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# MitMot
inner_ins += rename(nd.op.inner_mitmot(), idx)
inner_outs += nd.op.inner_mitmot_outs()
inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd), idx)
outer_outs += nd.op.outer_mitmot_outs(nd)
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# MitSot
inner_ins += rename(nd.op.inner_mitsot(), idx)
inner_outs += nd.op.inner_mitsot_outs()
inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd), idx)
outer_outs += nd.op.outer_mitsot_outs(nd)
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# SitSot
inner_ins += rename(nd.op.inner_sitsot(), idx)
inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs()
outer_ins += rename(nd.op.outer_sitsot(nd), idx)
outer_outs += nd.op.outer_sitsot_outs(nd)
inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
inner_ins += rename(nd.op.inner_shared(), idx)
outer_ins += rename(nd.op.outer_shared(nd), idx)
inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# NitSot
inner_outs += nd.op.inner_nitsot_outs()
outer_ins += rename(nd.op.outer_nitsot(nd), idx)
outer_outs += nd.op.outer_nitsot_outs(nd)
inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd)
inner_outs += nd.op.inner_shared_outs()
outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs += nd.op.inner_shared_outs(nd.op.outputs)
for idx, nd in enumerate(nodes):
# Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd), idx)
inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论