提交 da6ea2fb authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Rename get_debug_values to get_test_values

上级 0ea10588
......@@ -288,23 +288,23 @@ def test_test_value_op():
@change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
def test_get_test_values_no_debugger():
"""Tests that `get_test_values` returns `[]` when debugger is off."""
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
@change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore():
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
def test_get_test_values_ignore():
"""Tests that `get_test_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
def test_get_debug_values_success():
"""Tests that `get_debug_value` returns values when available (and the debugger is on)."""
def test_get_test_values_success():
"""Tests that `get_test_values` returns values when available (and the debugger is on)."""
for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode):
......@@ -314,7 +314,7 @@ def test_get_debug_values_success():
iters = 0
for x_val, y_val in op.get_debug_values(x, y):
for x_val, y_val in op.get_test_values(x, y):
assert x_val.shape == (4,)
assert y_val.shape == (5, 5)
......@@ -325,9 +325,9 @@ def test_get_debug_values_success():
@change_flags(compute_test_value="raise")
def test_get_debug_values_exc():
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
def test_get_test_values_exc():
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with pytest.raises(AttributeError):
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
......@@ -1080,7 +1080,7 @@ def missing_test_message(msg):
assert action in ["ignore", "off"]
def get_debug_values(*args):
def get_test_values(*args):
"""
Intended use:
......
......@@ -15,7 +15,7 @@ from theano import gof
from theano.gof import utils, Variable
from theano.gof.null_type import NullType, null_type
from theano.gof.op import get_debug_values
from theano.gof.op import get_test_values
from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
......@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
continue
if isinstance(new_output_grad.type, DisconnectedType):
continue
for orig_output_v, new_output_grad_v in get_debug_values(*packed):
for orig_output_v, new_output_grad_v in get_test_values(*packed):
o_shape = orig_output_v.shape
g_shape = new_output_grad_v.shape
if o_shape != g_shape:
......@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape
if hasattr(term, "shape"):
orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_debug_values(orig_ipt, term):
for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape
t_shape = term_v.shape
if i_shape != t_shape:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论