提交 eeb32bd2 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Allow None as a valid context.

The NoContext object in graph will serve a "absent context" marker.
上级 df669cc6
...@@ -516,7 +516,7 @@ class CLinker(link.Linker): ...@@ -516,7 +516,7 @@ class CLinker(link.Linker):
self.contexts = dict() self.contexts = dict()
for node in self.node_order: for node in self.node_order:
ctx = node.run_context() ctx = node.run_context()
if ctx is not None: if ctx is not graph.NoContext:
# try to avoid creating more than one variable for the # try to avoid creating more than one variable for the
# same context. # same context.
if ctx in self.contexts: if ctx in self.contexts:
...@@ -665,7 +665,7 @@ class CLinker(link.Linker): ...@@ -665,7 +665,7 @@ class CLinker(link.Linker):
sub = dict(failure_var=failure_var) sub = dict(failure_var=failure_var)
ctx = node.run_context() ctx = node.run_context()
if ctx is not None: if ctx is not graph.NoContext:
context_var = symbol[self.contexts[ctx]] context_var = symbol[self.contexts[ctx]]
# The placeholder will be replaced by a hash of the entire # The placeholder will be replaced by a hash of the entire
...@@ -682,15 +682,15 @@ class CLinker(link.Linker): ...@@ -682,15 +682,15 @@ class CLinker(link.Linker):
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
sub['id'] = id sub['id'] = id
sub['fail'] = failure_code(sub) sub['fail'] = failure_code(sub)
if ctx is not None: if ctx is not graph.NoContext:
sub['context'] = context_var sub['context'] = context_var
sub_struct = dict() sub_struct = dict()
sub_struct['id'] = id + 1 sub_struct['id'] = id + 1
sub_struct['fail'] = failure_code_init(sub) sub_struct['fail'] = failure_code_init(sub)
if ctx is not None: if ctx is not graph.NoContext:
# Since context inputs are always constants they are # Since context inputs are always constants they are
# guarenteed to be available in the struct init code. # guaranteed to be available in the struct init code.
sub_struct['context'] = context_var sub_struct['context'] = context_var
struct_support = "" struct_support = ""
......
...@@ -21,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet ...@@ -21,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
is_same_graph_with_merge = None is_same_graph_with_merge = None
equal_computations = None equal_computations = None
NoContext = object()
class Node(utils.object2): class Node(utils.object2):
"""A Node in a theano graph. """A Node in a theano graph.
...@@ -120,7 +121,7 @@ class Apply(Node): ...@@ -120,7 +121,7 @@ class Apply(Node):
""" """
if hasattr(self.op, 'get_context'): if hasattr(self.op, 'get_context'):
return self.op.get_context(self) return self.op.get_context(self)
return None return NoContext
def default_output(self): def default_output(self):
"""Returns the default output for this node. """Returns the default output for this node.
......
...@@ -742,9 +742,9 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -742,9 +742,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
p = node.op.perform p = node.op.perform
ctx = node.get_context() ctx = node.run_context()
if ctx is None: if ctx is graph.NoContext:
# default arguments are stored in the closure of `rval` # default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o) r = p(n, [x[0] for x in i], o)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论