提交 f94d63f4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Move ops_with_inner_function to to gof/op.py

上级 ccb01355
from theano import gof from theano import gof
from theano import gradient as G from theano import gradient as G
from function_module import orig_function from theano.compile.function_module import orig_function
from theano.gof import ops_with_inner_function
class OpFromGraph(gof.Op): class OpFromGraph(gof.Op):
...@@ -69,13 +70,6 @@ class OpFromGraph(gof.Op): ...@@ -69,13 +70,6 @@ class OpFromGraph(gof.Op):
grad_depth=grad_depth - 1, grad_depth=grad_depth - 1,
on_unused_input='ignore')) on_unused_input='ignore'))
# Since OpFromGraph contains a Theano compiled function, we should let
# DebugMode know about it
# We do that here to avoid circular import problems
from theano.compile.debugmode import ops_with_inner_function
if type(self) not in ops_with_inner_function:
ops_with_inner_function[type(self)] = 'fn'
def __eq__(self, other): def __eq__(self, other):
#TODO: recognize a copy #TODO: recognize a copy
return self is other return self is other
...@@ -106,3 +100,7 @@ class OpFromGraph(gof.Op): ...@@ -106,3 +100,7 @@ class OpFromGraph(gof.Op):
return [go(*(inputs + output_grads)) for go in self.grad_ops] return [go(*(inputs + output_grads)) for go in self.grad_ops]
else: else:
raise NotImplementedError raise NotImplementedError
# Since OpFromGraph contains a Theano compiled function, we should let
# DebugMode know about it
ops_with_inner_function[OpFromGraph] = 'fn'
...@@ -13,7 +13,7 @@ import numpy ...@@ -13,7 +13,7 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.gof import Env, graph, utils, link from theano.gof import Env, graph, utils, link, ops_with_inner_function
from theano.gof.link import raise_with_op from theano.gof.link import raise_with_op
from theano.gof.cc import CLinker from theano.gof.cc import CLinker
from theano.gof.python25 import product as itertools_product from theano.gof.python25 import product as itertools_product
...@@ -104,20 +104,6 @@ class NoDuplicateOptWarningFilter(logging.Filter): ...@@ -104,20 +104,6 @@ class NoDuplicateOptWarningFilter(logging.Filter):
_logger.addFilter(NoDuplicateOptWarningFilter()) _logger.addFilter(NoDuplicateOptWarningFilter())
"""
Registry of Ops that have an inner compiled Theano function.
The keys are Op classes (not instances), and values are the name of the
attribute that contains the function. For instance, if the function is
self.fn, the value will be 'fn'.
We need that to be able not to run debug checks a number of times that is
exponential in the nesting level of those ops.
For instance, Scan will be registered here.
"""
ops_with_inner_function = {}
######################## ########################
# #
# Exceptions # Exceptions
......
...@@ -18,7 +18,7 @@ from link import \ ...@@ -18,7 +18,7 @@ from link import \
Container, Linker, LocalLinker, PerformLinker, WrapLinker, WrapLinkerMany Container, Linker, LocalLinker, PerformLinker, WrapLinker, WrapLinkerMany
from op import \ from op import \
Op, PureOp Op, PureOp, ops_with_inner_function
from opt import (Optimizer, optimizer, SeqOptimizer, from opt import (Optimizer, optimizer, SeqOptimizer,
MergeOptimizer, MergeOptMerge, MergeOptimizer, MergeOptMerge,
......
...@@ -717,3 +717,17 @@ def get_debug_values(*args): ...@@ -717,3 +717,17 @@ def get_debug_values(*args):
return rval return rval
return [tuple(rval)] return [tuple(rval)]
ops_with_inner_function = {}
"""
Registry of Ops that have an inner compiled Theano function.
The keys are Op classes (not instances), and values are the name of the
attribute that contains the function. For instance, if the function is
self.fn, the value will be 'fn'.
We need that to be able not to run debug checks a number of times that is
exponential in the nesting level of those ops.
For instance, Scan will be registered here.
"""
...@@ -1678,7 +1678,7 @@ class Scan(PureOp): ...@@ -1678,7 +1678,7 @@ class Scan(PureOp):
# Since Scan is an op that contains a Theano compiled function, it is # Since Scan is an op that contains a Theano compiled function, it is
# useful to let DebugMode know about it. # useful to let DebugMode know about it.
compile.debugmode.ops_with_inner_function[Scan] = 'fn' gof.ops_with_inner_function[Scan] = 'fn'
@theano.compile.profilemode.register_profiler_printer @theano.compile.profilemode.register_profiler_printer
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论