提交 95470852 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove filter_update from shared.

上级 067a3631
...@@ -118,29 +118,6 @@ class SharedVariable(Variable): ...@@ -118,29 +118,6 @@ class SharedVariable(Variable):
cp.tag = copy.copy(self.tag)
return cp
def filter_update(self, update):
    """
    Pre-process the value a pfunc will assign to this shared variable.

    This hook is the place to cast or otherwise adapt the update
    expression as necessary.  The default behaviour passes a Variable
    through untouched and wraps any non-Variable value in a
    SharedVariable by calling ``shared(update)``.

    :param update: the new value for this shared variable when updated
        by a pfunc.
    :returns: a Variable whose value will be assigned to this
        SharedVariable by a pfunc.
    :note: The return value of this function must match self.type, or
        else pfunc() will raise a TypeError.
    """
    if isinstance(update, Variable):
        return update
    # The update value is not a Variable: wrap it in a shared Variable
    # so that it can be used by 'function'.  Note that this means the
    # update value may change if it is mutable and its value is
    # modified after the function is created.
    return shared(update)
def __getitem__(self, *args):
    # __getitem__ is not available for generic SharedVariable objects.
    # We raise a TypeError like Python would do if __getitem__ was not
......
...@@ -127,19 +127,6 @@ class CudaNdarraySharedVariable(SharedVariable, _operators): ...@@ -127,19 +127,6 @@ class CudaNdarraySharedVariable(SharedVariable, _operators):
value = copy.deepcopy(value)
self.container.value = value # this will copy a numpy ndarray
def filter_update(self, other):
    """
    Adapt an update expression so it can be assigned to this GPU-backed
    shared variable.

    Objects that know how to present themselves as a CudaNdarrayVariable
    are converted directly.  Otherwise ``other`` must be a TensorType
    variable whose dtype and broadcastable pattern match this shared
    variable's, and it is moved to the device with GpuFromHost.

    :param other: the candidate update expression.
    :returns: a Variable suitable for assignment to this shared variable.
    :raises TypeError: if ``other`` is not a compatible TensorType
        variable.
    """
    # Anything that can convert itself to a CudaNdarrayVariable wins.
    if hasattr(other, '_as_CudaNdarrayVariable'):
        return other._as_CudaNdarrayVariable()

    # The checks below run in this order so the first mismatch found
    # determines which TypeError is raised.
    other_type = other.type
    if not isinstance(other_type, tensor.TensorType):
        raise TypeError('Incompatible type', (self, (self.type, other_type)))
    if other_type.dtype != self.dtype:
        raise TypeError('Incompatible dtype', (self, (self.dtype, other_type.dtype)))
    if other_type.broadcastable != self.broadcastable:
        raise TypeError('Incompatible broadcastable',
                        (self, (self.broadcastable, other_type.broadcastable)))

    # Compatible host tensor: insert the host-to-GPU transfer op.
    return GpuFromHost()(other)
def __getitem__(self, *args):
    # Defined to explicitly use the implementation from `_operators`, since
    # the definition in `SharedVariable` is only meant to raise an error.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论