提交 692ab088 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed error message

The error was raised if either the dtype didn't match or ndim didn't match. However the error message did not display the ndims.
上级 065845eb
...@@ -179,19 +179,20 @@ class Scan(PureOp): ...@@ -179,19 +179,20 @@ class Scan(PureOp):
err_msg1 = ('When compiling the inner function of scan the ' err_msg1 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
'%s %s (argument number %d) has dtype ' '%s %s (argument number %d) has dtype '
'%s. The corresponding slice %s however has' '%s and %d dimension(s). The corresponding slice %s '
' dtype %s. This should never happen, please ' 'however has dtype %s and %d dimension(s). This '
'should never happen, please '
'report to theano-dev mailing list' 'report to theano-dev mailing list'
) )
err_msg2 = ('When compiling the inner function of scan the ' err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)' 'initial state (outputs_info in scan nomenclature)'
'of variable %s (argument number %d)' 'of variable %s (argument number %d)'
' has dtype %s, while the result of the' ' has dtype %s and %d dimension(s), while the result of the'
' inner function for this output has dtype %s. This ' ' inner function for this output has dtype %s and %d '
'could happen if the inner graph of scan results in ' 'dimension(s). This could happen if the inner graph of '
'an upcast or downcast. Please make sure that you use' ' scan results in an upcast or downcast. Please make '
'dtypes consistently') 'sure that you use dtypes consistently')
# TODO make the assert exact # TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and inputs correspond) # TODO assert the type(dtype, nbdim of self.inputs and inputs correspond)
#assert len(inputs) >= len(self.inputs) #assert len(inputs) >= len(self.inputs)
...@@ -273,8 +274,10 @@ class Scan(PureOp): ...@@ -273,8 +274,10 @@ class Scan(PureOp):
str(outer_mitsot), str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
otuer_mitsot.type.ndim,
str(inner_mitsot[ipos+k]), str(inner_mitsot[ipos+k]),
inner_mitsots[ipos+k].type.dtype)) inner_mitsots[ipos+k].type.dtype,
inner_mitsots[ipos+k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1): inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
...@@ -282,7 +285,9 @@ class Scan(PureOp): ...@@ -282,7 +285,9 @@ class Scan(PureOp):
(str(outer_mitsot), (str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
inner_mitsot_out.type.dtype)) outer_mitsot.type.ndim,
inner_mitsot_out.type.dtype,
inner_mitsot_out.type.ndim ))
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
...@@ -297,15 +302,19 @@ class Scan(PureOp): ...@@ -297,15 +302,19 @@ class Scan(PureOp):
str(outer_sitsot), str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
outer_sitsot.type.ndim,
str(inner_sitsot), str(inner_sitsot),
inner_sitsot.type.dtype)) inner_sitsot.type.dtype,
inner_sitsot.type.ndim))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1): inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_sitsot), (str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
inner_sitsot_out.type.dtype)) outer_sitsot.type.ndim,
inner_sitsot_out.type.dtype,
inner_sitsot_out.type.ndim ))
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
...@@ -320,7 +329,9 @@ class Scan(PureOp): ...@@ -320,7 +329,9 @@ class Scan(PureOp):
raise ValueError(err_msg2 % (str(outer_shared), raise ValueError(err_msg2 % (str(outer_shared),
idx + argoffset, idx + argoffset,
outer_shared.dtype, outer_shared.dtype,
inner_shared_out.dtype)) outer_shared.ndim,
inner_shared_out.dtype,
inner_shared_out.ndim))
if (hasattr(outer_shared, 'dtype') and if (hasattr(outer_shared, 'dtype') and
(outer_shared.dtype != inner_shared.dtype or (outer_shared.dtype != inner_shared.dtype or
...@@ -330,8 +341,10 @@ class Scan(PureOp): ...@@ -330,8 +341,10 @@ class Scan(PureOp):
str(outer_shared), str(outer_shared),
argoffset + idx, argoffset + idx,
outer_shared.dtype, outer_shared.dtype,
outer_shared.ndim,
str(inner_shared), str(inner_shared),
inner_shared.dtype)) inner_shared.dtype,
inner_shared.ndim))
for inner_nonseq, outer_nonseq in zip( for inner_nonseq, outer_nonseq in zip(
self.inner_non_seqs(self.inputs), self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)): self.outer_non_seqs(inputs)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论