提交 1838b8e2 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2861 from SinaHonari/issue2799

first commit to Make equal_computations more efficient
...@@ -420,6 +420,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -420,6 +420,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False return False
common = set(zip(in_xs, in_ys)) common = set(zip(in_xs, in_ys))
different = set()
for dx, dy in izip(xs, ys): for dx, dy in izip(xs, ys):
# We checked above that both dx and dy have an owner or not # We checked above that both dx and dy have an owner or not
if not dx.owner: if not dx.owner:
...@@ -434,7 +435,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -434,7 +435,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Explore the two graphs, in parallel, depth first, comparing the nodes # Explore the two graphs, in parallel, depth first, comparing the nodes
# along the way for equality. # along the way for equality.
def compare_nodes(nd_x, nd_y): def compare_nodes(nd_x, nd_y, common, different):
''' Compare two nodes to determine if they perform equal computation. ''' Compare two nodes to determine if they perform equal computation.
This is done by comparing the ops, the number of inputs, outputs and This is done by comparing the ops, the number of inputs, outputs and
by ensuring that the inputs themselves are the result of equal by ensuring that the inputs themselves are the result of equal
...@@ -451,6 +452,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -451,6 +452,16 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
elif len(nd_x.outputs) != len(nd_y.outputs): elif len(nd_x.outputs) != len(nd_y.outputs):
return False return False
else: else:
all_in_common=True
for dx, dy in izip(nd_x.outputs, nd_y.outputs):
if (dx, dy) in different:
return False
if (dx, dy) not in common:
all_in_common = False
if all_in_common:
return True
# Compare the individual inputs for equality # Compare the individual inputs for equality
for dx, dy in izip(nd_x.inputs, nd_y.inputs): for dx, dy in izip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common: if (dx, dy) not in common:
...@@ -461,8 +472,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -461,8 +472,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
dx.owner.outputs.index(dx) == dx.owner.outputs.index(dx) ==
dy.owner.outputs.index(dy)): dy.owner.outputs.index(dy)):
nodes_equal = compare_nodes(dx.owner, dy.owner) nodes_equal = compare_nodes(dx.owner, dy.owner, common, different)
if not nodes_equal: if not nodes_equal:
different.add((dx, dy))
return False return False
# If both variables don't have an owner, then they are # If both variables don't have an owner, then they are
...@@ -493,7 +505,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -493,7 +505,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
if xs[i].owner: if xs[i].owner:
# The case where pairs of x[i]s and y[i]s don't both have an owner # The case where pairs of x[i]s and y[i]s don't both have an owner
# have already been adressed. # have already been adressed.
is_equal = compare_nodes(xs[i].owner, ys[i].owner) is_equal = compare_nodes(xs[i].owner, ys[i].owner, common, different)
if not is_equal: if not is_equal:
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论