提交 bc1ffeea authored 作者: James Bergstra's avatar James Bergstra

added givens parameter to compile.sandbox.pfunc

上级 88f7f401
...@@ -32,7 +32,7 @@ class Param(object): ...@@ -32,7 +32,7 @@ class Param(object):
self.strict = strict self.strict = strict
self.implicit = implicit self.implicit = implicit
def pfunc(params, outputs=None, mode=None, updates=[]): def pfunc(params, outputs=None, mode=None, updates=[], givens=[]):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -42,38 +42,72 @@ def pfunc(params, outputs=None, mode=None, updates=[]): ...@@ -42,38 +42,72 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
:type outputs: list of Variables or Out instances :type outputs: list of Variables or Out instances
:param outputs: expressions to compute :param outputs: expressions to compute
:type mode: string or `theano.compile.Mode` instance.
:param mode: compilation mode :param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict. :type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs according to these expressions :param updates: update the values for SharedVariable inputs according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1
and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1).
:rtype: theano.compile.Function :rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`. and update the implicit function arguments according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
optimizations in that Var2 is not expected to be equivalent to Var1.
""" """
# Note: in its early design, pfunc was also meant to accept another #
# parameter, 'givens'. This was a dictionary assigning some specific # This function works by cloning the graph (except for the inputs), and then shipping it
# values to some of the Variable in the graph, so as to allow the # off to compile.function
# function to possibly make some optimizations at compile time. # (There it will be cloned again, unnecessarily, because it doesn't know that we already
# In the end, this feature was not kept, because it was not obvious # cloned it.)
# how to implement it, nor whether it was really needed. #
# If one wants to add this feature in the future, it may be easier instead # First, it clones the replacements named in the givens argument, and points each Var1 to
# to add a new parameter to 'Param' to indicate that some input of the # the clone of Var2.
# function is taking a specific constant value. # Then it sets the inputs in the clone dictionary.
# After these steps, we are assuming that the clone dictionary contains all the inputs to
if not isinstance(outputs, list): # the computation graph.
computed_list = [outputs] #
else: # Then it clones the outputs and the update expressions. This rebuilds a computation graph
# Copy list (because it may be extended later). # from the inputs and the givens.
computed_list = [out for out in outputs] #
# initialize the clone_d mapping with the `givens` argument
clone_d = {}
def v_clone(v):
return _v_clone(v, clone_d)
try:
givens = givens.items() # converts a dictionary to the sort of list that we want.
except:
pass
for v_orig, v_repl in givens:
if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl, Variable):
v_repl = shared(v_repl)
assert v_orig not in clone_d
clone_d[v_orig] = v_clone(v_repl)
# transform params into theano.compile.In objects. # transform params into theano.compile.In objects.
# #
# call theano.function # call theano.function
inputs = [_pfunc_param_to_in(p) for p in params] inputs = [_pfunc_param_to_in(p) for p in params]
set_of_param_variables = set([i.variable for i in inputs]) #Switch inputs to cloned variables
input_variables = [clone_d.setdefault(i.variable, i.variable) for i in inputs]
for i, iv in zip(inputs, input_variables):
i.variable = iv
set_of_param_variables = set(input_variables)
# It was decided, as a first step, to prevent shared variables from being # It was decided, as a first step, to prevent shared variables from being
# used as function inputs. Although it is technically possible, it is also # used as function inputs. Although it is technically possible, it is also
...@@ -83,11 +117,28 @@ def pfunc(params, outputs=None, mode=None, updates=[]): ...@@ -83,11 +117,28 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
raise TypeError('Cannot use a shared variable (%s) as explicit input ' raise TypeError('Cannot use a shared variable (%s) as explicit input '
% v) % v)
# computed_list is a list of output variables
if isinstance(outputs, list):
for v in outputs:
if not isinstance(v, Variable):
raise TypeError('outputs must be theano Variable instances', v)
# Copy list (because it may be extended later).
computed_list = [v_clone(o) for o in outputs]
cloned_outputs = list(computed_list)
else:
if not isinstance(outputs, Variable):
raise TypeError('output must be a theano Variable instance', outputs)
cloned_outputs = v_clone(outputs)
computed_list = [cloned_outputs]
# Add update values as quantities that must be computed. # Add update values as quantities that must be computed.
# Here, we
# - extend the computed_list
# - replace some update expressions (but update keys remain)
new_updates = {} new_updates = {}
for (store_into, update_val) in iter_over_pairs(updates): for (store_into, update_val) in iter_over_pairs(updates):
assert isinstance(store_into, SharedVariable) assert isinstance(store_into, SharedVariable)
update_val = store_into.filter_update(update_val) update_val = v_clone(store_into.filter_update(update_val))
if update_val.type != store_into.type: if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable', raise TypeError('an update must have the same type as the original shared variable',
(store_into, store_into.type, (store_into, store_into.type,
...@@ -98,7 +149,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]): ...@@ -98,7 +149,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
# Obtain all inputs we need to compute what we want. # Obtain all inputs we need to compute what we want.
graph_inputs = graph.inputs(computed_list, graph_inputs = graph.inputs(computed_list,
blockers=set([i.variable for i in inputs])) blockers=set_of_param_variables)
shared_inputs = [i for i in graph_inputs if isinstance(i, SharedVariable)] shared_inputs = [i for i in graph_inputs if isinstance(i, SharedVariable)]
...@@ -131,7 +182,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]): ...@@ -131,7 +182,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
in_sv.update = new_val in_sv.update = new_val
in_sv.mutable = True in_sv.mutable = True
return function(inputs, outputs, mode, accept_inplace=False) return function(inputs, cloned_outputs, mode, accept_inplace=False)
def _pfunc_param_to_in(param): def _pfunc_param_to_in(param):
if isinstance(param, Constant): if isinstance(param, Constant):
...@@ -168,3 +219,23 @@ def iter_over_pairs(pairs): ...@@ -168,3 +219,23 @@ def iter_over_pairs(pairs):
return pairs.iteritems() return pairs.iteritems()
else: else:
return pairs return pairs
#TODO: Make these non-recursive so they can deal with larger graphs
def _a_clone(a, dct):
if a is None:
return None
if a not in dct:
for i in a.inputs:
_v_clone(i, dct)
dct[a] = a.clone_with_new_inputs([dct[i] for i in a.inputs])
for old_o, new_o in zip(a.outputs, dct[a].outputs):
dct.setdefault(old_o, new_o)
return dct[a]
def _v_clone(v, dct):
assert v is not None
if v.owner:
_a_clone(v.owner, dct)
return dct.setdefault(v, v)
...@@ -193,6 +193,22 @@ class Test_pfunc(unittest.TestCase): ...@@ -193,6 +193,22 @@ class Test_pfunc(unittest.TestCase):
inc_by_y() inc_by_y()
self.failUnless(x.value == 1) self.failUnless(x.value == 1)
def test_givens(self):
    """Exercise the `givens` argument of pfunc with three substitutions."""
    x = shared(0)

    # Replace the shared variable itself by a plain value.
    read_x = pfunc([], x, givens={x: 3})
    assert read_x() == 3
    assert x.value == 0

    # Replace a shared variable that appears inside an expression.
    y = tensor.ivector()
    times_x = pfunc([y], y * x, givens={x: 6})
    assert numpy.all(times_x([1, 1, 1]) == [6, 6, 6])
    assert x.value == 0

    # Replace a non-shared symbolic variable by an array value.
    z = tensor.ivector()
    c = z * y
    fours = numpy.asarray([4, 4, 4], dtype='int32')
    shifted = pfunc([y], c + 7, givens={z: fours})
    assert numpy.all(shifted([1, 1, 1]) == [11, 11, 11])
    assert x.value == 0
if __name__ == '__main__': if __name__ == '__main__':
theano.compile.mode.default_mode = 'FAST_COMPILE' theano.compile.mode.default_mode = 'FAST_COMPILE'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论