提交 157330e4 authored 作者: --global's avatar --global

Update tests for helper functions (and fixe error in one expected value)

上级 8a60518e
......@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase):
tensor.grad(a[-1], a0)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results
scan_node = a.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [1, 2]
result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = {0: 1, 1: 2}
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 1, 2, 2]
result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = {0: 1, 1: 1, 2: 2, 3: 2}
assert(result == expected_result)
def test_connection_pattern2(self):
# This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as
......@@ -690,16 +689,16 @@ class T_Scan(unittest.TestCase):
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results
scan_node = out.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [2]
result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = {0: 2}
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 2]
result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = {0: 1, 1: 2, 2: 2}
assert(result == expected_result)
def test_grad_grad_mitsot_sitsot(self):
......@@ -1704,16 +1703,16 @@ class T_Scan(unittest.TestCase):
analytic_grad[max_err_pos],
num_grad.gx[max_err_pos]))
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results
scan_node = updates.values()[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq()
expected_result = [3, -1, 4]
result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = {0: 3, 1: 5, 2: 4}
assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq()
expected_result = [1, 2, 3, 4, 6]
result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6}
assert(result == expected_result)
def test_grad_multiple_outs_some_truncate(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论