提交 27780095 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Documented `equal_computations` and removed useless flag strict.

上级 469803aa
......@@ -286,8 +286,7 @@ class Scan(PureOp):
if not scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs,
strict=True):
other.inputs):
return False
# If they do, then they need to match in other small details
......
......@@ -268,11 +268,21 @@ def expand(tensor_var, size):
return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
'''
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
equal if strict is set to False.
def equal_computations(xs, ys, in_xs=None, in_ys=None):
'''Checks if Theano graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
`x` and `y` represent the same computations on the same variables
(unless equivalences are provided unsing `in_xs`, `in_ys`)
If `in_xs` and `in_ys` are provided, then these nodes are considered
equivalent even if they do not compare equal (they should however have
the same type). These lists could be used for example to provide
equivalence between inputs of two different graphs if what we want is
actually to see if (regardless of the input units) the two graph
actually perform the same computations on them.
'''
if in_xs is None:
in_xs = []
......@@ -314,12 +324,8 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
else:
pass
else:
if not strict:
if dx.type != dy.type:
return False
else:
if (dx, dy) not in common:
return False
if (dx, dy) not in common:
return False
while cont and idx < n_nodes:
nd_x = nds_x[idx]
......@@ -333,7 +339,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
else:
for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common:
if strict and dx != dy:
if dx != dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and
......@@ -344,8 +350,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
pass
else:
cont = False
else:
cont = cont and (dx.type == dy.type)
if cont:
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论