提交 2c2edc8c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a way for nodes to specify the context in which they run.

This depends on the Op that is in the node and context is optional.
上级 26a22806
......@@ -43,6 +43,12 @@ There are less methods to define for an Op than for a Type:
that a python exception is set) if your C code needs to
raise an exception.
``sub['context']``
(optional) The name of the variable which holds the context
for the node. This will only appear if the op has requested
a context by having a :meth:`get_context()` method that return
something other than None.
.. method:: c_code_cleanup(node, name, input_names, output_names, sub)
......@@ -112,6 +118,13 @@ There are less methods to define for an Op than for a Type:
that a python exception is set) if your C code needs to
raise an exception.
``sub['context']``
(optional) The name of the variable which holds the context
for the node. This will only appear if the op has requested
a context by having a :meth:`get_context()` method that return
something other than None.
.. method:: c_support_code()
Allows you to specify helper functions/structs that the
......@@ -186,6 +199,19 @@ There are less methods to define for an Op than for a Type:
is high or when theano compilation directory is shared by many
process (like on a network file server on a cluster).
.. method:: get_context(node)
(optional) If defined, should return the runtime context the op
needs. This context will be passed to the C code through the
variable named in `sub['context']`. The variable is also
available for use in the code returned by
:meth:`c_init_code_struct`. If it returns `None` this is
considered the same as if the method was not defined.
If this method is defined and does not return `None`, then the
Op *must* have a `context_type` property with the Type to use
for the context variable.
The ``name`` argument is currently given an invalid value, so steer
away from it. As was the case with Type, ``sub['fail']`` provides
failure code that you *must* use if you want to raise an exception,
......
......@@ -503,12 +503,32 @@ class CLinker(link.Linker):
self.inputs = fgraph.inputs
self.outputs = fgraph.outputs
self.node_order = self.schedule(fgraph)
# list(fgraph.variables)
# We need to include the not used inputs in our variables,
# We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(var.clients)]
self.variables += graph.variables(self.inputs, self.outputs)
# This adds a hidden input which is the context for each node
# that needs it
self.contexts = dict()
for node in self.node_order:
ctx = node.run_context()
if ctx is not None:
# try to avoid creating more than one variable for the
# same context.
if ctx in self.contexts:
var = self.contexts[ctx]
assert var.type == node.context_type
var.clients.append((node, 'context'))
else:
var = graph.Constant(node.context_type, ctx)
var.clients = [(node, 'context')]
self.contexts[ctx] = var
self.variables.append(var)
# The orphans field is listified to ensure a consistent order.
#list(fgraph.orphans.difference(self.outputs))
self.orphans = list(r for r in self.variables
......@@ -517,7 +537,6 @@ class CLinker(link.Linker):
self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
self.node_order = self.schedule(fgraph)
def code_gen(self):
"""WRITEME
......@@ -642,9 +661,13 @@ class CLinker(link.Linker):
id += 2
for node_num, node in enumerate(self.node_order):
# Why is this here?
sub = dict(failure_var=failure_var)
ctx = node.run_context()
if ctx is not None:
context_var = symbol[self.contexts[ctx]]
# The placeholder will be replaced by a hash of the entire
# code (module + support code) in DynamicModule.code.
# This ensures that, when defining functions in support code,
......@@ -659,10 +682,16 @@ class CLinker(link.Linker):
# Make the CodeBlock for c_code
sub['id'] = id
sub['fail'] = failure_code(sub)
if ctx is not None:
sub['context'] = context_var
sub_struct = dict()
sub_struct['id'] = id + 1
sub_struct['fail'] = failure_code_init(sub)
if ctx is not None:
# Since context inputs are always constants they are
# guarenteed to be available in the struct init code.
sub_struct['context'] = context_var
struct_support = ""
struct_init = ""
......@@ -1433,8 +1462,8 @@ class CLinker(link.Linker):
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
orphd = [[orphan.data] for orphan in self.orphans]
ret = module.instantiate(error_storage, *(in_storage + out_storage +
orphd))
ret = module.instantiate(error_storage,
*(in_storage + out_storage + orphd))
return ret
......
......@@ -115,6 +115,13 @@ class Apply(Node):
else:
raise TypeError("The 'outputs' argument to Apply must contain Variable instances with no owner, not %s" % output)
def run_context(self):
"""Returns the context for the node, or None if no context is set.
"""
if hasattr(self.op, 'get_context'):
return self.op.get_context(self)
return None
def default_output(self):
"""Returns the default output for this node.
......@@ -237,6 +244,8 @@ class Apply(Node):
nout = property(lambda self: len(self.outputs), doc='same as len(self.outputs)')
"""property: Number of outputs"""
context_type = property(lambda self: self.op.context_type, doc='type to use for the context')
class Variable(Node):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论