提交 394857d6 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add shortcut theano.change_flags and use it.

上级 682ce237
......@@ -65,6 +65,7 @@ def disable_log_handler(logger=theano_logger, handler=logging_default_handler):
from theano.version import version as __version__
from theano.configdefaults import config
from theano.configparser import change_flags
# This is the api version for ops that generate C code. External ops
# might need manual changes if this number goes up. An undefined
......
......@@ -271,7 +271,7 @@ class OpFromGraph(gof.Op):
is_inline = self.is_inline
return '%(name)s{inline=%(is_inline)s}' % locals()
@theano.configparser.change_flags(compute_test_value='off')
@theano.change_flags(compute_test_value='off')
def _recompute_grad_op(self):
'''
converts self._grad_op from user supplied form to type(self) instance
......@@ -375,7 +375,7 @@ class OpFromGraph(gof.Op):
self._grad_op_stypes_l = all_grads_ov_l
self._grad_op_is_cached = True
@theano.configparser.change_flags(compute_test_value='off')
@theano.change_flags(compute_test_value='off')
def _recompute_rop_op(self):
'''
converts self._rop_op from user supplied form to type(self) instance
......
......@@ -28,7 +28,7 @@ from theano.compile.function_module import (
std_fgraph)
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard
from theano.configparser import change_flags
from theano import change_flags
__docformat__ = "restructuredtext en"
......
......@@ -1829,7 +1829,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
on_unused_input=on_unused_input,
output_keys=output_keys,
name=name)
with theano.configparser.change_flags(compute_test_value="off"):
with theano.change_flags(compute_test_value="off"):
fn = m.create(defaults)
finally:
t2 = time.time()
......
......@@ -315,7 +315,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
np.ones([3, 4], dtype=config.floatX)],
OpFromGraph)
@theano.configparser.change_flags(compute_test_value='raise')
@theano.change_flags(compute_test_value='raise')
def test_compute_test_value(self):
x = T.scalar('x')
x.tag.test_value = np.array(1., dtype=config.floatX)
......
......@@ -284,7 +284,7 @@ def test_badoptimization_opt_err():
# Test that opt that do an illegal change still get the error from gof.
try:
with theano.configparser.change_flags(on_opt_error='raise'):
with theano.change_flags(on_opt_error='raise'):
f2 = theano.function([a, b], a + b,
mode=debugmode.DebugMode(optimizer=opt2,
stability_patience=1))
......
......@@ -12,7 +12,7 @@ from theano.gof import destroyhandler
from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import ReplaceValidate
from theano.configparser import change_flags
from theano import change_flags
from copy import copy
......
......@@ -571,7 +571,7 @@ class TestEquilibrium(object):
opt.optimize(g)
assert str(g) == '[Op2(x, y)]'
@theano.configparser.change_flags(on_opt_error='ignore')
@theano.change_flags(on_opt_error='ignore')
def test_low_use_ratio(self):
x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y))
......
......@@ -12,10 +12,9 @@ from six import string_types
import re
import theano
from theano.gof import utils
from theano.gof import graph, utils
from theano.gof.utils import MethodNotDefined, object2
from theano.gof import graph
from theano.configparser import change_flags
from theano import change_flags
########
# Type #
......
......@@ -283,7 +283,7 @@ class test_gpu_ifelse(test_ifelse.test_ifelse):
z = tensor.constant(2.)
a = theano.ifelse.ifelse(x, y, z)
with theano.configparser.change_flags(on_opt_error='raise'):
with theano.change_flags(on_opt_error='raise'):
theano.function([x], [a], mode=mode_with_gpu)
......@@ -516,7 +516,7 @@ def test_not_useless_scalar_gpuelemwise():
# We don't want to move elemwise on scalar on the GPU when the
# result will not be used on the GPU!
with theano.configparser.change_flags(warn_float64='ignore'):
with theano.change_flags(warn_float64='ignore'):
X = tensor.fmatrix()
x = np.random.randn(32, 32).astype(np.float32)
m1 = theano.shared(np.random.randn(32, 32).astype(np.float32))
......
......@@ -4,8 +4,7 @@ import functools
import numpy as np
import theano
from theano import tensor
from theano.configparser import change_flags
from theano import change_flags, tensor
from theano.sandbox import rng_mrg
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.sandbox.tests.test_rng_mrg import java_samples, rng_mrg_overflow
......
......@@ -123,7 +123,7 @@ def test_advinc_subtensor1_dtype():
assert np.allclose(rval, rep)
@theano.configparser.change_flags(deterministic='more')
@theano.change_flags(deterministic='more')
def test_deterministic_flag():
shp = (3, 4)
for dtype1, dtype2 in [('float32', 'int8')]:
......
......@@ -754,7 +754,7 @@ class MRG_RandomStreams(object):
self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2)
assert self.rstate.dtype == np.int32
@theano.configparser.change_flags(compute_test_value='off')
@theano.change_flags(compute_test_value='off')
def get_substream_rstates(self, n_streams, dtype, inc_rstate=True):
# TODO : need description for parameter and return
"""
......
......@@ -9,14 +9,12 @@ import numpy as np
from six.moves import xrange
import theano
from theano import tensor, config
from theano import change_flags, config, tensor
from theano.sandbox import rng_mrg
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr
from theano.configparser import change_flags
# TODO: test MRG_RandomStreams
# Partly done in test_consistency_randomstreams
......
......@@ -2790,7 +2790,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(expected_output, scan_output)
utt.assert_allclose(expected_output, jacobian_outputs)
@theano.configparser.change_flags(on_opt_error='raise')
@theano.change_flags(on_opt_error='raise')
def test_pushout_seqs2(self):
# This test for a bug with PushOutSeqScan that was reported on the
# theano-user mailing list where the optimization raised an exception
......@@ -2807,7 +2807,7 @@ class T_Scan(unittest.TestCase):
# an exception being raised
theano.function([x], outputs, updates=updates)
@theano.configparser.change_flags(on_opt_error='raise')
@theano.change_flags(on_opt_error='raise')
def test_pushout_nonseq(self):
# Test case originally reported by Daniel Renshaw. The crashed occured
# during the optimization PushOutNonSeqScan when it attempted to
......
......@@ -8,7 +8,7 @@ from six.moves import xrange
import theano
from theano import gof
from theano.compat import izip
from theano.configparser import change_flags
from theano import change_flags
from theano.gof import Apply, Op, COp, OpenMPOp, ParamsType
from theano import scalar
from theano.scalar import get_scalar_type
......
......@@ -14,7 +14,7 @@ from theano.scalar import add, sub, true_div, mul
class BNComposite(Composite):
init_param = ('dtype',)
@theano.configparser.change_flags(compute_test_value='off')
@theano.change_flags(compute_test_value='off')
def __init__(self, dtype):
self.dtype = dtype
x = theano.scalar.Scalar(dtype=dtype).make_variable()
......
......@@ -6,7 +6,7 @@ from nose.tools import assert_raises, assert_true
import theano
from theano import tensor
from theano.configparser import change_flags
from theano import change_flags
from theano.gof.opt import check_stack_trace
from theano.tests import unittest_tools as utt
from theano.tensor.nnet import (corr, corr3d, conv2d_transpose,
......
......@@ -8,7 +8,7 @@ import theano.tensor as T
from theano.tensor.nnet.neighbours import images2neibs, neibs2images, Images2Neibs
from theano.tests import unittest_tools
from theano.configparser import change_flags
from theano import change_flags
mode_without_gpu = theano.compile.mode.get_default_mode().excluding('gpu')
......
......@@ -57,7 +57,7 @@ from theano.tensor import (
from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr
from theano.configparser import change_flags
from theano import change_flags
imported_scipy_special = False
mode_no_scipy = get_default_mode()
......
......@@ -6,7 +6,7 @@ import subprocess
import os
from theano.gof.sched import sort_schedule_fn
from theano.configparser import change_flags
from theano import change_flags
mpi_scheduler = sort_schedule_fn(*mpi_cmps)
mpi_linker = theano.OpWiseCLinker(schedule=mpi_scheduler)
......
......@@ -63,7 +63,7 @@ from theano.tensor.elemwise import DimShuffle
from theano.tensor.type import values_eq_approx_remove_nan
from theano.tests import unittest_tools as utt
from theano.gof.opt import check_stack_trace, out2in
from theano.configparser import change_flags
from theano import change_flags
from nose.plugins.attrib import attr
mode_opt = theano.config.mode
......
......@@ -34,7 +34,7 @@ from theano.tensor.subtensor import (AdvancedIncSubtensor,
from theano.tensor.tests.test_basic import inplace_func, rand, randint_ranged
from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr
from theano.configparser import change_flags
from theano import change_flags
if PY3:
def L(i):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论