提交 710e6b48 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a new prepare_node() method that can be used to do stuff 'just

before' make_thunk and help DebugMode work correctly.
上级 3a0de13a
...@@ -1784,8 +1784,8 @@ class _VariableEquivalenceTracker(object): ...@@ -1784,8 +1784,8 @@ class _VariableEquivalenceTracker(object):
# List of default version of make thunk. # List of default version of make thunk.
# This is needed to know if the user overrided it. # This is needed to know if the user overrided it.
# The GpuOp will be added here when theano.sandbox.cuda is imported. # The GpuOp will be added here when theano.sandbox.cuda is imported.
default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk), default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk)]
get_unbound_function(theano.gof.OpenMPOp.make_thunk)]
# Debug mode cheats and initializes the linker in a different way in # Debug mode cheats and initializes the linker in a different way in
...@@ -1879,6 +1879,10 @@ class _Linker(gof.link.LocalLinker): ...@@ -1879,6 +1879,10 @@ class _Linker(gof.link.LocalLinker):
thunk.inputs = [storage_map[v] for v in node.inputs] thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs] thunk.outputs = [storage_map[v] for v in node.outputs]
thunk_other = thunk thunk_other = thunk
else:
new_node = node.op.prepare_node(node)
if new_node is not None:
node = new_node
try: try:
if not self.maker.mode.check_c_code: if not self.maker.mode.check_c_code:
......
...@@ -821,6 +821,14 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -821,6 +821,14 @@ class Op(utils.object2, PureOp, CLinkerOp):
else: else:
return NotImplemented return NotImplemented
def prepare_node(self, node):
"""
Make any special modifications that the Op needs before doing
make_thunk().
"""
pass
def make_c_thunk(self, node, storage_map, compute_map, no_recycling): def make_c_thunk(self, node, storage_map, compute_map, no_recycling):
""" """
Like make_thunk, but will only try to make a C thunk. Like make_thunk, but will only try to make a C thunk.
...@@ -930,6 +938,10 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -930,6 +938,10 @@ class Op(utils.object2, PureOp, CLinkerOp):
""" """
logger = logging.getLogger('theano.gof.op.Op') logger = logging.getLogger('theano.gof.op.Op')
new_node = self.prepare_node(self, node)
if new_node is not None:
node = new_node
if self._op_use_c_code: if self._op_use_c_code:
try: try:
return self.make_c_thunk(node, storage_map, compute_map, return self.make_c_thunk(node, storage_map, compute_map,
...@@ -1166,10 +1178,8 @@ int main( int argc, const char* argv[] ) ...@@ -1166,10 +1178,8 @@ int main( int argc, const char* argv[] )
self.openmp = False self.openmp = False
theano.config.openmp = False theano.config.openmp = False
def make_thunk(self, node, storage_map, compute_map, no_recycling): def prepare_node(self, node):
self.update_self_openmp() self.update_self_openmp()
return super(OpenMPOp, self).make_thunk(node, storage_map,
compute_map, no_recycling)
def simple_meth(tag): def simple_meth(tag):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论