提交 e527fda1 authored 作者: James Bergstra's avatar James Bergstra

Changes to TensorConstantSignature.

1) Fixed __eq__ to work when other is not of same type as self. 2) Optimized TensorConstantSignature to compare data.sum() before comparing all elements.
上级 9d5e1145
......@@ -770,13 +770,29 @@ class TensorVariable(Variable, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Variable` class."""
class TensorConstantSignature(tuple):
"""A Signature object for comparing TensorConstant instances
An instance is a pair: (Type instance, ndarray).
"""
def __eq__(self, other):
(a, b), (x,y) = self, other
try:
(t0, d0), (t1,d1) = self, other
except:
return False
#N.B. compare shape to ensure no broadcasting in ==
return (x == a) and (b.shape == y.shape) and (numpy.all(b == y))
return (t0 == t1) and (d0.shape == d1.shape) \
and (self.sum == other.sum) and (numpy.all(d0 == d1))
def __hash__(self):
a, b = self
return hashtype(self) ^ hash(a) ^ hash(b.shape)
t, d = self
return hashtype(self) ^ hash(t) ^ hash(d.shape) ^ hash(self.sum)
def _get_sum(self):
try:
return self._sum
except:
self._sum = self[1].sum()
return self._sum
sum = property(_get_sum)
class TensorConstant(Constant, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Constant` class.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论