提交 394857d6 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add shortcut theano.change_flags and use it.

上级 682ce237
...@@ -65,6 +65,7 @@ def disable_log_handler(logger=theano_logger, handler=logging_default_handler): ...@@ -65,6 +65,7 @@ def disable_log_handler(logger=theano_logger, handler=logging_default_handler):
from theano.version import version as __version__ from theano.version import version as __version__
from theano.configdefaults import config from theano.configdefaults import config
from theano.configparser import change_flags
# This is the api version for ops that generate C code. External ops # This is the api version for ops that generate C code. External ops
# might need manual changes if this number goes up. An undefined # might need manual changes if this number goes up. An undefined
......
...@@ -271,7 +271,7 @@ class OpFromGraph(gof.Op): ...@@ -271,7 +271,7 @@ class OpFromGraph(gof.Op):
is_inline = self.is_inline is_inline = self.is_inline
return '%(name)s{inline=%(is_inline)s}' % locals() return '%(name)s{inline=%(is_inline)s}' % locals()
@theano.configparser.change_flags(compute_test_value='off') @theano.change_flags(compute_test_value='off')
def _recompute_grad_op(self): def _recompute_grad_op(self):
''' '''
converts self._grad_op from user supplied form to type(self) instance converts self._grad_op from user supplied form to type(self) instance
...@@ -375,7 +375,7 @@ class OpFromGraph(gof.Op): ...@@ -375,7 +375,7 @@ class OpFromGraph(gof.Op):
self._grad_op_stypes_l = all_grads_ov_l self._grad_op_stypes_l = all_grads_ov_l
self._grad_op_is_cached = True self._grad_op_is_cached = True
@theano.configparser.change_flags(compute_test_value='off') @theano.change_flags(compute_test_value='off')
def _recompute_rop_op(self): def _recompute_rop_op(self):
''' '''
converts self._rop_op from user supplied form to type(self) instance converts self._rop_op from user supplied form to type(self) instance
......
...@@ -28,7 +28,7 @@ from theano.compile.function_module import ( ...@@ -28,7 +28,7 @@ from theano.compile.function_module import (
std_fgraph) std_fgraph)
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard from theano.compile.ops import OutputGuard, _output_guard
from theano.configparser import change_flags from theano import change_flags
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
......
...@@ -1829,7 +1829,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1829,7 +1829,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys, output_keys=output_keys,
name=name) name=name)
with theano.configparser.change_flags(compute_test_value="off"): with theano.change_flags(compute_test_value="off"):
fn = m.create(defaults) fn = m.create(defaults)
finally: finally:
t2 = time.time() t2 = time.time()
......
...@@ -315,7 +315,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -315,7 +315,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
np.ones([3, 4], dtype=config.floatX)], np.ones([3, 4], dtype=config.floatX)],
OpFromGraph) OpFromGraph)
@theano.configparser.change_flags(compute_test_value='raise') @theano.change_flags(compute_test_value='raise')
def test_compute_test_value(self): def test_compute_test_value(self):
x = T.scalar('x') x = T.scalar('x')
x.tag.test_value = np.array(1., dtype=config.floatX) x.tag.test_value = np.array(1., dtype=config.floatX)
......
...@@ -284,7 +284,7 @@ def test_badoptimization_opt_err(): ...@@ -284,7 +284,7 @@ def test_badoptimization_opt_err():
# Test that opt that do an illegal change still get the error from gof. # Test that opt that do an illegal change still get the error from gof.
try: try:
with theano.configparser.change_flags(on_opt_error='raise'): with theano.change_flags(on_opt_error='raise'):
f2 = theano.function([a, b], a + b, f2 = theano.function([a, b], a + b,
mode=debugmode.DebugMode(optimizer=opt2, mode=debugmode.DebugMode(optimizer=opt2,
stability_patience=1)) stability_patience=1))
......
...@@ -12,7 +12,7 @@ from theano.gof import destroyhandler ...@@ -12,7 +12,7 @@ from theano.gof import destroyhandler
from theano.gof.fg import FunctionGraph, InconsistencyError from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import ReplaceValidate from theano.gof.toolbox import ReplaceValidate
from theano.configparser import change_flags from theano import change_flags
from copy import copy from copy import copy
......
...@@ -571,7 +571,7 @@ class TestEquilibrium(object): ...@@ -571,7 +571,7 @@ class TestEquilibrium(object):
opt.optimize(g) opt.optimize(g)
assert str(g) == '[Op2(x, y)]' assert str(g) == '[Op2(x, y)]'
@theano.configparser.change_flags(on_opt_error='ignore') @theano.change_flags(on_opt_error='ignore')
def test_low_use_ratio(self): def test_low_use_ratio(self):
x, y, z = map(MyVariable, 'xyz') x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
......
...@@ -12,10 +12,9 @@ from six import string_types ...@@ -12,10 +12,9 @@ from six import string_types
import re import re
import theano import theano
from theano.gof import utils from theano.gof import graph, utils
from theano.gof.utils import MethodNotDefined, object2 from theano.gof.utils import MethodNotDefined, object2
from theano.gof import graph from theano import change_flags
from theano.configparser import change_flags
######## ########
# Type # # Type #
......
...@@ -283,7 +283,7 @@ class test_gpu_ifelse(test_ifelse.test_ifelse): ...@@ -283,7 +283,7 @@ class test_gpu_ifelse(test_ifelse.test_ifelse):
z = tensor.constant(2.) z = tensor.constant(2.)
a = theano.ifelse.ifelse(x, y, z) a = theano.ifelse.ifelse(x, y, z)
with theano.configparser.change_flags(on_opt_error='raise'): with theano.change_flags(on_opt_error='raise'):
theano.function([x], [a], mode=mode_with_gpu) theano.function([x], [a], mode=mode_with_gpu)
...@@ -516,7 +516,7 @@ def test_not_useless_scalar_gpuelemwise(): ...@@ -516,7 +516,7 @@ def test_not_useless_scalar_gpuelemwise():
# We don't want to move elemwise on scalar on the GPU when the # We don't want to move elemwise on scalar on the GPU when the
# result will not be used on the GPU! # result will not be used on the GPU!
with theano.configparser.change_flags(warn_float64='ignore'): with theano.change_flags(warn_float64='ignore'):
X = tensor.fmatrix() X = tensor.fmatrix()
x = np.random.randn(32, 32).astype(np.float32) x = np.random.randn(32, 32).astype(np.float32)
m1 = theano.shared(np.random.randn(32, 32).astype(np.float32)) m1 = theano.shared(np.random.randn(32, 32).astype(np.float32))
......
...@@ -4,8 +4,7 @@ import functools ...@@ -4,8 +4,7 @@ import functools
import numpy as np import numpy as np
import theano import theano
from theano import tensor from theano import change_flags, tensor
from theano.configparser import change_flags
from theano.sandbox import rng_mrg from theano.sandbox import rng_mrg
from theano.sandbox.rng_mrg import MRG_RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.sandbox.tests.test_rng_mrg import java_samples, rng_mrg_overflow from theano.sandbox.tests.test_rng_mrg import java_samples, rng_mrg_overflow
......
...@@ -123,7 +123,7 @@ def test_advinc_subtensor1_dtype(): ...@@ -123,7 +123,7 @@ def test_advinc_subtensor1_dtype():
assert np.allclose(rval, rep) assert np.allclose(rval, rep)
@theano.configparser.change_flags(deterministic='more') @theano.change_flags(deterministic='more')
def test_deterministic_flag(): def test_deterministic_flag():
shp = (3, 4) shp = (3, 4)
for dtype1, dtype2 in [('float32', 'int8')]: for dtype1, dtype2 in [('float32', 'int8')]:
......
...@@ -754,7 +754,7 @@ class MRG_RandomStreams(object): ...@@ -754,7 +754,7 @@ class MRG_RandomStreams(object):
self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2) self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2)
assert self.rstate.dtype == np.int32 assert self.rstate.dtype == np.int32
@theano.configparser.change_flags(compute_test_value='off') @theano.change_flags(compute_test_value='off')
def get_substream_rstates(self, n_streams, dtype, inc_rstate=True): def get_substream_rstates(self, n_streams, dtype, inc_rstate=True):
# TODO : need description for parameter and return # TODO : need description for parameter and return
""" """
......
...@@ -9,14 +9,12 @@ import numpy as np ...@@ -9,14 +9,12 @@ import numpy as np
from six.moves import xrange from six.moves import xrange
import theano import theano
from theano import tensor, config from theano import change_flags, config, tensor
from theano.sandbox import rng_mrg from theano.sandbox import rng_mrg
from theano.sandbox.rng_mrg import MRG_RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr from theano.tests.unittest_tools import attr
from theano.configparser import change_flags
# TODO: test MRG_RandomStreams # TODO: test MRG_RandomStreams
# Partly done in test_consistency_randomstreams # Partly done in test_consistency_randomstreams
......
...@@ -2790,7 +2790,7 @@ class T_Scan(unittest.TestCase): ...@@ -2790,7 +2790,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(expected_output, scan_output) utt.assert_allclose(expected_output, scan_output)
utt.assert_allclose(expected_output, jacobian_outputs) utt.assert_allclose(expected_output, jacobian_outputs)
@theano.configparser.change_flags(on_opt_error='raise') @theano.change_flags(on_opt_error='raise')
def test_pushout_seqs2(self): def test_pushout_seqs2(self):
# This test for a bug with PushOutSeqScan that was reported on the # This test for a bug with PushOutSeqScan that was reported on the
# theano-user mailing list where the optimization raised an exception # theano-user mailing list where the optimization raised an exception
...@@ -2807,7 +2807,7 @@ class T_Scan(unittest.TestCase): ...@@ -2807,7 +2807,7 @@ class T_Scan(unittest.TestCase):
# an exception being raised # an exception being raised
theano.function([x], outputs, updates=updates) theano.function([x], outputs, updates=updates)
@theano.configparser.change_flags(on_opt_error='raise') @theano.change_flags(on_opt_error='raise')
def test_pushout_nonseq(self): def test_pushout_nonseq(self):
# Test case originally reported by Daniel Renshaw. The crashed occured # Test case originally reported by Daniel Renshaw. The crashed occured
# during the optimization PushOutNonSeqScan when it attempted to # during the optimization PushOutNonSeqScan when it attempted to
......
...@@ -8,7 +8,7 @@ from six.moves import xrange ...@@ -8,7 +8,7 @@ from six.moves import xrange
import theano import theano
from theano import gof from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.configparser import change_flags from theano import change_flags
from theano.gof import Apply, Op, COp, OpenMPOp, ParamsType from theano.gof import Apply, Op, COp, OpenMPOp, ParamsType
from theano import scalar from theano import scalar
from theano.scalar import get_scalar_type from theano.scalar import get_scalar_type
......
...@@ -14,7 +14,7 @@ from theano.scalar import add, sub, true_div, mul ...@@ -14,7 +14,7 @@ from theano.scalar import add, sub, true_div, mul
class BNComposite(Composite): class BNComposite(Composite):
init_param = ('dtype',) init_param = ('dtype',)
@theano.configparser.change_flags(compute_test_value='off') @theano.change_flags(compute_test_value='off')
def __init__(self, dtype): def __init__(self, dtype):
self.dtype = dtype self.dtype = dtype
x = theano.scalar.Scalar(dtype=dtype).make_variable() x = theano.scalar.Scalar(dtype=dtype).make_variable()
......
...@@ -6,7 +6,7 @@ from nose.tools import assert_raises, assert_true ...@@ -6,7 +6,7 @@ from nose.tools import assert_raises, assert_true
import theano import theano
from theano import tensor from theano import tensor
from theano.configparser import change_flags from theano import change_flags
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import (corr, corr3d, conv2d_transpose, from theano.tensor.nnet import (corr, corr3d, conv2d_transpose,
......
...@@ -8,7 +8,7 @@ import theano.tensor as T ...@@ -8,7 +8,7 @@ import theano.tensor as T
from theano.tensor.nnet.neighbours import images2neibs, neibs2images, Images2Neibs from theano.tensor.nnet.neighbours import images2neibs, neibs2images, Images2Neibs
from theano.tests import unittest_tools from theano.tests import unittest_tools
from theano.configparser import change_flags from theano import change_flags
mode_without_gpu = theano.compile.mode.get_default_mode().excluding('gpu') mode_without_gpu = theano.compile.mode.get_default_mode().excluding('gpu')
......
...@@ -57,7 +57,7 @@ from theano.tensor import ( ...@@ -57,7 +57,7 @@ from theano.tensor import (
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr from theano.tests.unittest_tools import attr
from theano.configparser import change_flags from theano import change_flags
imported_scipy_special = False imported_scipy_special = False
mode_no_scipy = get_default_mode() mode_no_scipy = get_default_mode()
......
...@@ -6,7 +6,7 @@ import subprocess ...@@ -6,7 +6,7 @@ import subprocess
import os import os
from theano.gof.sched import sort_schedule_fn from theano.gof.sched import sort_schedule_fn
from theano.configparser import change_flags from theano import change_flags
mpi_scheduler = sort_schedule_fn(*mpi_cmps) mpi_scheduler = sort_schedule_fn(*mpi_cmps)
mpi_linker = theano.OpWiseCLinker(schedule=mpi_scheduler) mpi_linker = theano.OpWiseCLinker(schedule=mpi_scheduler)
......
...@@ -63,7 +63,7 @@ from theano.tensor.elemwise import DimShuffle ...@@ -63,7 +63,7 @@ from theano.tensor.elemwise import DimShuffle
from theano.tensor.type import values_eq_approx_remove_nan from theano.tensor.type import values_eq_approx_remove_nan
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.gof.opt import check_stack_trace, out2in from theano.gof.opt import check_stack_trace, out2in
from theano.configparser import change_flags from theano import change_flags
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
mode_opt = theano.config.mode mode_opt = theano.config.mode
......
...@@ -34,7 +34,7 @@ from theano.tensor.subtensor import (AdvancedIncSubtensor, ...@@ -34,7 +34,7 @@ from theano.tensor.subtensor import (AdvancedIncSubtensor,
from theano.tensor.tests.test_basic import inplace_func, rand, randint_ranged from theano.tensor.tests.test_basic import inplace_func, rand, randint_ranged
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr from theano.tests.unittest_tools import attr
from theano.configparser import change_flags from theano import change_flags
if PY3: if PY3:
def L(i): def L(i):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论