提交 46adff86 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Turn speed tests into actual unit tests

These tests now assume that the C VM versions of scan Ops should be faster than their standard Python counterparts.
上级 abc69153
......@@ -4838,7 +4838,8 @@ def test_speed_rnn():
for i in range(1, L):
r[i] = np.tanh(np.dot(r[i - 1], w))
t1 = time.time()
print("python", t1 - t0)
python_duration = t1 - t0
print("python", python_duration)
r = np.arange(L * N).astype(config.floatX).reshape(L, N)
s_r = tensor.matrix()
......@@ -4854,7 +4855,8 @@ def test_speed_rnn():
t2 = time.time()
f(r)
t3 = time.time()
print("theano (cvm)", t1 - t0)
cvm_duration = t3 - t2
print("theano (cvm)", cvm_duration)
r = np.arange(L * N).astype(config.floatX).reshape(L, N)
shared_r = theano.shared(r)
......@@ -4872,11 +4874,14 @@ def test_speed_rnn():
)
f_fn = f.fn
t2 = time.time()
t4 = time.time()
f_fn(n_calls=L - 2)
f() # 999 to update the profiling timers
t3 = time.time()
print("theano (updates, cvm)", t3 - t2)
t5 = time.time()
cvm_shared_duration = t5 - t4
print("theano (updates, cvm)", cvm_shared_duration)
assert cvm_shared_duration < python_duration
@pytest.mark.skipif(
......@@ -4892,9 +4897,6 @@ def test_speed_batchrnn():
of the optimizations applied, but generally correctness-testing
is not the goal of this test.
To be honest, it isn't really a unit test so much as a tool for testing
approaches to scan.
The computation being tested here is a repeated tanh of a matrix-vector
multiplication - the heart of an ESN or RNN.
"""
......@@ -4910,7 +4912,7 @@ def test_speed_batchrnn():
for i in range(1, L):
r[i] = np.tanh(np.dot(r[i - 1], w))
t1 = time.time()
print("python", t1 - t0)
python_duration = t1 - t0
r = np.arange(B * L * N).astype(config.floatX).reshape(L, B, N)
shared_r = theano.shared(r)
......@@ -4926,14 +4928,14 @@ def test_speed_batchrnn():
updates=[(s_i, s_i + 1), (shared_r, s_rinc)],
mode=theano.Mode(linker="cvm"),
)
# theano.printing.debugprint(f)
f_fn = f.fn
# print f_fn
t2 = time.time()
f_fn(n_calls=L - 2)
f() # 999 to update the profiling timers
t3 = time.time()
print("theano (updates, cvm)", t3 - t2)
cvm_duration = t3 - t2
assert cvm_duration < python_duration
def test_compute_test_value():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论