提交 520dcf13 authored 作者: Frederic's avatar Frederic

code simplification.

上级 ce8cc0f8
...@@ -1041,13 +1041,13 @@ class ShapeFeature(object): ...@@ -1041,13 +1041,13 @@ class ShapeFeature(object):
# Ensure shapes are in 'int64'. This is to make sure the assert # Ensure shapes are in 'int64'. This is to make sure the assert
# found in the `local_useless_subtensor` optimization does not fail. # found in the `local_useless_subtensor` optimization does not fail.
new_shape = []
for sh_idx, sh in enumerate(o_shapes): for sh_idx, sh in enumerate(o_shapes):
if sh is None: if sh is None:
continue continue
if not isinstance(sh, (list, tuple)): if not isinstance(sh, (list, tuple)):
raise ValueError("infer_shape of %s didn't return a list of" raise ValueError("infer_shape of %s didn't return a list of"
" list. It returned '%s'" % (str(node), str(o_shapes))) " list. It returned '%s'" % (str(node), str(o_shapes)))
new_shape = []
for i, d in enumerate(sh): for i, d in enumerate(sh):
# Note: we ignore any shape element that is not typed (i.e., # Note: we ignore any shape element that is not typed (i.e.,
# does not have a 'dtype' attribute). This means there may # does not have a 'dtype' attribute). This means there may
...@@ -1064,7 +1064,6 @@ class ShapeFeature(object): ...@@ -1064,7 +1064,6 @@ class ShapeFeature(object):
# 'int64'. # 'int64'.
new_shape += sh[len(new_shape):] new_shape += sh[len(new_shape):]
o_shapes[sh_idx] = tuple(new_shape) o_shapes[sh_idx] = tuple(new_shape)
new_shape = []
for r, s in izip(node.outputs, o_shapes): for r, s in izip(node.outputs, o_shapes):
self.set_shape(r, s) self.set_shape(r, s)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论