提交 180520eb authored 作者: Sina Honari's avatar Sina Honari

first commit to Make equal_computations more efficient

上级 6f01cf86
...@@ -420,6 +420,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -420,6 +420,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False return False
common = set(zip(in_xs, in_ys)) common = set(zip(in_xs, in_ys))
different = set()
for dx, dy in izip(xs, ys): for dx, dy in izip(xs, ys):
# We checked above that both dx and dy have an owner or not # We checked above that both dx and dy have an owner or not
if not dx.owner: if not dx.owner:
...@@ -434,7 +435,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -434,7 +435,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Explore the two graphs, in parallel, depth first, comparing the nodes # Explore the two graphs, in parallel, depth first, comparing the nodes
# along the way for equality. # along the way for equality.
def compare_nodes(nd_x, nd_y): def compare_nodes(nd_x, nd_y, common, different):
''' Compare two nodes to determine if they perform equal computation. ''' Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal by ensuring that the inputs themselves are the result of equal
...@@ -451,6 +452,18 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -451,6 +452,18 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif len(nd_x.outputs) != len(nd_y.outputs): elif len(nd_x.outputs) != len(nd_y.outputs):
return False return False
else: else:
for dx, dy in izip(nd_x.outputs, nd_y.outputs):
if (dx, dy) in different:
return False
all_in_common=True
for dx, dy in izip(nd_x.outputs, nd_y.outputs):
if (dx, dy) not in common:
all_in_common = False
break
if all_in_common:
return True
# Compare the individual inputs for equality # Compare the individual inputs for equality
for dx, dy in izip(nd_x.inputs, nd_y.inputs): for dx, dy in izip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common: if (dx, dy) not in common:
...@@ -461,8 +474,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -461,8 +474,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
dx.owner.outputs.index(dx) == dx.owner.outputs.index(dx) ==
dy.owner.outputs.index(dy)): dy.owner.outputs.index(dy)):
nodes_equal = compare_nodes(dx.owner, dy.owner) nodes_equal = compare_nodes(dx.owner, dy.owner, common, different)
if not nodes_equal: if not nodes_equal:
different.add((dx, dy))
return False return False
# If both variables don't have an owner, then they are # If both variables don't have an owner, then they are
...@@ -493,7 +507,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -493,7 +507,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
if xs[i].owner: if xs[i].owner:
# The case where pairs of x[i]s and y[i]s don't both have an owner # The case where pairs of x[i]s and y[i]s don't both have an owner
# have already been adressed. # have already been adressed.
is_equal = compare_nodes(xs[i].owner, ys[i].owner) is_equal = compare_nodes(xs[i].owner, ys[i].owner, common, different)
if not is_equal: if not is_equal:
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论