提交 4bf5448d authored 作者: Frederic Bastien's avatar Frederic Bastien

Better error message to help user debug error like gh-5347

上级 13c2b145
...@@ -1041,11 +1041,14 @@ class ShapeFeature(object): ...@@ -1041,11 +1041,14 @@ class ShapeFeature(object):
rval.append(None) rval.append(None)
return rval return rval
def unpack(self, s_i): def unpack(self, s_i, var):
"""Return a symbolic integer scalar for the shape element s_i. """Return a symbolic integer scalar for the shape element s_i.
The s_i argument was produced by the infer_shape() of an Op subclass. The s_i argument was produced by the infer_shape() of an Op subclass.
var: the variable that correspond to s_i. This is just for
error reporting.
""" """
# unpack the s_i that the Op returned # unpack the s_i that the Op returned
assert s_i is not None assert s_i is not None
...@@ -1059,7 +1062,10 @@ class ShapeFeature(object): ...@@ -1059,7 +1062,10 @@ class ShapeFeature(object):
isinstance(s_i, numpy.integer) or isinstance(s_i, numpy.integer) or
(isinstance(s_i, numpy.ndarray) and s_i.ndim == 0)): (isinstance(s_i, numpy.ndarray) and s_i.ndim == 0)):
# this shape is a constant # this shape is a constant
assert s_i >= 0 if s_i < 0:
msg = "There is a negative shape in the graph!"
msg += gof.utils.get_variable_trace_string(var)
raise ValueError(msg)
return T.constant(s_i, dtype='int64') return T.constant(s_i, dtype='int64')
if type(s_i) in (tuple, list): if type(s_i) in (tuple, list):
# this dimension is the same as many of the inputs # this dimension is the same as many of the inputs
...@@ -1137,7 +1143,7 @@ class ShapeFeature(object): ...@@ -1137,7 +1143,7 @@ class ShapeFeature(object):
r.type.broadcastable[i]): r.type.broadcastable[i]):
shape_vars.append(self.lscalar_one) shape_vars.append(self.lscalar_one)
else: else:
shape_vars.append(self.unpack(s[i])) shape_vars.append(self.unpack(s[i], r))
assert all([not hasattr(r.type, "broadcastable") or assert all([not hasattr(r.type, "broadcastable") or
not r.type.broadcastable[i] or not r.type.broadcastable[i] or
# The two following comparison are a speed optimization # The two following comparison are a speed optimization
...@@ -1238,7 +1244,7 @@ class ShapeFeature(object): ...@@ -1238,7 +1244,7 @@ class ShapeFeature(object):
new_shape = [] new_shape = []
for j, s_j in enumerate(prev_shape): for j, s_j in enumerate(prev_shape):
if j == i: if j == i:
new_shape.append(self.unpack(s_i)) new_shape.append(self.unpack(s_i, r))
else: else:
new_shape.append(s_j) new_shape.append(s_j)
assert all([not hasattr(r.type, "broadcastable") or assert all([not hasattr(r.type, "broadcastable") or
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论