提交 f8dfe6a9 authored 作者: James Bergstra's avatar James Bergstra

added some code comments to pfunc

上级 0c5936e2
...@@ -98,10 +98,10 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -98,10 +98,10 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
clone_d = {} clone_d = {}
# Updates as list and dictionary. # Updates as list and dictionary.
# They will also store the 'default_update' expressions applicable. # They will both store the 'default_update' expressions (where applicable).
# The dictionary is used to look up the existence of the keys, and to store # The dictionary (update_d) is used to look up the existence of the keys, and to store
# the final (cloned) update expressions. # the final [cloned] update expressions.
# The list of pairs is used to iterate in a consistent order while adding # The list of pairs (update_expr) is used to iterate in a consistent order while adding
# new pairs. # new pairs.
update_d = {} update_d = {}
update_expr = [] update_expr = []
...@@ -109,10 +109,11 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -109,10 +109,11 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
shared_inputs = [] shared_inputs = []
def clone_v_get_shared_updates(v): def clone_v_get_shared_updates(v):
'''Clone a variable and its inputs, until all are in clone_d. '''Clone a variable and its inputs recursively until all are in clone_d.
Also appends all shared variables met along the way to shared_inputs, Also appends all shared variables met along the way to shared_inputs,
and their default_update (if applicable) to update_d and update_expr. and their default_update (if applicable) to update_d and update_expr.
''' '''
# this method co-recurses with clone_a
assert v is not None assert v is not None
if v.owner: if v.owner:
clone_a(v.owner) clone_a(v.owner)
...@@ -137,6 +138,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -137,6 +138,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
return clone_d.setdefault(v, v) return clone_d.setdefault(v, v)
def clone_a(a): def clone_a(a):
# this method co-recurses with clone_v_get_shared_updates
if a is None: if a is None:
return None return None
if a not in clone_d: if a not in clone_d:
...@@ -174,12 +176,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -174,12 +176,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
#set_of_param_variables = set(input_variables) #set_of_param_variables = set(input_variables)
# It was decided, as a first step, to prevent shared variables from being # It was decided, as a first step, to prevent shared variables from being
# used as function inputs. Although it is technically possible, it is also # used as function inputs. Although it is technically possible, it is also not clear
# potentially ambiguous and dangerous. This restriction may be revisited in # when/how to use the value of that shared variable (is it a default? ignored?, if the
# the future if there is a need for such a feature. # shared variable changes, does that function default also change?).
if numpy.any([isinstance(v, SharedVariable) for v in input_variables]): if numpy.any([isinstance(v, SharedVariable) for v in input_variables]):
raise TypeError('Cannot use a shared variable (%s) as explicit input ' raise TypeError(('Cannot use a shared variable (%s) as explicit input.'
% v) ' Consider substituting a non-shared'
' variable via the `givens` parameter') % v)
# Fill update_d and update_expr with provided updates # Fill update_d and update_expr with provided updates
for (store_into, update_val) in iter_over_pairs(updates): for (store_into, update_val) in iter_over_pairs(updates):
...@@ -189,7 +192,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -189,7 +192,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
raise ValueError('this shared variable already has an update expression', raise ValueError('this shared variable already has an update expression',
(store_into, update_d[store_into])) (store_into, update_d[store_into]))
update_val = store_into.filter_update(update_val) update_val = store_into.filter_update(update_val) # typically this might be a cast()
if update_val.type != store_into.type: if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable', raise TypeError('an update must have the same type as the original shared variable',
(store_into, store_into.type, (store_into, store_into.type,
...@@ -224,7 +227,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -224,7 +227,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
cloned_outputs = Out(cloned_v, borrow=outputs.borrow) cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
#computed_list.append(cloned_v) #computed_list.append(cloned_v)
elif outputs is None: elif outputs is None:
cloned_outputs = [] # TODO: return None cloned_outputs = [] # TODO: get Function.__call__ to return None
else: else:
raise TypeError('output must be a theano Variable or Out instance (or list of them)', outputs) raise TypeError('output must be a theano Variable or Out instance (or list of them)', outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论