提交 ec7681c1 authored 作者: Frederic's avatar Frederic

Code refactoring for later change

上级 8d3a67b7
......@@ -921,6 +921,39 @@ class ShapeFeature(object):
constants?? That would be confusing.
"""
def get_node_infer_shape(self, node):
try:
shape_infer = node.op.infer_shape
except AttributeError:
shape_infer = self.default_infer_shape
try:
o_shapes = shape_infer(node,
[self.shape_of[r] for r in node.inputs])
except ShapeError:
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for
r in node.inputs])
except NotImplementedError as e:
raise NotImplementedError(
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e)
except Exception as e:
msg = ('Failed to infer_shape from Op %s.\nInput shapes: '
'%s\nException encountered during infer_shape: '
'%s\nException message: %s\nTraceback: %s') % (
node.op, [self.shape_of[r] for r in node.inputs],
type(e), str(e), traceback.format_exc())
if config.on_shape_error == "raise":
raise Exception(msg)
else:
_logger.warning(msg)
o_shapes = self.default_infer_shape(
node, [self.shape_of[r] for r in node.inputs])
return o_shapes
def shape_ir(self, i, r):
"""Return symbolic r.shape[i] for tensor variable r, int i."""
......@@ -1207,36 +1240,7 @@ class ShapeFeature(object):
# make sure we have shapes for the inputs
self.init_r(r)
try:
shape_infer = node.op.infer_shape
except AttributeError:
shape_infer = self.default_infer_shape
try:
o_shapes = shape_infer(node,
[self.shape_of[r] for r in node.inputs])
except ShapeError:
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for
r in node.inputs])
except NotImplementedError as e:
raise NotImplementedError(
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e)
except Exception as e:
msg = ('Failed to infer_shape from Op %s.\nInput shapes: '
'%s\nException encountered during infer_shape: '
'%s\nException message: %s\nTraceback: %s') % (
node.op, [self.shape_of[r] for r in node.inputs],
type(e), str(e), traceback.format_exc())
if config.on_shape_error == "raise":
raise Exception(msg)
else:
_logger.warning(msg)
o_shapes = self.default_infer_shape(
node, [self.shape_of[r] for r in node.inputs])
o_shapes = self.get_node_infer_shape(node)
# this is packed information
# an element of o_shapes is either None or a tuple
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论