提交 eaf14709 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Use change_flags instead of try...finally statements in test_compute_test_value.py

上级 24b8bc28
......@@ -3,15 +3,14 @@ import sys
import traceback
import warnings
import numpy as np
import pytest
import numpy as np
import theano
from theano import config
from theano import scalar
from theano import tensor as T
from theano.gof import Apply, Op
from theano.gof import utils
import theano.tensor as tt
from theano import config, scalar
from theano.gof import Apply, Op, utils
from theano.tensor.basic import _allclose
......@@ -38,18 +37,15 @@ class IncOneC(Op):
class TestComputeTestValue:
@theano.change_flags(compute_test_value="raise")
def test_variable_only(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
x = T.matrix("x")
x = tt.matrix("x")
x.tag.test_value = np.random.rand(3, 4).astype(config.floatX)
y = T.matrix("y")
y = tt.matrix("y")
y.tag.test_value = np.random.rand(4, 5).astype(config.floatX)
# should work
z = T.dot(x, y)
z = tt.dot(x, y)
assert hasattr(z.tag, "test_value")
f = theano.function([x, y], z)
assert _allclose(f(x.tag.test_value, y.tag.test_value), z.tag.test_value)
......@@ -57,80 +53,67 @@ class TestComputeTestValue:
# this test should fail
y.tag.test_value = np.random.rand(6, 5).astype(config.floatX)
with pytest.raises(ValueError):
T.dot(x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
tt.dot(x, y)
@theano.change_flags(compute_test_value="raise")
def test_compute_flag(self):
orig_compute_test_value = theano.config.compute_test_value
try:
x = T.matrix("x")
y = T.matrix("y")
x = tt.matrix("x")
y = tt.matrix("y")
y.tag.test_value = np.random.rand(4, 5).astype(config.floatX)
# should skip computation of test value
theano.config.compute_test_value = "off"
z = T.dot(x, y)
z = tt.dot(x, y)
assert not hasattr(z.tag, "test_value")
# should fail when asked by user
theano.config.compute_test_value = "raise"
with pytest.raises(ValueError):
T.dot(x, y)
tt.dot(x, y)
# test that a warning is raised if required
theano.config.compute_test_value = "warn"
warnings.simplefilter("error", UserWarning)
try:
with pytest.raises(UserWarning):
T.dot(x, y)
tt.dot(x, y)
finally:
# Restore the default behavior.
# TODO There is a cleaner way to do this in Python 2.6, once
# Theano drops support of Python 2.4 and 2.5.
warnings.simplefilter("default", UserWarning)
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_string_var(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
x = T.matrix("x")
x = tt.matrix("x")
x.tag.test_value = np.random.rand(3, 4).astype(config.floatX)
y = T.matrix("y")
y = tt.matrix("y")
y.tag.test_value = np.random.rand(4, 5).astype(config.floatX)
z = theano.shared(np.random.rand(5, 6).astype(config.floatX))
# should work
out = T.dot(T.dot(x, y), z)
out = tt.dot(tt.dot(x, y), z)
assert hasattr(out.tag, "test_value")
tf = theano.function([x, y], out)
assert _allclose(tf(x.tag.test_value, y.tag.test_value), out.tag.test_value)
def f(x, y, z):
return T.dot(T.dot(x, y), z)
return tt.dot(tt.dot(x, y), z)
# this test should fail
z.set_value(np.random.rand(7, 6).astype(config.floatX))
with pytest.raises(ValueError):
f(x, y, z)
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_shared(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
x = T.matrix("x")
x = tt.matrix("x")
x.tag.test_value = np.random.rand(3, 4).astype(config.floatX)
y = theano.shared(np.random.rand(4, 6).astype(config.floatX), "y")
# should work
z = T.dot(x, y)
z = tt.dot(x, y)
assert hasattr(z.tag, "test_value")
f = theano.function([x], z)
assert _allclose(f(x.tag.test_value), z.tag.test_value)
......@@ -138,20 +121,15 @@ class TestComputeTestValue:
# this test should fail
y.set_value(np.random.rand(5, 6).astype(config.floatX))
with pytest.raises(ValueError):
T.dot(x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
tt.dot(x, y)
@theano.change_flags(compute_test_value="raise")
def test_ndarray(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
x = np.random.rand(2, 3).astype(config.floatX)
y = theano.shared(np.random.rand(3, 6).astype(config.floatX), "y")
# should work
z = T.dot(x, y)
z = tt.dot(x, y)
assert hasattr(z.tag, "test_value")
f = theano.function([], z)
assert _allclose(f(), z.tag.test_value)
......@@ -159,15 +137,10 @@ class TestComputeTestValue:
# this test should fail
x = np.random.rand(2, 4).astype(config.floatX)
with pytest.raises(ValueError):
T.dot(x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
tt.dot(x, y)
@theano.change_flags(compute_test_value="raise")
def test_empty_elemwise(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
x = theano.shared(np.random.rand(0, 6).astype(config.floatX), "x")
# should work
......@@ -176,70 +149,49 @@ class TestComputeTestValue:
f = theano.function([], z)
assert _allclose(f(), z.tag.test_value)
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_constant(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
x = T.constant(np.random.rand(2, 3), dtype=config.floatX)
x = tt.constant(np.random.rand(2, 3), dtype=config.floatX)
y = theano.shared(np.random.rand(3, 6).astype(config.floatX), "y")
# should work
z = T.dot(x, y)
z = tt.dot(x, y)
assert hasattr(z.tag, "test_value")
f = theano.function([], z)
assert _allclose(f(), z.tag.test_value)
# this test should fail
x = T.constant(np.random.rand(2, 4), dtype=config.floatX)
x = tt.constant(np.random.rand(2, 4), dtype=config.floatX)
with pytest.raises(ValueError):
T.dot(x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
tt.dot(x, y)
@theano.change_flags(compute_test_value="raise")
def test_incorrect_type(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
x = T.fmatrix("x")
x = tt.fmatrix("x")
# Incorrect dtype (float64) for test_value
x.tag.test_value = np.random.rand(3, 4)
y = T.dmatrix("y")
y = tt.dmatrix("y")
y.tag.test_value = np.random.rand(4, 5)
with pytest.raises(TypeError):
T.dot(x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
tt.dot(x, y)
@theano.change_flags(compute_test_value="raise")
def test_overided_function(self):
# We need to test those as they mess with Exception
# And we don't want the exception to be changed.
orig_compute_test_value = theano.config.compute_test_value
try:
config.compute_test_value = "raise"
x = T.matrix()
x = tt.matrix()
x.tag.test_value = np.zeros((2, 3), dtype=config.floatX)
y = T.matrix()
y = tt.matrix()
y.tag.test_value = np.zeros((2, 2), dtype=config.floatX)
with pytest.raises(ValueError):
x.__mul__(y)
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_scan(self):
# Test the compute_test_value mechanism Scan.
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
# theano.config.compute_test_value = 'warn'
k = T.iscalar("k")
A = T.vector("A")
k = tt.iscalar("k")
A = tt.vector("A")
k.tag.test_value = 3
A.tag.test_value = np.random.rand(5).astype(config.floatX)
......@@ -248,7 +200,7 @@ class TestComputeTestValue:
# Symbolic description of the result
result, updates = theano.scan(
fn=fx, outputs_info=T.ones_like(A), non_sequences=A, n_steps=k
fn=fx, outputs_info=tt.ones_like(A), non_sequences=A, n_steps=k
)
# We only care about A**k, but scan has provided us with A**1 through A**k.
......@@ -256,29 +208,22 @@ class TestComputeTestValue:
# notice this and not waste memory saving them.
final_result = result[-1]
assert hasattr(final_result.tag, "test_value")
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_scan_err1(self):
# This test should fail when building fx for the first time
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
k = T.iscalar("k")
A = T.matrix("A")
k = tt.iscalar("k")
A = tt.matrix("A")
k.tag.test_value = 3
A.tag.test_value = np.random.rand(5, 3).astype(config.floatX)
def fx(prior_result, A):
return T.dot(prior_result, A)
return tt.dot(prior_result, A)
# Since we have to inspect the traceback,
# we cannot simply use self.assertRaises()
try:
theano.scan(
fn=fx, outputs_info=T.ones_like(A), non_sequences=A, n_steps=k
)
theano.scan(fn=fx, outputs_info=tt.ones_like(A), non_sequences=A, n_steps=k)
assert False
except ValueError:
# Get traceback
......@@ -288,49 +233,38 @@ class TestComputeTestValue:
# We should be in the "fx" function defined above
expected = "test_compute_test_value.py"
assert any(
(
os.path.split(frame_info[0])[1] == expected
and frame_info[2] == "fx"
)
(os.path.split(frame_info[0])[1] == expected and frame_info[2] == "fx")
for frame_info in frame_infos
), frame_infos
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_scan_err2(self):
# This test should not fail when building fx for the first time,
# but when calling the scan's perform()
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
k = T.iscalar("k")
A = T.matrix("A")
k = tt.iscalar("k")
A = tt.matrix("A")
k.tag.test_value = 3
A.tag.test_value = np.random.rand(5, 3).astype(config.floatX)
def fx(prior_result, A):
return T.dot(prior_result, A)
return tt.dot(prior_result, A)
with pytest.raises(ValueError):
theano.scan(
fn=fx, outputs_info=T.ones_like(A.T), non_sequences=A, n_steps=k
fn=fx, outputs_info=tt.ones_like(A.T), non_sequences=A, n_steps=k
)
# Since we have to inspect the traceback,
# we cannot simply use self.assertRaises()
try:
theano.scan(
fn=fx, outputs_info=T.ones_like(A.T), non_sequences=A, n_steps=k
fn=fx, outputs_info=tt.ones_like(A.T), non_sequences=A, n_steps=k
)
assert False
except ValueError as e:
assert str(e).startswith("could not broadcast input"), str(e)
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_no_c_code(self):
class IncOnePython(Op):
"""
......@@ -349,10 +283,6 @@ class TestComputeTestValue:
(output,) = outputs
output[0] = input + 1
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
i = scalar.int32("i")
i.tag.test_value = 3
......@@ -365,17 +295,11 @@ class TestComputeTestValue:
assert hasattr(o.tag, "test_value")
assert o.tag.test_value == 4
finally:
theano.config.compute_test_value = orig_compute_test_value
@pytest.mark.skipif(
not theano.config.cxx, reason="G++ not available, so we need to skip this test."
)
@theano.change_flags(compute_test_value="raise")
def test_no_perform(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
i = scalar.int32("i")
i.tag.test_value = 3
......@@ -390,19 +314,10 @@ class TestComputeTestValue:
assert hasattr(o.tag, "test_value")
assert o.tag.test_value == 4
finally:
theano.config.compute_test_value = orig_compute_test_value
@theano.change_flags(compute_test_value="raise")
def test_disabled_during_compilation(self):
# We test that it is disabled when we include deep copy in the code
# This don't test that it is disabled during optimization, but the code do it.
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = "raise"
init_Mu1 = theano.shared(np.zeros((5,), dtype=config.floatX)).dimshuffle(
"x", 0
)
init_Mu1 = theano.shared(np.zeros((5,), dtype=config.floatX)).dimshuffle("x", 0)
theano.function([], outputs=[init_Mu1])
finally:
theano.config.compute_test_value = orig_compute_test_value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论