提交 7860e808 authored 作者: James Bergstra's avatar James Bergstra

Removed shape attribute of TensorType.

上级 bcd2bfea
......@@ -238,20 +238,12 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
try:
if rtype is TensorConstant:
if 0:
# put the shape into the type
# This is disabled because if a tensor has shape, then the following fails:
# theano.lvector == as_tensor_variable([0,1]).type
# I think the solution is that we should implement something more like
# compatability instead of equality in our Type comparisons... but we're not
# there yet.
x_shape = x_.shape
else:
x_shape = None
return rtype(
TensorType(dtype = x_.dtype, broadcastable = bcastable, shape=x_shape),
x_.copy(), name=name)
rval = rtype(
TensorType(dtype = x_.dtype, broadcastable = bcastable),
x_.copy(),
name=name)
rval.tag.shape = x_.shape
return rval
else:
# leave the shape out of the type
return rtype(TensorType(dtype = x_.dtype, broadcastable = bcastable), x_, name=name)
......@@ -320,7 +312,7 @@ class TensorType(Type):
When this is True, strict filtering rejects data containing NaN or Inf entries. (Used in `DebugMode`)
"""
def __init__(self, dtype, broadcastable, name = None, shape=None):
def __init__(self, dtype, broadcastable, name = None):
"""Initialize self.dtype and self.broadcastable.
:Parameters:
......@@ -343,30 +335,7 @@ class TensorType(Type):
self.dtype_specs() # error checking is done there
self.name = name
self.numpy_dtype = numpy.dtype(self.dtype)
if shape is None:
#backport self.shape = tuple((1 if b else None) for b in self.broadcastable)
l=[]
for b in self.broadcastable:
if b: l.append(1)
else: l.append(None)
self.shape = tuple(l)
else:
self.shape = tuple(shape)
if len(self.shape) != len(self.broadcastable):
raise ValueError('shape and broadcastable must have equal lengths', (self.shape,
self.broadcastable))
def __setstate__(self, dct):
self.__dict__.update(dct)
#add shape when unpickling old pickled things
if 'shape' not in dct:
l=[]
for b in self.broadcastable:
if b: l.append(1)
else: l.append(None)
self.shape = tuple(l)
#backport self.shape = tuple(1 if b else None for b in self.broadcastable)
def filter(self, data, strict = False):
"""Convert `data` to something which can be associated to a `TensorVariable`.
......@@ -523,8 +492,6 @@ class TensorType(Type):
def __str__(self):
if self.name:
return self.name
elif not all(None == si for si in self.shape):
return 'TensorType{%s, %s}' % (self.dtype, self.shape)
else:
b = self.broadcastable
named_broadcastable = {(): 'scalar',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论