提交 35baedd6 authored 作者: nouiz's avatar nouiz

Merge pull request #297 from delallea/fix_shape_error

Fixed tests when on_shape_error = 'raise'
...@@ -785,8 +785,8 @@ class ShapeFeature(object): ...@@ -785,8 +785,8 @@ class ShapeFeature(object):
if s is None: if s is None:
self.shape_of[r] = s self.shape_of[r] = s
else: else:
if r.ndim != len(s): if r.ndim != len(s):
raise ShapeError( raise AssertionError(
"Something inferred a shape with %d dimensions " "Something inferred a shape with %d dimensions "
"for a variable with %d dimensions." % ( "for a variable with %d dimensions." % (
len(s), r.ndim)) len(s), r.ndim))
...@@ -908,8 +908,6 @@ class ShapeFeature(object): ...@@ -908,8 +908,6 @@ class ShapeFeature(object):
o_shapes = shape_infer(node, o_shapes = shape_infer(node,
[self.shape_of[r] for r in node.inputs]) [self.shape_of[r] for r in node.inputs])
except ShapeError: except ShapeError:
if config.on_shape_error == "raise":
raise
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for o_shapes = self.default_infer_shape(node, [self.shape_of[r] for
r in node.inputs]) r in node.inputs])
except NotImplementedError, e: except NotImplementedError, e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论