提交 075b8f4e authored 作者: --global's avatar --global

Remove dependance on toposort in equal_computations

上级 928a8af5
......@@ -430,16 +430,18 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif (dx, dy) not in common and dx != dy:
return False
nds_x = gof.graph.io_toposort(in_xs, xs)
nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y):
return False
# Explore the two graphs, in parallel, depth first, comparing the nodes
# along the way for equality.
def compare_nodes(nd_x, nd_y):
''' Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal
computation.
NOTE : This function relies on the variable common to cache
results to be more efficient.
'''
n_nodes = len(nds_x)
idx = 0
while idx < n_nodes:
nd_x = nds_x[idx]
nd_y = nds_y[idx]
if nd_x.op != nd_y.op:
return False
elif len(nd_x.inputs) != len(nd_y.inputs):
......@@ -447,21 +449,51 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif len(nd_x.outputs) != len(nd_y.outputs):
return False
else:
# Compare the individual inputs for equality
for dx, dy in izip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common:
if dx != dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not dx.equals(dy):
return False
else:
pass
else:
# Equality between the variables is unknown, compare
# their respective owners, if they have some
if (dx.owner and dy.owner and
dx.owner.outputs.index(dx) ==
dy.owner.outputs.index(dy)):
nodes_equal = compare_nodes(dx.owner, dy.owner)
if not nodes_equal:
return False
# If both variables don't have an owner, then they are
# inputs and can be directly compared
elif dx.owner is None and dy.owner is None:
if dx != dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not dx.equals(dy):
return False
else:
return False
else:
return False
# If the code reaches this statement then the inputs are pair-wise
# equivalent so the outputs of the current nodes are also
# pair-wise equivalents
for dx, dy in izip(nd_x.outputs, nd_y.outputs):
common.add((dx, dy))
idx += 1
return True
# Validate that each xs[i], ys[i] pair represents the same computation
for i in range(len(xs)):
if xs[0].owner:
# The case where xs and ys don't both have an owner
# have already been adressed.
is_equal = compare_nodes(xs[i].owner, ys[i].owner)
if not is_equal:
return False
return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论