提交 d5ba6134 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6044 from ReyhaneAskari/bye_bye_outputGaurd

Remove output guard
...@@ -27,7 +27,8 @@ from theano.compile.function_module import ( ...@@ -27,7 +27,8 @@ from theano.compile.function_module import (
FunctionMaker, Function, infer_reuse_pattern, FunctionMaker, Function, infer_reuse_pattern,
std_fgraph) std_fgraph)
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard from theano.compile.ops import OutputGuard, _output_guard
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
_logger = logging.getLogger("theano.compile.debugmode") _logger = logging.getLogger("theano.compile.debugmode")
...@@ -2276,6 +2277,24 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2276,6 +2277,24 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
"of", len(li), "events was stable.", "of", len(li), "events was stable.",
file=sys.stderr) file=sys.stderr)
self.fgraph = fgraph self.fgraph = fgraph
destroy_handler_added = False
for feature in fgraph._features:
if isinstance(feature, gof.DestroyHandler):
destroy_handler_added = True
break
if not destroy_handler_added:
fgraph.attach_feature(gof.DestroyHandler())
for o in fgraph.outputs:
try:
fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
raise Exception("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?" % o)
except gof.InconsistencyError:
# This output is already impossible to destroy.
# No guard necessary
pass
linker = _Linker(self) linker = _Linker(self)
......
...@@ -18,7 +18,6 @@ import theano ...@@ -18,7 +18,6 @@ import theano
from theano import config, gof from theano import config, gof
from theano.compat import izip from theano.compat import izip
from theano.gof import graph from theano.gof import graph
import theano.compile.mode
import theano.compile.profiling import theano.compile.profiling
from theano.compile.io import ( from theano.compile.io import (
In, SymbolicInput, SymbolicOutput) In, SymbolicInput, SymbolicOutput)
......
...@@ -4,13 +4,14 @@ WRITEME ...@@ -4,13 +4,14 @@ WRITEME
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import logging import logging
import warnings
import theano import theano
from theano import gof from theano import gof
import theano.gof.vm import theano.gof.vm
from theano.configparser import config from theano.configparser import config
from theano.compile.ops import _output_guard
from six import string_types from six import string_types
from theano.compile.function_module import Supervisor
_logger = logging.getLogger('theano.compile.mode') _logger = logging.getLogger('theano.compile.mode')
...@@ -111,18 +112,16 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -111,18 +112,16 @@ class AddDestroyHandler(gof.Optimizer):
""" """
def apply(self, fgraph): def apply(self, fgraph):
for o in fgraph.outputs: supervisor_added = False
try: for feature in fgraph._features:
fgraph.replace_validate(o, _output_guard(o), if isinstance(feature, Supervisor):
reason='output_guard') supervisor_added = True
_logger.info("Output variable %s required output_guard, " break
"how was this output left unprotected against " if not supervisor_added:
"destructive operations?" warnings.warn("WARNING: Supervisor is not added. Please build a FunctionGraph"
% o) "via theano.compile.function_module.std_graph()"
except gof.InconsistencyError: "or add the Supervisor class manually.",
# This output is already impossible to destroy. stacklevel=3)
# No guard necessary
pass
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super(AddDestroyHandler, self).add_requirements(fgraph) super(AddDestroyHandler, self).add_requirements(fgraph)
......
...@@ -447,6 +447,8 @@ def get_scalar_constant_value(orig_v, elemwise=True, ...@@ -447,6 +447,8 @@ def get_scalar_constant_value(orig_v, elemwise=True,
max_recur > 0): max_recur > 0):
max_recur -= 1 max_recur -= 1
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast, if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
# outputguard is only used in debugmode but we
# keep it here to avoid problems with old pickels.
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
v = v.owner.inputs[0] v = v.owner.inputs[0]
......
...@@ -579,8 +579,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -579,8 +579,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_vector(self): def test_softmax_optimizations_vector(self):
...@@ -594,8 +593,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -594,8 +593,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias(self): def test_softmax_optimizations_w_bias(self):
...@@ -624,10 +622,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -624,10 +622,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# print node.op # print node.op
# print printing.pprint(node.outputs[0]) # print printing.pprint(node.outputs[0])
# print '====' # print '===='
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 1
assert (fgraph.outputs[0].owner.op ==
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias2(self): def test_softmax_optimizations_w_bias2(self):
...@@ -654,10 +650,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -654,10 +650,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
# print '====' # print '===='
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 2
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias_vector(self): def test_softmax_optimizations_w_bias_vector(self):
...@@ -681,9 +676,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -681,9 +676,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
# print '====' # print '===='
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 2
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_grad_optimizations(self): def test_softmax_grad_optimizations(self):
...@@ -1338,9 +1332,8 @@ def test_argmax_pushdown(): ...@@ -1338,9 +1332,8 @@ def test_argmax_pushdown():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 2 # an output_guard is second assert len(fgraph.toposort()) == 1
assert fgraph.toposort()[0].op == tensor.basic._argmax assert fgraph.toposort()[0].op == tensor.basic._argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=tensor.basic._argmax) fgraph, ops_to_check=tensor.basic._argmax)
x = tensor.matrix() x = tensor.matrix()
...@@ -1364,12 +1357,11 @@ def test_argmax_pushdown(): ...@@ -1364,12 +1357,11 @@ def test_argmax_pushdown():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 4 # an output_guard is second assert len(fgraph.toposort()) == 3
assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise) assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax) assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce) assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias(): def test_argmax_pushdown_bias():
...@@ -1388,10 +1380,9 @@ def test_argmax_pushdown_bias(): ...@@ -1388,10 +1380,9 @@ def test_argmax_pushdown_bias():
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax) types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
assert len(fgraph.toposort()) == 4 assert len(fgraph.toposort()) == 3
for i, type in enumerate(types_to_check): for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type) assert isinstance(fgraph.toposort()[i].op, type)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
assert check_stack_trace(fgraph, ops_to_check=types_to_check) assert check_stack_trace(fgraph, ops_to_check=types_to_check)
x = tensor.matrix() x = tensor.matrix()
...@@ -1412,11 +1403,10 @@ def test_argmax_pushdown_bias(): ...@@ -1412,11 +1403,10 @@ def test_argmax_pushdown_bias():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias) assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce) assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[2].op) == 'OutputGuard'
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=(SoftmaxWithBias, tensor.CAReduce)) fgraph, ops_to_check=(SoftmaxWithBias, tensor.CAReduce))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论