提交 157330e4 authored 作者: --global's avatar --global

Update tests for helper functions (and fixe error in one expected value)

上级 8a60518e
...@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase): ...@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase):
tensor.grad(a[-1], a0) tensor.grad(a[-1], a0)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq # Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = a.owner.inputs[0].owner scan_node = a.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq() result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = [1, 2] expected_result = {0: 1, 1: 2}
assert(result == expected_result) assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq() result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = [1, 1, 2, 2] expected_result = {0: 1, 1: 1, 2: 2, 3: 2}
assert(result == expected_result) assert(result == expected_result)
def test_connection_pattern2(self): def test_connection_pattern2(self):
# This tests for a crash in connection_pattern() when a scan node # This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as # has more than one mitmot (multiple input taps as well as
...@@ -690,16 +689,16 @@ class T_Scan(unittest.TestCase): ...@@ -690,16 +689,16 @@ class T_Scan(unittest.TestCase):
scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner scan_node = g_out[0].owner.inputs[1].owner.inputs[1].owner.inputs[0].owner
connection_pattern = scan_node.op.connection_pattern(scan_node) connection_pattern = scan_node.op.connection_pattern(scan_node)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq # Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = out.owner.inputs[0].owner scan_node = out.owner.inputs[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq() result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = [2] expected_result = {0: 2}
assert(result == expected_result) assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq() result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = [1, 2, 2] expected_result = {0: 1, 1: 2, 2: 2}
assert(result == expected_result) assert(result == expected_result)
def test_grad_grad_mitsot_sitsot(self): def test_grad_grad_mitsot_sitsot(self):
...@@ -1704,16 +1703,16 @@ class T_Scan(unittest.TestCase): ...@@ -1704,16 +1703,16 @@ class T_Scan(unittest.TestCase):
analytic_grad[max_err_pos], analytic_grad[max_err_pos],
num_grad.gx[max_err_pos])) num_grad.gx[max_err_pos]))
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq # Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = updates.values()[0].owner scan_node = updates.values()[0].owner
result = scan_node.op.get_outer_iidx_from_outer_oidx_seq() result = scan_node.op.var_mappings['outer_inp_from_outer_out']
expected_result = [3, -1, 4] expected_result = {0: 3, 1: 5, 2: 4}
assert(result == expected_result) assert(result == expected_result)
result = scan_node.op.get_outer_iidx_from_inner_iidx_seq() result = scan_node.op.var_mappings['outer_inp_from_inner_inp']
expected_result = [1, 2, 3, 4, 6] expected_result = {0: 1, 1: 2, 2: 3, 3: 4, 4: 6}
assert(result == expected_result) assert(result == expected_result)
def test_grad_multiple_outs_some_truncate(self): def test_grad_multiple_outs_some_truncate(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论