提交 bf926bb7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add test

上级 c23ae0eb
...@@ -4107,6 +4107,40 @@ class T_Scan(unittest.TestCase): ...@@ -4107,6 +4107,40 @@ class T_Scan(unittest.TestCase):
f_strict = theano.function([x0_], ret_strict[0][-1]) f_strict = theano.function([x0_], ret_strict[0][-1])
result_strict = f_strict(x0) result_strict = f_strict(x0)
def test_monitor_mode(self):
# Test that it is possible to pass an instance of MonitorMode
# to the inner function
k = tensor.iscalar("k")
A = tensor.vector("A")
# Build a MonitorMode that counts how many values are greater than 10
def detect_large_outputs(i, node, fn):
for output in fn.outputs:
if isinstance(output[0], numpy.ndarray):
detect_large_outputs.large_count += (output[0] > 10).sum()
detect_large_outputs.large_count = 0
mode = theano.compile.MonitorMode(post_func=detect_large_outputs)
# Symbolic description of the result
result, updates = theano.scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=tensor.ones_like(A),
non_sequences=A,
n_steps=k,
mode=mode)
final_result = result[-1]
f = theano.function(inputs=[A, k],
outputs=final_result,
updates=updates)
f([2, 3, .1, 0, 1], 4)
# There should be 3 outputs greater than 10: prior_result[0] at step 3,
# and prior_result[1] at steps 2 and 3.
assert detect_large_outputs.large_count == 3
class ScanGpuTests: class ScanGpuTests:
""" This class defines a number of tests for Scan on GPU as well as a few """ This class defines a number of tests for Scan on GPU as well as a few
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论