提交 04e10833 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed astensor when taking Tensor or Result inputs

上级 59aa5470
......@@ -5,14 +5,14 @@ from copy import copy
from compile import Function
import gof
def _tensor(data, broadcastable=None, role=None, name=None):
def _tensor(data, broadcastable=None, name=None):
"""Return a BaseTensor containing given data"""
data = numpy.asarray(data)
if broadcastable is None:
broadcastable = [s==1 for s in data.shape]
elif broadcastable in [0, 1]:
broadcastable = [broadcastable] * len(data.shape)
rval = BaseTensor(data.dtype, broadcastable, role, name)
rval = BaseTensor(data.dtype, broadcastable, name)
rval.data = data # will raise if broadcastable was mis-specified
return rval
......
......@@ -26,7 +26,7 @@ class BaseTensor(Result):
on the L{Op}s that use it.
"""
def __init__(self, dtype, broadcastable, role=None, name=None):
def __init__(self, dtype, broadcastable, name=None):
"""Initialize a L{Tensor}
@todo: Initialize a L{Tensor} or a L{BaseTensor}? -jpt
......@@ -45,7 +45,7 @@ class BaseTensor(Result):
# the argument that is awkward to construct, I decided to put all this
# into the tensor(data,...) function below, which is like a second
# constructor that works with an ndarray.
Result.__init__(self, role=role, name=name)
Result.__init__(self, role=None, name=name)
self._dtype = str(dtype)
self.dtype_specs() # this is just for error checking
self._broadcastable = tuple(broadcastable)
......@@ -232,7 +232,7 @@ class BaseTensor(Result):
If transfer_data is True, a copy of self.data is assigned to the copy's
data property, otherwise the copy's data is left as None.
"""
cpy = self.__class__(self.dtype, self.broadcastable, None, self.name)
cpy = self.__class__(self.dtype, self.broadcastable, self.name)
if transfer_data:
cpy.data = copy(self.data)
return cpy
......
......@@ -75,16 +75,27 @@ s2t.Tensor = Tensor
# alternate Tensor constructor
def astensor(data, broadcastable=None, role=None, name=None):
def astensor(data, broadcastable=None, name=None):
"""Return a L{Tensor} containing given data"""
if isinstance(data, Tensor) and broadcastable is None and role is None and name is None:
return data
if isinstance(data, BaseTensor):
if broadcastable is not None and list(data.broadcastable) != list(broadcastable):
raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable))
if isinstance(data, Tensor) and (name is None or name == data.name):
return data
else:
return Tensor(data.dtype, data.broadcastable, name = name)
elif isinstance(data, Result):
data = data.data
if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None or a Result with no data.")
data = numpy.asarray(data)
if broadcastable is None:
broadcastable = [s==1 for s in data.shape]
elif broadcastable in [0, 1]:
broadcastable = [broadcastable] * len(data.shape)
rval = Tensor(data.dtype, broadcastable, role, name)
rval = Tensor(data.dtype, broadcastable, name = name)
rval.data = data # will raise if broadcastable was mis-specified
return rval
s2t.astensor = astensor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论