提交 3c3c186b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

updated make_node to use convinience functions

The body of make_node has been updated to make use of the newly revised methods of scan to parse arguments, hopefully increasing clarity.
上级 ab5555c5
......@@ -189,85 +189,129 @@ class Scan(PureOp):
# Check if input sequences and variables representing a slice of
# them have the same dtype
for idx in xrange(self.n_seqs):
if inputs[1 + idx].dtype != self.inputs[idx].dtype:
argoffset = 0
for idx, inner_seq, outer_seq in enumerate(
zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs))):
if inner_seq.dtype != outer_seq[idx].dtype:
raise ValueError(err_msg1 % ('sequence',
str(inputs[1 + idx]),
str(outer_seq),
idx,
inputs[1 + idx].dtype,
str(self.inputs[idx]),
self.inputs[idx].dtype))
outer_seq.dtype,
str(inner_seq),
inner_seq.dtype))
argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output
# - variable representing an input slice of the otuput
# - variable representing an output slice of the otuput
# Maybe checking that ndim fits would be good as well !?
index_i = self.n_seqs
index_o = 0
index = 1 + self.n_seqs
start = index
end = index + self.n_mit_mot
while index < end:
for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
ipos = 0
opos = 0
inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
for idx, itaps, otaps, outer_mitmot in enumerate(
zip(self.mitmot_taps(),
self.mitmot_out_taps(),
self.outer_mitmot(inputs))):
for k in xrange(len(itaps)):
if inner_mitmot[ipos+k].dtpe != outer_mitmot.dtype:
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(inputs[index]),
index,
inputs[index].dtype,
str(self.inputs[index_i]),
self.inputs[index_i].dtype))
index_i += 1
for k in self.mit_mot_out_slices[index - start]:
if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index_o += 1
index += 1
# Same checks as above but for outputs of type mit_sot and sit_sot
end += self.n_mit_sot + self.n_sit_sot
while index < end:
for k in self.tap_array[index - start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1 % ('Initial state',
str(inputs[index]),
index,
inputs[index].dtype,
str(self.inputs[index_i]),
self.inputs[index_i].dtype))
index_i += 1
if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index_o += 1
index += 1
str(outer_mitmot),
argoffset + idx,
outer_mitmot.dtype,
str(inner_mitmot(ipos+k]),
inner_mitmot[ipos+k].dtype)))
ipos += len(itaps)
for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos+k].dtype !=
outer_mitmot.dtype):
raise ValueError(err_msg2 %
(str(outer_mitmot,
argoffset + idx,
outer_mitmot.dtype,
inner_mitmot_outs[opos+k].dtype)))
opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot
ipos = 0
inner_mitsots = self.inner_mitsot(self.inputs)
for idx, itaps, outer_mitsot, inner_mitsot_out in enumerate(
zip(self.mitsot_taps(),
self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs))):
for k in xrange(len(itaps)):
if inner_mitsots[ipos+k].dtype != outer_mitsot.dtype:
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitsot),
argoffset + idx,
outer_mitsot.dtype,
str(inner_mitsot(ipos+k]),
inner_mitsots[ipos+k].dtype)))
ipos += len(itaps)
if (inner_mitsot_out.dtype != outer_mitsot.dtype):
raise ValueError(err_msg2 %
(str(outer_mitsot,
argoffset + idx,
outer_mitsot.dtype,
inner_mitsot_out.dtype)))
argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot
for idx, inner_sitsot, outer_sitsot, inner_sitsot_out in enumerate(
zip(self.inner_sitsot(self.inputs),
self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs))):
if inner_sitsot.dtype != outer_sitsot.dtype:
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_sitsot),
argoffset + idx,
outer_sitsot.dtype,
str(inner_sitsot),
inner_sitsot.dtype)))
if (inner_sitsot_out.dtype != outer_sitsot.dtype):
raise ValueError(err_msg2 %
(str(outer_sitsot,
argoffset + idx,
outer_sitsot.dtype,
inner_sitsot_out.dtype)))
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?!
end += self.n_shared_outs
index_o += self.n_nit_sot
while index < end:
if (hasattr(inputs[index], 'dtype') and
inputs[index].dtype != self.outputs[index_o].dtype):
raise ValueError(err_msg2 % (str(inputs[index]),
index,
inputs[index].dtype,
self.outputs[index_o].dtype))
index += 1
index_o += 1
for x in inputs[index:index + self.n_nit_sot]:
for idx, inner_shared, inner_shared_out, outer_shared in enumerate(
zip(self.inner_shared(self.inputs),
self.inner_shared_outs(self.outputs),
self.outer_shared(inputs))):
if (hasattr(outer_shared, 'dtype') and
outer_shared.dtype != inner_shared_out.dtype):
raise ValueError(err_msg2 % (str(outer_shared),
idx + argoffset,
outer_shared.dtype,
inner_shared_out.dtype))
if (hasattr(outer_shared, 'dtype') and
outer_shared.dtype != inner_shared.dtype):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_shared),
argoffset + idx,
outer_shared.dtype,
str(inner_shared),
inner_shared.dtype)))
for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
if (str(x.dtype)[:3] not in ('uin', 'int') or
if (str(outer_nitsot.dtype)[:3] not in ('uin', 'int') or
x.ndim != 0):
raise ValueError('For output %d you need to provide a '
'scalar int !', x)
'scalar int !', outer_nitsot)
apply_node = Apply(self,
inputs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论