提交 ed4b2b14 authored 作者: Frederic's avatar Frederic

code simplification

上级 63e73e69
...@@ -416,8 +416,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -416,8 +416,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False return False
common = set(zip(in_xs, in_ys)) common = set(zip(in_xs, in_ys))
n_nodes = len(nds_x) n_nodes = len(nds_x)
cont = True
idx = 0
for dx, dy in izip(xs, ys): for dx, dy in izip(xs, ys):
if not dx.owner or not dy.owner: if not dx.owner or not dy.owner:
if dy.owner or dx.owner: if dy.owner or dx.owner:
...@@ -431,15 +429,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -431,15 +429,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif (dx, dy) not in common and dx != dy: elif (dx, dy) not in common and dx != dy:
return False return False
while cont and idx < n_nodes: idx = 0
while idx < n_nodes:
nd_x = nds_x[idx] nd_x = nds_x[idx]
nd_y = nds_y[idx] nd_y = nds_y[idx]
if nd_x.op != nd_y.op: if nd_x.op != nd_y.op:
cont = False return False
elif len(nd_x.inputs) != len(nd_y.inputs): elif len(nd_x.inputs) != len(nd_y.inputs):
cont = False return False
elif len(nd_x.outputs) != len(nd_y.outputs): elif len(nd_x.outputs) != len(nd_y.outputs):
cont = False return False
else: else:
for dx, dy in izip(nd_x.inputs, nd_y.inputs): for dx, dy in izip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common: if (dx, dy) not in common:
...@@ -451,14 +450,13 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -451,14 +450,13 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
else: else:
pass pass
else: else:
cont = False return False
if cont:
for dx, dy in izip(nd_x.outputs, nd_y.outputs): for dx, dy in izip(nd_x.outputs, nd_y.outputs):
common.add((dx, dy)) common.add((dx, dy))
idx += 1 idx += 1
return cont return True
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论