提交 680c25f2 authored 作者: James Bergstra's avatar James Bergstra

Backing out of the idea to add 'shape' attribute to TensorType. This is a

bigger job. See TRAC ticket on subject.
上级 50d2288f
...@@ -253,6 +253,11 @@ class TensorType(Type): ...@@ -253,6 +253,11 @@ class TensorType(Type):
When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`) When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`)
""" """
use_shape = False
"""
This should be removed (hardcoded to be False) after AISTATS09
"""
def __init__(self, dtype, broadcastable, name = None, shape=None): def __init__(self, dtype, broadcastable, name = None, shape=None):
"""Initialize self.dtype and self.broadcastable. """Initialize self.dtype and self.broadcastable.
...@@ -305,6 +310,7 @@ class TensorType(Type): ...@@ -305,6 +310,7 @@ class TensorType(Type):
if self.filter_checks_isfinite and (not numpy.all(numpy.isfinite(data))): if self.filter_checks_isfinite and (not numpy.all(numpy.isfinite(data))):
raise TypeError("non-finite elements not allowed") raise TypeError("non-finite elements not allowed")
if TensorType.use_shape:
for si, di in zip(self.shape, data.shape): for si, di in zip(self.shape, data.shape):
if not (si is None or si == di): if not (si is None or si == di):
raise TypeError('%s requires ndarray with shape matching %s (got %s)'%( raise TypeError('%s requires ndarray with shape matching %s (got %s)'%(
...@@ -347,9 +353,13 @@ class TensorType(Type): ...@@ -347,9 +353,13 @@ class TensorType(Type):
def __eq__(self, other): def __eq__(self, other):
"""Compare True iff other is the same kind of TensorType""" """Compare True iff other is the same kind of TensorType"""
if TensorType.use_shape:
return type(self) == type(other) and other.dtype == self.dtype \ return type(self) == type(other) and other.dtype == self.dtype \
and other.broadcastable == self.broadcastable \ and other.broadcastable == self.broadcastable \
and other.shape == self.shape and other.shape == self.shape
else:
return type(self) == type(other) and other.dtype == self.dtype \
and other.broadcastable == self.broadcastable
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
...@@ -420,7 +430,10 @@ class TensorType(Type): ...@@ -420,7 +430,10 @@ class TensorType(Type):
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of TensorType""" """Hash equal for same kinds of TensorType"""
if TensorType.use_shape:
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable) ^ hash(self.shape) return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable) ^ hash(self.shape)
else:
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable)
ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions") ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions")
"""Number of dimensions """Number of dimensions
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论