提交 27780095 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Documented `equal_computations` and removed useless flag strict.

上级 469803aa
...@@ -286,8 +286,7 @@ class Scan(PureOp): ...@@ -286,8 +286,7 @@ class Scan(PureOp):
if not scan_utils.equal_computations(self.outputs, if not scan_utils.equal_computations(self.outputs,
other.outputs, other.outputs,
self.inputs, self.inputs,
other.inputs, other.inputs):
strict=True):
return False return False
# If they do, then they need to match in other small details # If they do, then they need to match in other small details
......
...@@ -268,11 +268,21 @@ def expand(tensor_var, size): ...@@ -268,11 +268,21 @@ def expand(tensor_var, size):
return tensor.set_subtensor(empty[:shapes[0]], tensor_var) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): def equal_computations(xs, ys, in_xs=None, in_ys=None):
''' '''Checks if Theano graphs represent the same computations.
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed The two lists `xs`, `ys` should have the same number of entries. The
equal if strict is set to False. function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
`x` and `y` represent the same computations on the same variables
(unless equivalences are provided unsing `in_xs`, `in_ys`)
If `in_xs` and `in_ys` are provided, then these nodes are considered
equivalent even if they do not compare equal (they should however have
the same type). These lists could be used for example to provide
equivalence between inputs of two different graphs if what we want is
actually to see if (regardless of the input units) the two graph
actually perform the same computations on them.
''' '''
if in_xs is None: if in_xs is None:
in_xs = [] in_xs = []
...@@ -314,12 +324,8 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): ...@@ -314,12 +324,8 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
else: else:
pass pass
else: else:
if not strict: if (dx, dy) not in common:
if dx.type != dy.type: return False
return False
else:
if (dx, dy) not in common:
return False
while cont and idx < n_nodes: while cont and idx < n_nodes:
nd_x = nds_x[idx] nd_x = nds_x[idx]
...@@ -333,7 +339,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): ...@@ -333,7 +339,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
else: else:
for dx, dy in zip(nd_x.inputs, nd_y.inputs): for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common: if (dx, dy) not in common:
if strict and dx != dy: if dx != dy:
if (isinstance(dx, tensor.Constant) and if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and if not (numpy.all(dx.data == dy.data) and
...@@ -344,8 +350,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): ...@@ -344,8 +350,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
pass pass
else: else:
cont = False cont = False
else:
cont = cont and (dx.type == dy.type)
if cont: if cont:
for dx, dy in zip(nd_x.outputs, nd_y.outputs): for dx, dy in zip(nd_x.outputs, nd_y.outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论