提交 3f6b9b51 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added checks for ndim and switched from obj.dtype to obj.type.dtype

上级 4dffef0b
...@@ -207,13 +207,13 @@ class Scan(PureOp): ...@@ -207,13 +207,13 @@ class Scan(PureOp):
for idx, (inner_seq, outer_seq) in enumerate( for idx, (inner_seq, outer_seq) in enumerate(
zip(self.inner_seqs(self.inputs), zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs))): self.outer_seqs(inputs))):
if inner_seq.dtype != outer_seq[idx].dtype: if inner_seq.type.dtype != outer_seq[idx].type.dtype:
raise ValueError(err_msg1 % ('sequence', raise ValueError(err_msg1 % ('sequence',
str(outer_seq), str(outer_seq),
idx, idx,
outer_seq.dtype, outer_seq.type.dtype,
str(inner_seq), str(inner_seq),
inner_seq.dtype)) inner_seq.type.dtype))
argoffset += len(self.outer_seqs(inputs)) argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output # - initial state of the output
...@@ -228,23 +228,26 @@ class Scan(PureOp): ...@@ -228,23 +228,26 @@ class Scan(PureOp):
self.mitmot_out_taps(), self.mitmot_out_taps(),
self.outer_mitmot(inputs))): self.outer_mitmot(inputs))):
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if inner_mitmot[ipos+k].dtpe != outer_mitmot.dtype: if (inner_mitmot[ipos+k].type.dtype !=
outer_mitmot.type.dtype or
inner_mitmot[ipos+k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitmot), str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.dtype, outer_mitmot.type.dtype,
str(inner_mitmot(ipos+k]), str(inner_mitmot[ipos+k]),
inner_mitmot[ipos+k].dtype))) inner_mitmot[ipos+k].type.dtype))
ipos += len(itaps) ipos += len(itaps)
for k in xrange(len(otaps)): for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos+k].dtype != if (inner_mitmot_outs[opos+k].type.dtype !=
outer_mitmot.dtype): outer_mitmot.type.dtype or
inner_mitmot_outs[opos+k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_mitmot, (str(outer_mitmot,
argoffset + idx, argoffset + idx,
outer_mitmot.dtype, outer_mitmot.type.dtype,
inner_mitmot_outs[opos+k].dtype))) inner_mitmot_outs[opos+k].type.dtype)))
opos += len(otaps) opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs)) argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot # Same checks as above but for outputs of type mit_sot
...@@ -255,21 +258,24 @@ class Scan(PureOp): ...@@ -255,21 +258,24 @@ class Scan(PureOp):
self.outer_mitsot(inputs), self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs))): self.inner_mitsot_outs(self.outputs))):
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if inner_mitsots[ipos+k].dtype != outer_mitsot.dtype: if (inner_mitsots[ipos+k].type.dtype !=
outer_mitsot.type.dtype or
inner_mitsots[ipos+k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitsot), str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.dtype, outer_mitsot.type.dtype,
str(inner_mitsot(ipos+k]), str(inner_mitsot[ipos+k]),
inner_mitsots[ipos+k].dtype))) inner_mitsots[ipos+k].type.dtype))
ipos += len(itaps) ipos += len(itaps)
if (inner_mitsot_out.dtype != outer_mitsot.dtype): if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_mitsot, (str(outer_mitsot,
argoffset + idx, argoffset + idx,
outer_mitsot.dtype, outer_mitsot.type.dtype,
inner_mitsot_out.dtype))) inner_mitsot_out.type.dtype)))
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
...@@ -277,20 +283,22 @@ class Scan(PureOp): ...@@ -277,20 +283,22 @@ class Scan(PureOp):
zip(self.inner_sitsot(self.inputs), zip(self.inner_sitsot(self.inputs),
self.outer_sitsot(inputs), self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs))): self.inner_sitsot_outs(self.outputs))):
if inner_sitsot.dtype != outer_sitsot.dtype: if (inner_sitsot.type.dtype != outer_sitsot.type.dtype or
inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_sitsot), str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.dtype, outer_sitsot.type.dtype,
str(inner_sitsot), str(inner_sitsot),
inner_sitsot.dtype))) inner_sitsot.type.dtype))
if (inner_sitsot_out.dtype != outer_sitsot.dtype): if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_sitsot, (str(outer_sitsot,
argoffset + idx, argoffset + idx,
outer_sitsot.dtype, outer_sitsot.type.dtype,
inner_sitsot_out.dtype))) inner_sitsot_out.type.dtype)))
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
...@@ -300,30 +308,40 @@ class Scan(PureOp): ...@@ -300,30 +308,40 @@ class Scan(PureOp):
self.inner_shared_outs(self.outputs), self.inner_shared_outs(self.outputs),
self.outer_shared(inputs))): self.outer_shared(inputs))):
if (hasattr(outer_shared, 'dtype') and if (hasattr(outer_shared, 'dtype') and
outer_shared.dtype != inner_shared_out.dtype): (outer_shared.dtype != inner_shared_out.dtype or
outer_shared.ndim != inner_shared_out.ndim)):
raise ValueError(err_msg2 % (str(outer_shared), raise ValueError(err_msg2 % (str(outer_shared),
idx + argoffset, idx + argoffset,
outer_shared.dtype, outer_shared.dtype,
inner_shared_out.dtype)) inner_shared_out.dtype))
if (hasattr(outer_shared, 'dtype') and if (hasattr(outer_shared, 'dtype') and
outer_shared.dtype != inner_shared.dtype): (outer_shared.dtype != inner_shared.dtype or
outer_shared.ndim != inner_shared.ndim)):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_shared), str(outer_shared),
argoffset + idx, argoffset + idx,
outer_shared.dtype, outer_shared.dtype,
str(inner_shared), str(inner_shared),
inner_shared.dtype))) inner_shared.dtype))
for inner_nonseq, outer_nonseq in zip(
self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)):
if (inner_nonseq.type.dtype != outer_nonseq.type.dtype or
inner_nonseq.type.ndim != outer_nonseq.type.ndim):
raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s')%
(str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs): for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization # used by truncated BPTT and by scan space optimization
if (str(outer_nitsot.dtype)[:3] not in ('uin', 'int') or if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or
x.ndim != 0): outer_nitsot.ndim != 0):
raise ValueError('For output %d you need to provide a ' raise ValueError('For output %s you need to provide a '
'scalar int !', outer_nitsot) 'scalar int !', str(outer_nitsot))
apply_node = Apply(self, apply_node = Apply(self,
inputs, inputs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论