提交 61707ce3 authored 作者: Frederic's avatar Frederic

Override TensorConstant.equals to allow to compare with numpy.ndarray.

上级 507624c4
...@@ -2004,6 +2004,13 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -2004,6 +2004,13 @@ class TensorConstant(_tensor_py_operators, Constant):
def signature(self): def signature(self):
return TensorConstantSignature((self.type, self.data)) return TensorConstantSignature((self.type, self.data))
def equals(self, other):
# Override Contant.equals to allow to compare with numpy.ndarray
if isinstance(other, numpy.ndarray):
# Make a TensorConstant to be able to compare
other = constant(other)
return (isinstance(other, TensorConstant) and
self.signature() == other.signature())
TensorType.Constant = TensorConstant TensorType.Constant = TensorConstant
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论