提交 23aa7ddc authored 作者: James Bergstra's avatar James Bergstra

Unhack the filter code by adding a new filter_inplace method to the Type API - NEEDS TESTING

上级 b5986c84
...@@ -163,19 +163,16 @@ class Container(object): ...@@ -163,19 +163,16 @@ class Container(object):
if value is None: if value is None:
self.storage[0] = None self.storage[0] = None
return return
if self.type.__class__.__name__ == "CudaNdarrayType" and isinstance(value,numpy.ndarray):
#The filter method of CudaNdarray allocates a new memory region on the GPU.
#The reference count of the old value is only decremented after that,
#which causes two regions to be allocated at the same time!
#We drop the reference to the old value now to try to lower the memory usage.
self.storage[0] = None
kwargs = {} kwargs = {}
if self.strict: if self.strict:
kwargs['strict'] = True kwargs['strict'] = True
if self.allow_downcast is not None: if self.allow_downcast is not None:
kwargs['allow_downcast'] = self.allow_downcast kwargs['allow_downcast'] = self.allow_downcast
self.storage[0] = self.type.filter(value, **kwargs) if hasattr(self.type,'filter_inplace'):
self.storage[0] = self.type.filter_inplace(value, self.storage[0], **kwargs)
else:
self.storage[0] = self.type.filter(value, **kwargs)
except Exception, e: except Exception, e:
e.args = e.args + (('Container name "%s"' % self.name),) e.args = e.args + (('Container name "%s"' % self.name),)
......
...@@ -53,14 +53,18 @@ class CudaNdarrayType(Type): ...@@ -53,14 +53,18 @@ class CudaNdarrayType(Type):
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
return self.filter_inplace(data, None, strict=strict, allow_downcast=allow_downcast)
def filter_inplace(self, data, old_data, strict=False, allow_downcast=None):
if strict or allow_downcast or isinstance(data, cuda.CudaNdarray): if strict or allow_downcast or isinstance(data, cuda.CudaNdarray):
return cuda.filter(data, self.broadcastable, strict, None) return cuda.filter(data, self.broadcastable, strict, old_data)
else: # (not strict) and (not allow_downcast) else: # (not strict) and (not allow_downcast)
# Check if data.dtype can be accurately cast to self.dtype # Check if data.dtype can be accurately cast to self.dtype
if isinstance(data, numpy.ndarray): if isinstance(data, numpy.ndarray):
up_dtype = scal.upcast(self.dtype, data.dtype) up_dtype = scal.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype: if up_dtype == self.dtype:
return cuda.filter(data, self.broadcastable, strict, None) return cuda.filter(data, self.broadcastable, strict, old_data)
else: else:
raise TypeError( raise TypeError(
'%s, with dtype %s, cannot store a value of ' '%s, with dtype %s, cannot store a value of '
...@@ -75,10 +79,10 @@ class CudaNdarrayType(Type): ...@@ -75,10 +79,10 @@ class CudaNdarrayType(Type):
type(data) is float and type(data) is float and
self.dtype==theano.config.floatX): self.dtype==theano.config.floatX):
return cuda.filter(converted_data, self.broadcastable, return cuda.filter(converted_data, self.broadcastable,
strict, None) strict, old_data)
elif numpy.all(data == converted_data): elif numpy.all(data == converted_data):
return cuda.filter(converted_data, self.broadcastable, return cuda.filter(converted_data, self.broadcastable,
strict, None) strict, old_data)
else: else:
raise TypeError( raise TypeError(
'%s, with dtype %s, cannot store accurately value %s, ' '%s, with dtype %s, cannot store accurately value %s, '
...@@ -87,6 +91,7 @@ class CudaNdarrayType(Type): ...@@ -87,6 +91,7 @@ class CudaNdarrayType(Type):
% (self, self.dtype, data, converted_data, self.dtype), % (self, self.dtype, data, converted_data, self.dtype),
data) data)
@staticmethod @staticmethod
def bound(a): def bound(a):
high = a.gpudata high = a.gpudata
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论