提交 9c6930f5 authored 作者: Frederic's avatar Frederic

Add assert, remove useless check and move variable closer to where they are used.

上级 ed4b2b14
...@@ -391,6 +391,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -391,6 +391,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
or `ys`. or `ys`.
''' '''
assert len(xs) == len(ys)
if in_xs is None: if in_xs is None:
in_xs = [] in_xs = []
if in_ys is None: if in_ys is None:
...@@ -401,7 +402,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -401,7 +402,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False return False
if y.owner and not x.owner: if y.owner and not x.owner:
return False return False
if x.owner and y.owner: if x.owner: # Check above tell that y.owner eval to True too.
if x.owner.outputs.index(x) != y.owner.outputs.index(y): if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False return False
if len(in_xs) != len(in_ys): if len(in_xs) != len(in_ys):
...@@ -415,12 +416,10 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -415,12 +416,10 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
if len(nds_x) != len(nds_y): if len(nds_x) != len(nds_y):
return False return False
common = set(zip(in_xs, in_ys)) common = set(zip(in_xs, in_ys))
n_nodes = len(nds_x)
for dx, dy in izip(xs, ys): for dx, dy in izip(xs, ys):
if not dx.owner or not dy.owner: # We checked above that both dx and dy have an owner or not
if dy.owner or dx.owner: if not dx.owner:
return False if (isinstance(dx, tensor.Constant) and
elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not dx.equals(dy): if not dx.equals(dy):
return False return False
...@@ -429,6 +428,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -429,6 +428,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif (dx, dy) not in common and dx != dy: elif (dx, dy) not in common and dx != dy:
return False return False
n_nodes = len(nds_x)
idx = 0 idx = 0
while idx < n_nodes: while idx < n_nodes:
nd_x = nds_x[idx] nd_x = nds_x[idx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论