提交 ea44b16d authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Refactor imports and test design in tests.gof.test_op

上级 747e80fd
import numpy as np
import pytest
import theano
import theano.gof.op as op
import theano.tensor as tt
from six import string_types
from theano.gof.type import Type, Generic
from theano import scalar, shared
from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable
import theano.tensor as T
from theano import scalar
from theano import shared
from theano.gof.type import Generic, Type
config = theano.config
Op = op.Op
......@@ -238,15 +238,15 @@ class TestMakeThunk:
__props__ = ()
itypes = [T.dmatrix]
otypes = [T.dmatrix]
itypes = [tt.dmatrix]
otypes = [tt.dmatrix]
def perform(self, node, inputs, outputs):
inp = inputs[0]
output = outputs[0]
output[0] = inp * 2
x_input = T.dmatrix("x_input")
x_input = tt.dmatrix("x_input")
f = theano.function([x_input], DoubleOp()(x_input))
inp = np.random.rand(5, 4)
out = f(inp)
......@@ -255,17 +255,17 @@ class TestMakeThunk:
def test_test_value_python_objects():
for x in ([0, 1, 2], 0, 0.5, 1):
assert (op.get_test_value(x) == x).all()
assert np.all(op.get_test_value(x) == x)
def test_test_value_ndarray():
x = np.zeros((5, 5))
v = op.get_test_value(x)
assert (v == x).all()
assert np.all(v == x)
def test_test_value_constant():
x = T.as_tensor_variable(np.zeros((5, 5)))
x = tt.as_tensor_variable(np.zeros((5, 5)))
v = op.get_test_value(x)
assert np.all(v == np.zeros((5, 5)))
......@@ -278,62 +278,37 @@ def test_test_value_shared():
assert np.all(v == np.zeros((5, 5)))
@change_flags(compute_test_value="raise")
def test_test_value_op():
try:
prev_value = config.compute_test_value
config.compute_test_value = "raise"
x = T.log(np.ones((5, 5)))
v = op.get_test_value(x)
assert np.allclose(v, np.zeros((5, 5)))
finally:
config.compute_test_value = prev_value
x = tt.log(np.ones((5, 5)))
v = op.get_test_value(x)
def test_get_debug_values_no_debugger():
"get_debug_values should return [] when debugger is off"
prev_value = config.compute_test_value
try:
config.compute_test_value = "off"
assert np.allclose(v, np.zeros((5, 5)))
x = T.vector()
for x_val in op.get_debug_values(x):
assert False
@change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
finally:
config.compute_test_value = prev_value
x = tt.vector()
assert op.get_debug_values(x) == []
@change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore():
# get_debug_values should return [] when debugger is ignore
# and some values are missing
prev_value = config.compute_test_value
try:
config.compute_test_value = "ignore"
x = T.vector()
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
for x_val in op.get_debug_values(x):
assert False
finally:
config.compute_test_value = prev_value
x = tt.vector()
assert op.get_debug_values(x) == []
def test_get_debug_values_success():
# tests that get_debug_value returns values when available
# (and the debugger is on)
"""Tests that `get_debug_value` returns values when available (and the debugger is on)."""
prev_value = config.compute_test_value
for mode in ["ignore", "warn", "raise"]:
try:
config.compute_test_value = mode
x = T.vector()
with change_flags(compute_test_value=mode):
x = tt.vector()
x.tag.test_value = np.zeros((4,), dtype=config.floatX)
y = np.zeros((5, 5))
......@@ -348,33 +323,11 @@ def test_get_debug_values_success():
assert iters == 1
finally:
config.compute_test_value = prev_value
@change_flags(compute_test_value="raise")
def test_get_debug_values_exc():
# tests that get_debug_value raises an exception when
# debugger is set to raise and a value is missing
prev_value = config.compute_test_value
try:
config.compute_test_value = "raise"
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
x = T.vector()
try:
for x_val in op.get_debug_values(x):
# this assert catches the case where we
# erroneously get a value returned
assert False
raised = False
except AttributeError:
raised = True
# this assert catches the case where we got []
# returned, and possibly issued a warning,
# rather than raising an exception
assert raised
finally:
config.compute_test_value = prev_value
with pytest.raises(AttributeError):
x = tt.vector()
assert op.get_debug_values(x) == []
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论