提交 b27c59d1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify `expand_empty`

上级 c70a8869
......@@ -231,8 +231,8 @@ def expand_empty(tensor_var, size):
if size == 0:
return tensor_var
shapes = [tensor_var.shape[x] for x in range(tensor_var.ndim)]
new_shape = [size + shapes[0]] + shapes[1:]
shapes = tuple(tensor_var.shape)
new_shape = (size + shapes[0], *shapes[1:])
empty = AllocEmpty(tensor_var.dtype)(*new_shape)
ret = set_subtensor(empty[: shapes[0]], tensor_var)
......
......@@ -44,25 +44,24 @@ def test_debugprint_sitsot():
│ │ │ │ │ │ └─ 1.0 [id O]
│ │ │ │ │ └─ 0 [id P]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id S]
│ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id R]
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ ScalarFromTensor [id S]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ └─ A [id M] (outer_in_non_seqs-0)
│ └─ 1 [id U]
└─ -1 [id V]
│ └─ 1 [id T]
└─ -1 [id U]
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id X] -> [id E] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id Y] -> [id M] (inner_in_non_seqs-0)
← Mul [id V] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -103,25 +102,24 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ │ │ │ └─ 1.0 [id O]
│ │ │ │ │ └─ 0 [id P]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id S]
│ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id R]
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ ScalarFromTensor [id S]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ └─ A [id M]
│ └─ 1 [id U]
└─ -1 [id V]
│ └─ 1 [id T]
└─ -1 [id U]
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W]
├─ *0-<Vector(float64, shape=(?,))> [id X] -> [id E]
└─ *1-<Vector(float64, shape=(?,))> [id Y] -> [id M]
← Mul [id V]
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E]
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M]
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -288,25 +286,24 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ 1.0 [id BQ]
│ │ │ │ │ │ └─ 0 [id BR]
│ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ ├─ Shape [id BT]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BU]
│ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BT]
│ │ │ │ ├─ Unbroadcast{0} [id BL]
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BV]
│ │ │ │ └─ ScalarFromTensor [id BU]
│ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BW]
│ └─ -1 [id BX]
└─ ExpandDims{axis=0} [id BY]
└─ *1-<Scalar(int64, shape=())> [id BZ] -> [id U] (inner_in_seqs-1)
│ │ └─ 1 [id BV]
│ └─ -1 [id BW]
└─ ExpandDims{axis=0} [id BX]
└─ *1-<Scalar(int64, shape=())> [id BY] -> [id U] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
← Mul [id CA] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)
← Mul [id BZ] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BO] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -386,27 +383,26 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ 1.0 [id BR]
│ │ │ │ │ │ └─ 0 [id BS]
│ │ │ │ │ └─ Subtensor{i} [id BT]
│ │ │ │ │ ├─ Shape [id BU]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BV]
│ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BU]
│ │ │ │ ├─ Unbroadcast{0} [id BN]
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BW]
│ │ │ │ └─ ScalarFromTensor [id BV]
│ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BX]
│ └─ -1 [id BY]
└─ ExpandDims{axis=0} [id BZ]
│ │ └─ 1 [id BW]
│ └─ -1 [id BX]
└─ ExpandDims{axis=0} [id BY]
└─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
→ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
→ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CC] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CA] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CB] (inner_in_non_seqs-0)
→ *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BI] (inner_in_sit_sot-0)
→ *1-<Vector(float64, shape=(?,))> [id CA] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CB] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id BZ] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CA] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -528,98 +524,97 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ │ │ └─ 1.0 [id R]
│ │ │ │ │ │ │ └─ 0 [id S]
│ │ │ │ │ │ └─ Subtensor{i} [id T]
│ │ │ │ │ │ ├─ Shape [id U]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id V]
│ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id U]
│ │ │ │ │ ├─ Unbroadcast{0} [id M]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ScalarFromTensor [id W]
│ │ │ │ │ └─ ScalarFromTensor [id V]
│ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ └─ ···
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
│ │ │ └─ 0 [id X]
│ │ └─ 1 [id Y]
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id BA]
│ │ │ ├─ Subtensor{:stop} [id BB]
│ │ │ └─ 0 [id W]
│ │ └─ 1 [id X]
│ ├─ Subtensor{:stop} [id Y] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id Z]
│ │ │ ├─ Subtensor{:stop} [id BA]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BC]
│ │ │ └─ -1 [id BD]
│ │ └─ ScalarFromTensor [id BE]
│ │ │ │ └─ -1 [id BB]
│ │ │ └─ -1 [id BC]
│ │ └─ ScalarFromTensor [id BD]
│ │ └─ Sub [id C]
│ │ └─ ···
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BG]
│ │ │ ├─ Subtensor{::step} [id BH]
│ ├─ Subtensor{:stop} [id BE] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BF]
│ │ │ ├─ Subtensor{::step} [id BG]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BI]
│ │ │ └─ -1 [id BJ]
│ │ └─ ScalarFromTensor [id BK]
│ │ │ │ └─ -1 [id BH]
│ │ │ └─ -1 [id BI]
│ │ └─ ScalarFromTensor [id BJ]
│ │ └─ Sub [id C]
│ │ └─ ···
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BM]
│ │ │ ├─ Second [id BN]
│ ├─ Subtensor{::step} [id BK] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BL]
│ │ │ ├─ Second [id BM]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ 0.0 [id BP]
│ │ │ ├─ IncSubtensor{i} [id BQ]
│ │ │ │ ├─ Second [id BR]
│ │ │ │ │ ├─ Subtensor{start:} [id BS]
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN]
│ │ │ │ └─ 0.0 [id BO]
│ │ │ ├─ IncSubtensor{i} [id BP]
│ │ │ │ ├─ Second [id BQ]
│ │ │ │ │ ├─ Subtensor{start:} [id BR]
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id BT]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
│ │ │ │ │ └─ 0.0 [id BV]
│ │ │ │ ├─ Second [id BW]
│ │ │ │ │ ├─ Subtensor{i} [id BX]
│ │ │ │ │ │ ├─ Subtensor{start:} [id BS]
│ │ │ │ │ │ └─ 1 [id BS]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT]
│ │ │ │ │ └─ 0.0 [id BU]
│ │ │ │ ├─ Second [id BV]
│ │ │ │ │ ├─ Subtensor{i} [id BW]
│ │ │ │ │ │ ├─ Subtensor{start:} [id BR]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ -1 [id BY]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
│ │ │ │ │ └─ Second [id CA]
│ │ │ │ │ ├─ Sum{axes=None} [id CB]
│ │ │ │ │ │ └─ Subtensor{i} [id BX]
│ │ │ │ │ │ └─ -1 [id BX]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BY]
│ │ │ │ │ └─ Second [id BZ]
│ │ │ │ │ ├─ Sum{axes=None} [id CA]
│ │ │ │ │ │ └─ Subtensor{i} [id BW]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1.0 [id CC]
│ │ │ │ └─ -1 [id BY]
│ │ │ └─ 1 [id BT]
│ │ └─ -1 [id CD]
│ ├─ Alloc [id CE] (outer_in_sit_sot-0)
│ │ ├─ 0.0 [id CF]
│ │ ├─ Add [id CG]
│ │ │ │ │ └─ 1.0 [id CB]
│ │ │ │ └─ -1 [id BX]
│ │ │ └─ 1 [id BS]
│ │ └─ -1 [id CC]
│ ├─ Alloc [id CD] (outer_in_sit_sot-0)
│ │ ├─ 0.0 [id CE]
│ │ ├─ Add [id CF]
│ │ │ ├─ Sub [id C]
│ │ │ │ └─ ···
│ │ │ └─ 1 [id CH]
│ │ └─ Subtensor{i} [id CI]
│ │ ├─ Shape [id CJ]
│ │ │ └─ 1 [id CG]
│ │ └─ Subtensor{i} [id CH]
│ │ ├─ Shape [id CI]
│ │ │ └─ A [id P]
│ │ └─ 0 [id CK]
│ │ └─ 0 [id CJ]
│ └─ A [id P] (outer_in_non_seqs-0)
└─ -1 [id CL]
└─ -1 [id CK]
Inner graphs:
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
← Add [id CM] (inner_out_mit_mot-0-0)
├─ Mul [id CN]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
│ └─ *5-<Vector(float64, shape=(?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
└─ *3-<Vector(float64, shape=(?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
← Add [id CR] (inner_out_sit_sot-0)
├─ Mul [id CS]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
│ └─ *0-<Vector(float64, shape=(?,))> [id CT] -> [id Z] (inner_in_seqs-0)
└─ *4-<Vector(float64, shape=(?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
← Add [id CL] (inner_out_mit_mot-0-0)
├─ Mul [id CM]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
│ └─ *5-<Vector(float64, shape=(?,))> [id CO] -> [id P] (inner_in_non_seqs-0)
└─ *3-<Vector(float64, shape=(?,))> [id CP] -> [id BK] (inner_in_mit_mot-0-1)
← Add [id CQ] (inner_out_sit_sot-0)
├─ Mul [id CR]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
│ └─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id Y] (inner_in_seqs-0)
└─ *4-<Vector(float64, shape=(?,))> [id CT] -> [id CD] (inner_in_sit_sot-0)
Scan{scan_fn, while_loop=False, inplace=none} [id F]
← Mul [id CV] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CW] -> [id P] (inner_in_non_seqs-0)
← Mul [id CU] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id H] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CV] -> [id P] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论