提交 82efd74e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

More sane check of whether two graphs represent the same computations.

The main thing I was looking for was efficiency. The new algorithm is not recursive, it considers a bunch of outputs at the same time, and it uses toposort to traverse the graph, making sure each node is seen only once.
上级 befdc27d
...@@ -301,19 +301,26 @@ class Scan(Op): ...@@ -301,19 +301,26 @@ class Scan(Op):
return False return False
elif not len(self.outputs) == len(other.outputs): elif not len(self.outputs) == len(other.outputs):
return False return False
elif self.info != other.info:
return False
else: else:
# If everything went OK up to here, there is still one thing to # If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same # check. Namely, do the internal graph represent same
# computations # computations
for x,y in zip(self.inputs, other.inputs): if not scan_utils.equal_computations(self.inputs,
if not scan_utils.equal_computations(x,y): other.inputs,
strict = False):
return False return False
for x,y in zip(self.outputs, other.outputs):
if not scan_utils.equal_computations(x,y): if not scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs):
return False return False
# If they do, then they need to match in other small details # If they do, then they need to match in other small details
# like name, mode, etc. # like name, mode, etc.
return self.info == other.info return True
def __str__(self): def __str__(self):
if self.gpu: if self.gpu:
......
...@@ -229,40 +229,85 @@ def expand( tensor_var, size): ...@@ -229,40 +229,85 @@ def expand( tensor_var, size):
, dtype = tensor_var.dtype) , dtype = tensor_var.dtype)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
'''
Checks if two Theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
equal if strict is set to False.
'''
import time
t00 = time.time()
if in_xs is None:
in_xs = []
if in_ys is None:
in_ys = []
for x,y in zip(xs,ys):
if x.owner and not y.owner:
def equal_computations(x,y, strict=False):
'''
Checks if two Theano graphs represent the same computations (applied to
different inputs).
'''
if not x.type == y.type:
return False return False
elif not x.owner and not y.owner: if y.owner and not x.owner:
if not strict:
return True
else:
if isinstance(x, tensor.Constant):
# note they both have the same type
return x.data == y.data
else:
return x == y
elif x.owner and not y.owner:
return False return False
elif not x.owner and y.owner: if x.owner and y.owner:
if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False return False
elif not x.owner.op == y.owner.op:
nds_x = gof.graph.io_toposort(in_xs, xs)
nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y):
return False return False
elif not len(x.owner.inputs) == len(y.owner.inputs): common = set(zip(in_xs,in_ys))
n_nodes = len(nds_x)
cont = True
idx = 0
for dx,dy in zip(xs,ys):
if not dx.owner or not dy.owner:
if dy.owner or dx.owner:
return False
elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and
dx.data == dy.data):
pass
elif strict:
if dx != dy:
return False return False
else: else:
for xx,yy in zip(x.owner.inputs,y.owner.inputs): if dx.type != dy.type:
if not equal_computations(xx,yy):
return False return False
return True
while cont and idx < n_nodes:
nd_x = nds_x[idx]
nd_y = nds_y[idx]
if nd_x.op != nd_y.op:
cont = False
elif len(nd_x.inputs) != len(nd_y.inputs):
cont = False
elif len(nd_x.outputs) != len(nd_y.outputs):
cont = False
else:
for dx,dy in zip(nd_x.inputs, nd_y.inputs):
if (dx,dy) not in common:
if strict and dx!= dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and
dx.data == dy.data):
pass
else:
cont = False
else:
cont = cont and (dx.type == dy.type)
if cont:
for dx,dy in zip(nd_x.outputs, nd_y.outputs):
common.add((dx,dy))
idx += 1
return cont
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
''' '''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论