提交 46adff86 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Turn speed tests into actual unit tests

These tests now assume that the C VM versions of scan Ops should be faster than their standard Python counterparts.
上级 abc69153
...@@ -4838,7 +4838,8 @@ def test_speed_rnn(): ...@@ -4838,7 +4838,8 @@ def test_speed_rnn():
for i in range(1, L): for i in range(1, L):
r[i] = np.tanh(np.dot(r[i - 1], w)) r[i] = np.tanh(np.dot(r[i - 1], w))
t1 = time.time() t1 = time.time()
print("python", t1 - t0) python_duration = t1 - t0
print("python", python_duration)
r = np.arange(L * N).astype(config.floatX).reshape(L, N) r = np.arange(L * N).astype(config.floatX).reshape(L, N)
s_r = tensor.matrix() s_r = tensor.matrix()
...@@ -4854,7 +4855,8 @@ def test_speed_rnn(): ...@@ -4854,7 +4855,8 @@ def test_speed_rnn():
t2 = time.time() t2 = time.time()
f(r) f(r)
t3 = time.time() t3 = time.time()
print("theano (cvm)", t1 - t0) cvm_duration = t3 - t2
print("theano (cvm)", cvm_duration)
r = np.arange(L * N).astype(config.floatX).reshape(L, N) r = np.arange(L * N).astype(config.floatX).reshape(L, N)
shared_r = theano.shared(r) shared_r = theano.shared(r)
...@@ -4872,11 +4874,14 @@ def test_speed_rnn(): ...@@ -4872,11 +4874,14 @@ def test_speed_rnn():
) )
f_fn = f.fn f_fn = f.fn
t2 = time.time() t4 = time.time()
f_fn(n_calls=L - 2) f_fn(n_calls=L - 2)
f() # 999 to update the profiling timers f() # 999 to update the profiling timers
t3 = time.time() t5 = time.time()
print("theano (updates, cvm)", t3 - t2) cvm_shared_duration = t5 - t4
print("theano (updates, cvm)", cvm_shared_duration)
assert cvm_shared_duration < python_duration
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -4892,9 +4897,6 @@ def test_speed_batchrnn(): ...@@ -4892,9 +4897,6 @@ def test_speed_batchrnn():
of the optimizations applied, but generally correctness-testing of the optimizations applied, but generally correctness-testing
is not the goal of this test. is not the goal of this test.
To be honest, it isn't really a unit test so much as a tool for testing
approaches to scan.
The computation being tested here is a repeated tanh of a matrix-vector The computation being tested here is a repeated tanh of a matrix-vector
multiplication - the heart of an ESN or RNN. multiplication - the heart of an ESN or RNN.
""" """
...@@ -4910,7 +4912,7 @@ def test_speed_batchrnn(): ...@@ -4910,7 +4912,7 @@ def test_speed_batchrnn():
for i in range(1, L): for i in range(1, L):
r[i] = np.tanh(np.dot(r[i - 1], w)) r[i] = np.tanh(np.dot(r[i - 1], w))
t1 = time.time() t1 = time.time()
print("python", t1 - t0) python_duration = t1 - t0
r = np.arange(B * L * N).astype(config.floatX).reshape(L, B, N) r = np.arange(B * L * N).astype(config.floatX).reshape(L, B, N)
shared_r = theano.shared(r) shared_r = theano.shared(r)
...@@ -4926,14 +4928,14 @@ def test_speed_batchrnn(): ...@@ -4926,14 +4928,14 @@ def test_speed_batchrnn():
updates=[(s_i, s_i + 1), (shared_r, s_rinc)], updates=[(s_i, s_i + 1), (shared_r, s_rinc)],
mode=theano.Mode(linker="cvm"), mode=theano.Mode(linker="cvm"),
) )
# theano.printing.debugprint(f)
f_fn = f.fn f_fn = f.fn
# print f_fn
t2 = time.time() t2 = time.time()
f_fn(n_calls=L - 2) f_fn(n_calls=L - 2)
f() # 999 to update the profiling timers f() # 999 to update the profiling timers
t3 = time.time() t3 = time.time()
print("theano (updates, cvm)", t3 - t2) cvm_duration = t3 - t2
assert cvm_duration < python_duration
def test_compute_test_value(): def test_compute_test_value():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论