提交 49fe4a6e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed a bunch of errors in equal_computation

Errors are obvious, first of all the data of the constants don't have to scalars, and secondly, is not enough for outputs that have no owner to have same type, but they should have the same index (i.e. correspond to the same input). As I have all corresponding pairs of inputs in `common` and these outputs are equal with their inputs, is sufficient to check if them ( as pair) are part of common
上级 568c8a08
...@@ -322,14 +322,15 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -322,14 +322,15 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
return False return False
elif (isinstance(dx, tensor.Constant) and elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and isinstance(dy, tensor.Constant) and
dx.data == dy.data): numpy.all(dx.data == dy.data)):
pass pass
elif strict:
if dx != dy:
return False
else: else:
if dx.type != dy.type: if not strict:
return False if dx.type != dy.type:
return False
else:
if (dx,dy) not in common:
return False
while cont and idx < n_nodes: while cont and idx < n_nodes:
nd_x = nds_x[idx] nd_x = nds_x[idx]
...@@ -346,7 +347,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -346,7 +347,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if strict and dx!= dy: if strict and dx!= dy:
if (isinstance(dx, tensor.Constant) and if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and isinstance(dy, tensor.Constant) and
dx.data == dy.data): numpy.all(dx.data == dy.data)):
pass pass
else: else:
cont = False cont = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论