提交 ea44b16d authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Refactor imports and test design in tests.gof.test_op

上级 747e80fd
import numpy as np import numpy as np
import pytest import pytest
import theano import theano
import theano.gof.op as op import theano.gof.op as op
import theano.tensor as tt
from six import string_types from six import string_types
from theano.gof.type import Type, Generic from theano import scalar, shared
from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
import theano.tensor as T from theano.gof.type import Generic, Type
from theano import scalar
from theano import shared
config = theano.config config = theano.config
Op = op.Op Op = op.Op
...@@ -238,15 +238,15 @@ class TestMakeThunk: ...@@ -238,15 +238,15 @@ class TestMakeThunk:
__props__ = () __props__ = ()
itypes = [T.dmatrix] itypes = [tt.dmatrix]
otypes = [T.dmatrix] otypes = [tt.dmatrix]
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
inp = inputs[0] inp = inputs[0]
output = outputs[0] output = outputs[0]
output[0] = inp * 2 output[0] = inp * 2
x_input = T.dmatrix("x_input") x_input = tt.dmatrix("x_input")
f = theano.function([x_input], DoubleOp()(x_input)) f = theano.function([x_input], DoubleOp()(x_input))
inp = np.random.rand(5, 4) inp = np.random.rand(5, 4)
out = f(inp) out = f(inp)
...@@ -255,17 +255,17 @@ class TestMakeThunk: ...@@ -255,17 +255,17 @@ class TestMakeThunk:
def test_test_value_python_objects(): def test_test_value_python_objects():
for x in ([0, 1, 2], 0, 0.5, 1): for x in ([0, 1, 2], 0, 0.5, 1):
assert (op.get_test_value(x) == x).all() assert np.all(op.get_test_value(x) == x)
def test_test_value_ndarray(): def test_test_value_ndarray():
x = np.zeros((5, 5)) x = np.zeros((5, 5))
v = op.get_test_value(x) v = op.get_test_value(x)
assert (v == x).all() assert np.all(v == x)
def test_test_value_constant(): def test_test_value_constant():
x = T.as_tensor_variable(np.zeros((5, 5))) x = tt.as_tensor_variable(np.zeros((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert np.all(v == np.zeros((5, 5))) assert np.all(v == np.zeros((5, 5)))
...@@ -278,62 +278,37 @@ def test_test_value_shared(): ...@@ -278,62 +278,37 @@ def test_test_value_shared():
assert np.all(v == np.zeros((5, 5))) assert np.all(v == np.zeros((5, 5)))
@change_flags(compute_test_value="raise")
def test_test_value_op(): def test_test_value_op():
try:
prev_value = config.compute_test_value x = tt.log(np.ones((5, 5)))
config.compute_test_value = "raise"
x = T.log(np.ones((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert np.allclose(v, np.zeros((5, 5))) assert np.allclose(v, np.zeros((5, 5)))
finally:
config.compute_test_value = prev_value
@change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger(): def test_get_debug_values_no_debugger():
"get_debug_values should return [] when debugger is off" """Tests that `get_debug_values` returns `[]` when debugger is off."""
prev_value = config.compute_test_value
try:
config.compute_test_value = "off"
x = T.vector()
for x_val in op.get_debug_values(x):
assert False
finally: x = tt.vector()
config.compute_test_value = prev_value assert op.get_debug_values(x) == []
@change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore(): def test_get_det_debug_values_ignore():
# get_debug_values should return [] when debugger is ignore """Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
# and some values are missing
prev_value = config.compute_test_value
try:
config.compute_test_value = "ignore"
x = T.vector() x = tt.vector()
assert op.get_debug_values(x) == []
for x_val in op.get_debug_values(x):
assert False
finally:
config.compute_test_value = prev_value
def test_get_debug_values_success(): def test_get_debug_values_success():
# tests that get_debug_value returns values when available """Tests that `get_debug_value` returns values when available (and the debugger is on)."""
# (and the debugger is on)
prev_value = config.compute_test_value
for mode in ["ignore", "warn", "raise"]: for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode):
try: x = tt.vector()
config.compute_test_value = mode
x = T.vector()
x.tag.test_value = np.zeros((4,), dtype=config.floatX) x.tag.test_value = np.zeros((4,), dtype=config.floatX)
y = np.zeros((5, 5)) y = np.zeros((5, 5))
...@@ -348,33 +323,11 @@ def test_get_debug_values_success(): ...@@ -348,33 +323,11 @@ def test_get_debug_values_success():
assert iters == 1 assert iters == 1
finally:
config.compute_test_value = prev_value
@change_flags(compute_test_value="raise")
def test_get_debug_values_exc(): def test_get_debug_values_exc():
# tests that get_debug_value raises an exception when """Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
# debugger is set to raise and a value is missing
prev_value = config.compute_test_value
try:
config.compute_test_value = "raise"
x = T.vector() with pytest.raises(AttributeError):
x = tt.vector()
try: assert op.get_debug_values(x) == []
for x_val in op.get_debug_values(x):
# this assert catches the case where we
# erroneously get a value returned
assert False
raised = False
except AttributeError:
raised = True
# this assert catches the case where we got []
# returned, and possibly issued a warning,
# rather than raising an exception
assert raised
finally:
config.compute_test_value = prev_value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论