提交 cfecf720 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Prepare for default updates of shared variable. Keyword "no_default_updates" in…

Prepare for default updates of shared variable. Keyword "no_default_updates" in theano.function added.
上级 1997674d
...@@ -10,7 +10,8 @@ from function_module import orig_function ...@@ -10,7 +10,8 @@ from function_module import orig_function
from pfunc import pfunc from pfunc import pfunc
from numpy import any #for to work in python 2.4 from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inplace=False, name=None): def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None):
""" """
Return a callable object that will calculate `outputs` from `inputs`. Return a callable object that will calculate `outputs` from `inputs`.
...@@ -31,7 +32,12 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl ...@@ -31,7 +32,12 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
and Var2 in each pair must have the same Type. and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces :param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1). Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic (default) update on Variables.
If False (default), perform them all. Otherwise, it must be a list of Variables: automatic
updates are then performed on all Variables that are neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function. :param name: an optional name for this function. The profile mode will print the time spent in this function.
...@@ -65,4 +71,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl ...@@ -65,4 +71,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
mode=mode, mode=mode,
updates=updates, updates=updates,
givens=givens, givens=givens,
no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name) accept_inplace=accept_inplace,name=name)
...@@ -33,7 +33,8 @@ class Param(object): ...@@ -33,7 +33,8 @@ class Param(object):
self.strict = strict self.strict = strict
self.implicit = implicit self.implicit = implicit
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace=False, name=None): def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -53,7 +54,12 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace ...@@ -53,7 +54,12 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
and Var2 in each pair must have the same Type. and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces :param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1). Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic (default) update on Variables.
If False (default), perform them all. Otherwise, it must be a list of Variables: automatic
updates are then performed on all Variables that are neither in "updates" nor in "no_default_updates".
:param name: an optional name for this fct. If used, the profile mode will print the time spent in this fct. :param name: an optional name for this fct. If used, the profile mode will print the time spent in this fct.
...@@ -86,11 +92,61 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace ...@@ -86,11 +92,61 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
if not isinstance(params,(list,tuple)): if not isinstance(params,(list,tuple)):
raise Exception("in pfunc() the first argument must be a list or a tuple") raise Exception("in pfunc() the first argument must be a list or a tuple")
# initialize the clone_d mapping with the `givens` argument
clone_d = {} clone_d = {}
def v_clone(v): # Updates as list and dictionary.
return _v_clone(v, clone_d) # They will also store the 'default_update' expressions applicable.
# The dictionary is used to look up the existence of the keys, and to store
# the final (cloned) update expressions.
# The list of pairs is used to iterate in a consistent order while adding
# new pairs.
update_d = {}
update_expr = []
# list of shared inputs that are used as inputs of the graph
shared_inputs = []
def clone_v_get_shared_updates(v):
'''Clone a variable and its inputs, until all are in clone_d.

Also appends all shared variables met along the way to shared_inputs,
and their default_update (if applicable) to update_d and update_expr.

:param v: the variable to clone; must not be None.
:returns: the clone_d entry for v; if v has no entry yet, v itself is
stored in clone_d and returned.
:raises TypeError: if the filtered default_update of a shared variable
does not have the same type as that shared variable.
'''
assert v is not None
# v has an owner: clone_a clones the owner (after visiting its inputs)
# and registers the cloned outputs in clone_d.
if v.owner:
clone_a(v.owner)
elif isinstance(v, SharedVariable):
# Record each shared variable only once.
if v not in shared_inputs:
shared_inputs.append(v)
# NOTE(review): indentation is lost in this view, so it is unclear whether
# the default_update handling below is nested inside the
# "v not in shared_inputs" test or sits beside it -- confirm.
if hasattr(v, 'default_update'):
# Check that v should not be excluded from the default updates list
# NOTE(review): only False or a `list` enables default updates here; any
# other value (True, but also a tuple or a set of Variables) skips them.
if no_default_updates is False or\
(isinstance(no_default_updates, list) and\
v not in no_default_updates):
# Do not use default_update if a "real" update was provided
if v not in update_d:
v_update = v.filter_update(v.default_update)
if v_update.type != v.type:
raise TypeError('an update must have the same type as the original shared variable',
(v, v.type, v_update, v_update.type))
# Store the (not yet cloned) default update in both the lookup dict
# and the ordered list of pairs.
update_d[v] = v_update
update_expr.append((v, v_update))
# setdefault keeps an entry that is already present (e.g. a cloned output
# registered by clone_a); otherwise v is mapped to itself.
return clone_d.setdefault(v, v)
def clone_a(a):
'''Clone the node `a` onto cloned inputs, memoizing the result in clone_d.

Each input of `a` is first passed to clone_v_get_shared_updates, then `a`
is cloned with clone_with_new_inputs using the clone_d entries of its inputs.

:param a: the node to clone (the `owner` of a variable), or None.
:returns: the clone_d entry for `a`, or None if `a` is None.
'''
if a is None:
return None
if a not in clone_d:
# Make sure every input has an entry in clone_d; this call also collects
# the shared variables (and their default updates) met along the way.
for i in a.inputs:
clone_v_get_shared_updates(i)
clone_d[a] = a.clone_with_new_inputs([clone_d[i] for i in a.inputs])
# Map each original output to the matching output of the clone;
# setdefault leaves any entry that is already present untouched.
for old_o, new_o in zip(a.outputs, clone_d[a].outputs):
clone_d.setdefault(old_o, new_o)
return clone_d[a]
#def v_clone(v):
# return _v_clone(v, clone_d)
# initialize the clone_d mapping with the `givens` argument
try: try:
givens = givens.items() # converts a dictionary to the sort of list that we want. givens = givens.items() # converts a dictionary to the sort of list that we want.
except: except:
...@@ -101,11 +157,9 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace ...@@ -101,11 +157,9 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
if not isinstance(v_repl, Variable): if not isinstance(v_repl, Variable):
v_repl = shared(v_repl) v_repl = shared(v_repl)
assert v_orig not in clone_d assert v_orig not in clone_d
clone_d[v_orig] = v_clone(v_repl) clone_d[v_orig] = clone_v_get_shared_updates(v_repl)
# transform params into theano.compile.In objects. # transform params into theano.compile.In objects.
#
# call theano.function
inputs = [_pfunc_param_to_in(p) for p in params] inputs = [_pfunc_param_to_in(p) for p in params]
#Switch inputs to cloned variables #Switch inputs to cloned variables
...@@ -113,101 +167,148 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace ...@@ -113,101 +167,148 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
for i, iv in zip(inputs, input_variables): for i, iv in zip(inputs, input_variables):
i.variable = iv i.variable = iv
set_of_param_variables = set(input_variables) #set_of_param_variables = set(input_variables)
# It was decided, as a first step, to prevent shared variables from being # It was decided, as a first step, to prevent shared variables from being
# used as function inputs. Although it is technically possible, it is also # used as function inputs. Although it is technically possible, it is also
# potentially ambiguous and dangerous. This restriction may be revisited in # potentially ambiguous and dangerous. This restriction may be revisited in
# the future if there is a need for such a feature. # the future if there is a need for such a feature.
if numpy.any([isinstance(v, SharedVariable) for v in set_of_param_variables]): if numpy.any([isinstance(v, SharedVariable) for v in input_variables]):
raise TypeError('Cannot use a shared variable (%s) as explicit input ' raise TypeError('Cannot use a shared variable (%s) as explicit input '
% v) % v)
# Fill update_d and update_expr with provided updates
for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable):
raise TypeError('update target must be a SharedVariable', store_into)
if store_into in update_d:
raise ValueError('this shared variable already has an update expression',
(store_into, update_d[store_into]))
update_val = store_into.filter_update(update_val)
if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable',
(store_into, store_into.type,
update_val, update_val.type))
update_d[store_into] = update_val
update_expr.append((store_into, update_val))
# computed_list is a list of output variables (which will be extended later) # computed_list is a list of output variables (which will be extended later)
computed_list = [] #computed_list = []
# Elements of "outputs" are here cloned to "cloned_outputs"
if isinstance(outputs, list): if isinstance(outputs, list):
cloned_outputs = [] cloned_outputs = []
for v in outputs: for v in outputs:
if isinstance(v, Variable): if isinstance(v, Variable):
cloned_v = v_clone(v) cloned_v = clone_v_get_shared_updates(v)
cloned_outputs.append(cloned_v) cloned_outputs.append(cloned_v)
elif isinstance(v, Out): elif isinstance(v, Out):
cloned_v = v_clone(v.variable) cloned_v = clone_v_get_shared_updates(v.variable)
cloned_outputs.append(Out(cloned_v, borrow=v.borrow)) cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
else: else:
raise TypeError('outputs must be theano Variable or Out instances', v) raise TypeError('outputs must be theano Variable or Out instances', v)
computed_list.append(cloned_v) #computed_list.append(cloned_v)
else: else:
if isinstance(outputs, Variable): if isinstance(outputs, Variable):
cloned_v = v_clone(outputs) cloned_v = clone_v_get_shared_updates(outputs)
cloned_outputs = cloned_v cloned_outputs = cloned_v
computed_list.append(cloned_v) #computed_list.append(cloned_v)
elif isinstance(outputs, Out): elif isinstance(outputs, Out):
cloned_v = v_clone(outputs.variable) cloned_v = clone_v_get_shared_updates(outputs.variable)
cloned_outputs = Out(cloned_v, borrow=outputs.borrow) cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
computed_list.append(cloned_v) #computed_list.append(cloned_v)
elif outputs is None: elif outputs is None:
cloned_outputs = [] # TODO: return None cloned_outputs = [] # TODO: return None
else: else:
raise TypeError('output must be a theano Variable or Out instance (or list of them)', outputs) raise TypeError('output must be a theano Variable or Out instance (or list of them)', outputs)
# Add update values as quantities that must be computed. # Iterate over update_expr, cloning its elements, and updating
# Here, we # shared_inputs, update_d and update_expr from the SharedVariables
# - extend the computed_list # we discover.
# - replace some update expressions (but update keys remain) # If the variable to be updated is a shared variable not already
new_updates = {} # in shared_inputs, add it.
for (store_into, update_val) in iter_over_pairs(updates): # Note: we extend update_expr while iterating over it.
if not isinstance(store_into, SharedVariable): i = 0
raise TypeError('update target must be a SharedVariable', store_into) while i<len(update_expr):
if store_into in new_updates: v, v_update = update_expr[i]
raise ValueError('this shared variable already has an update expression', cloned_v_update = clone_v_get_shared_updates(v_update)
(store_into, new_updates[store_into])) update_d[v] = cloned_v_update
update_val = v_clone(store_into.filter_update(update_val)) if isinstance(v, SharedVariable) and v not in shared_inputs:
if update_val.type != store_into.type: shared_inputs.append(v)
raise TypeError('an update must have the same type as the original shared variable', i += 1
(store_into, store_into.type,
update_val, update_val.type)) #updates = update_d #?
computed_list.append(update_val) for sv in shared_inputs:
new_updates[store_into] = update_val if sv in update_d:
updates = new_updates si = In(variable=sv, value=sv.container, mutable=True,
update=update_d[sv])
# Obtain all inputs we need to compute what we want.
graph_inputs = graph.inputs(computed_list,
blockers=set_of_param_variables)
shared_inputs = [i for i in graph_inputs if isinstance(i, SharedVariable)]
# Add shared variables (from shared_inputs) that were not already present in the list of
# params.
inputs += [In(variable=si, value=si.container, mutable=False)
for si in shared_inputs
if si not in set_of_param_variables]
del shared_inputs
# Iterate over the updates, which are either pairs
# (shared_var, expressionvariable), or a similar dictionary.
# For each shared_variable, find the In instance that we created for it in the inputs list.
# Give that In instance (in_sv) an update expression.
#
# I think we usually want to set these Inputs to be mutable,
# ... are there exceptions?
for (sv, new_val) in iter_over_pairs(updates):
in_sv = None
for in_sv_i in inputs:
if in_sv_i.variable is sv:
assert in_sv is None
in_sv = in_sv_i
if in_sv is None:
# This variable was not used anywhere and thus is not in the input
# list yet.
inputs.append(In(variable=sv, value=sv.container, mutable=True,
update=new_val))
else: else:
in_sv.update = new_val si = In(variable=sv, value=sv.container, mutable=False)
in_sv.mutable = True inputs.append(si)
return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name)
if 0:
# Add update values as quantities that must be computed.
# Here, we
# - extend the computed_list
# - replace some update expressions (but update keys remain)
new_updates = {}
for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable):
raise TypeError('update target must be a SharedVariable', store_into)
if store_into in new_updates:
raise ValueError('this shared variable already has an update expression',
(store_into, new_updates[store_into]))
update_val = v_clone(store_into.filter_update(update_val))
if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable',
(store_into, store_into.type,
update_val, update_val.type))
computed_list.append(update_val)
new_updates[store_into] = update_val
updates = new_updates
# Obtain all inputs we need to compute what we want.
graph_inputs = graph.inputs(computed_list,
blockers=set_of_param_variables)
shared_inputs = [i for i in graph_inputs if isinstance(i, SharedVariable)]
# Add shared variables (from shared_inputs) that were not already present in the list of
# params.
inputs += [In(variable=si, value=si.container, mutable=False)
for si in shared_inputs
if si not in set_of_param_variables]
del shared_inputs
# Iterate over the updates, which are either pairs
# (shared_var, expressionvariable), or a similar dictionary.
# For each shared_variable, find the In instance that we created for it in the inputs list.
# Give that In instance (in_sv) an update expression.
#
# I think we usually want to set these Inputs to be mutable,
# ... are there exceptions?
for (sv, new_val) in iter_over_pairs(updates):
in_sv = None
for in_sv_i in inputs:
if in_sv_i.variable is sv:
assert in_sv is None
in_sv = in_sv_i
if in_sv is None:
# This variable was not used anywhere and thus is not in the input
# list yet.
inputs.append(In(variable=sv, value=sv.container, mutable=True,
update=new_val))
else:
in_sv.update = new_val
in_sv.mutable = True
return orig_function(inputs, cloned_outputs, mode, accept_inplace=accept_inplace,name=name) return orig_function(inputs, cloned_outputs, mode, accept_inplace=accept_inplace,name=name)
def _pfunc_param_to_in(param): def _pfunc_param_to_in(param):
if isinstance(param, Constant): if isinstance(param, Constant):
......
...@@ -28,6 +28,11 @@ class SharedVariable(Variable): ...@@ -28,6 +28,11 @@ class SharedVariable(Variable):
:type: `Container` :type: `Container`
""" """
# default_update
# If this member is present, its value will be used as the "update" for
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" argument passed to "function" is True or is a
# list that contains this Variable.
def __init__(self, name, type, value, strict, container=None): def __init__(self, name, type, value, strict, container=None):
""" """
:param name: The name for this variable (see `Variable`). :param name: The name for this variable (see `Variable`).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论