提交 f5ccee2e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed equal_computation following Pascal's suggestion

numpy.all doesn't check dtype and shape
上级 1d752218
...@@ -321,9 +321,13 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -321,9 +321,13 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if dy.owner or dx.owner: if dy.owner or dx.owner:
return False return False
elif (isinstance(dx, tensor.Constant) and elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and isinstance(dy, tensor.Constant)):
numpy.all(dx.data == dy.data)): if not ( numpy.all(dx.data == dy.data) and
pass dx.dtype == dy.dtype and
dx.shape == dy.shape):
return False
else:
pass
else: else:
if not strict: if not strict:
if dx.type != dy.type: if dx.type != dy.type:
...@@ -346,9 +350,13 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True): ...@@ -346,9 +350,13 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if (dx,dy) not in common: if (dx,dy) not in common:
if strict and dx!= dy: if strict and dx!= dy:
if (isinstance(dx, tensor.Constant) and if (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant) and isinstance(dy, tensor.Constant)):
numpy.all(dx.data == dy.data)): if not (numpy.all(dx.data == dy.data) and
pass dx.dtype == dy.dtype and
dx.shape == dy.shape):
return False
else:
pass
else: else:
cont = False cont = False
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论