提交 7d532774 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge branch 'documentation_equal_computation' of…

Merge branch 'documentation_equal_computation' of https://github.com/pascanur/Theano into pascanur-documentation_equal_computation Conflicts: theano/scan_module/scan_utils.py
......@@ -286,8 +286,7 @@ class Scan(PureOp):
if not scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs,
strict=True):
other.inputs):
return False
# If they do, then they need to match in other small details
......
......@@ -268,11 +268,21 @@ def expand(tensor_var, size):
return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
'''
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
equal if strict is set to False.
def equal_computations(xs, ys, in_xs=None, in_ys=None):
'''Checks if Theano graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
`x` and `y` represent the same computations on the same variables
(unless equivalences are provided using `in_xs`, `in_ys`).
If `in_xs` and `in_ys` are provided, then when comparing a node `x` with
a node `y` they are automatically considered as equal if there is some
index `i` such that `x == in_xs[i]` and `y == in_ys[i]`(and they both
have the same type). Note that `x` and `y` can be in the list `xs` and
`ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`.
'''
if in_xs is None:
in_xs = []
......@@ -313,12 +323,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
return False
else:
pass
else:
if not strict:
if dx.type != dy.type:
return False
else:
if (dx, dy) not in common and dx != dy:
elif (dx, dy) not in common and dx != dy:
return False
while cont and idx < n_nodes:
......@@ -333,7 +338,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
else:
for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common:
if strict and dx != dy:
if dx != dy:
if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and
......@@ -344,8 +349,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
pass
else:
cont = False
else:
cont = cont and (dx.type == dy.type)
if cont:
for dx, dy in zip(nd_x.outputs, nd_y.outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论