提交 b27c59d1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify `expand_empty`

上级 c70a8869
...@@ -231,8 +231,8 @@ def expand_empty(tensor_var, size): ...@@ -231,8 +231,8 @@ def expand_empty(tensor_var, size):
if size == 0: if size == 0:
return tensor_var return tensor_var
shapes = [tensor_var.shape[x] for x in range(tensor_var.ndim)] shapes = tuple(tensor_var.shape)
new_shape = [size + shapes[0]] + shapes[1:] new_shape = (size + shapes[0], *shapes[1:])
empty = AllocEmpty(tensor_var.dtype)(*new_shape) empty = AllocEmpty(tensor_var.dtype)(*new_shape)
ret = set_subtensor(empty[: shapes[0]], tensor_var) ret = set_subtensor(empty[: shapes[0]], tensor_var)
......
...@@ -44,25 +44,24 @@ def test_debugprint_sitsot(): ...@@ -44,25 +44,24 @@ def test_debugprint_sitsot():
│ │ │ │ │ │ └─ 1.0 [id O] │ │ │ │ │ │ └─ 1.0 [id O]
│ │ │ │ │ └─ 0 [id P] │ │ │ │ │ └─ 0 [id P]
│ │ │ │ └─ Subtensor{i} [id Q] │ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R] │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ··· │ │ │ │ └─ 1 [id R]
│ │ │ │ └─ 1 [id S]
│ │ │ ├─ Unbroadcast{0} [id J] │ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T] │ │ │ └─ ScalarFromTensor [id S]
│ │ │ └─ Subtensor{i} [id H] │ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ··· │ │ │ └─ ···
│ │ └─ A [id M] (outer_in_non_seqs-0) │ │ └─ A [id M] (outer_in_non_seqs-0)
│ └─ 1 [id U] │ └─ 1 [id T]
└─ -1 [id V] └─ -1 [id U]
Inner graphs: Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C] Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W] (inner_out_sit_sot-0) ← Mul [id V] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id X] -> [id E] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id Y] -> [id M] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -103,25 +102,24 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -103,25 +102,24 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ │ │ │ └─ 1.0 [id O] │ │ │ │ │ │ └─ 1.0 [id O]
│ │ │ │ │ └─ 0 [id P] │ │ │ │ │ └─ 0 [id P]
│ │ │ │ └─ Subtensor{i} [id Q] │ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R] │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ··· │ │ │ │ └─ 1 [id R]
│ │ │ │ └─ 1 [id S]
│ │ │ ├─ Unbroadcast{0} [id J] │ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T] │ │ │ └─ ScalarFromTensor [id S]
│ │ │ └─ Subtensor{i} [id H] │ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ··· │ │ │ └─ ···
│ │ └─ A [id M] │ │ └─ A [id M]
│ └─ 1 [id U] │ └─ 1 [id T]
└─ -1 [id V] └─ -1 [id U]
Inner graphs: Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C] Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W] ← Mul [id V]
├─ *0-<Vector(float64, shape=(?,))> [id X] -> [id E] ├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E]
└─ *1-<Vector(float64, shape=(?,))> [id Y] -> [id M] └─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M]
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -288,25 +286,24 @@ def test_debugprint_nested_scans(): ...@@ -288,25 +286,24 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ 1.0 [id BQ] │ │ │ │ │ │ │ └─ 1.0 [id BQ]
│ │ │ │ │ │ └─ 0 [id BR] │ │ │ │ │ │ └─ 0 [id BR]
│ │ │ │ │ └─ Subtensor{i} [id BS] │ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ ├─ Shape [id BT] │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BL] │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ··· │ │ │ │ │ └─ 1 [id BT]
│ │ │ │ │ └─ 1 [id BU]
│ │ │ │ ├─ Unbroadcast{0} [id BL] │ │ │ │ ├─ Unbroadcast{0} [id BL]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BV] │ │ │ │ └─ ScalarFromTensor [id BU]
│ │ │ │ └─ Subtensor{i} [id BJ] │ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) │ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BW] │ │ └─ 1 [id BV]
│ └─ -1 [id BX] │ └─ -1 [id BW]
└─ ExpandDims{axis=0} [id BY] └─ ExpandDims{axis=0} [id BX]
└─ *1-<Scalar(int64, shape=())> [id BZ] -> [id U] (inner_in_seqs-1) └─ *1-<Scalar(int64, shape=())> [id BY] -> [id U] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BE] Scan{scan_fn, while_loop=False, inplace=none} [id BE]
← Mul [id CA] (inner_out_sit_sot-0) ← Mul [id BZ] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CB] -> [id BG] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CC] -> [id BO] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BO] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -386,27 +383,26 @@ def test_debugprint_nested_scans(): ...@@ -386,27 +383,26 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ 1.0 [id BR] │ │ │ │ │ │ │ └─ 1.0 [id BR]
│ │ │ │ │ │ └─ 0 [id BS] │ │ │ │ │ │ └─ 0 [id BS]
│ │ │ │ │ └─ Subtensor{i} [id BT] │ │ │ │ │ └─ Subtensor{i} [id BT]
│ │ │ │ │ ├─ Shape [id BU] │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BN] │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ··· │ │ │ │ │ └─ 1 [id BU]
│ │ │ │ │ └─ 1 [id BV]
│ │ │ │ ├─ Unbroadcast{0} [id BN] │ │ │ │ ├─ Unbroadcast{0} [id BN]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BW] │ │ │ │ └─ ScalarFromTensor [id BV]
│ │ │ │ └─ Subtensor{i} [id BL] │ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) │ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BX] │ │ └─ 1 [id BW]
│ └─ -1 [id BY] │ └─ -1 [id BX]
└─ ExpandDims{axis=0} [id BZ] └─ ExpandDims{axis=0} [id BY]
└─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1) └─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BH] Scan{scan_fn, while_loop=False, inplace=none} [id BH]
→ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BI] (inner_in_sit_sot-0) → *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BI] (inner_in_sit_sot-0)
→ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BA] (inner_in_non_seqs-0) → *1-<Vector(float64, shape=(?,))> [id CA] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CC] (inner_out_sit_sot-0) ← Mul [id CB] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CA] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id BZ] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CB] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id CA] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -528,98 +524,97 @@ def test_debugprint_mitmot(): ...@@ -528,98 +524,97 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ │ │ └─ 1.0 [id R] │ │ │ │ │ │ │ │ └─ 1.0 [id R]
│ │ │ │ │ │ │ └─ 0 [id S] │ │ │ │ │ │ │ └─ 0 [id S]
│ │ │ │ │ │ └─ Subtensor{i} [id T] │ │ │ │ │ │ └─ Subtensor{i} [id T]
│ │ │ │ │ │ ├─ Shape [id U] │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ 1 [id U]
│ │ │ │ │ │ └─ 1 [id V]
│ │ │ │ │ ├─ Unbroadcast{0} [id M] │ │ │ │ │ ├─ Unbroadcast{0} [id M]
│ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ScalarFromTensor [id W] │ │ │ │ │ └─ ScalarFromTensor [id V]
│ │ │ │ │ └─ Subtensor{i} [id K] │ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0) │ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
│ │ │ └─ 0 [id X] │ │ │ └─ 0 [id W]
│ │ └─ 1 [id Y] │ │ └─ 1 [id X]
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0) │ ├─ Subtensor{:stop} [id Y] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id BA] │ │ ├─ Subtensor{::step} [id Z]
│ │ │ ├─ Subtensor{:stop} [id BB] │ │ │ ├─ Subtensor{:stop} [id BA]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BC] │ │ │ │ └─ -1 [id BB]
│ │ │ └─ -1 [id BD] │ │ │ └─ -1 [id BC]
│ │ └─ ScalarFromTensor [id BE] │ │ └─ ScalarFromTensor [id BD]
│ │ └─ Sub [id C] │ │ └─ Sub [id C]
│ │ └─ ··· │ │ └─ ···
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1) │ ├─ Subtensor{:stop} [id BE] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BG] │ │ ├─ Subtensor{:stop} [id BF]
│ │ │ ├─ Subtensor{::step} [id BH] │ │ │ ├─ Subtensor{::step} [id BG]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BI] │ │ │ │ └─ -1 [id BH]
│ │ │ └─ -1 [id BJ] │ │ │ └─ -1 [id BI]
│ │ └─ ScalarFromTensor [id BK] │ │ └─ ScalarFromTensor [id BJ]
│ │ └─ Sub [id C] │ │ └─ Sub [id C]
│ │ └─ ··· │ │ └─ ···
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0) │ ├─ Subtensor{::step} [id BK] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BM] │ │ ├─ IncSubtensor{start:} [id BL]
│ │ │ ├─ Second [id BN] │ │ │ ├─ Second [id BM]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO] │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN]
│ │ │ │ └─ 0.0 [id BP] │ │ │ │ └─ 0.0 [id BO]
│ │ │ ├─ IncSubtensor{i} [id BQ] │ │ │ ├─ IncSubtensor{i} [id BP]
│ │ │ │ ├─ Second [id BR] │ │ │ │ ├─ Second [id BQ]
│ │ │ │ │ ├─ Subtensor{start:} [id BS] │ │ │ │ │ ├─ Subtensor{start:} [id BR]
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id BT] │ │ │ │ │ │ └─ 1 [id BS]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU] │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT]
│ │ │ │ │ └─ 0.0 [id BV] │ │ │ │ │ └─ 0.0 [id BU]
│ │ │ │ ├─ Second [id BW] │ │ │ │ ├─ Second [id BV]
│ │ │ │ │ ├─ Subtensor{i} [id BX] │ │ │ │ │ ├─ Subtensor{i} [id BW]
│ │ │ │ │ │ ├─ Subtensor{start:} [id BS] │ │ │ │ │ │ ├─ Subtensor{start:} [id BR]
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ -1 [id BY] │ │ │ │ │ │ └─ -1 [id BX]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ] │ │ │ │ │ └─ ExpandDims{axis=0} [id BY]
│ │ │ │ │ └─ Second [id CA] │ │ │ │ │ └─ Second [id BZ]
│ │ │ │ │ ├─ Sum{axes=None} [id CB] │ │ │ │ │ ├─ Sum{axes=None} [id CA]
│ │ │ │ │ │ └─ Subtensor{i} [id BX] │ │ │ │ │ │ └─ Subtensor{i} [id BW]
│ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1.0 [id CC] │ │ │ │ │ └─ 1.0 [id CB]
│ │ │ │ └─ -1 [id BY] │ │ │ │ └─ -1 [id BX]
│ │ │ └─ 1 [id BT] │ │ │ └─ 1 [id BS]
│ │ └─ -1 [id CD] │ │ └─ -1 [id CC]
│ ├─ Alloc [id CE] (outer_in_sit_sot-0) │ ├─ Alloc [id CD] (outer_in_sit_sot-0)
│ │ ├─ 0.0 [id CF] │ │ ├─ 0.0 [id CE]
│ │ ├─ Add [id CG] │ │ ├─ Add [id CF]
│ │ │ ├─ Sub [id C] │ │ │ ├─ Sub [id C]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ 1 [id CH] │ │ │ └─ 1 [id CG]
│ │ └─ Subtensor{i} [id CI] │ │ └─ Subtensor{i} [id CH]
│ │ ├─ Shape [id CJ] │ │ ├─ Shape [id CI]
│ │ │ └─ A [id P] │ │ │ └─ A [id P]
│ │ └─ 0 [id CK] │ │ └─ 0 [id CJ]
│ └─ A [id P] (outer_in_non_seqs-0) │ └─ A [id P] (outer_in_non_seqs-0)
└─ -1 [id CL] └─ -1 [id CK]
Inner graphs: Inner graphs:
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B] Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
← Add [id CM] (inner_out_mit_mot-0-0) ← Add [id CL] (inner_out_mit_mot-0-0)
├─ Mul [id CN] ├─ Mul [id CM]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0) │ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
│ └─ *5-<Vector(float64, shape=(?,))> [id CP] -> [id P] (inner_in_non_seqs-0) │ └─ *5-<Vector(float64, shape=(?,))> [id CO] -> [id P] (inner_in_non_seqs-0)
└─ *3-<Vector(float64, shape=(?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1) └─ *3-<Vector(float64, shape=(?,))> [id CP] -> [id BK] (inner_in_mit_mot-0-1)
← Add [id CR] (inner_out_sit_sot-0) ← Add [id CQ] (inner_out_sit_sot-0)
├─ Mul [id CS] ├─ Mul [id CR]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0) │ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
│ └─ *0-<Vector(float64, shape=(?,))> [id CT] -> [id Z] (inner_in_seqs-0) │ └─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id Y] (inner_in_seqs-0)
└─ *4-<Vector(float64, shape=(?,))> [id CU] -> [id CE] (inner_in_sit_sot-0) └─ *4-<Vector(float64, shape=(?,))> [id CT] -> [id CD] (inner_in_sit_sot-0)
Scan{scan_fn, while_loop=False, inplace=none} [id F] Scan{scan_fn, while_loop=False, inplace=none} [id F]
← Mul [id CV] (inner_out_sit_sot-0) ← Mul [id CU] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CT] -> [id H] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id H] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CW] -> [id P] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id CV] -> [id P] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论