提交 6a0d70ef authored 作者: James Bergstra's avatar James Bergstra

added impl of values_eq to TensorType

上级 7a87a567
...@@ -314,6 +314,15 @@ class TensorType(Type): ...@@ -314,6 +314,15 @@ class TensorType(Type):
return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable return type(self) == type(other) and other.dtype == self.dtype and other.broadcastable == self.broadcastable
@staticmethod @staticmethod
def values_eq(a, b):
#TODO: check to see if the dtype and shapes must match
# for now, we err on safe side...
if a.shape != b.shape:
return False
if a.dtype != b.dtype:
return False
return numpy.all(a==b)
@staticmethod
def values_eq_approx(a, b): def values_eq_approx(a, b):
if type(a) is numpy.ndarray and type(b) is numpy.ndarray: if type(a) is numpy.ndarray and type(b) is numpy.ndarray:
if a.shape != b.shape: if a.shape != b.shape:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论