提交 7860e808 authored 作者: James Bergstra's avatar James Bergstra

Removed shape attribute of TensorType.

上级 bcd2bfea
...@@ -238,20 +238,12 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -238,20 +238,12 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
try: try:
if rtype is TensorConstant: if rtype is TensorConstant:
if 0: rval = rtype(
# put the shape into the type TensorType(dtype = x_.dtype, broadcastable = bcastable),
x_.copy(),
# This is disabled because if a tensor has shape, then the following fails: name=name)
# theano.lvector == as_tensor_variable([0,1]).type rval.tag.shape = x_.shape
# I think the solution is that we should implement something more like return rval
# compatability instead of equality in our Type comparisons... but we're not
# there yet.
x_shape = x_.shape
else:
x_shape = None
return rtype(
TensorType(dtype = x_.dtype, broadcastable = bcastable, shape=x_shape),
x_.copy(), name=name)
else: else:
# leave the shape out of the type # leave the shape out of the type
return rtype(TensorType(dtype = x_.dtype, broadcastable = bcastable), x_, name=name) return rtype(TensorType(dtype = x_.dtype, broadcastable = bcastable), x_, name=name)
...@@ -320,7 +312,7 @@ class TensorType(Type): ...@@ -320,7 +312,7 @@ class TensorType(Type):
When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`) When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`)
""" """
def __init__(self, dtype, broadcastable, name = None, shape=None): def __init__(self, dtype, broadcastable, name = None):
"""Initialize self.dtype and self.broadcastable. """Initialize self.dtype and self.broadcastable.
:Parameters: :Parameters:
...@@ -343,30 +335,7 @@ class TensorType(Type): ...@@ -343,30 +335,7 @@ class TensorType(Type):
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
self.name = name self.name = name
self.numpy_dtype = numpy.dtype(self.dtype) self.numpy_dtype = numpy.dtype(self.dtype)
if shape is None:
#backport self.shape = tuple((1 if b else None) for b in self.broadcastable)
l=[]
for b in self.broadcastable:
if b: l.append(1)
else: l.append(None)
self.shape = tuple(l)
else:
self.shape = tuple(shape)
if len(self.shape) != len(self.broadcastable):
raise ValueError('shape and broadcastable must have equal lengths', (self.shape,
self.broadcastable))
def __setstate__(self, dct):
self.__dict__.update(dct)
#add shape when unpickling old pickled things
if 'shape' not in dct:
l=[]
for b in self.broadcastable:
if b: l.append(1)
else: l.append(None)
self.shape = tuple(l)
#backport self.shape = tuple(1 if b else None for b in self.broadcastable)
def filter(self, data, strict = False): def filter(self, data, strict = False):
"""Convert `data` to something which can be associated to a `TensorVariable`. """Convert `data` to something which can be associated to a `TensorVariable`.
...@@ -523,8 +492,6 @@ class TensorType(Type): ...@@ -523,8 +492,6 @@ class TensorType(Type):
def __str__(self): def __str__(self):
if self.name: if self.name:
return self.name return self.name
elif not all(None == si for si in self.shape):
return 'TensorType{%s, %s}' % (self.dtype, self.shape)
else: else:
b = self.broadcastable b = self.broadcastable
named_broadcastable = {(): 'scalar', named_broadcastable = {(): 'scalar',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论