提交 82efd74e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

More sane check if two graph represent same computations.

The main thing I was looking for was efficiency. The new algorithm is not recursive, it considers a bunch of outputs at the same time, and it uses toposort to traverse the graph, making sure each node is seen only once.
上级 befdc27d
......@@ -301,19 +301,26 @@ class Scan(Op):
return False
elif not len(self.outputs) == len(other.outputs):
return False
elif self.info != other.info:
return False
else:
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
for x,y in zip(self.inputs, other.inputs):
if not scan_utils.equal_computations(x,y):
return False
for x,y in zip(self.outputs, other.outputs):
if not scan_utils.equal_computations(x,y):
return False
if not scan_utils.equal_computations(self.inputs,
other.inputs,
strict = False):
return False
if not scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs):
return False
# If they do, then they need to match in other small details
# like name, mode, etc.
return self.info == other.info
return True
def __str__(self):
if self.gpu:
......
......@@ -229,40 +229,85 @@ def expand( tensor_var, size):
, dtype = tensor_var.dtype)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
'''
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
equal if strict is set to False.
'''
import time
t00 = time.time()
if in_xs is None:
in_xs = []
if in_ys is None:
in_ys = []
for x,y in zip(xs,ys):
if x.owner and not y.owner:
return False
if y.owner and not x.owner:
return False
if x.owner and y.owner:
if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False
nds_x = gof.graph.io_toposort(in_xs, xs)
nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y):
return False
common = set(zip(in_xs,in_ys))
n_nodes = len(nds_x)
cont = True
idx = 0
for dx,dy in zip(xs,ys):
if not dx.owner or not dy.owner:
if dy.owner or dx.owner:
return False
elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and
dx.data == dy.data):
pass
elif strict:
if dx != dy:
return False
else:
if dx.type != dy.type:
return False
while cont and idx < n_nodes:
nd_x = nds_x[idx]
nd_y = nds_y[idx]
if nd_x.op != nd_y.op:
cont = False
elif len(nd_x.inputs) != len(nd_y.inputs):
cont = False
elif len(nd_x.outputs) != len(nd_y.outputs):
cont = False
else:
for dx,dy in zip(nd_x.inputs, nd_y.inputs):
if (dx,dy) not in common:
if strict and dx!= dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and
dx.data == dy.data):
pass
else:
cont = False
else:
cont = cont and (dx.type == dy.type)
if cont:
for dx,dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx,dy))
idx += 1
return cont
def equal_computations(x,y, strict=False):
'''
Checks if to theano graphs represent the same computations (applied to
different inputs).
'''
if not x.type == y.type:
return False
elif not x.owner and not y.owner:
if not strict:
return True
else:
if isinstance(x, tensor.Constant):
# not they both have the same type
return x.data == y.data
else:
return x == y
elif x.owner and not y.owner:
return False
elif not x.owner and y.owner:
return False
elif not x.owner.op == y.owner.op:
return False
elif not len(x.owner.inputs) == len(y.owner.inputs):
return False
else:
for xx,yy in zip(x.owner.inputs,y.owner.inputs):
if not equal_computations(xx,yy):
return False
return True
def infer_shape(outs, inputs, input_shapes):
'''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论