提交 35baedd6 authored 作者: nouiz's avatar nouiz

Merge pull request #297 from delallea/fix_shape_error

Fixed tests when on_shape_error = 'raise'
......@@ -785,8 +785,8 @@ class ShapeFeature(object):
if s is None:
self.shape_of[r] = s
else:
if r.ndim != len(s):
raise ShapeError(
if r.ndim != len(s):
raise AssertionError(
"Something inferred a shape with %d dimensions "
"for a variable with %d dimensions." % (
len(s), r.ndim))
......@@ -908,8 +908,6 @@ class ShapeFeature(object):
o_shapes = shape_infer(node,
[self.shape_of[r] for r in node.inputs])
except ShapeError:
if config.on_shape_error == "raise":
raise
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for
r in node.inputs])
except NotImplementedError, e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论