Commit b7ef62dc authored by James Bergstra

Added shape tuple to TensorType. This is not a finished commit.

In a finished commit, the shape would replace the broadcastable parameter, since the broadcastable vector can be computed from the shape. The shape has None in components that are unknown.
Parent a8a59e90
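The commit message describes the intended relationship between the two attributes: the broadcastable vector is recoverable from a shape tuple that uses None for unknown sizes. A minimal standalone sketch of that rule (not part of this commit): a dimension is broadcastable exactly when its length is statically known to be 1.

def broadcastable_from_shape(shape):
    # a dimension is broadcastable exactly when its length is statically known to be 1
    return tuple(s == 1 for s in shape)

assert broadcastable_from_shape((1, None, 5)) == (True, False, False)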
@@ -180,6 +180,13 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
assert len(bcastable) == ndim
try:
if rtype is TensorConstant:
# put the shape into the type
return rtype(
TensorType(dtype = x_.dtype, broadcastable = bcastable, shape=x_.shape),
x_, name=name)
else:
# leave the shape out of the type
return rtype(TensorType(dtype = x_.dtype, broadcastable = bcastable), x_, name=name)
except:
raise TypeError("Could not convert %s to TensorType" % x, type(x))
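A hedged sketch of what the new branch buys (standalone, not the library call itself): only constants, whose numpy value is in hand, get a fully concrete shape stored in their type; every other variable falls through to the old constructor and keeps the default, mostly-unknown shape.

import numpy
x_ = numpy.zeros((1, 3))
bcastable = tuple(s == 1 for s in x_.shape)                      # (True, False)
constant_shape = x_.shape                                        # TensorConstant branch: shape fully known
variable_shape = tuple((1 if b else None) for b in bcastable)    # other branch: default shape, mostly unknown
assert constant_shape == (1, 3) and variable_shape == (1, None)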
@@ -236,7 +243,7 @@ class TensorType(Type):
When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`)
"""
def __init__(self, dtype, broadcastable, name = None, shape=None):
"""Initialize self.dtype and self.broadcastable. """Initialize self.dtype and self.broadcastable.
:Parameters: :Parameters:
...@@ -256,6 +263,20 @@ class TensorType(Type): ...@@ -256,6 +263,20 @@ class TensorType(Type):
self.broadcastable = tuple(broadcastable) self.broadcastable = tuple(broadcastable)
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
self.name = name self.name = name
if shape is None:
self.shape = tuple((1 if b else None) for b in self.broadcastable)
else:
self.shape = tuple(shape)
if len(self.shape) != len(self.broadcastable):
raise ValueError('shape and broadcastable must have equal lengths', (self.shape,
self.broadcastable))
def __setstate__(self, dct):
self.__dict__.update(dct)
#add shape when unpickling old pickled things
if 'shape' not in dct:
self.shape = tuple(1 if b else None for b in self.broadcastable)
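Both the constructor default and the unpickling backfill above apply the same rule, in the opposite direction from the commit message: a broadcastable dimension has a known length of 1, anything else is unknown. A standalone illustration:

broadcastable = (True, False)    # e.g. a row: 1 x N
default_shape = tuple((1 if b else None) for b in broadcastable)
assert default_shape == (1, None)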
def filter(self, data, strict = False):
"""Convert `data` to something which can be associated to a `TensorVariable`.
@@ -273,6 +294,11 @@ class TensorType(Type):
raise TypeError("%s expected a ndarray object with %s dimensions (got %s)." % (self, self.ndim, data.ndim))
if self.filter_checks_isfinite and (not numpy.all(numpy.isfinite(data))):
raise TypeError("non-finite elements not allowed")
for si, di in zip(self.shape, data.shape):
if not (si is None or si == di):
raise TypeError('%s requires ndarray with shape matching %s (got %s)'%(
self, self.shape, data.shape))
return data
else:
data = numpy.asarray(data, dtype = self.dtype)
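A standalone sketch of the shape check added above: every statically known component must match the incoming data exactly, while None components accept any length.

def shape_ok(type_shape, data_shape):
    # mirrors the loop added to filter(): None means "this dimension is unconstrained"
    return all(si is None or si == di for si, di in zip(type_shape, data_shape))

assert shape_ok((1, None), (1, 7))        # unknown dimension accepts 7
assert not shape_ok((1, None), (2, 7))    # known dimension must match exactly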
@@ -311,7 +337,9 @@ class TensorType(Type):
def __eq__(self, other):
"""Compare True iff other is the same kind of TensorType"""
return type(self) == type(other) and other.dtype == self.dtype \
and other.broadcastable == self.broadcastable \
and other.shape == self.shape
@staticmethod
def values_eq(a, b):
@@ -382,7 +410,7 @@ class TensorType(Type):
def __hash__(self):
"""Hash equal for same kinds of TensorType"""
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable) ^ hash(self.shape)
ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions")
"""Number of dimensions
@@ -405,6 +433,8 @@ class TensorType(Type):
def __str__(self):
if self.name:
return self.name
elif not all(None == si for si in self.shape):
return 'TensorType{%s, %s}' % (self.dtype, self.shape)
else:
b = self.broadcastable
named_broadcastable = {(): 'scalar',
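A sketch of the new printing rule added above: when any component of the shape is known, the type renders with its shape tuple instead of the broadcastable-pattern nickname.

dtype, shape = 'float64', (3, None)
if not all(si is None for si in shape):
    # same format string as the __str__ branch above
    rendered = 'TensorType{%s, %s}' % (dtype, shape)
assert rendered == 'TensorType{float64, (3, None)}'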
@@ -782,7 +812,6 @@ class _tensor_py_operators:
dtype = property(lambda self: self.type.dtype)
""" The dtype of this tensor. """
#extra pseudo-operator symbols
def __dot__(left, right): return dot(left, right)
def __rdot__(right, left): return dot(left, right)
@@ -1051,9 +1080,23 @@ class Shape(Op):
out[0] = numpy.asarray(x.shape, dtype = 'int64')
def grad(self, (x,), (gz,)):
return [None]
_shape = Shape()
@constructor
def shape(a):
pass """Return the shape tuple of a TensorType Variable, it may be either symbolic or nonsymbolic.
If the shape of the expression is not known at graph-construction time, then a symbolic
lvector will be returned, corresponding to the actual shape at graph-execution time.
"""
print 'GOT A', a, a.type
va = as_tensor_variable(a)
if None in va.type.shape:
# Some shape components are unknown at this time
return _shape(va)
else:
# all shape components are known at compile time, so we return
# a tuple directly. This tuple is like the numpy.ndarray.shape tuple.
return va.type.shape
pprint.assign(shape, printing.MemberPrinter('shape'))
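A standalone sketch of the dispatch in shape() above: when every component of the static shape is known, it is returned directly as an ordinary tuple, and any None forces the symbolic Shape op so the real shape is read at graph-execution time.

def pick_shape(static_shape, symbolic_result):
    # mirrors the 'None in va.type.shape' test above
    if None in static_shape:
        return symbolic_result            # symbolic lvector, computed when the graph runs
    return static_shape                   # plain tuple, known at graph-construction time

assert pick_shape((3, 4), 'Shape(x)') == (3, 4)
assert pick_shape((3, None), 'Shape(x)') == 'Shape(x)'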
...