提交 680c25f2 authored 作者: James Bergstra's avatar James Bergstra

Backing out of the idea to add 'shape' attribute to TensorType. This is a

bigger job. See TRAC ticket on subject.
上级 50d2288f
...@@ -253,6 +253,11 @@ class TensorType(Type): ...@@ -253,6 +253,11 @@ class TensorType(Type):
When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`) When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`)
""" """
use_shape = False
"""
This should be removed (hardcoded to be False) after AISTATS09
"""
def __init__(self, dtype, broadcastable, name = None, shape=None): def __init__(self, dtype, broadcastable, name = None, shape=None):
"""Initialize self.dtype and self.broadcastable. """Initialize self.dtype and self.broadcastable.
...@@ -305,10 +310,11 @@ class TensorType(Type): ...@@ -305,10 +310,11 @@ class TensorType(Type):
if self.filter_checks_isfinite and (not numpy.all(numpy.isfinite(data))): if self.filter_checks_isfinite and (not numpy.all(numpy.isfinite(data))):
raise TypeError("non-finite elements not allowed") raise TypeError("non-finite elements not allowed")
for si, di in zip(self.shape, data.shape): if TensorType.use_shape:
if not (si is None or si == di): for si, di in zip(self.shape, data.shape):
raise TypeError('%s requires ndarray with shape matching %s (got %s)'%( if not (si is None or si == di):
self, self.shape, data.shape)) raise TypeError('%s requires ndarray with shape matching %s (got %s)'%(
self, self.shape, data.shape))
return data return data
else: else:
data = numpy.asarray(data, dtype = self.dtype) data = numpy.asarray(data, dtype = self.dtype)
...@@ -347,9 +353,13 @@ class TensorType(Type): ...@@ -347,9 +353,13 @@ class TensorType(Type):
def __eq__(self, other): def __eq__(self, other):
"""Compare True iff other is the same kind of TensorType""" """Compare True iff other is the same kind of TensorType"""
return type(self) == type(other) and other.dtype == self.dtype \ if TensorType.use_shape:
and other.broadcastable == self.broadcastable \ return type(self) == type(other) and other.dtype == self.dtype \
and other.shape == self.shape and other.broadcastable == self.broadcastable \
and other.shape == self.shape
else:
return type(self) == type(other) and other.dtype == self.dtype \
and other.broadcastable == self.broadcastable
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
...@@ -420,7 +430,10 @@ class TensorType(Type): ...@@ -420,7 +430,10 @@ class TensorType(Type):
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of TensorType""" """Hash equal for same kinds of TensorType"""
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable) ^ hash(self.shape) if TensorType.use_shape:
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable) ^ hash(self.shape)
else:
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable)
ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions") ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions")
"""Number of dimensions """Number of dimensions
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论