提交 d5ba6134 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6044 from ReyhaneAskari/bye_bye_outputGaurd

Remove output guard
......@@ -27,7 +27,8 @@ from theano.compile.function_module import (
FunctionMaker, Function, infer_reuse_pattern,
std_fgraph)
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard
from theano.compile.ops import OutputGuard, _output_guard
__docformat__ = "restructuredtext en"
_logger = logging.getLogger("theano.compile.debugmode")
......@@ -2276,6 +2277,24 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
"of", len(li), "events was stable.",
file=sys.stderr)
self.fgraph = fgraph
destroy_handler_added = False
for feature in fgraph._features:
if isinstance(feature, gof.DestroyHandler):
destroy_handler_added = True
break
if not destroy_handler_added:
fgraph.attach_feature(gof.DestroyHandler())
for o in fgraph.outputs:
try:
fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
raise Exception("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?" % o)
except gof.InconsistencyError:
# This output is already impossible to destroy.
# No guard necessary
pass
linker = _Linker(self)
......
......@@ -18,7 +18,6 @@ import theano
from theano import config, gof
from theano.compat import izip
from theano.gof import graph
import theano.compile.mode
import theano.compile.profiling
from theano.compile.io import (
In, SymbolicInput, SymbolicOutput)
......
......@@ -4,13 +4,14 @@ WRITEME
"""
from __future__ import absolute_import, print_function, division
import logging
import warnings
import theano
from theano import gof
import theano.gof.vm
from theano.configparser import config
from theano.compile.ops import _output_guard
from six import string_types
from theano.compile.function_module import Supervisor
_logger = logging.getLogger('theano.compile.mode')
......@@ -111,18 +112,16 @@ class AddDestroyHandler(gof.Optimizer):
"""
def apply(self, fgraph):
for o in fgraph.outputs:
try:
fgraph.replace_validate(o, _output_guard(o),
reason='output_guard')
_logger.info("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?"
% o)
except gof.InconsistencyError:
# This output is already impossible to destroy.
# No guard necessary
pass
supervisor_added = False
for feature in fgraph._features:
if isinstance(feature, Supervisor):
supervisor_added = True
break
if not supervisor_added:
warnings.warn("WARNING: Supervisor is not added. Please build a FunctionGraph"
"via theano.compile.function_module.std_graph()"
"or add the Supervisor class manually.",
stacklevel=3)
def add_requirements(self, fgraph):
super(AddDestroyHandler, self).add_requirements(fgraph)
......
......@@ -447,6 +447,8 @@ def get_scalar_constant_value(orig_v, elemwise=True,
max_recur > 0):
max_recur -= 1
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
# outputguard is only used in debugmode but we
# keep it here to avoid problems with old pickles.
compile.ops.OutputGuard,
compile.DeepCopyOp)):
v = v.owner.inputs[0]
......
......@@ -579,8 +579,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_vector(self):
......@@ -594,8 +593,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias(self):
......@@ -624,10 +622,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# print node.op
# print printing.pprint(node.outputs[0])
# print '===='
assert len(fgraph.toposort()) == 2
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
assert len(fgraph.toposort()) == 1
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias2(self):
......@@ -654,10 +650,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort():
# print node.op
# print '===='
assert len(fgraph.toposort()) == 3
assert len(fgraph.toposort()) == 2
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias_vector(self):
......@@ -681,9 +676,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort():
# print node.op
# print '===='
assert len(fgraph.toposort()) == 3
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
assert len(fgraph.toposort()) == 2
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_grad_optimizations(self):
......@@ -1338,9 +1332,8 @@ def test_argmax_pushdown():
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 2 # an output_guard is second
assert len(fgraph.toposort()) == 1
assert fgraph.toposort()[0].op == tensor.basic._argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
assert check_stack_trace(
fgraph, ops_to_check=tensor.basic._argmax)
x = tensor.matrix()
......@@ -1364,12 +1357,11 @@ def test_argmax_pushdown():
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 4 # an output_guard is second
assert len(fgraph.toposort()) == 3
assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias():
......@@ -1388,10 +1380,9 @@ def test_argmax_pushdown_bias():
# for node in fgraph.toposort():
# print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
assert len(fgraph.toposort()) == 4
assert len(fgraph.toposort()) == 3
for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
assert check_stack_trace(fgraph, ops_to_check=types_to_check)
x = tensor.matrix()
......@@ -1412,11 +1403,10 @@ def test_argmax_pushdown_bias():
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 3
assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[2].op) == 'OutputGuard'
assert check_stack_trace(
fgraph, ops_to_check=(SoftmaxWithBias, tensor.CAReduce))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论