提交 de9cd253 authored 作者: Frederic's avatar Frederic

[crash fix] fix crash in fast compile introduced recently.

上级 d22eca31
...@@ -57,7 +57,10 @@ from theano.gof.link import \ ...@@ -57,7 +57,10 @@ from theano.gof.link import \
from theano.gof.op import \ from theano.gof.op import \
Op, OpenMPOp, PureOp, ops_with_inner_function Op, OpenMPOp, PureOp, ops_with_inner_function
from theano.gof.opt import (Optimizer, optimizer, SeqOptimizer, from theano.gof.opt import (
Optimizer,
optimizer, inplace_optimizer,
SeqOptimizer,
MergeOptimizer, MergeOptMerge, MergeOptimizer, MergeOptMerge,
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub, OpSub, OpRemove, PatternSub,
......
...@@ -114,13 +114,13 @@ class Optimizer(object): ...@@ -114,13 +114,13 @@ class Optimizer(object):
class FromFunctionOptimizer(Optimizer): class FromFunctionOptimizer(Optimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, fn): def __init__(self, fn, requirements=()):
self.apply = fn self.apply = fn
self.requirements = requirements
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
# Added by default for req in self.requirements:
#fgraph.attach_feature(toolbox.ReplaceValidate()) req(fgraph)
pass
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" % ( print >> stream, "%s%s id=%i" % (
...@@ -142,6 +142,16 @@ def optimizer(f): ...@@ -142,6 +142,16 @@ def optimizer(f):
return rval return rval
def inplace_optimizer(f):
"""decorator for FromFunctionOptimizer"""
dh_handler = dh.DestroyHandler
requirements = (lambda fgraph:
fgraph.attach_feature(dh_handler()),)
rval = FromFunctionOptimizer(f, requirements)
rval.__name__ = f.__name__
return rval
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__ #inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME """WRITEME
......
...@@ -174,7 +174,7 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -174,7 +174,7 @@ def inplace_elemwise_optimizer_op(OP):
""" """
We parametrise it to make it work for Elemwise and GpuElemwise op. We parametrise it to make it work for Elemwise and GpuElemwise op.
""" """
@gof.optimizer @gof.inplace_optimizer
def inplace_elemwise_optimizer(fgraph): def inplace_elemwise_optimizer(fgraph):
""" """
Usage: inplace_elemwise_optimizer.optimize(fgraph) Usage: inplace_elemwise_optimizer.optimize(fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论