提交 da6ea2fb authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Rename get_debug_values to get_test_values

上级 0ea10588
...@@ -288,23 +288,23 @@ def test_test_value_op(): ...@@ -288,23 +288,23 @@ def test_test_value_op():
@change_flags(compute_test_value="off") @change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger(): def test_get_test_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off.""" """Tests that `get_test_values` returns `[]` when debugger is off."""
x = tt.vector() x = tt.vector()
assert op.get_debug_values(x) == [] assert op.get_test_values(x) == []
@change_flags(compute_test_value="ignore") @change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore(): def test_get_test_values_ignore():
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing.""" """Tests that `get_test_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
x = tt.vector() x = tt.vector()
assert op.get_debug_values(x) == [] assert op.get_test_values(x) == []
def test_get_debug_values_success(): def test_get_test_values_success():
"""Tests that `get_debug_value` returns values when available (and the debugger is on).""" """Tests that `get_test_values` returns values when available (and the debugger is on)."""
for mode in ["ignore", "warn", "raise"]: for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode): with change_flags(compute_test_value=mode):
...@@ -314,7 +314,7 @@ def test_get_debug_values_success(): ...@@ -314,7 +314,7 @@ def test_get_debug_values_success():
iters = 0 iters = 0
for x_val, y_val in op.get_debug_values(x, y): for x_val, y_val in op.get_test_values(x, y):
assert x_val.shape == (4,) assert x_val.shape == (4,)
assert y_val.shape == (5, 5) assert y_val.shape == (5, 5)
...@@ -325,9 +325,9 @@ def test_get_debug_values_success(): ...@@ -325,9 +325,9 @@ def test_get_debug_values_success():
@change_flags(compute_test_value="raise") @change_flags(compute_test_value="raise")
def test_get_debug_values_exc(): def test_get_test_values_exc():
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing.""" """Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
x = tt.vector() x = tt.vector()
assert op.get_debug_values(x) == [] assert op.get_test_values(x) == []
...@@ -1080,7 +1080,7 @@ def missing_test_message(msg): ...@@ -1080,7 +1080,7 @@ def missing_test_message(msg):
assert action in ["ignore", "off"] assert action in ["ignore", "off"]
def get_debug_values(*args): def get_test_values(*args):
""" """
Intended use: Intended use:
......
...@@ -15,7 +15,7 @@ from theano import gof ...@@ -15,7 +15,7 @@ from theano import gof
from theano.gof import utils, Variable from theano.gof import utils, Variable
from theano.gof.null_type import NullType, null_type from theano.gof.null_type import NullType, null_type
from theano.gof.op import get_debug_values from theano.gof.op import get_test_values
from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow" __authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
...@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
continue continue
if isinstance(new_output_grad.type, DisconnectedType): if isinstance(new_output_grad.type, DisconnectedType):
continue continue
for orig_output_v, new_output_grad_v in get_debug_values(*packed): for orig_output_v, new_output_grad_v in get_test_values(*packed):
o_shape = orig_output_v.shape o_shape = orig_output_v.shape
g_shape = new_output_grad_v.shape g_shape = new_output_grad_v.shape
if o_shape != g_shape: if o_shape != g_shape:
...@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape # has the right shape
if hasattr(term, "shape"): if hasattr(term, "shape"):
orig_ipt = inputs[i] orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_debug_values(orig_ipt, term): for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape i_shape = orig_ipt_v.shape
t_shape = term_v.shape t_shape = term_v.shape
if i_shape != t_shape: if i_shape != t_shape:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论