提交 db72e2b8 authored 作者: James Bergstra's avatar James Bergstra

ShapeFeature - better diagnostic information when unpacking infer_shape rval fails

上级 c1aeb86c
...@@ -379,10 +379,13 @@ class ShapeFeature(object): ...@@ -379,10 +379,13 @@ class ShapeFeature(object):
# #
# worst case, we loop over shape_of and replace things # worst case, we loop over shape_of and replace things
raise NotImplementedError(s_i) raise NotImplementedError(s_i)
elif s_i.type == T.lscalar: elif s_i.type.dtype[:3] in ('int', 'uint'):
if getattr(s_i.type, 'ndim', 0):
raise TypeError('Shape element must be scalar', s_i)
return s_i return s_i
else: else:
raise TypeError('Unsupported shape element', s_i) raise TypeError('Unsupported shape element',
s_i, type(s_i), getattr(s_i, 'type', None))
def set_shape(self, r, s): def set_shape(self, r, s):
assert r not in self.shape_of assert r not in self.shape_of
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论