提交 04e10833 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed astensor when taking Tensor or Result inputs

上级 59aa5470
...@@ -5,14 +5,14 @@ from copy import copy ...@@ -5,14 +5,14 @@ from copy import copy
from compile import Function from compile import Function
import gof import gof
def _tensor(data, broadcastable=None, role=None, name=None): def _tensor(data, broadcastable=None, name=None):
"""Return a BaseTensor containing given data""" """Return a BaseTensor containing given data"""
data = numpy.asarray(data) data = numpy.asarray(data)
if broadcastable is None: if broadcastable is None:
broadcastable = [s==1 for s in data.shape] broadcastable = [s==1 for s in data.shape]
elif broadcastable in [0, 1]: elif broadcastable in [0, 1]:
broadcastable = [broadcastable] * len(data.shape) broadcastable = [broadcastable] * len(data.shape)
rval = BaseTensor(data.dtype, broadcastable, role, name) rval = BaseTensor(data.dtype, broadcastable, name)
rval.data = data # will raise if broadcastable was mis-specified rval.data = data # will raise if broadcastable was mis-specified
return rval return rval
......
...@@ -26,7 +26,7 @@ class BaseTensor(Result): ...@@ -26,7 +26,7 @@ class BaseTensor(Result):
on the L{Op}s that use it. on the L{Op}s that use it.
""" """
def __init__(self, dtype, broadcastable, role=None, name=None): def __init__(self, dtype, broadcastable, name=None):
"""Initialize a L{Tensor} """Initialize a L{Tensor}
@todo: Initialize a L{Tensor} or a L{BaseTensor}? -jpt @todo: Initialize a L{Tensor} or a L{BaseTensor}? -jpt
...@@ -45,7 +45,7 @@ class BaseTensor(Result): ...@@ -45,7 +45,7 @@ class BaseTensor(Result):
# the argument that is awkward to construct, I decided to put all this # the argument that is awkward to construct, I decided to put all this
# into the tensor(data,...) function below, which is like a second # into the tensor(data,...) function below, which is like a second
# constructor that works with an ndarray. # constructor that works with an ndarray.
Result.__init__(self, role=role, name=name) Result.__init__(self, role=None, name=name)
self._dtype = str(dtype) self._dtype = str(dtype)
self.dtype_specs() # this is just for error checking self.dtype_specs() # this is just for error checking
self._broadcastable = tuple(broadcastable) self._broadcastable = tuple(broadcastable)
...@@ -232,7 +232,7 @@ class BaseTensor(Result): ...@@ -232,7 +232,7 @@ class BaseTensor(Result):
If transfer_data is True, a copy of self.data is assigned to the copy's If transfer_data is True, a copy of self.data is assigned to the copy's
data property, otherwise the copy's data is left as None. data property, otherwise the copy's data is left as None.
""" """
cpy = self.__class__(self.dtype, self.broadcastable, None, self.name) cpy = self.__class__(self.dtype, self.broadcastable, self.name)
if transfer_data: if transfer_data:
cpy.data = copy(self.data) cpy.data = copy(self.data)
return cpy return cpy
......
...@@ -75,16 +75,27 @@ s2t.Tensor = Tensor ...@@ -75,16 +75,27 @@ s2t.Tensor = Tensor
# alternate Tensor constructor # alternate Tensor constructor
def astensor(data, broadcastable=None, role=None, name=None): def astensor(data, broadcastable=None, name=None):
"""Return a L{Tensor} containing given data""" """Return a L{Tensor} containing given data"""
if isinstance(data, Tensor) and broadcastable is None and role is None and name is None: if isinstance(data, BaseTensor):
if broadcastable is not None and list(data.broadcastable) != list(broadcastable):
raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable))
if isinstance(data, Tensor) and (name is None or name == data.name):
return data return data
else:
return Tensor(data.dtype, data.broadcastable, name = name)
elif isinstance(data, Result):
data = data.data
if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None or a Result with no data.")
data = numpy.asarray(data) data = numpy.asarray(data)
if broadcastable is None: if broadcastable is None:
broadcastable = [s==1 for s in data.shape] broadcastable = [s==1 for s in data.shape]
elif broadcastable in [0, 1]: elif broadcastable in [0, 1]:
broadcastable = [broadcastable] * len(data.shape) broadcastable = [broadcastable] * len(data.shape)
rval = Tensor(data.dtype, broadcastable, role, name) rval = Tensor(data.dtype, broadcastable, name = name)
rval.data = data # will raise if broadcastable was mis-specified rval.data = data # will raise if broadcastable was mis-specified
return rval return rval
s2t.astensor = astensor s2t.astensor = astensor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论