提交 d650fa30 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add c_{init,support,cleanup}_code_struct to allow inserting stuf at the struct level.

上级 b1cc9fd8
...@@ -615,15 +615,8 @@ class CLinker(link.Linker): ...@@ -615,15 +615,8 @@ class CLinker(link.Linker):
id += 2 id += 2
for node_num, node in enumerate(self.node_order): for node_num, node in enumerate(self.node_order):
# Why is this here?
# We populate sub with a mapping from the variable names
# specified by the op's c_var_names method to the actual
# variable names that we will use.
## ivnames, ovnames = op.c_var_names()
sub = dict(failure_var=failure_var) sub = dict(failure_var=failure_var)
## for variable, vname in zip(op.inputs + op.outputs,
## ivnames + ovnames):
## sub[vname] = symbol[variable]
# The placeholder will be replaced by a hash of the entire # The placeholder will be replaced by a hash of the entire
# code (module + support code) in DynamicModule.code. # code (module + support code) in DynamicModule.code.
...@@ -636,15 +629,15 @@ class CLinker(link.Linker): ...@@ -636,15 +629,15 @@ class CLinker(link.Linker):
isyms = [symbol[r] for r in node.inputs] isyms = [symbol[r] for r in node.inputs]
osyms = [symbol[r] for r in node.outputs] osyms = [symbol[r] for r in node.outputs]
# c_validate_update is deprecated
if hasattr(node.op, 'c_validate_update'):
raise Exception("c_validate_update is deprecated,"
" move contents to c_code", node.op)
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
sub['id'] = id sub['id'] = id
sub['struct_id'] = id + 1
sub['fail'] = failure_code(sub) sub['fail'] = failure_code(sub)
struct_support = ""
struct_init = ""
struct_cleanup = ""
op = node.op op = node.op
# type-specific support code # type-specific support code
try: try:
...@@ -657,6 +650,7 @@ class CLinker(link.Linker): ...@@ -657,6 +650,7 @@ class CLinker(link.Linker):
assert isinstance(c_support_code_apply[-1], basestring), ( assert isinstance(c_support_code_apply[-1], basestring), (
str(node.op) + str(node.op) +
" didn't return a string for c_support_code_apply") " didn't return a string for c_support_code_apply")
try: try:
c_init_code_apply.append(op.c_init_code_apply(node, name)) c_init_code_apply.append(op.c_init_code_apply(node, name))
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -666,6 +660,30 @@ class CLinker(link.Linker): ...@@ -666,6 +660,30 @@ class CLinker(link.Linker):
str(node.op) + str(node.op) +
" didn't return a string for c_init_code_apply") " didn't return a string for c_init_code_apply")
try:
struct_init = op.c_init_code_struct(node, id + 1)
assert isinstance(struct_init, basestring), (
str(node.op) +
" didn't return a string for c_init_code_struct")
except utils.MethodNotDefined:
pass
try:
struct_support = op.c_support_code_struct(node, id + 1)
assert isinstance(struct_support, basestring), (
str(node.op) +
" didn't return a string for c_support_code_struct")
except utils.MethodNotDefined:
pass
try:
struct_cleanup = op.c_cleanup_code_struct(node, id + 1)
assert isinstance(struct_cleanup, basestring), (
str(node.op) +
" didn't return a string for c_cleanup_code_struct")
except utils.MethodNotDefined:
pass
# emit c_code # emit c_code
try: try:
behavior = op.c_code(node, name, isyms, osyms, sub) behavior = op.c_code(node, name, isyms, osyms, sub)
...@@ -690,6 +708,12 @@ class CLinker(link.Linker): ...@@ -690,6 +708,12 @@ class CLinker(link.Linker):
tasks.append((node, 'code', id)) tasks.append((node, 'code', id))
id += 1 id += 1
init_blocks.append(CodeBlock(struct_support, struct_init,
struct_cleanup, {'id': id}))
init_tasks.append((node, 'init', id))
id += 1
# List of arg names for use in struct_gen. Note the call to # List of arg names for use in struct_gen. Note the call to
# uniq: duplicate inputs must only be passed once because they # uniq: duplicate inputs must only be passed once because they
# are mapped to the same name. Duplicates are defined by (a # are mapped to the same name. Duplicates are defined by (a
...@@ -955,7 +979,8 @@ class CLinker(link.Linker): ...@@ -955,7 +979,8 @@ class CLinker(link.Linker):
id += 2 id += 2
for node in self.node_order: for node in self.node_order:
tasks.append((node, 'code', id)) tasks.append((node, 'code', id))
id += 1 init_tasks.append((node, 'init', id + 1))
id += 2
return init_tasks, tasks return init_tasks, tasks
def make_thunk(self, input_storage=None, output_storage=None, def make_thunk(self, input_storage=None, output_storage=None,
......
...@@ -276,7 +276,7 @@ class CLinkerOp(CLinkerObject): ...@@ -276,7 +276,7 @@ class CLinkerOp(CLinkerObject):
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
"""Optional: Return utility code for use by an `Op` that will be """Optional: Return utility code for use by an `Op` that will be
inserted at struct scope, that can be specialized for the inserted at global scope, that can be specialized for the
support of a particular `Apply` node. support of a particular `Apply` node.
:param node: an Apply instance in the graph being compiled :param node: an Apply instance in the graph being compiled
...@@ -300,7 +300,7 @@ class CLinkerOp(CLinkerObject): ...@@ -300,7 +300,7 @@ class CLinkerOp(CLinkerObject):
def c_init_code_apply(self, node, name): def c_init_code_apply(self, node, name):
""" """
Optional: return a code string specific to the apply Optional: return a code string specific to the apply
to be inserted in the struct initialization code. to be inserted in the module initialization code.
:param node: an Apply instance in the graph being compiled :param node: an Apply instance in the graph being compiled
...@@ -319,6 +319,61 @@ class CLinkerOp(CLinkerObject): ...@@ -319,6 +319,61 @@ class CLinkerOp(CLinkerObject):
raise utils.MethodNotDefined("c_init_code_apply", type(self), raise utils.MethodNotDefined("c_init_code_apply", type(self),
self.__class__.__name__) self.__class__.__name__)
def c_init_code_struct(self, node, struct_id):
"""
Optional: return a code string specific to the apply
to be inserted in the struct initialization code.
:param node: an Apply instance in the graph being compiled
:param struct_id: a number that serves to uniquely identify
this code. The c_code will receive another
sub parameter named struct_id that will
contain this name.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_init_code_apply", type(self),
self.__class__.__name__)
def c_support_code_struct(self, node, struct_id):
"""Optional: Return utility code for use by an `Op` that will be
inserted at struct scope, that can be specialized for the
support of a particular `Apply` node.
:param node: an Apply instance in the graph being compiled
:param struct_id: a number that serves to uniquely identify
this code. The c_code will receive another
sub parameter named struct_id that will
contain this name.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
"""
raise utils.MethodNotDefined("c_support_code_struct",
type(self), self.__class__.__name__)
def c_cleanup_code_struct(self, node, struct_id):
"""
Optional: return a code string specific to the apply to be
inserted in the struct cleanup code.
:param node: an Apply instance in the graph being compiled
:param struct_id: a number that serves to uniquely identify
this code. The c_code will receive another
sub parameter named struct_id that will
contain this name.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_cleanup_code_struct", type(self),
self.__class__.__name__)
class PureOp(object): class PureOp(object):
""" """
......
...@@ -63,6 +63,10 @@ class HideC(object): ...@@ -63,6 +63,10 @@ class HideC(object):
c_init_code = __hide c_init_code = __hide
c_init_code_apply = __hide c_init_code_apply = __hide
c_init_code_struct = __hide
c_support_code_struct = __hide
c_cleanup_code_struct = __hide
def c_code_cache_version(self): def c_code_cache_version(self):
return () return ()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论