提交 aecfb7d0 authored 作者: Bart van Merriënboer's avatar Bart van Merriënboer

Clearer error message

上级 4d3f04a9
......@@ -205,13 +205,18 @@ class Scan(PureOp):
)
err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature) '
'of variable %s (argument number %d)'
' has dtype %s and %d dimension(s), while the result '
'of the inner function for this output has dtype %s '
'and %d dimension(s). This could happen if the inner '
'graph of scan results in an upcast or downcast. '
'Please make sure that you use dtypes consistently')
'initial state (`outputs_info` in scan nomenclature) '
'of variable %s (argument number %d) '
'has dtype %s, while the result of the inner function '
'(`fn`) has dtype %s. This can happen if the inner '
'function of scan results in an upcast or downcast.')
err_msg3 = ('When compiling the inner function of scan the '
'following error has been encountered: The '
'initial state (`outputs_info` in scan nomenclature) '
'of variable %s (argument number %d) has %d dimension(s), '
'while the result of the inner function (`fn`) has %d '
'dimension(s) (should be one less than the initial '
'state).')
def format(var, as_var):
""" This functions ensures that ``out`` has the same dtype as
......@@ -272,17 +277,19 @@ class Scan(PureOp):
inner_mitmot[ipos + k].type.ndim))
ipos += len(itaps)
for k in xrange(len(otaps)):
if (inner_mitmot_outs[opos + k].type.dtype != \
outer_mitmot.type.dtype or
inner_mitmot_outs[opos + k].ndim != \
outer_mitmot.ndim - 1):
if (inner_mitmot_outs[opos + k].type.dtype !=
outer_mitmot.type.dtype):
raise ValueError(err_msg2 %
(str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
outer_mitmot.ndim - 1,
inner_mitmot_outs[opos + k].type.dtype,
inner_mitmot_outs[opos + k].ndim))
(str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
inner_mitmot_outs[opos + k].type.dtype))
if inner_mitmot_outs[opos + k].ndim != outer_mitmot.ndim - 1:
raise ValueError(err_msg3 %
(str(outer_mitmot),
argoffset + idx,
outer_mitmot.ndim,
inner_mitmot_outs[opos + k].ndim))
opos += len(otaps)
argoffset += len(self.outer_mitmot(inputs))
# Same checks as above but for outputs of type mit_sot
......@@ -309,15 +316,18 @@ class Scan(PureOp):
inner_mitsots[ipos + k].type.dtype,
inner_mitsots[ipos + k].type.ndim))
ipos += len(itaps)
if (inner_mitsot_out.type.dtype != outer_mitsot.type.dtype or
inner_mitsot_out.ndim != outer_mitsot.ndim - 1):
if inner_mitsot_out.type.dtype != outer_mitsot.type.dtype:
raise ValueError(err_msg2 %
(str(outer_mitsot),
argoffset + idx,
outer_mitsot.type.dtype,
outer_mitsot.type.ndim,
inner_mitsot_out.type.dtype,
inner_mitsot_out.type.ndim))
argoffset + idx,
outer_mitsot.type.dtype,
inner_mitsot_out.type.dtype))
if inner_mitsot_out.ndim != outer_mitsot.ndim - 1:
raise ValueError(err_msg3 %
(str(outer_mitsot),
argoffset + idx,
outer_mitsot.ndim,
inner_mitsot_out.ndim))
argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot
......@@ -337,15 +347,18 @@ class Scan(PureOp):
str(inner_sitsot),
inner_sitsot.type.dtype,
inner_sitsot.type.ndim))
if (inner_sitsot_out.type.dtype != outer_sitsot.type.dtype or
inner_sitsot_out.ndim != outer_sitsot.ndim - 1):
if inner_sitsot_out.type.dtype != outer_sitsot.type.dtype:
raise ValueError(err_msg2 %
(str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.dtype,
outer_sitsot.type.ndim,
inner_sitsot_out.type.dtype,
inner_sitsot_out.type.ndim))
(str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.dtype,
inner_sitsot_out.type.dtype))
if inner_sitsot_out.ndim != outer_sitsot.ndim - 1:
raise ValueError(err_msg3 %
(str(outer_sitsot),
argoffset + idx,
outer_sitsot.type.ndim,
inner_sitsot_out.type.ndim))
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same
......@@ -357,13 +370,16 @@ class Scan(PureOp):
outer_shared = format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared)
if (hasattr(outer_shared, 'dtype') and
(outer_shared.dtype != inner_shared_out.dtype or
outer_shared.ndim != inner_shared_out.ndim)):
outer_shared.dtype != inner_shared_out.dtype):
raise ValueError(err_msg2 % (str(outer_shared),
idx + argoffset,
outer_shared.dtype,
inner_shared_out.dtype))
if (hasattr(outer_shared, 'dtype') and
outer_shared.ndim != inner_shared_out.ndim):
raise ValueError(err_msg3 % (str(outer_shared),
idx + argoffset,
outer_shared.ndim,
inner_shared_out.dtype,
inner_shared_out.ndim))
if (hasattr(outer_shared, 'dtype') and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论