提交 e4c45d5a authored 作者: Olivier Delalleau's avatar Olivier Delalleau

infer_shape does not survive through all exceptions anymore, and now uses…

infer_shape does not survive through all exceptions anymore, and now uses ShapeError instead of NotImplementedError to catch situations where a shape cannot be computed
上级 2114f1d0
......@@ -32,6 +32,12 @@ from basic import get_constant_value
# Utilities
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
pass
def out2in(*local_opts):
"""WRITEME """
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
......@@ -528,7 +534,7 @@ class ShapeFeature(object):
the cost of many Ops accurately, and generate c-code that is specific [e.g. unrolled] to
particular sizes.
If you can determine the shape only in some case, return NotImplementedError when you can't
In cases where you cannot figure out the shape, raise a ShapeError.
.. note::
......@@ -714,13 +720,22 @@ class ShapeFeature(object):
try:
o_shapes = shape_infer(node, [self.shape_of[r] for r in node.inputs])
except NotImplementedError:
except ShapeError:
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for r in node.inputs])
except NotImplementedError, e:
raise NotImplementedError(
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.opt.ShapeError '
'instead. The original exception message is: %s' % e)
except Exception, e:
_logger.error('Failed to infer_shape from Op %s (i_shapes=%s): %s %s'% (node.op,
[self.shape_of[r] for r in node.inputs],
type(e), str(e)))
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for r in node.inputs])
# We raise the exception to make sure the user knows something bad
# is going on.
raise
# this is packed information
# an element of o_shapes is either None or a tuple
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论