提交 fb605c1c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Made error message more explicit

上级 5992e548
...@@ -154,12 +154,22 @@ class Scan(PureOp): ...@@ -154,12 +154,22 @@ class Scan(PureOp):
def make_node(self, *inputs): def make_node(self, *inputs):
assert numpy.all(isinstance(i, gof.Variable) for i in inputs) assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
# assert dtype is consistent # assert dtype is consistent
err_msg1 = ('%s %s (index %d) has dtype %s. Slice %s representing ' err_msg1 = ('When compiling the inner function of scan the '
'this input has dtype %s' ) 'following error has been encountered: The '
'%s %s( the entry number %d) has dtype '
err_msg2 = ('Initial state %s (index %d) has dtype %s. The ' '%s. The corresponding slice %s however has'
'corresponding output of the inner function applied ' ' dtype %s. This should never happen, please '
'recurrently has dtype %s') 'report to theano-dev mailing list'
)
err_msg2 = ('When compiling the inner function of scan the '
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)'
'of variable %s (the entry number %d)'
' has dtype %s, while the result of the'
' inner function for this output has dtype %s. This '
'could happen if the inner graph of scan results in '
'an upcast or downcast. Please make sure that you use'
'dtypes consistently')
# Flags that indicate which inputs are vectors # Flags that indicate which inputs are vectors
...@@ -174,7 +184,7 @@ class Scan(PureOp): ...@@ -174,7 +184,7 @@ class Scan(PureOp):
# them have the same dtype # them have the same dtype
for idx in xrange(self.n_seqs): for idx in xrange(self.n_seqs):
if inputs[1+idx].dtype != self.inputs[idx].dtype: if inputs[1+idx].dtype != self.inputs[idx].dtype:
raise ValueError(err_msg1%( 'Sequence' raise ValueError(err_msg1%( 'sequence'
, str(inputs[1+idx]) , str(inputs[1+idx])
, idx , idx
, inputs[1+idx].dtype , inputs[1+idx].dtype
...@@ -194,7 +204,8 @@ class Scan(PureOp): ...@@ -194,7 +204,8 @@ class Scan(PureOp):
while index < end: while index < end:
for k in self.tap_array[index-start]: for k in self.tap_array[index-start]:
if inputs[index].dtype != self.inputs[index_i].dtype: if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'Initial state' raise ValueError(err_msg1%( 'initial state (outputs_info'
' in scan nomenclature) '
, str(inputs[index]) , str(inputs[index])
, index , index
, inputs[index].dtype , inputs[index].dtype
...@@ -203,7 +214,7 @@ class Scan(PureOp): ...@@ -203,7 +214,7 @@ class Scan(PureOp):
index_i += 1 index_i += 1
for k in self.mit_mot_out_slices[index-start]: for k in self.mit_mot_out_slices[index-start]:
if inputs[index].dtype != self.outputs[index_o].dtype: if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2%( inputs[index].name raise ValueError(err_msg2%( str(inputs[index])
, index , index
, inputs[index].dtype , inputs[index].dtype
, self.outputs[index_o].dtype) ) , self.outputs[index_o].dtype) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论