提交 8f2c17ec authored 作者: James Bergstra's avatar James Bergstra

added filter_update function to shared variable interface

上级 120b1ee9
......@@ -86,12 +86,8 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
# Add update values as quantities that must be computed.
new_updates = {}
for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(update_val, Variable):
# The value for the update is not a Variable: we cast it into
# a shared Variable so that it can be used by 'function'. Note that
# it means the update value may change if it is mutable and its
# value is modified after the function is created.
update_val = shared(update_val)
assert isinstance(store_into, SharedVariable)
update_val = store_into.filter_update(update_val)
computed_list.append(update_val)
new_updates[store_into] = update_val
updates = new_updates
......
......@@ -83,6 +83,24 @@ class SharedVariable(Variable):
"""
def filter_update(self, update):
"""When this shared variable is updated by a pfunc, the update value will be run through this function.
This is a good spot to cast or convert the update expression as necessary.
Default behaviour is to return `update` unmodified if it is a Variable, otherwise to create a SharedVariable for it by calling ``shared(update)``.
:param update: the new value for this shared variable when updated by a pfunc.
:returns: a Variable whose value will be assigned to this SharedVariable by a pfunc.
"""
if not isinstance(update, Variable):
# The value for the update is not a Variable: we cast it into
# a shared Variable so that it can be used by 'function'. Note that
# it means the update value may change if it is mutable and its
# value is modified after the function is created.
update = shared(update)
return update
def shared_constructor(ctor):
shared.constructors.append(ctor)
return ctor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论