提交 3c3c186b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

updated make_node to use convinience functions

The body of make_node has been updated to make use of the newly revised methods of scan to parse arguments, hopefully increasing clarity.
上级 ab5555c5
...@@ -189,85 +189,129 @@ class Scan(PureOp): ...@@ -189,85 +189,129 @@ class Scan(PureOp):
# Check if input sequences and variables representing a slice of # Check if input sequences and variables representing a slice of
# them have the same dtype # them have the same dtype
for idx in xrange(self.n_seqs): argoffset = 0
if inputs[1 + idx].dtype != self.inputs[idx].dtype: for idx, inner_seq, outer_seq in enumerate(
zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs))):
if inner_seq.dtype != outer_seq[idx].dtype:
raise ValueError(err_msg1 % ('sequence', raise ValueError(err_msg1 % ('sequence',
str(inputs[1 + idx]), str(outer_seq),
idx, idx,
inputs[1 + idx].dtype, outer_seq.dtype,
str(self.inputs[idx]), str(inner_seq),
self.inputs[idx].dtype)) inner_seq.dtype))
argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output # - initial state of the output
# - variable representing an input slice of the otuput # - variable representing an input slice of the otuput
# - variable representing an output slice of the otuput # - variable representing an output slice of the otuput
# Maybe checking that ndim fits would be good as well !? ipos = 0
index_i = self.n_seqs opos = 0
index_o = 0 inner_mitmot = self.inner_mitmot(self.inputs)
index = 1 + self.n_seqs inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
start = index for idx, itaps, otaps, outer_mitmot in enumerate(
end = index + self.n_mit_mot zip(self.mitmot_taps(),
while index < end: self.mitmot_out_taps(),
for k in self.tap_array[index - start]: self.outer_mitmot(inputs))):
if inputs[index].dtype != self.inputs[index_i].dtype: for k in xrange(len(itaps)):
if inner_mitmot[ipos+k].dtpe != outer_mitmot.dtype:
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(inputs[index]), str(outer_mitmot),
index, argoffset + idx,
inputs[index].dtype, outer_mitmot.dtype,
str(self.inputs[index_i]), str(inner_mitmot(ipos+k]),
self.inputs[index_i].dtype)) inner_mitmot[ipos+k].dtype)))
index_i += 1 ipos += len(itaps)
for k in self.mit_mot_out_slices[index - start]: for k in xrange(len(otaps)):
if inputs[index].dtype != self.outputs[index_o].dtype: if (inner_mitmot_outs[opos+k].dtype !=
raise ValueError(err_msg2 % (str(inputs[index]), outer_mitmot.dtype):
index, raise ValueError(err_msg2 %
inputs[index].dtype, (str(outer_mitmot,
self.outputs[index_o].dtype)) argoffset + idx,
index_o += 1 outer_mitmot.dtype,
index += 1 inner_mitmot_outs[opos+k].dtype)))
# Same checks as above but for outputs of type mit_sot and sit_sot opos += len(otaps)
end += self.n_mit_sot + self.n_sit_sot argoffset += len(self.outer_mitmot(inputs))
while index < end: # Same checks as above but for outputs of type mit_sot
for k in self.tap_array[index - start]: ipos = 0
if inputs[index].dtype != self.inputs[index_i].dtype: inner_mitsots = self.inner_mitsot(self.inputs)
raise ValueError(err_msg1 % ('Initial state', for idx, itaps, outer_mitsot, inner_mitsot_out in enumerate(
str(inputs[index]), zip(self.mitsot_taps(),
index, self.outer_mitsot(inputs),
inputs[index].dtype, self.inner_mitsot_outs(self.outputs))):
str(self.inputs[index_i]), for k in xrange(len(itaps)):
self.inputs[index_i].dtype)) if inner_mitsots[ipos+k].dtype != outer_mitsot.dtype:
index_i += 1 raise ValueError(err_msg1 % ('initial state (outputs_info'
if inputs[index].dtype != self.outputs[index_o].dtype: ' in scan nomenclature) ',
raise ValueError(err_msg2 % (str(inputs[index]), str(outer_mitsot),
index, argoffset + idx,
inputs[index].dtype, outer_mitsot.dtype,
self.outputs[index_o].dtype)) str(inner_mitsot(ipos+k]),
index_o += 1 inner_mitsots[ipos+k].dtype)))
index += 1 ipos += len(itaps)
if (inner_mitsot_out.dtype != outer_mitsot.dtype):
raise ValueError(err_msg2 %
(str(outer_mitsot,
argoffset + idx,
outer_mitsot.dtype,
inner_mitsot_out.dtype)))
argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot
for idx, inner_sitsot, outer_sitsot, inner_sitsot_out in enumerate(
zip(self.inner_sitsot(self.inputs),
self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs))):
if inner_sitsot.dtype != outer_sitsot.dtype:
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_sitsot),
argoffset + idx,
outer_sitsot.dtype,
str(inner_sitsot),
inner_sitsot.dtype)))
if (inner_sitsot_out.dtype != outer_sitsot.dtype):
raise ValueError(err_msg2 %
(str(outer_sitsot,
argoffset + idx,
outer_sitsot.dtype,
inner_sitsot_out.dtype)))
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?! # dtype. Maybe even same type ?!
end += self.n_shared_outs end += self.n_shared_outs
index_o += self.n_nit_sot index_o += self.n_nit_sot
while index < end: for idx, inner_shared, inner_shared_out, outer_shared in enumerate(
if (hasattr(inputs[index], 'dtype') and zip(self.inner_shared(self.inputs),
inputs[index].dtype != self.outputs[index_o].dtype): self.inner_shared_outs(self.outputs),
raise ValueError(err_msg2 % (str(inputs[index]), self.outer_shared(inputs))):
index, if (hasattr(outer_shared, 'dtype') and
inputs[index].dtype, outer_shared.dtype != inner_shared_out.dtype):
self.outputs[index_o].dtype)) raise ValueError(err_msg2 % (str(outer_shared),
index += 1 idx + argoffset,
index_o += 1 outer_shared.dtype,
for x in inputs[index:index + self.n_nit_sot]: inner_shared_out.dtype))
if (hasattr(outer_shared, 'dtype') and
outer_shared.dtype != inner_shared.dtype):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_shared),
argoffset + idx,
outer_shared.dtype,
str(inner_shared),
inner_shared.dtype)))
for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization # used by truncated BPTT and by scan space optimization
if (str(x.dtype)[:3] not in ('uin', 'int') or if (str(outer_nitsot.dtype)[:3] not in ('uin', 'int') or
x.ndim != 0): x.ndim != 0):
raise ValueError('For output %d you need to provide a ' raise ValueError('For output %d you need to provide a '
'scalar int !', x) 'scalar int !', outer_nitsot)
apply_node = Apply(self, apply_node = Apply(self,
inputs, inputs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论