提交 ce1d5d01 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add inplace filtering for GpuArrayType so that shared value replacement

can be done without additional memory usage.
上级 e8daf149
......@@ -197,6 +197,11 @@ class GpuArrayType(Type):
self.broadcastable)
def filter(self, data, strict=False, allow_downcast=None):
return self.filter_inplace(data, None, strict=strict,
allow_downcast=allow_downcast)
def filter_inplace(self, data, old_data, strict=False,
allow_downcast=None):
if (isinstance(data, gpuarray.GpuArray) and
data.typecode == self.typecode):
# This is just to make this condition not enter the
......@@ -218,6 +223,10 @@ class GpuArrayType(Type):
(allow_downcast is None and
type(data) == float and
self.dtype == config.floatX)):
if not isinstance(data, gpuarray.GpuArray):
data = numpy.array(data, dtype=self.dtype, copy=False,
ndmin=len(self.broadcastable))
else:
data = gpuarray.array(data, dtype=self.typecode, copy=False,
ndmin=len(self.broadcastable),
context=self.context)
......@@ -230,12 +239,10 @@ class GpuArrayType(Type):
converted_data,
force_same_dtype=False):
data = converted_data
data = gpuarray.array(data, context=self.context)
up_dtype = scalar.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype:
data = gpuarray.array(data, dtype=self.dtype, copy=False,
context=self.context)
data = numpy.array(data, dtype=self.dtype, copy=False)
else:
raise TypeError("%s cannot store a value of dtype %s "
"without risking loss of precision." %
......@@ -250,6 +257,12 @@ class GpuArrayType(Type):
if b and shp[i] != 1:
raise TypeError("Non-unit value on shape on a broadcastable"
" dimension.", shp, self.broadcastable)
if not isinstance(data, gpuarray.GpuArray):
if old_data and old_data.shape == data.shape:
old_data.write(data)
data = old_data
else:
data = pygpu.array(data, context=self.context)
return data
def filter_variable(self, other, allow_convert=True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论