提交 4551c0ab authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5462 from abergeron/shared_init

Add inplace filtering for GpuArrayType so that shared value replacement
...@@ -197,6 +197,11 @@ class GpuArrayType(Type): ...@@ -197,6 +197,11 @@ class GpuArrayType(Type):
self.broadcastable) self.broadcastable)
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
return self.filter_inplace(data, None, strict=strict,
allow_downcast=allow_downcast)
def filter_inplace(self, data, old_data, strict=False,
allow_downcast=None):
if (isinstance(data, gpuarray.GpuArray) and if (isinstance(data, gpuarray.GpuArray) and
data.typecode == self.typecode): data.typecode == self.typecode):
# This is just to make this condition not enter the # This is just to make this condition not enter the
...@@ -218,9 +223,13 @@ class GpuArrayType(Type): ...@@ -218,9 +223,13 @@ class GpuArrayType(Type):
(allow_downcast is None and (allow_downcast is None and
type(data) == float and type(data) == float and
self.dtype == config.floatX)): self.dtype == config.floatX)):
data = gpuarray.array(data, dtype=self.typecode, copy=False, if not isinstance(data, gpuarray.GpuArray):
ndmin=len(self.broadcastable), data = np.array(data, dtype=self.dtype, copy=False,
context=self.context) ndmin=len(self.broadcastable))
else:
data = gpuarray.array(data, dtype=self.typecode, copy=False,
ndmin=len(self.broadcastable),
context=self.context)
else: else:
if not hasattr(data, 'dtype'): if not hasattr(data, 'dtype'):
converted_data = theano._asarray(data, self.dtype) converted_data = theano._asarray(data, self.dtype)
...@@ -230,12 +239,13 @@ class GpuArrayType(Type): ...@@ -230,12 +239,13 @@ class GpuArrayType(Type):
converted_data, converted_data,
force_same_dtype=False): force_same_dtype=False):
data = converted_data data = converted_data
data = gpuarray.array(data, context=self.context)
up_dtype = scalar.upcast(self.dtype, data.dtype) up_dtype = scalar.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype: if up_dtype == self.dtype:
data = gpuarray.array(data, dtype=self.dtype, copy=False, if not isinstance(data, gpuarray.GpuArray):
context=self.context) data = np.array(data, dtype=self.dtype, copy=False)
else:
data = gpuarray.array(data, dtype=self.dtype, copy=False)
else: else:
raise TypeError("%s cannot store a value of dtype %s " raise TypeError("%s cannot store a value of dtype %s "
"without risking loss of precision." % "without risking loss of precision." %
...@@ -250,6 +260,12 @@ class GpuArrayType(Type): ...@@ -250,6 +260,12 @@ class GpuArrayType(Type):
if b and shp[i] != 1: if b and shp[i] != 1:
raise TypeError("Non-unit value on shape on a broadcastable" raise TypeError("Non-unit value on shape on a broadcastable"
" dimension.", shp, self.broadcastable) " dimension.", shp, self.broadcastable)
if not isinstance(data, gpuarray.GpuArray):
if old_data is not None and old_data.shape == data.shape:
old_data.write(data)
data = old_data
else:
data = pygpu.array(data, context=self.context)
return data return data
def filter_variable(self, other, allow_convert=True): def filter_variable(self, other, allow_convert=True):
......
...@@ -778,8 +778,7 @@ class Scan(PureOp): ...@@ -778,8 +778,7 @@ class Scan(PureOp):
# broadcastable dimensions, 0 on the others). # broadcastable dimensions, 0 on the others).
default_shape = [1 if _b else 0 default_shape = [1 if _b else 0
for _b in inp.broadcastable] for _b in inp.broadcastable]
default_val = numpy.zeros(default_shape, default_val = inp.type.value_zeros(default_shape)
dtype=inp.dtype)
wrapped_inp = In(variable=inp, value=default_val, wrapped_inp = In(variable=inp, value=default_val,
update=self.outputs[output_idx]) update=self.outputs[output_idx])
wrapped_inputs.append(wrapped_inp) wrapped_inputs.append(wrapped_inp)
......
...@@ -4956,6 +4956,9 @@ class T_Scan_Gpuarray(unittest.TestCase, ScanGpuTests): ...@@ -4956,6 +4956,9 @@ class T_Scan_Gpuarray(unittest.TestCase, ScanGpuTests):
super(T_Scan_Gpuarray, self).__init__(*args, **kwargs) super(T_Scan_Gpuarray, self).__init__(*args, **kwargs)
def setUp(self): def setUp(self):
# Make sure to activate the new backend, if possible; otherwise
# testing this class directly will always skip.
import theano.gpuarray.tests.config
# Skip the test if pygpu is not available # Skip the test if pygpu is not available
if not self.gpu_backend.pygpu_activated: if not self.gpu_backend.pygpu_activated:
raise SkipTest('Optional package pygpu disabled') raise SkipTest('Optional package pygpu disabled')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论