提交 bc1ffeea authored 作者: James Bergstra's avatar James Bergstra

added givens parameter to compile.sandbox.pfunc

上级 88f7f401
......@@ -32,7 +32,7 @@ class Param(object):
self.strict = strict
self.implicit = implicit
def pfunc(params, outputs=None, mode=None, updates=[]):
def pfunc(params, outputs=None, mode=None, updates=[], givens=[]):
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
......@@ -42,38 +42,72 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
:type outputs: list of Variables or Out instances
:param outputs: expressions to compute
:type mode: string or `theano.compile.Mode` instance.
:param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1
and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1).
:rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
optimizations in that Var2 is not expected to be equivalent to Var1.
"""
# Note: in its early design, pfunc was also meant to accept another
# parameter, 'givens'. This was a dictionary assigning some specific
# values to some of the Variable in the graph, so as to allow the
# function to possibly make some optimizations at compile time.
# In the end, this feature was not kept, because it was not obvious
# how to implement it, nor whether it was really needed.
# If one wants to add this feature in the future, it may be easier instead
# to add a new parameter to 'Param' to indicate that some input of the
# function is taking a specific constant value.
if not isinstance(outputs, list):
computed_list = [outputs]
else:
# Copy list (because it may be extended later).
computed_list = [out for out in outputs]
#
# This function works by cloning the graph (except for the inputs), and then shipping it
# off to compile.function
# (There it will be cloned again, unnecessarily, because it doesn't know that we already
# cloned it.)
#
# First, it clones the replacements named in the givens argument, and points each Var1 to
# the clone of Var2.
# Then it sets the inputs in the clone dictionary.
# After these steps, we are assuming that the clone dictionary contains all the inputs to
# the computation graph.
#
# Then it clones the outputs and the update expressions. This rebuilds a computation graph
# from the inputs and the givens.
#
# initialize the clone_d mapping with the `givens` argument
clone_d = {}
def v_clone(v):
return _v_clone(v, clone_d)
try:
givens = givens.items() # converts a dictionary to the sort of list that we want.
except:
pass
for v_orig, v_repl in givens:
if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl, Variable):
v_repl = shared(v_repl)
assert v_orig not in clone_d
clone_d[v_orig] = v_clone(v_repl)
# transform params into theano.compile.In objects.
#
# call theano.function
inputs = [_pfunc_param_to_in(p) for p in params]
set_of_param_variables = set([i.variable for i in inputs])
#Switch inputs to cloned variables
input_variables = [clone_d.setdefault(i.variable, i.variable) for i in inputs]
for i, iv in zip(inputs, input_variables):
i.variable = iv
set_of_param_variables = set(input_variables)
# It was decided, as a first step, to prevent shared variables from being
# used as function inputs. Although it is technically possible, it is also
......@@ -83,11 +117,28 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
raise TypeError('Cannot use a shared variable (%s) as explicit input '
% v)
# computed_list is a list of output variables
if isinstance(outputs, list):
for v in outputs:
if not isinstance(v, Variable):
raise TypeError('outputs must be theano Variable instances', v)
# Copy list (because it may be extended later).
computed_list = [v_clone(o) for o in outputs]
cloned_outputs = list(computed_list)
else:
if not isinstance(outputs, Variable):
raise TypeError('output must be a theano Variable instance', outputs)
cloned_outputs = v_clone(outputs)
computed_list = [cloned_outputs]
# Add update values as quantities that must be computed.
# Here, we
# - extend the computed_list
# - replace some update expressions (but update keys remain)
new_updates = {}
for (store_into, update_val) in iter_over_pairs(updates):
assert isinstance(store_into, SharedVariable)
update_val = store_into.filter_update(update_val)
update_val = v_clone(store_into.filter_update(update_val))
if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable',
(store_into, store_into.type,
......@@ -98,7 +149,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
# Obtain all inputs we need to compute what we want.
graph_inputs = graph.inputs(computed_list,
blockers=set([i.variable for i in inputs]))
blockers=set_of_param_variables)
shared_inputs = [i for i in graph_inputs if isinstance(i, SharedVariable)]
......@@ -131,7 +182,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
in_sv.update = new_val
in_sv.mutable = True
return function(inputs, outputs, mode, accept_inplace=False)
return function(inputs, cloned_outputs, mode, accept_inplace=False)
def _pfunc_param_to_in(param):
if isinstance(param, Constant):
......@@ -168,3 +219,23 @@ def iter_over_pairs(pairs):
return pairs.iteritems()
else:
return pairs
#TODO: Make these non-recursive so they can deal with larger graphs
def _a_clone(a, dct):
if a is None:
return None
if a not in dct:
for i in a.inputs:
_v_clone(i, dct)
dct[a] = a.clone_with_new_inputs([dct[i] for i in a.inputs])
for old_o, new_o in zip(a.outputs, dct[a].outputs):
dct.setdefault(old_o, new_o)
return dct[a]
def _v_clone(v, dct):
assert v is not None
if v.owner:
_a_clone(v.owner, dct)
return dct.setdefault(v, v)
......@@ -193,6 +193,22 @@ class Test_pfunc(unittest.TestCase):
inc_by_y()
self.failUnless(x.value == 1)
def test_givens(self):
x = shared(0)
assign = pfunc([], x, givens = {x: 3})
assert assign() == 3
assert x.value == 0
y = tensor.ivector()
f = pfunc([y], y*x, givens = {x : 6})
assert numpy.all(f([1,1,1]) == [6,6,6])
assert x.value == 0
z = tensor.ivector()
c = z*y
f = pfunc([y], c+7, givens = {z : numpy.asarray([4,4,4], dtype='int32')})
assert numpy.all(f([1,1,1]) == [11,11,11])
assert x.value == 0
if __name__ == '__main__':
theano.compile.mode.default_mode = 'FAST_COMPILE'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论