提交 454df052 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1843 from bartvm/scan_error_message

Fixed error message for scan op dimension mismatch
...@@ -205,13 +205,18 @@ class Scan(PureOp): ...@@ -205,13 +205,18 @@ class Scan(PureOp):
) )
err_msg2 = ('When compiling the inner function of scan the ' err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature) ' 'initial state (`outputs_info` in scan nomenclature) '
'of variable %s (argument number %d)' 'of variable %s (argument number %d) '
' has dtype %s and %d dimension(s), while the result ' 'has dtype %s, while the result of the inner function '
'of the inner function for this output has dtype %s ' '(`fn`) has dtype %s. This can happen if the inner '
'and %d dimension(s). This could happen if the inner ' 'function of scan results in an upcast or downcast.')
'graph of scan results in an upcast or downcast. ' err_msg3 = ('When compiling the inner function of scan the '
'Please make sure that you use dtypes consistently') 'following error has been encountered: The '
'initial state (`outputs_info` in scan nomenclature) '
'of variable %s (argument number %d) has %d dimension(s), '
'while the result of the inner function (`fn`) has %d '
'dimension(s) (should be one less than the initial '
'state).')
def format(var, as_var): def format(var, as_var):
""" This functions ensures that ``out`` has the same dtype as """ This functions ensures that ``out`` has the same dtype as
...@@ -272,16 +277,18 @@ class Scan(PureOp): ...@@ -272,16 +277,18 @@ class Scan(PureOp):
inner_mitmot[ipos + k].type.ndim)) inner_mitmot[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
for k in xrange(len(otaps)): for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos + k].type.dtype != \ if (inner_mitmot_outs[opos + k].type.dtype !=
outer_mitmot.type.dtype or outer_mitmot.type.dtype):
inner_mitmot_outs[opos + k].ndim != \
outer_mitmot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_mitmot), (str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.type.dtype, outer_mitmot.type.dtype,
inner_mitmot_outs[opos + k].type.dtype))
if inner_mitmot_outs[opos + k].ndim != outer_mitmot.ndim - 1:
raise ValueError(err_msg3 %
(str(outer_mitmot),
argoffset + idx,
outer_mitmot.ndim, outer_mitmot.ndim,
inner_mitmot_outs[opos + k].type.dtype,
inner_mitmot_outs[opos + k].ndim)) inner_mitmot_outs[opos + k].ndim))
opos += len(otaps) opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs)) argoffset += len(self.outer_mitmot(inputs))
...@@ -309,15 +316,18 @@ class Scan(PureOp): ...@@ -309,15 +316,18 @@ class Scan(PureOp):
inner_mitsots[ipos + k].type.dtype, inner_mitsots[ipos + k].type.dtype,
inner_mitsots[ipos + k].type.ndim)) inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps) ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or if inner_mitsot_out.type.dtype != outer_mitsot.type.dtype:
inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_mitsot), (str(outer_mitsot),
argoffset + idx, argoffset + idx,
outer_mitsot.type.dtype, outer_mitsot.type.dtype,
outer_mitsot.type.ndim, inner_mitsot_out.type.dtype))
inner_mitsot_out.type.dtype, if inner_mitsot_out.ndim != outer_mitsot.ndim - 1:
inner_mitsot_out.type.ndim)) raise ValueError(err_msg3 %
(str(outer_mitsot),
argoffset + idx,
outer_mitsot.ndim,
inner_mitsot_out.ndim))
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
...@@ -337,14 +347,17 @@ class Scan(PureOp): ...@@ -337,14 +347,17 @@ class Scan(PureOp):
str(inner_sitsot), str(inner_sitsot),
inner_sitsot.type.dtype, inner_sitsot.type.dtype,
inner_sitsot.type.ndim)) inner_sitsot.type.ndim))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or if inner_sitsot_out.type.dtype != outer_sitsot.type.dtype:
inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg2 % raise ValueError(err_msg2 %
(str(outer_sitsot), (str(outer_sitsot),
argoffset + idx, argoffset + idx,
outer_sitsot.type.dtype, outer_sitsot.type.dtype,
inner_sitsot_out.type.dtype))
if inner_sitsot_out.ndim != outer_sitsot.ndim - 1:
raise ValueError(err_msg3 %
(str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.ndim, outer_sitsot.type.ndim,
inner_sitsot_out.type.dtype,
inner_sitsot_out.type.ndim)) inner_sitsot_out.type.ndim))
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
...@@ -357,13 +370,16 @@ class Scan(PureOp): ...@@ -357,13 +370,16 @@ class Scan(PureOp):
outer_shared = format(_outer_shared, as_var=inner_shared) outer_shared = format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared) new_inputs.append(outer_shared)
if (hasattr(outer_shared, 'dtype') and if (hasattr(outer_shared, 'dtype') and
(outer_shared.dtype != inner_shared_out.dtype or outer_shared.dtype != inner_shared_out.dtype):
outer_shared.ndim != inner_shared_out.ndim)):
raise ValueError(err_msg2 % (str(outer_shared), raise ValueError(err_msg2 % (str(outer_shared),
idx + argoffset, idx + argoffset,
outer_shared.dtype, outer_shared.dtype,
inner_shared_out.dtype))
if (hasattr(outer_shared, 'dtype') and
outer_shared.ndim != inner_shared_out.ndim):
raise ValueError(err_msg3 % (str(outer_shared),
idx + argoffset,
outer_shared.ndim, outer_shared.ndim,
inner_shared_out.dtype,
inner_shared_out.ndim)) inner_shared_out.ndim))
if (hasattr(outer_shared, 'dtype') and if (hasattr(outer_shared, 'dtype') and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论