提交 b2cbd92b authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add Feature (and optimization to add it to a fgraph) to prevent computing…

Add Feature (and optimization to add it to a fgraph) to prevent computing outputs via inplace operations
上级 cd3979e5
...@@ -159,6 +159,16 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -159,6 +159,16 @@ class AddDestroyHandler(gof.Optimizer):
fgraph.attach_feature(gof.DestroyHandler()) fgraph.attach_feature(gof.DestroyHandler())
class AddNoOutputFromInplace(gof.Optimizer):
"""This optimizer adds to the fgraph a feature that will prevent outputs
of a fgraph to be created by performing inplace operations on intermediary
variables.
"""
def add_requirements(self, fgraph):
super(AddNoOutputFromInplace, self).add_requirements(fgraph)
fgraph.attach_feature(gof.NoOutputFromInplace())
class PrintCurrentFunctionGraph(gof.Optimizer): class PrintCurrentFunctionGraph(gof.Optimizer):
"""This optimizer is for debugging. """This optimizer is for debugging.
...@@ -211,6 +221,9 @@ optdb.register('specialize_device', gof.EquilibriumDB(), ...@@ -211,6 +221,9 @@ optdb.register('specialize_device', gof.EquilibriumDB(),
optdb.register('merge2', gof.MergeOptimizer(), optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run', 'merge') 49, 'fast_run', 'merge')
optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(),
49.4)
optdb.register('add_destroy_handler', AddDestroyHandler(), optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace') 49.5, 'fast_run', 'inplace')
......
import theano
from theano.compile.mode import Mode
import theano.tensor as T
def test_no_output_from_implace():
x = T.matrix()
y = T.matrix()
a = T.dot(x, y)
b = T.tanh(a)
# Ensure that the elemwise op that produces the output is inplace when
# using a mode that does not include the optimization
fct_no_opt = theano.function([x,y], b, mode="FAST_RUN")
op = fct_no_opt.maker.fgraph.outputs[0].owner.op
assert (hasattr(op, 'destroy_map') and 0 in op.destroy_map)
# Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization
mode_opt = Mode(linker="cvm", optimizer="fast_run")
mode_opt = mode_opt.including("add_no_output_from_inplace")
fct_opt = theano.function([x,y], b, mode=mode_opt)
op = fct_opt.maker.fgraph.outputs[0].owner.op
assert (not hasattr(op, 'destroy_map') or 0 not in op.destroy_map)
...@@ -46,7 +46,7 @@ from theano.gof.fg import \ ...@@ -46,7 +46,7 @@ from theano.gof.fg import \
CachedConstantError, InconsistencyError, MissingInputError, FunctionGraph CachedConstantError, InconsistencyError, MissingInputError, FunctionGraph
from theano.gof.destroyhandler import \ from theano.gof.destroyhandler import \
DestroyHandler DestroyHandler, NoOutputFromInplace
from theano.gof.graph import \ from theano.gof.graph import \
Apply, Variable, Constant, view_roots Apply, Variable, Constant, view_roots
......
...@@ -7,6 +7,7 @@ import toolbox ...@@ -7,6 +7,7 @@ import toolbox
import graph import graph
from theano.gof.python25 import deque from theano.gof.python25 import deque
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
from theano.gof.toolbox import Feature
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
from fg import InconsistencyError from fg import InconsistencyError
...@@ -1026,3 +1027,23 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -1026,3 +1027,23 @@ class DestroyHandler(toolbox.Bookkeeper):
rval[app] = root_clients rval[app] = root_clients
return rval return rval
class NoOutputFromInplace(Feature):
def __init__(self):
pass
def validate(self, fgraph):
if not hasattr(fgraph, 'destroyers'):
return True
for out in list(fgraph.outputs):
# Validate that the node that produces the output does not produce
# it by modifying something else inplace.
node = out.owner
op = node.op
out_idx = node.outputs.index(out)
if hasattr(op, 'destroy_map') and out_idx in op.destroy_map.keys():
raise InconsistencyError(
"Trying to produce an output ", out,
" by modifying another variable inplace")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论