提交 032a0aa6 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2345 from abergeron/multi_fixes

Add support for node context
...@@ -43,6 +43,12 @@ There are less methods to define for an Op than for a Type: ...@@ -43,6 +43,12 @@ There are less methods to define for an Op than for a Type:
that a python exception is set) if your C code needs to that a python exception is set) if your C code needs to
raise an exception. raise an exception.
``sub['context']``
(optional) The name of the variable which holds the context
for the node. This will only appear if the op has requested
a context by having a :meth:`get_context()` method that return
something other than None.
.. method:: c_code_cleanup(node, name, input_names, output_names, sub) .. method:: c_code_cleanup(node, name, input_names, output_names, sub)
...@@ -112,6 +118,13 @@ There are less methods to define for an Op than for a Type: ...@@ -112,6 +118,13 @@ There are less methods to define for an Op than for a Type:
that a python exception is set) if your C code needs to that a python exception is set) if your C code needs to
raise an exception. raise an exception.
``sub['context']``
(optional) The name of the variable which holds the context
for the node. This will only appear if the op has requested
a context by having a :meth:`get_context()` method that return
something other than None.
.. method:: c_support_code() .. method:: c_support_code()
Allows you to specify helper functions/structs that the Allows you to specify helper functions/structs that the
...@@ -186,6 +199,19 @@ There are less methods to define for an Op than for a Type: ...@@ -186,6 +199,19 @@ There are less methods to define for an Op than for a Type:
is high or when theano compilation directory is shared by many is high or when theano compilation directory is shared by many
process (like on a network file server on a cluster). process (like on a network file server on a cluster).
.. method:: get_context(node)
(optional) If defined, should return the runtime context the op
needs. This context will be passed to the C code through the
variable named in `sub['context']`. The variable is also
available for use in the code returned by
:meth:`c_init_code_struct`. If it returns `None` this is
considered the same as if the method was not defined.
If this method is defined and does not return `None`, then the
Op *must* have a `context_type` property with the Type to use
for the context variable.
The ``name`` argument is currently given an invalid value, so steer The ``name`` argument is currently given an invalid value, so steer
away from it. As was the case with Type, ``sub['fail']`` provides away from it. As was the case with Type, ``sub['fail']`` provides
failure code that you *must* use if you want to raise an exception, failure code that you *must* use if you want to raise an exception,
......
...@@ -14,19 +14,17 @@ import theano ...@@ -14,19 +14,17 @@ import theano
from theano import gof from theano import gof
from theano.compat import get_unbound_function from theano.compat import get_unbound_function
from theano.compat.six import StringIO from theano.compat.six import StringIO
from theano.gof import FunctionGraph,graph, utils, link, ops_with_inner_function from theano.gof import (FunctionGraph, graph, utils, link,
ops_with_inner_function)
from theano.gof.link import raise_with_op from theano.gof.link import raise_with_op
from theano.gof.cc import CLinker from theano.gof.cc import CLinker
from theano.gof.python25 import all, any, product as itertools_product from theano.gof.python25 import all, any, product as itertools_product
from theano.configparser import (config, AddConfigVar, BoolParam, IntParam, from theano.configparser import (config, AddConfigVar, BoolParam, IntParam,
StrParam) StrParam)
from theano.compile.function_module import (FunctionMaker, from theano.compile.function_module import (
Function, FunctionMaker, Function, infer_reuse_pattern,
infer_reuse_pattern, SymbolicInputKit, SymbolicOutput, Supervisor, std_fgraph
SymbolicInputKit, )
SymbolicOutput,
Supervisor,
std_fgraph)
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard from theano.compile.ops import OutputGuard
...@@ -1694,9 +1692,16 @@ class _Linker(gof.link.LocalLinker): ...@@ -1694,9 +1692,16 @@ class _Linker(gof.link.LocalLinker):
if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and
node.op.perform.func_code != gof.op.PureOp.perform.func_code): node.op.perform.func_code != gof.op.PureOp.perform.func_code):
p = node.op.perform p = node.op.perform
ctx = node.run_context()
if ctx is graph.NoContext:
thunk = (lambda p=p, i=node_input_storage, thunk = (lambda p=p, i=node_input_storage,
o=node_output_storage, o=node_output_storage,
n=node: p(n, [x[0] for x in i], o)) n=node: p(n, [x[0] for x in i], o))
else:
ctx_val = node.context_type.filter(ctx)
thunk = (lambda p=p, i=node_input_storage,
o=node_output_storage, ctx=ctx_val,
n=node: p(n, [x[0] for x in i], o, ctx))
thunk.inputs = node_input_storage thunk.inputs = node_input_storage
thunk.outputs = node_output_storage thunk.outputs = node_output_storage
thunk.perform = p thunk.perform = p
......
...@@ -503,12 +503,32 @@ class CLinker(link.Linker): ...@@ -503,12 +503,32 @@ class CLinker(link.Linker):
self.inputs = fgraph.inputs self.inputs = fgraph.inputs
self.outputs = fgraph.outputs self.outputs = fgraph.outputs
self.node_order = self.schedule(fgraph)
# list(fgraph.variables) # list(fgraph.variables)
# We need to include the not used inputs in our variables, # We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module. # otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(var.clients)] self.variables = [var for var in self.inputs if not len(var.clients)]
self.variables += graph.variables(self.inputs, self.outputs) self.variables += graph.variables(self.inputs, self.outputs)
# This adds a hidden input which is the context for each node
# that needs it
self.contexts = dict()
for node in self.node_order:
ctx = node.run_context()
if ctx is not graph.NoContext:
# try to avoid creating more than one variable for the
# same context.
if ctx in self.contexts:
var = self.contexts[ctx]
assert var.type == node.context_type
var.clients.append((node, 'context'))
else:
var = graph.Constant(node.context_type, ctx)
var.clients = [(node, 'context')]
self.contexts[ctx] = var
self.variables.append(var)
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
#list(fgraph.orphans.difference(self.outputs)) #list(fgraph.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables self.orphans = list(r for r in self.variables
...@@ -517,7 +537,6 @@ class CLinker(link.Linker): ...@@ -517,7 +537,6 @@ class CLinker(link.Linker):
self.temps = list(set(self.variables).difference( self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans)) self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = [] self.consts = []
self.node_order = self.schedule(fgraph)
def code_gen(self): def code_gen(self):
"""WRITEME """WRITEME
...@@ -642,9 +661,13 @@ class CLinker(link.Linker): ...@@ -642,9 +661,13 @@ class CLinker(link.Linker):
id += 2 id += 2
for node_num, node in enumerate(self.node_order): for node_num, node in enumerate(self.node_order):
# Why is this here?
sub = dict(failure_var=failure_var) sub = dict(failure_var=failure_var)
ctx = node.run_context()
if ctx is not graph.NoContext:
context_var = symbol[self.contexts[ctx]]
# The placeholder will be replaced by a hash of the entire # The placeholder will be replaced by a hash of the entire
# code (module + support code) in DynamicModule.code. # code (module + support code) in DynamicModule.code.
# This ensures that, when defining functions in support code, # This ensures that, when defining functions in support code,
...@@ -659,10 +682,16 @@ class CLinker(link.Linker): ...@@ -659,10 +682,16 @@ class CLinker(link.Linker):
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
sub['id'] = id sub['id'] = id
sub['fail'] = failure_code(sub) sub['fail'] = failure_code(sub)
if ctx is not graph.NoContext:
sub['context'] = context_var
sub_struct = dict() sub_struct = dict()
sub_struct['id'] = id + 1 sub_struct['id'] = id + 1
sub_struct['fail'] = failure_code_init(sub) sub_struct['fail'] = failure_code_init(sub)
if ctx is not graph.NoContext:
# Since context inputs are always constants they are
# guaranteed to be available in the struct init code.
sub_struct['context'] = context_var
struct_support = "" struct_support = ""
struct_init = "" struct_init = ""
...@@ -1422,8 +1451,8 @@ class CLinker(link.Linker): ...@@ -1422,8 +1451,8 @@ class CLinker(link.Linker):
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx] in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
orphd = [[orphan.data] for orphan in self.orphans] orphd = [[orphan.data] for orphan in self.orphans]
ret = module.instantiate(error_storage, *(in_storage + out_storage + ret = module.instantiate(error_storage,
orphd)) *(in_storage + out_storage + orphd))
return ret return ret
......
...@@ -23,6 +23,7 @@ from theano.misc.ordered_set import OrderedSet ...@@ -23,6 +23,7 @@ from theano.misc.ordered_set import OrderedSet
is_same_graph_with_merge = None is_same_graph_with_merge = None
equal_computations = None equal_computations = None
NoContext = object()
class Node(utils.object2): class Node(utils.object2):
"""A Node in a theano graph. """A Node in a theano graph.
...@@ -116,6 +117,13 @@ class Apply(Node): ...@@ -116,6 +117,13 @@ class Apply(Node):
else: else:
raise TypeError("The 'outputs' argument to Apply must contain Variable instances with no owner, not %s" % output) raise TypeError("The 'outputs' argument to Apply must contain Variable instances with no owner, not %s" % output)
def run_context(self):
"""Returns the context for the node, or NoContext if no context is set.
"""
if hasattr(self.op, 'get_context'):
return self.op.get_context(self)
return NoContext
def default_output(self): def default_output(self):
"""Returns the default output for this node. """Returns the default output for this node.
...@@ -238,6 +246,8 @@ class Apply(Node): ...@@ -238,6 +246,8 @@ class Apply(Node):
nout = property(lambda self: len(self.outputs), doc='same as len(self.outputs)') nout = property(lambda self: len(self.outputs), doc='same as len(self.outputs)')
"""property: Number of outputs""" """property: Number of outputs"""
context_type = property(lambda self: self.op.context_type, doc='type to use for the context')
class Variable(Node): class Variable(Node):
""" """
......
...@@ -733,13 +733,24 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -733,13 +733,24 @@ class Op(utils.object2, PureOp, CLinkerOp):
# condition: either there was no c_code, or it failed # condition: either there was no c_code, or it failed
p = node.op.perform p = node.op.perform
# default arguments are stored in the closure of `rval`
ctx = node.run_context()
if ctx is graph.NoContext:
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o) r = p(n, [x[0] for x in i], o)
for o in node.outputs: for o in node.outputs:
compute_map[o][0] = True compute_map[o][0] = True
return r return r
else:
ctx_val = node.context_type.filter(ctx)
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
ctx=ctx_val):
r = p(n, [x[0] for x in i], o, ctx)
for o in node.outputs:
compute_map[o][0] = True
return r
rval.inputs = node_input_storage rval.inputs = node_input_storage
rval.outputs = node_output_storage rval.outputs = node_output_storage
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论