提交 73d4fe6c authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed tests when on_shape_error = 'raise'

It is expected that some implementations of 'infer_shape' raise a ShapeError to indicate that they do not know how to compute the shape. In such a situation, we should silently fall back to the default implementation, regardless of the value of 'on_shape_error' (whose name may be misleading for developers, since we want to catch all exceptions *except* ShapeError). Also, when a bad shape was inferred, we should not raise a ShapeError (which may be caught silently), but rather something that indicates that there is a serious bug somewhere, like an AssertionError.
上级 5c0887dd
...@@ -786,7 +786,7 @@ class ShapeFeature(object): ...@@ -786,7 +786,7 @@ class ShapeFeature(object):
self.shape_of[r] = s self.shape_of[r] = s
else: else:
if r.ndim != len(s): if r.ndim != len(s):
raise ShapeError( raise AssertionError(
"Something inferred a shape with %d dimensions " "Something inferred a shape with %d dimensions "
"for a variable with %d dimensions." % ( "for a variable with %d dimensions." % (
len(s), r.ndim)) len(s), r.ndim))
...@@ -908,8 +908,6 @@ class ShapeFeature(object): ...@@ -908,8 +908,6 @@ class ShapeFeature(object):
o_shapes = shape_infer(node, o_shapes = shape_infer(node,
[self.shape_of[r] for r in node.inputs]) [self.shape_of[r] for r in node.inputs])
except ShapeError: except ShapeError:
if config.on_shape_error == "raise":
raise
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for o_shapes = self.default_infer_shape(node, [self.shape_of[r] for
r in node.inputs]) r in node.inputs])
except NotImplementedError, e: except NotImplementedError, e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论