提交 4bf5448d authored 作者: Frederic Bastien's avatar Frederic Bastien

Better error message to help user debug error like gh-5347

上级 13c2b145
......@@ -1041,11 +1041,14 @@ class ShapeFeature(object):
rval.append(None)
return rval
def unpack(self, s_i):
def unpack(self, s_i, var):
"""Return a symbolic integer scalar for the shape element s_i.
The s_i argument was produced by the infer_shape() of an Op subclass.
var: the variable that correspond to s_i. This is just for
error reporting.
"""
# unpack the s_i that the Op returned
assert s_i is not None
......@@ -1059,7 +1062,10 @@ class ShapeFeature(object):
isinstance(s_i, numpy.integer) or
(isinstance(s_i, numpy.ndarray) and s_i.ndim == 0)):
# this shape is a constant
assert s_i >= 0
if s_i < 0:
msg = "There is a negative shape in the graph!"
msg += gof.utils.get_variable_trace_string(var)
raise ValueError(msg)
return T.constant(s_i, dtype='int64')
if type(s_i) in (tuple, list):
# this dimension is the same as many of the inputs
......@@ -1137,7 +1143,7 @@ class ShapeFeature(object):
r.type.broadcastable[i]):
shape_vars.append(self.lscalar_one)
else:
shape_vars.append(self.unpack(s[i]))
shape_vars.append(self.unpack(s[i], r))
assert all([not hasattr(r.type, "broadcastable") or
not r.type.broadcastable[i] or
# The two following comparison are a speed optimization
......@@ -1238,7 +1244,7 @@ class ShapeFeature(object):
new_shape = []
for j, s_j in enumerate(prev_shape):
if j == i:
new_shape.append(self.unpack(s_i))
new_shape.append(self.unpack(s_i, r))
else:
new_shape.append(s_j)
assert all([not hasattr(r.type, "broadcastable") or
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论