提交 a8df83f8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2769 from carriepl/equal_computations

Remove dependance on io_toposort in equal_computations()
...@@ -432,16 +432,18 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -432,16 +432,18 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif (dx, dy) not in common and dx != dy: elif (dx, dy) not in common and dx != dy:
return False return False
nds_x = gof.graph.io_toposort(in_xs, xs) # Explore the two graphs, in parallel, depth first, comparing the nodes
nds_y = gof.graph.io_toposort(in_ys, ys) # along the way for equality.
if len(nds_x) != len(nds_y): def compare_nodes(nd_x, nd_y):
return False ''' Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal
computation.
NOTE : This function relies on the variable common to cache
results to be more efficient.
'''
n_nodes = len(nds_x)
idx = 0
while idx < n_nodes:
nd_x = nds_x[idx]
nd_y = nds_y[idx]
if nd_x.op != nd_y.op: if nd_x.op != nd_y.op:
return False return False
elif len(nd_x.inputs) != len(nd_y.inputs): elif len(nd_x.inputs) != len(nd_y.inputs):
...@@ -449,21 +451,51 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -449,21 +451,51 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif len(nd_x.outputs) != len(nd_y.outputs): elif len(nd_x.outputs) != len(nd_y.outputs):
return False return False
else: else:
# Compare the individual inputs for equality
for dx, dy in izip(nd_x.inputs, nd_y.inputs): for dx, dy in izip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common: if (dx, dy) not in common:
# Equality between the variables is unknown, compare
# their respective owners, if they have some
if (dx.owner and dy.owner and
dx.owner.outputs.index(dx) ==
dy.owner.outputs.index(dy)):
nodes_equal = compare_nodes(dx.owner, dy.owner)
if not nodes_equal:
return False
# If both variables don't have an owner, then they are
# inputs and can be directly compared
elif dx.owner is None and dy.owner is None:
if dx != dy: if dx != dy:
if (isinstance(dx, tensor.Constant) and if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not dx.equals(dy): if not dx.equals(dy):
return False return False
else: else:
pass return False
else: else:
return False return False
# If the code reaches this statement then the inputs are pair-wise
# equivalent so the outputs of the current nodes are also
# pair-wise equivalents
for dx, dy in izip(nd_x.outputs, nd_y.outputs): for dx, dy in izip(nd_x.outputs, nd_y.outputs):
common.add((dx, dy)) common.add((dx, dy))
idx += 1
return True
# Validate that each xs[i], ys[i] pair represents the same computation
for i in range(len(xs)):
if xs[0].owner:
# The case where xs and ys don't both have an owner
# have already been adressed.
is_equal = compare_nodes(xs[i].owner, ys[i].owner)
if not is_equal:
return False
return True return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论