Unverified 提交 9f7a1b69 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #110 from brandonwillard/enforce-sane-test-values

Refactor test value framework so that test value validation is performed up-front.
...@@ -15,10 +15,6 @@ ...@@ -15,10 +15,6 @@
.. autoclass:: theano.misc.pkl_utils.StripPickler .. autoclass:: theano.misc.pkl_utils.StripPickler
.. autoclass:: theano.misc.pkl_utils.CompatUnpickler
.. seealso:: .. seealso::
:ref:`tutorial_loadsave` :ref:`tutorial_loadsave`
...@@ -167,14 +167,16 @@ class TestComputeTestValue: ...@@ -167,14 +167,16 @@ class TestComputeTestValue:
@theano.change_flags(compute_test_value="raise") @theano.change_flags(compute_test_value="raise")
def test_incorrect_type(self): def test_incorrect_type(self):
x = tt.fmatrix("x")
# Incorrect dtype (float64) for test_value
x.tag.test_value = np.random.rand(3, 4)
y = tt.dmatrix("y")
y.tag.test_value = np.random.rand(4, 5)
x = tt.vector("x")
with pytest.raises(TypeError): with pytest.raises(TypeError):
tt.dot(x, y) # Incorrect shape for test value
x.tag.test_value = np.empty((2, 2))
x = tt.fmatrix("x")
with pytest.raises(TypeError):
# Incorrect dtype (float64) for test value
x.tag.test_value = np.random.rand(3, 4)
@theano.change_flags(compute_test_value="raise") @theano.change_flags(compute_test_value="raise")
def test_overided_function(self): def test_overided_function(self):
......
import os
import pickle import pickle
import pytest
import theano
from theano.compat import PY3
from theano.gof.fg import FunctionGraph
from theano import tensor as tt from theano import tensor as tt
from theano.gof.fg import FunctionGraph
class TestFunctionGraph: class TestFunctionGraph:
...@@ -16,24 +11,3 @@ class TestFunctionGraph: ...@@ -16,24 +11,3 @@ class TestFunctionGraph:
s = pickle.dumps(func) s = pickle.dumps(func)
pickle.loads(s) pickle.loads(s)
@pytest.mark.skipif(
not theano.config.cxx, reason="G++ not available, so we need to skip this test."
)
@pytest.mark.slow
def test_node_outputs_not_used(self):
# In the past, we where removing some not used variable from
# fgraph.variables event if the apply had other output used in
# the graph. This caused a crash.
# This test run the pickle that reproduce this case.
with open(
os.path.join(os.path.dirname(__file__), "test_fg_old_crash.pkl"), "rb"
) as f:
from theano.misc.pkl_utils import CompatUnpickler
if PY3:
u = CompatUnpickler(f, encoding="latin1")
else:
u = CompatUnpickler(f)
d = u.load()
f = theano.function(**d)
import numpy as np import numpy as np
import pytest import pytest
import theano import theano
import theano.gof.op as op import theano.gof.op as op
import theano.tensor as tt
from six import string_types from six import string_types
from theano.gof.type import Type, Generic from theano import scalar, shared
from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
import theano.tensor as T from theano.gof.type import Generic, Type
from theano import scalar
from theano import shared
config = theano.config config = theano.config
Op = op.Op Op = op.Op
...@@ -238,15 +238,15 @@ class TestMakeThunk: ...@@ -238,15 +238,15 @@ class TestMakeThunk:
__props__ = () __props__ = ()
itypes = [T.dmatrix] itypes = [tt.dmatrix]
otypes = [T.dmatrix] otypes = [tt.dmatrix]
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
inp = inputs[0] inp = inputs[0]
output = outputs[0] output = outputs[0]
output[0] = inp * 2 output[0] = inp * 2
x_input = T.dmatrix("x_input") x_input = tt.dmatrix("x_input")
f = theano.function([x_input], DoubleOp()(x_input)) f = theano.function([x_input], DoubleOp()(x_input))
inp = np.random.rand(5, 4) inp = np.random.rand(5, 4)
out = f(inp) out = f(inp)
...@@ -255,17 +255,17 @@ class TestMakeThunk: ...@@ -255,17 +255,17 @@ class TestMakeThunk:
def test_test_value_python_objects(): def test_test_value_python_objects():
for x in ([0, 1, 2], 0, 0.5, 1): for x in ([0, 1, 2], 0, 0.5, 1):
assert (op.get_test_value(x) == x).all() assert np.all(op.get_test_value(x) == x)
def test_test_value_ndarray(): def test_test_value_ndarray():
x = np.zeros((5, 5)) x = np.zeros((5, 5))
v = op.get_test_value(x) v = op.get_test_value(x)
assert (v == x).all() assert np.all(v == x)
def test_test_value_constant(): def test_test_value_constant():
x = T.as_tensor_variable(np.zeros((5, 5))) x = tt.as_tensor_variable(np.zeros((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert np.all(v == np.zeros((5, 5))) assert np.all(v == np.zeros((5, 5)))
...@@ -278,62 +278,37 @@ def test_test_value_shared(): ...@@ -278,62 +278,37 @@ def test_test_value_shared():
assert np.all(v == np.zeros((5, 5))) assert np.all(v == np.zeros((5, 5)))
@change_flags(compute_test_value="raise")
def test_test_value_op(): def test_test_value_op():
try:
prev_value = config.compute_test_value
config.compute_test_value = "raise"
x = T.log(np.ones((5, 5)))
v = op.get_test_value(x)
assert np.allclose(v, np.zeros((5, 5)))
finally:
config.compute_test_value = prev_value
def test_get_debug_values_no_debugger(): x = tt.log(np.ones((5, 5)))
"get_debug_values should return [] when debugger is off" v = op.get_test_value(x)
prev_value = config.compute_test_value assert np.allclose(v, np.zeros((5, 5)))
try:
config.compute_test_value = "off"
x = T.vector()
for x_val in op.get_debug_values(x): @change_flags(compute_test_value="off")
assert False def test_get_debug_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
finally: x = tt.vector()
config.compute_test_value = prev_value assert op.get_debug_values(x) == []
@change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore(): def test_get_det_debug_values_ignore():
# get_debug_values should return [] when debugger is ignore """Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
# and some values are missing
prev_value = config.compute_test_value x = tt.vector()
try: assert op.get_debug_values(x) == []
config.compute_test_value = "ignore"
x = T.vector()
for x_val in op.get_debug_values(x):
assert False
finally:
config.compute_test_value = prev_value
def test_get_debug_values_success(): def test_get_debug_values_success():
# tests that get_debug_value returns values when available """Tests that `get_debug_value` returns values when available (and the debugger is on)."""
# (and the debugger is on)
prev_value = config.compute_test_value
for mode in ["ignore", "warn", "raise"]: for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode):
try: x = tt.vector()
config.compute_test_value = mode
x = T.vector()
x.tag.test_value = np.zeros((4,), dtype=config.floatX) x.tag.test_value = np.zeros((4,), dtype=config.floatX)
y = np.zeros((5, 5)) y = np.zeros((5, 5))
...@@ -348,54 +323,11 @@ def test_get_debug_values_success(): ...@@ -348,54 +323,11 @@ def test_get_debug_values_success():
assert iters == 1 assert iters == 1
finally:
config.compute_test_value = prev_value
@change_flags(compute_test_value="raise")
def test_get_debug_values_exc(): def test_get_debug_values_exc():
# tests that get_debug_value raises an exception when """Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
# debugger is set to raise and a value is missing
prev_value = config.compute_test_value
try:
config.compute_test_value = "raise"
x = T.vector()
try:
for x_val in op.get_debug_values(x):
# this assert catches the case where we
# erroneously get a value returned
assert False
raised = False
except AttributeError:
raised = True
# this assert catches the case where we got []
# returned, and possibly issued a warning,
# rather than raising an exception
assert raised
finally: with pytest.raises(AttributeError):
config.compute_test_value = prev_value x = tt.vector()
assert op.get_debug_values(x) == []
def test_debug_error_message():
# tests that debug_error_message raises an
# exception when it should.
prev_value = config.compute_test_value
for mode in ["ignore", "raise"]:
try:
config.compute_test_value = mode
try:
op.debug_error_message("msg")
raised = False
except ValueError:
raised = True
assert raised
finally:
config.compute_test_value = prev_value
...@@ -7,9 +7,10 @@ import theano ...@@ -7,9 +7,10 @@ import theano
import tests.unittest_tools as utt import tests.unittest_tools as utt
from pickle import Unpickler
from theano import config, function, tensor from theano import config, function, tensor
from theano.compat import PY3 from theano.compat import PY3
from theano.misc.pkl_utils import CompatUnpickler
from theano.sandbox import multinomial from theano.sandbox import multinomial
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
from theano.gpuarray.multinomial import ( from theano.gpuarray.multinomial import (
...@@ -384,6 +385,6 @@ def test_unpickle_legacy_op(): ...@@ -384,6 +385,6 @@ def test_unpickle_legacy_op():
if not PY3: if not PY3:
with open(os.path.join(testfile_dir, fname), "r") as fp: with open(os.path.join(testfile_dir, fname), "r") as fp:
u = CompatUnpickler(fp) u = Unpickler(fp)
m = u.load() m = u.load()
assert isinstance(m, GPUAChoiceFromUniform) assert isinstance(m, GPUAChoiceFromUniform)
...@@ -13,9 +13,9 @@ import pytest ...@@ -13,9 +13,9 @@ import pytest
import numpy as np import numpy as np
from pickle import Unpickler
from theano import config from theano import config
from theano.compat import PY3
from theano.misc.pkl_utils import CompatUnpickler
from theano.gpuarray.type import ContextNotDefined from theano.gpuarray.type import ContextNotDefined
...@@ -37,10 +37,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag1(): ...@@ -37,10 +37,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag1():
fname = "GpuArray.pkl" fname = "GpuArray.pkl"
with open(os.path.join(testfile_dir, fname), "rb") as fp: with open(os.path.join(testfile_dir, fname), "rb") as fp:
if PY3: u = Unpickler(fp, encoding="latin1")
u = CompatUnpickler(fp, encoding="latin1")
else:
u = CompatUnpickler(fp)
with pytest.raises((ImportError, ContextNotDefined)): with pytest.raises((ImportError, ContextNotDefined)):
u.load() u.load()
finally: finally:
...@@ -56,10 +53,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag2(): ...@@ -56,10 +53,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag2():
fname = "GpuArray.pkl" fname = "GpuArray.pkl"
with open(os.path.join(testfile_dir, fname), "rb") as fp: with open(os.path.join(testfile_dir, fname), "rb") as fp:
if PY3: u = Unpickler(fp, encoding="latin1")
u = CompatUnpickler(fp, encoding="latin1")
else:
u = CompatUnpickler(fp)
try: try:
mat = u.load() mat = u.load()
except ImportError: except ImportError:
......
...@@ -5,10 +5,10 @@ import theano ...@@ -5,10 +5,10 @@ import theano
pygpu = pytest.importorskip("pygpu") pygpu = pytest.importorskip("pygpu")
from theano.compat import PY3 from pickle import Unpickler
from theano import config from theano import config
from theano.compile import DeepCopyOp, Rebroadcast, ViewOp from theano.compile import DeepCopyOp, Rebroadcast, ViewOp
from theano.misc.pkl_utils import CompatUnpickler
from theano.gpuarray.type import GpuArrayType, gpuarray_shared_constructor from theano.gpuarray.type import GpuArrayType, gpuarray_shared_constructor
from tests.gpuarray.config import test_ctx_name from tests.gpuarray.config import test_ctx_name
...@@ -122,10 +122,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag0(): ...@@ -122,10 +122,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag0():
fname = "GpuArray.pkl" fname = "GpuArray.pkl"
with open(os.path.join(testfile_dir, fname), "rb") as fp: with open(os.path.join(testfile_dir, fname), "rb") as fp:
if PY3: u = Unpickler(fp, encoding="latin1")
u = CompatUnpickler(fp, encoding="latin1")
else:
u = CompatUnpickler(fp)
mat = u.load() mat = u.load()
assert isinstance(mat, pygpu.gpuarray.GpuArray) assert isinstance(mat, pygpu.gpuarray.GpuArray)
assert np.asarray(mat)[0] == -42.0 assert np.asarray(mat)[0] == -42.0
......
import os
import sys
import numpy as np import numpy as np
import theano
import tests.unittest_tools as utt import tests.unittest_tools as utt
from theano import config, function, tensor from theano import config, function, tensor
from theano.sandbox import multinomial from theano.sandbox import multinomial
from theano.compat import PY3
from theano.misc.pkl_utils import CompatUnpickler
def test_n_samples_1(): def test_n_samples_1():
...@@ -51,40 +45,6 @@ def test_n_samples_2(): ...@@ -51,40 +45,6 @@ def test_n_samples_2():
assert res.sum() == i assert res.sum() == i
def test_n_samples_compatibility():
# This test checks if the new change to MultinomialFromUniform is still compatible
# with old interface. Here I will load a graph created (using the old interface) as follows:
# RandomStreams = theano.sandbox.rng_mrg.MRG_RandomStreams
# th_rng = RandomStreams(12345)
# X = T.matrix('X')
# pvals = T.exp(X)
# pvals = pvals / pvals.sum(axis=1, keepdims=True)
# samples = th_rng.multinomial(pvals=pvals)
# pickle.dump([X, samples], open("multinomial_test_graph.pkl", "w"))
folder = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(folder, "multinomial_test_graph.pkl"), "rb") as pkl_file:
if PY3:
u = CompatUnpickler(pkl_file, encoding="latin1")
else:
u = CompatUnpickler(pkl_file)
try:
X, samples = u.load()
except ImportError:
# Windows sometimes fail with nonsensical errors like:
# ImportError: No module named type
# ImportError: No module named copy_reg
# when "type" and "copy_reg" are builtin modules.
if sys.platform == "win32":
exc_type, exc_value, exc_trace = sys.exc_info()
raise
raise
f = theano.function([X], samples)
res = f(np.random.randn(20, 10))
assert np.all(res.sum(axis=1) == 1)
def test_multinomial_0(): def test_multinomial_0():
# This tests the MultinomialFromUniform Op directly, not going through the # This tests the MultinomialFromUniform Op directly, not going through the
# multinomial() call in GPU random generation. # multinomial() call in GPU random generation.
......
import numpy as np import numpy as np
import pytest import pytest
import os
from theano import config, function, tensor from theano import config, function, tensor
from theano.compat import PY3
from theano.misc.pkl_utils import CompatUnpickler
from theano.sandbox import multinomial from theano.sandbox import multinomial
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
...@@ -214,14 +212,3 @@ class TestFunction: ...@@ -214,14 +212,3 @@ class TestFunction:
avg_pvals /= avg_pvals.sum() avg_pvals /= avg_pvals.sum()
avg_diff = np.mean(abs(avg_pvals - pvals)) avg_diff = np.mean(abs(avg_pvals - pvals))
assert avg_diff < mean_rtol assert avg_diff < mean_rtol
def test_unpickle_legacy_op(self):
testfile_dir = os.path.dirname(os.path.realpath(__file__))
fname = "test_sandbox_multinomial_wo_replacement.pkl"
if not PY3:
with open(os.path.join(testfile_dir, fname), "r") as fp:
u = CompatUnpickler(fp)
m = u.load()
print(m)
assert isinstance(m, multinomial.ChoiceFromUniform)
...@@ -255,7 +255,7 @@ class InferShapeTester: ...@@ -255,7 +255,7 @@ class InferShapeTester:
else: else:
shp = inp.shape shp = inp.shape
if len(set(shp)) != len(shp): if len(set(shp)) != len(shp):
_logger.warn( _logger.warning(
"While testing shape inference for %r, we received an" "While testing shape inference for %r, we received an"
" input with a shape that has some repeated values: %r" " input with a shape that has some repeated values: %r"
", like a square matrix. This makes it impossible to" ", like a square matrix. This makes it impossible to"
......
...@@ -1437,7 +1437,7 @@ def _check_preallocated_output( ...@@ -1437,7 +1437,7 @@ def _check_preallocated_output(
fn_attr_name = ops_with_inner_function[type(node.op)] fn_attr_name = ops_with_inner_function[type(node.op)]
fn = getattr(node.op, fn_attr_name, None) fn = getattr(node.op, fn_attr_name, None)
if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"): if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"):
_logger.warn( _logger.warning(
"Expected theano function not found in %s.%s", node.op, fn_attr_name "Expected theano function not found in %s.%s", node.op, fn_attr_name
) )
else: else:
...@@ -1482,7 +1482,7 @@ def _check_preallocated_output( ...@@ -1482,7 +1482,7 @@ def _check_preallocated_output(
if not out_map: if not out_map:
# Map is empty, there is no need to execute thunk() again # Map is empty, there is no need to execute thunk() again
_logger.warn("%s: out_map is empty", name) _logger.warning("%s: out_map is empty", name)
continue continue
# Copy the inputs over, if they were marked as destroyed or viewed # Copy the inputs over, if they were marked as destroyed or viewed
...@@ -1904,7 +1904,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1904,7 +1904,7 @@ class _Linker(gof.link.LocalLinker):
thunks_py.append(None) thunks_py.append(None)
if not self.maker.mode.check_c_code and thunks_py[-1] is None: if not self.maker.mode.check_c_code and thunks_py[-1] is None:
_logger.warn( _logger.warning(
"Op %s doesn't have a perform, " "Op %s doesn't have a perform, "
"forcing check of the C code" % node.op "forcing check of the C code" % node.op
) )
...@@ -1921,7 +1921,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1921,7 +1921,7 @@ class _Linker(gof.link.LocalLinker):
elif thunks_c[-1] is None: elif thunks_c[-1] is None:
thunks_c[-1] = thunk_other thunks_c[-1] = thunk_other
else: else:
_logger.warn( _logger.warning(
"We won't check the perform function " "We won't check the perform function "
"of node '%s' but we will check its " "of node '%s' but we will check its "
"make_thunk function" % node "make_thunk function" % node
......
...@@ -2055,7 +2055,7 @@ class GCC_compiler(Compiler): ...@@ -2055,7 +2055,7 @@ class GCC_compiler(Compiler):
and "clang-omp++" not in theano.config.cxx and "clang-omp++" not in theano.config.cxx
and "icpc" not in theano.config.cxx and "icpc" not in theano.config.cxx
): ):
_logger.warn( _logger.warning(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be" "OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization" " the g++ compiler. So we disable the compiler optimization"
" specific to g++ that tell to compile for a specific CPU." " specific to g++ that tell to compile for a specific CPU."
...@@ -2124,7 +2124,7 @@ class GCC_compiler(Compiler): ...@@ -2124,7 +2124,7 @@ class GCC_compiler(Compiler):
) )
else: else:
reported_lines = native_lines reported_lines = native_lines
_logger.warn( _logger.warning(
"OPTIMIZATION WARNING: Theano was not able to find the" "OPTIMIZATION WARNING: Theano was not able to find the"
" g++ parameters that tune the compilation to your " " g++ parameters that tune the compilation to your "
" specific CPU. This can slow down the execution of Theano" " specific CPU. This can slow down the execution of Theano"
...@@ -2137,7 +2137,7 @@ class GCC_compiler(Compiler): ...@@ -2137,7 +2137,7 @@ class GCC_compiler(Compiler):
default_lines = get_lines("%s -E -v -" % theano.config.cxx) default_lines = get_lines("%s -E -v -" % theano.config.cxx)
_logger.info("g++ default lines: %s", default_lines) _logger.info("g++ default lines: %s", default_lines)
if len(default_lines) < 1: if len(default_lines) < 1:
_logger.warn( _logger.warning(
"OPTIMIZATION WARNING: Theano was not able to find the" "OPTIMIZATION WARNING: Theano was not able to find the"
" default g++ parameters. This is needed to tune" " default g++ parameters. This is needed to tune"
" the compilation to your specific" " the compilation to your specific"
......
...@@ -349,7 +349,7 @@ def refresh_lock(lock_file): ...@@ -349,7 +349,7 @@ def refresh_lock(lock_file):
# This way, only 1 test would fail. # This way, only 1 test would fail.
while get_lock.n_lock > 0: while get_lock.n_lock > 0:
release_lock() release_lock()
_logger.warn( _logger.warning(
"Refreshing lock failed, we release the" "Refreshing lock failed, we release the"
" lock before raising again the exception" " lock before raising again the exception"
) )
......
...@@ -92,7 +92,7 @@ class Apply(Node): ...@@ -92,7 +92,7 @@ class Apply(Node):
def __init__(self, op, inputs, outputs): def __init__(self, op, inputs, outputs):
self.op = op self.op = op
self.inputs = [] self.inputs = []
self.tag = utils.scratchpad() self.tag = utils.Scratchpad()
if not isinstance(inputs, (list, tuple)): if not isinstance(inputs, (list, tuple)):
raise TypeError("The inputs of an Apply must be a list or tuple") raise TypeError("The inputs of an Apply must be a list or tuple")
...@@ -383,7 +383,8 @@ class Variable(Node): ...@@ -383,7 +383,8 @@ class Variable(Node):
def __init__(self, type, owner=None, index=None, name=None): def __init__(self, type, owner=None, index=None, name=None):
super(Variable, self).__init__() super(Variable, self).__init__()
self.tag = utils.scratchpad() self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance", owner)
......
...@@ -553,29 +553,10 @@ class PureOp(object): ...@@ -553,29 +553,10 @@ class PureOp(object):
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
return v.get_value(borrow=True, return_internal_type=True) return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, "test_value"): elif isinstance(v, graph.Variable) and hasattr(v.tag, "test_value"):
# ensure that the test value is correct return v.tag.test_value
try:
ret = v.type.filter(v.tag.test_value)
except Exception as e:
# Better error message.
detailed_err_msg = (
"For compute_test_value, one input test value does not"
" have the requested type.\n"
)
detailed_err_msg += utils.get_variable_trace_string(v)
detailed_err_msg += (
"\nThe error when converting the test value to that"
" variable type:"
)
# We need to only have 1 args and it should be of type
# string. Otherwise, it print the tuple and so the
# new line do not get printed.
args = (detailed_err_msg,) + tuple(str(arg) for arg in e.args)
e.args = ("\n".join(args),)
raise
return ret
detailed_err_msg = utils.get_variable_trace_string(v) detailed_err_msg = utils.get_variable_trace_string(v)
raise AttributeError("%s has no test value %s" % (v, detailed_err_msg)) raise AttributeError("%s has no test value %s" % (v, detailed_err_msg))
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
...@@ -1057,48 +1038,13 @@ def missing_test_message(msg): ...@@ -1057,48 +1038,13 @@ def missing_test_message(msg):
assert action in ["ignore", "off"] assert action in ["ignore", "off"]
def debug_error_message(msg):
"""
Displays a message saying that an error was found in some
test_values. Becomes a warning or a ValueError depending on
config.compute_test_value.
"""
action = config.compute_test_value
# this message should never be called when the debugger is off
assert action != "off"
if action in ["raise", "ignore"]:
raise ValueError(msg)
else:
assert action == "warn"
warnings.warn(msg, stacklevel=2)
def debug_assert(condition, msg=None):
"""
Customized assert with options to ignore the assert
with just a warning
"""
if msg is None:
msg = "debug_assert failed"
if not condition:
action = config.compute_test_value
if action in ["raise", "ignore"]:
raise AssertionError(msg)
else:
assert action == "warn"
warnings.warn(msg, stacklevel=2)
def get_debug_values(*args): def get_debug_values(*args):
""" """
Intended use: Intended use:
for val_1, ..., val_n in get_debug_values(var_1, ..., var_n): for val_1, ..., val_n in get_debug_values(var_1, ..., var_n):
if some condition on val_1, ..., val_n is not met: if some condition on val_1, ..., val_n is not met:
debug_error_message("condition was not met") missing_test_message("condition was not met")
Given a list of variables, get_debug_values does one of three things: Given a list of variables, get_debug_values does one of three things:
...@@ -1128,10 +1074,10 @@ def get_debug_values(*args): ...@@ -1128,10 +1074,10 @@ def get_debug_values(*args):
except AttributeError: except AttributeError:
if hasattr(arg, "name") and arg.name is not None: if hasattr(arg, "name") and arg.name is not None:
missing_test_message( missing_test_message(
"Argument " + str(i) + "('" + arg.name + "') has no test value" "Argument {} ('{}') has no test value".format(i, arg.name)
) )
else: else:
missing_test_message("Argument " + str(i) + " has no test value") missing_test_message("Argument {} has no test value".format(i))
return [] return []
if len(rval) == 1: if len(rval) == 1:
......
...@@ -239,7 +239,7 @@ class object2(with_metaclass(MetaObject, object)): ...@@ -239,7 +239,7 @@ class object2(with_metaclass(MetaObject, object)):
return not self == other return not self == other
class scratchpad(object): class Scratchpad(object):
def clear(self): def clear(self):
self.__dict__.clear() self.__dict__.clear()
...@@ -259,6 +259,23 @@ class scratchpad(object): ...@@ -259,6 +259,23 @@ class scratchpad(object):
print(" %s: %s" % (k, v)) print(" %s: %s" % (k, v))
class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values."""
def __init__(self, attr, attr_filter):
super().__init__()
object.__setattr__(self, "attr", attr)
object.__setattr__(self, "attr_filter", attr_filter)
def __setattr__(self, attr, obj):
if getattr(self, "attr", None) == attr:
obj = self.attr_filter(obj)
return object.__setattr__(self, attr, obj)
class D: class D:
def __init__(self, **d): def __init__(self, **d):
self.__dict__.update(d) self.__dict__.update(d)
......
...@@ -924,7 +924,7 @@ class VM_Linker(link.LocalLinker): ...@@ -924,7 +924,7 @@ class VM_Linker(link.LocalLinker):
if self.use_cloop and ( if self.use_cloop and (
self.callback is not None or self.callback_input is not None self.callback is not None or self.callback_input is not None
): ):
logger.warn("CVM does not support callback, using Stack VM.") logger.warning("CVM does not support callback, using Stack VM.")
if self.use_cloop and config.profile_memory: if self.use_cloop and config.profile_memory:
warnings.warn("CVM does not support memory profile, using Stack VM.") warnings.warn("CVM does not support memory profile, using Stack VM.")
if not self.use_cloop and self.allow_partial_eval: if not self.use_cloop and self.allow_partial_eval:
......
...@@ -12,6 +12,9 @@ import sys ...@@ -12,6 +12,9 @@ import sys
import tempfile import tempfile
import zipfile import zipfile
import warnings import warnings
import theano
from collections import defaultdict from collections import defaultdict
from contextlib import closing from contextlib import closing
from pickle import HIGHEST_PROTOCOL from pickle import HIGHEST_PROTOCOL
...@@ -22,10 +25,7 @@ try: ...@@ -22,10 +25,7 @@ try:
except ImportError: except ImportError:
DEFAULT_PROTOCOL = HIGHEST_PROTOCOL DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
import theano
from theano import config from theano import config
from theano.compat import PY3
from six import string_types
from theano.compile.sharedvalue import SharedVariable from theano.compile.sharedvalue import SharedVariable
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -68,7 +68,7 @@ class StripPickler(Pickler): ...@@ -68,7 +68,7 @@ class StripPickler(Pickler):
def save(self, obj): def save(self, obj):
# Remove the tag.trace attribute from Variable and Apply nodes # Remove the tag.trace attribute from Variable and Apply nodes
if isinstance(obj, theano.gof.utils.scratchpad): if isinstance(obj, theano.gof.utils.Scratchpad):
for tag in self.tag_to_remove: for tag in self.tag_to_remove:
if hasattr(obj, tag): if hasattr(obj, tag):
del obj.__dict__[tag] del obj.__dict__[tag]
...@@ -80,93 +80,6 @@ class StripPickler(Pickler): ...@@ -80,93 +80,6 @@ class StripPickler(Pickler):
return Pickler.save(self, obj) return Pickler.save(self, obj)
# Make an unpickler that tries encoding byte streams before raising TypeError.
# This is useful with python 3, in order to unpickle files created with
# python 2.
# This code is taken from Pandas, https://github.com/pydata/pandas,
# under the same 3-clause BSD license.
def load_reduce(self):
stack = self.stack
args = stack.pop()
func = stack[-1]
try:
value = func(*args)
except Exception:
# try to reencode the arguments
if self.encoding is not None:
new_args = []
for arg in args:
if isinstance(arg, string_types):
new_args.append(arg.encode(self.encoding))
else:
new_args.append(arg)
args = tuple(new_args)
try:
stack[-1] = func(*args)
return
except Exception:
pass
# if self.is_verbose:
# print(sys.exc_info())
# print(func, args)
raise
stack[-1] = value
if PY3:
class CompatUnpickler(pickle._Unpickler):
"""
Allow to reload in python 3 some pickled numpy ndarray.
.. versionadded:: 0.8
Examples
--------
::
with open(fname, 'rb') as fp:
if PY3:
u = CompatUnpickler(fp, encoding="latin1")
else:
u = CompatUnpickler(fp)
mat = u.load()
"""
pass
# Register `load_reduce` defined above in CompatUnpickler
CompatUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce
else:
class CompatUnpickler(pickle.Unpickler):
"""
Allow to reload in python 3 some pickled numpy ndarray.
.. versionadded:: 0.8
Examples
--------
::
with open(fname, 'rb') as fp:
if PY3:
u = CompatUnpickler(fp, encoding="latin1")
else:
u = CompatUnpickler(fp)
mat = u.load()
"""
pass
class PersistentNdarrayID(object): class PersistentNdarrayID(object):
"""Persist ndarrays in an object by saving them to a zip file. """Persist ndarrays in an object by saving them to a zip file.
......
...@@ -371,11 +371,11 @@ class Print(Op): ...@@ -371,11 +371,11 @@ class Print(Op):
return (1,) return (1,)
class PrinterState(gof.utils.scratchpad): class PrinterState(gof.utils.Scratchpad):
def __init__(self, props=None, **more_props): def __init__(self, props=None, **more_props):
if props is None: if props is None:
props = {} props = {}
elif isinstance(props, gof.utils.scratchpad): elif isinstance(props, gof.utils.Scratchpad):
self.__update__(props) self.__update__(props)
else: else:
self.__dict__.update(props) self.__dict__.update(props)
...@@ -862,7 +862,7 @@ def pydotprint( ...@@ -862,7 +862,7 @@ def pydotprint(
): ):
cond = node cond = node
if cond is None: if cond is None:
_logger.warn( _logger.warning(
"pydotprint: cond_highlight is set but there is no" "pydotprint: cond_highlight is set but there is no"
" IfElse node in the graph" " IfElse node in the graph"
) )
......
...@@ -559,7 +559,7 @@ class ConvOp(OpenMPOp): ...@@ -559,7 +559,7 @@ class ConvOp(OpenMPOp):
" bsize(%i). We revert it to %i. This" " bsize(%i). We revert it to %i. This"
" won't change the result, but may make it slower." " won't change the result, but may make it slower."
) )
_logger.warn(warnstr, self.unroll_batch, self.bsize, new) _logger.warning(warnstr, self.unroll_batch, self.bsize, new)
self.unroll_batch = new self.unroll_batch = new
...@@ -585,7 +585,7 @@ class ConvOp(OpenMPOp): ...@@ -585,7 +585,7 @@ class ConvOp(OpenMPOp):
" nkern(%i). We revert it to %i. This" " nkern(%i). We revert it to %i. This"
" won't change the result, but may make it slower." " won't change the result, but may make it slower."
) )
_logger.warn(warnstr, self.unroll_kern, self.nkern, new) _logger.warning(warnstr, self.unroll_kern, self.nkern, new)
self.unroll_kern = new self.unroll_kern = new
self.outshp = get_conv_output_shape( self.outshp = get_conv_output_shape(
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论