提交 d61f56fc authored 作者: --global's avatar --global

Add validations for helper functions in existing test cases

上级 8b835037
......@@ -1511,7 +1511,6 @@ class Scan(PureOp):
output_offset += 1
# Process shared inputs/outputs
input_offset += self.n_nit_sot
output_offset += self.n_nit_sot
for i in range(self.n_shared_outs):
result[output_offset] = input_offset
......
......@@ -836,8 +836,22 @@ class T_Scan(unittest.TestCase):
outputs_info=[{'initial': a0, 'taps': [-2, -1]},
{'initial': b0, 'taps': [-2, -1]}],
n_steps=2)
tensor.grad(a[-1], a0)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
scan_node = a.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [1, 2]
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 1, 2, 2]
assert(result == expected_result)
def test_connection_pattern2(self):
# This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as
......@@ -858,6 +872,18 @@ class T_Scan(unittest.TestCase):
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
scan_node = out.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [2]
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 2]
assert(result == expected_result)
def test_grad_two_scans(self):
# data input & output
......@@ -1870,6 +1896,18 @@ class T_Scan(unittest.TestCase):
analytic_grad[max_err_pos],
num_grad.gx[max_err_pos]))
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
scan_node = updates.values()[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [3, -1, 4]
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 3, 4, 6]
assert(result == expected_result)
def test_grad_multiple_outs_some_truncate(self):
rng = numpy.random.RandomState(utt.fetch_seed())
vW_in = asarrayX(rng.uniform(size=(2, 2), low=-.1, high=.1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论