提交 4a79b089 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor test.scan.test_basic.test_compute_test_values

上级 af069dbb
......@@ -4834,32 +4834,32 @@ def test_speed_batchrnn():
assert cvm_duration < python_duration
def test_compute_test_value():
# Verify that test values can be used with scan.
with config.change_flags(compute_test_value="raise"):
x = vector("x")
xv = np.ones(3, dtype=config.floatX)
x.tag.test_value = xv
y = shared(np.arange(3, dtype=config.floatX), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x, y])
assert not updates
z.name = "z"
# The gradient computation used to crash before 6af465e.
grad(z.sum(), x)
def test_compute_test_value_nonseq():
# Verify that test values can be used for non_sequences with scan.
with config.change_flags(compute_test_value="raise"):
x = vector("x")
xv = np.ones(3, dtype=config.floatX)
x.tag.test_value = xv
y = shared(np.arange(9, dtype=config.floatX).reshape(3, 3), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x], non_sequences=[y])
assert not updates
z.name = "z"
# The gradient computation used to crash before 6af465e.
grad(z.sum(), x)
@config.change_flags(mode="FAST_COMPILE", compute_test_value="raise")
def test_compute_test_values():
"""Verify that test values can be used with scan."""
x = vector("x")
x.tag.test_value = np.ones(3, dtype=config.floatX)
y = shared(np.arange(3, dtype=config.floatX), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x, y])
assert not updates
z_grad = grad(z.sum(), x)
assert np.array_equal(z_grad.tag.test_value, np.r_[1.0, 1.0, 1.0])
# Use `non_sequences` this time
y = shared(np.arange(9, dtype=config.floatX).reshape(3, 3), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x], non_sequences=[y])
assert not updates
z_grad = grad(z.sum(), x)
assert np.array_equal(z_grad.tag.test_value, np.r_[9.0, 9.0, 9.0])
def test_compute_test_value_grad():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论