提交 7d532774 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge branch 'documentation_equal_computation' of…

Merge branch 'documentation_equal_computation' of https://github.com/pascanur/Theano into pascanur-documentation_equal_computation Conflicts: theano/scan_module/scan_utils.py
...@@ -286,8 +286,7 @@ class Scan(PureOp): ...@@ -286,8 +286,7 @@ class Scan(PureOp):
if not scan_utils.equal_computations(self.outputs, if not scan_utils.equal_computations(self.outputs,
other.outputs, other.outputs,
self.inputs, self.inputs,
other.inputs, other.inputs):
strict=True):
return False return False
# If they do, then they need to match in other small details # If they do, then they need to match in other small details
......
...@@ -268,11 +268,21 @@ def expand(tensor_var, size): ...@@ -268,11 +268,21 @@ def expand(tensor_var, size):
return tensor.set_subtensor(empty[:shapes[0]], tensor_var) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): def equal_computations(xs, ys, in_xs=None, in_ys=None):
''' '''Checks if Theano graphs represent the same computations.
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed The two lists `xs`, `ys` should have the same number of entries. The
equal if strict is set to False. function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
`x` and `y` represent the same computations on the same variables
(unless equivalences are provided using `in_xs`, `in_ys`).
If `in_xs` and `in_ys` are provided, then when comparing a node `x` with
a node `y` they are automatically considered as equal if there is some
index `i` such that `x == in_xs[i]` and `y == in_ys[i]`(and they both
have the same type). Note that `x` and `y` can be in the list `xs` and
`ys`, but also represent subgraphs of a computational graph in `xs`
or `ys`.
''' '''
if in_xs is None: if in_xs is None:
in_xs = [] in_xs = []
...@@ -313,13 +323,8 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): ...@@ -313,13 +323,8 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
return False return False
else: else:
pass pass
else: elif (dx, dy) not in common and dx != dy:
if not strict: return False
if dx.type != dy.type:
return False
else:
if (dx, dy) not in common and dx != dy:
return False
while cont and idx < n_nodes: while cont and idx < n_nodes:
nd_x = nds_x[idx] nd_x = nds_x[idx]
...@@ -333,7 +338,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): ...@@ -333,7 +338,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
else: else:
for dx, dy in zip(nd_x.inputs, nd_y.inputs): for dx, dy in zip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common: if (dx, dy) not in common:
if strict and dx != dy: if dx != dy:
if (isinstance(dx, tensor.Constant) and if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and if not (numpy.all(dx.data == dy.data) and
...@@ -344,8 +349,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True): ...@@ -344,8 +349,6 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None, strict=True):
pass pass
else: else:
cont = False cont = False
else:
cont = cont and (dx.type == dy.type)
if cont: if cont:
for dx, dy in zip(nd_x.outputs, nd_y.outputs): for dx, dy in zip(nd_x.outputs, nd_y.outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论