Commit 032c8b5d, authored by Frederic

[BUG] Re-add the DestroyHandler as a requirement for some optimizations.

This fixes bugs on the CPU in fast_compile that were introduced in commit 4f8f0da5 (Thu Feb 27 09:56:12 2014 -0500).
Parent commit: 1783f4e9
......@@ -790,9 +790,14 @@ class LocalOptimizer(object):
class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
def __init__(self, fn, tracks=None, requirements=()):
    """Wrap a transformation function as a local optimizer.

    The diff rendering had kept both the pre-change and post-change
    signature lines; only the post-change signature (with the new,
    backward-compatible ``requirements`` parameter) is valid.

    Parameters
    ----------
    fn : callable
        The transformation function applied to each node.
    tracks : list or None
        Op classes/instances this optimizer tracks; None means no
        track information is available.
    requirements : tuple of callables
        Each callable is invoked with the fgraph by
        ``add_requirements`` to attach any feature the transform
        needs (e.g. a DestroyHandler for inplace rewrites).
    """
    self.transform = fn
    self._tracks = tracks
    self.requirements = requirements
def add_requirements(self, fgraph):
    """Apply every registered requirement callback to *fgraph*.

    Each callback typically attaches a feature (such as a
    DestroyHandler) that the wrapped transform relies on.
    """
    for require in self.requirements:
        require(fgraph)
def tracks(self):
    """Return the ops tracked by this optimizer (or None if unknown)."""
    tracked = self._tracks
    return tracked
......@@ -808,7 +813,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
id(self))
def local_optimizer(tracks):
def local_optimizer(tracks, inplace=False):
def decorator(f):
"""WRITEME"""
if tracks is not None:
......@@ -817,7 +822,12 @@ def local_optimizer(tracks):
for t in tracks:
if not (isinstance(t, op.Op) or issubclass(t, op.PureOp)):
raise ValueError, ("Tracks are op classes or instances", f.__module__, f.__name__)
rval = FromFunctionLocalOptimizer(f, tracks)
requirements = ()
if inplace:
dh_handler = dh.DestroyHandler
requirements = (lambda fgraph:
fgraph.attach_feature(dh_handler()),)
rval = FromFunctionLocalOptimizer(f, tracks, requirements)
rval.__name__ = f.__name__
return rval
return decorator
......@@ -852,6 +862,10 @@ class LocalOptGroup(LocalOptimizer):
for lopt in self.opts:
lopt.print_summary(stream, level=(level + 2), depth=depth)
def add_requirements(self, fgraph):
    """Forward requirement registration to every optimizer in the group."""
    for member in self.opts:
        member.add_requirements(fgraph)
class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME"""
......
......@@ -1214,19 +1214,19 @@ def local_gpujoin_1(node):
# shared = dimshuffle(gemm_inplace(dimshuffle(shared)))
# which causes memory leaks (long term fix is to make the above not leak
# memory)
@local_optimizer([gpu_gemm_no_inplace], inplace=True)
def local_inplace_gemm(node):
    """Replace a non-inplace GpuGemm with its destructive counterpart.

    Registered with ``inplace=True`` so a DestroyHandler feature is
    attached to the fgraph before this rewrite runs.  The diff
    rendering had stacked the stale non-inplace decorator on top of
    this one, which would double-decorate the function and drop the
    DestroyHandler requirement; only the ``inplace=True`` decorator
    is kept.
    """
    if node.op == gpu_gemm_no_inplace:
        return [gpu_gemm_inplace(*node.inputs)]
@local_optimizer([gpu_gemv_no_inplace], inplace=True)
def local_inplace_gemv(node):
    """Replace a non-inplace GpuGemv with its destructive counterpart.

    ``inplace=True`` makes the decorator register a DestroyHandler
    requirement.  The stale duplicate decorator left by the diff
    rendering is removed.
    """
    if node.op == gpu_gemv_no_inplace:
        return [gpu_gemv_inplace(*node.inputs)]
@local_optimizer([gpu_ger_no_inplace], inplace=True)
def local_inplace_ger(node):
    """Replace a non-inplace GpuGer with its destructive counterpart.

    ``inplace=True`` makes the decorator register a DestroyHandler
    requirement.  The stale duplicate decorator left by the diff
    rendering is removed.
    """
    if node.op == gpu_ger_no_inplace:
        return [gpu_ger_inplace(*node.inputs)]
......
......@@ -1715,20 +1715,19 @@ def local_dot_to_dot22(node):
_logger.info('Not optimizing dot with inputs %s %s %s %s',
x, y, x.type, y.type)
@local_optimizer([gemm_no_inplace], inplace=True)
def local_inplace_gemm(node):
    """Replace a non-inplace CPU Gemm with its destructive counterpart.

    ``inplace=True`` makes the decorator register a DestroyHandler
    requirement.  The stale duplicate decorator left by the diff
    rendering is removed.
    """
    if node.op == gemm_no_inplace:
        return [gemm_inplace(*node.inputs)]
@local_optimizer([gemv_no_inplace], inplace=True)
def local_inplace_gemv(node):
    """Replace a non-inplace CPU Gemv with its destructive counterpart.

    ``inplace=True`` makes the decorator register a DestroyHandler
    requirement.  The stale duplicate decorator left by the diff
    rendering is removed.
    """
    if node.op == gemv_no_inplace:
        return [gemv_inplace(*node.inputs)]
@local_optimizer([ger], inplace=True)
def local_inplace_ger(node):
    """Replace a non-destructive CPU Ger with its destructive counterpart.

    ``inplace=True`` makes the decorator register a DestroyHandler
    requirement.  The stale duplicate decorator left by the diff
    rendering is removed.
    """
    if node.op == ger:
        return [ger_destructive(*node.inputs)]
......
......@@ -2110,7 +2110,7 @@ compile.optdb.register('pre_local_IncSubtensor_serialize',
#after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70
@gof.local_optimizer([IncSubtensor]) # XXX: GPU
@gof.local_optimizer([IncSubtensor], inplace=True)
def local_inplace_setsubtensor(node):
"""
Also work for GpuIncSubtensor
......@@ -2129,7 +2129,7 @@ compile.optdb.register('local_inplace_setsubtensor',
'fast_run', 'inplace') # DEBUG
@gof.local_optimizer([AdvancedIncSubtensor1]) # XXX: GPU
@gof.local_optimizer([AdvancedIncSubtensor1], inplace=True)
def local_inplace_incsubtensor1(node):
""" also work for GpuAdvancedIncSubtensor1 """
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
......
Markdown is supported.
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment.