提交 4450770c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Relax GpuArrayType.filter() when applied to python floats so that it mimics the…

Relax GpuArrayType.filter() when applied to python floats so that it mimics the Tensor behavior a bit more.
上级 2907f95a
...@@ -40,7 +40,12 @@ class GpuArrayType(Type): ...@@ -40,7 +40,12 @@ class GpuArrayType(Type):
return "GpuArrayType(%s, %s)" % (self.dtype, self.broadcastable) return "GpuArrayType(%s, %s)" % (self.dtype, self.broadcastable)
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if strict: if (isinstance(data, gpuarray.GpuArray) and
data.typecode == self.typecode):
# This is just to make this condition not enter the
# following branches
pass
elif strict:
if not isinstance(data, gpuarray.GpuArray): if not isinstance(data, gpuarray.GpuArray):
raise TypeError("%s expected a GpuArray object." % self, raise TypeError("%s expected a GpuArray object." % self,
data, type(data)) data, type(data))
...@@ -50,13 +55,20 @@ class GpuArrayType(Type): ...@@ -50,13 +55,20 @@ class GpuArrayType(Type):
(self, self.typecode, self.dtype, (self, self.typecode, self.dtype,
data.typecode, str(data.dtype))) data.typecode, str(data.dtype)))
# fallthrough to ndim check # fallthrough to ndim check
elif allow_downcast: elif (allow_downcast or
(allow_downcast is None and
type(data) == float and
self.dtype == theano.config.floatX)):
data = gpuarray.array(data, dtype=self.typecode, copy=False, data = gpuarray.array(data, dtype=self.typecode, copy=False,
ndmin=len(self.broadcastable)) ndmin=len(self.broadcastable))
else: else:
if not hasattr(data, 'dtype'):
data = gpuarray.array(data, copy=False)
up_dtype = scalar.upcast(self.dtype, data.dtype) up_dtype = scalar.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype: if up_dtype == self.dtype:
data = gpuarray.array(data, dtype=self.dtype, copy=False) data = gpuarray.array(data, dtype=self.dtype,
copy=False)
else: else:
raise TypeError("%s cannot store a value of dtype %s " raise TypeError("%s cannot store a value of dtype %s "
"without risking loss of precision." % "without risking loss of precision." %
...@@ -312,7 +324,9 @@ class GpuArraySharedVariable(_operators, SharedVariable): ...@@ -312,7 +324,9 @@ class GpuArraySharedVariable(_operators, SharedVariable):
return numpy.asarray(self.container.value) return numpy.asarray(self.container.value)
def set_value(self, value, borrow=False): def set_value(self, value, borrow=False):
self.container.value = pygpu.gpuarray.array(value, copy=(not borrow)) if isinstance(value, pygpu.gpuarray.GpuArray):
value = pygpu.gpuarray.array(value, copy=(not borrow))
self.container.value = value
def __getitem__(self, *args): def __getitem__(self, *args):
return _operators.__getitem__(self, *args) return _operators.__getitem__(self, *args)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论