提交 cfecf720 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Prepare for default updates of shared variable. Keyword "no_default_updates" in…

Prepare for default updates of shared variable. Keyword "no_default_updates" in theano.function added.
上级 1997674d
...@@ -10,7 +10,8 @@ from function_module import orig_function ...@@ -10,7 +10,8 @@ from function_module import orig_function
from pfunc import pfunc from pfunc import pfunc
from numpy import any #for to work in python 2.4 from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inplace=False, name=None): def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None):
""" """
Return a callable object that will calculate `outputs` from `inputs`. Return a callable object that will calculate `outputs` from `inputs`.
...@@ -31,7 +32,12 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl ...@@ -31,7 +32,12 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
and Var2 in each pair must have the same Type. and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces :param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1). Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables.
If False (default), perform them all. Else, perform automatic updates on all Variables
that are neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function. :param name: an optional name for this function. The profile mode will print the time spent in this function.
...@@ -65,4 +71,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl ...@@ -65,4 +71,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
mode=mode, mode=mode,
updates=updates, updates=updates,
givens=givens, givens=givens,
no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name) accept_inplace=accept_inplace,name=name)
差异被折叠。
...@@ -28,6 +28,11 @@ class SharedVariable(Variable): ...@@ -28,6 +28,11 @@ class SharedVariable(Variable):
:type: `Container` :type: `Container`
""" """
# default_update
# If this member is present, its value will be used as the "update" for
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, container=None): def __init__(self, name, type, value, strict, container=None):
""" """
:param name: The name for this variable (see `Variable`). :param name: The name for this variable (see `Variable`).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论