提交 88516d2e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify Scan string representation

上级 f6bb307d
...@@ -1282,27 +1282,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1282,27 +1282,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
def __str__(self): def __str__(self):
device_str = "cpu" inplace = "none"
if self.info.as_while:
name = "do_while"
else:
name = "for"
aux_txt = "%s"
if len(self.destroy_map.keys()) > 0: if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace # Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted( if sorted(self.destroy_map.keys()) == sorted(
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot) range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
): ):
aux_txt += "all_inplace,%s,%s}" inplace = "all"
else: else:
aux_txt += "{inplace{" inplace = str(list(self.destroy_map.keys()))
for k in self.destroy_map.keys(): return (
aux_txt += str(k) + "," f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
aux_txt += "},%s,%s}" )
else:
aux_txt += "{%s,%s}"
aux_txt = aux_txt % (name, device_str, str(self.name))
return aux_txt
def __hash__(self): def __hash__(self):
return hash( return hash(
......
...@@ -28,7 +28,7 @@ def test_debugprint_sitsot(): ...@@ -28,7 +28,7 @@ def test_debugprint_sitsot():
expected_output = """Subtensor{i} [id A] expected_output = """Subtensor{i} [id A]
├─ Subtensor{start:} [id B] ├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn} [id C] (outer_out_sit_sot-0) │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C] (outer_out_sit_sot-0)
│ │ ├─ k [id D] (n_steps) │ │ ├─ k [id D] (n_steps)
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0) │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F] │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
...@@ -59,7 +59,7 @@ def test_debugprint_sitsot(): ...@@ -59,7 +59,7 @@ def test_debugprint_sitsot():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id C] Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W] (inner_out_sit_sot-0) ← Mul [id W] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0) ├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)""" └─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
...@@ -86,7 +86,7 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -86,7 +86,7 @@ def test_debugprint_sitsot_no_extra_info():
expected_output = """Subtensor{i} [id A] expected_output = """Subtensor{i} [id A]
├─ Subtensor{start:} [id B] ├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn} [id C] │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C]
│ │ ├─ k [id D] │ │ ├─ k [id D]
│ │ ├─ SetSubtensor{:stop} [id E] │ │ ├─ SetSubtensor{:stop} [id E]
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F] │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
...@@ -117,7 +117,7 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -117,7 +117,7 @@ def test_debugprint_sitsot_no_extra_info():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id C] Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W] ← Mul [id W]
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] ├─ *0-<TensorType(float64, (?,))> [id X] -> [id E]
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]""" └─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
...@@ -148,7 +148,7 @@ def test_debugprint_nitsot(): ...@@ -148,7 +148,7 @@ def test_debugprint_nitsot():
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Sum{axes=None} [id A] expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) └─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0) ├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{i} [id D] │ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E] │ │ ├─ Shape [id E]
...@@ -183,7 +183,7 @@ def test_debugprint_nitsot(): ...@@ -183,7 +183,7 @@ def test_debugprint_nitsot():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id B] Scan{scan_fn, while_loop=False, inplace=none} [id B]
← Mul [id X] (inner_out_nit_sot-0) ← Mul [id X] (inner_out_nit_sot-0)
├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0) ├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
└─ Pow [id Z] └─ Pow [id Z]
...@@ -226,7 +226,7 @@ def test_debugprint_nested_scans(): ...@@ -226,7 +226,7 @@ def test_debugprint_nested_scans():
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Sum{axes=None} [id A] expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) └─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0) ├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{i} [id D] │ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E] │ │ ├─ Shape [id E]
...@@ -262,14 +262,14 @@ def test_debugprint_nested_scans(): ...@@ -262,14 +262,14 @@ def test_debugprint_nested_scans():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id B] Scan{scan_fn, while_loop=False, inplace=none} [id B]
← Mul [id Y] (inner_out_nit_sot-0) ← Mul [id Y] (inner_out_nit_sot-0)
├─ ExpandDims{axis=0} [id Z] ├─ ExpandDims{axis=0} [id Z]
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0) │ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
└─ Pow [id BB] └─ Pow [id BB]
├─ Subtensor{i} [id BC] ├─ Subtensor{i} [id BC]
│ ├─ Subtensor{start:} [id BD] │ ├─ Subtensor{start:} [id BD]
│ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0) │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps) │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0) │ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH] │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
...@@ -300,7 +300,7 @@ def test_debugprint_nested_scans(): ...@@ -300,7 +300,7 @@ def test_debugprint_nested_scans():
└─ ExpandDims{axis=0} [id BY] └─ ExpandDims{axis=0} [id BY]
└─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1) └─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE] Scan{scan_fn, while_loop=False, inplace=none} [id BE]
← Mul [id CA] (inner_out_sit_sot-0) ← Mul [id CA] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0) ├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)""" └─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
...@@ -319,7 +319,7 @@ def test_debugprint_nested_scans(): ...@@ -319,7 +319,7 @@ def test_debugprint_nested_scans():
→ k [id B] → k [id B]
→ A [id C] → A [id C]
Sum{axes=None} [id D] 13 Sum{axes=None} [id D] 13
└─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0) └─ Scan{scan_fn, while_loop=False, inplace=none} [id E] 12 (outer_out_nit_sot-0)
├─ Minimum [id F] 7 (outer_in_nit_sot-0) ├─ Minimum [id F] 7 (outer_in_nit_sot-0)
│ ├─ Subtensor{i} [id G] 6 │ ├─ Subtensor{i} [id G] 6
│ │ ├─ Shape [id H] 5 │ │ ├─ Shape [id H] 5
...@@ -355,7 +355,7 @@ def test_debugprint_nested_scans(): ...@@ -355,7 +355,7 @@ def test_debugprint_nested_scans():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id E] Scan{scan_fn, while_loop=False, inplace=none} [id E]
→ *0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0) → *0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
→ *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1) → *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
→ *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0) → *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
...@@ -366,7 +366,7 @@ def test_debugprint_nested_scans(): ...@@ -366,7 +366,7 @@ def test_debugprint_nested_scans():
└─ Pow [id BE] └─ Pow [id BE]
├─ Subtensor{i} [id BF] ├─ Subtensor{i} [id BF]
│ ├─ Subtensor{start:} [id BG] │ ├─ Subtensor{start:} [id BG]
│ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0) │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps) │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0) │ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ] │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
...@@ -397,7 +397,7 @@ def test_debugprint_nested_scans(): ...@@ -397,7 +397,7 @@ def test_debugprint_nested_scans():
└─ ExpandDims{axis=0} [id BZ] └─ ExpandDims{axis=0} [id BZ]
└─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1) └─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
for{cpu,scan_fn} [id BH] Scan{scan_fn, while_loop=False, inplace=none} [id BH]
→ *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0) → *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
→ *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0) → *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CC] (inner_out_sit_sot-0) ← Mul [id CC] (inner_out_sit_sot-0)
...@@ -431,7 +431,7 @@ def test_debugprint_mitsot(): ...@@ -431,7 +431,7 @@ def test_debugprint_mitsot():
expected_output = """Add [id A] expected_output = """Add [id A]
├─ Subtensor{start:} [id B] ├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0) │ ├─ Scan{scan_fn, while_loop=False, inplace=none}.0 [id C] (outer_out_mit_sot-0)
│ │ ├─ TensorConstant{5} [id D] (n_steps) │ │ ├─ TensorConstant{5} [id D] (n_steps)
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0) │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F] │ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
...@@ -465,13 +465,13 @@ def test_debugprint_mitsot(): ...@@ -465,13 +465,13 @@ def test_debugprint_mitsot():
│ │ └─ ··· │ │ └─ ···
│ └─ ScalarConstant{2} [id Y] │ └─ ScalarConstant{2} [id Y]
└─ Subtensor{start:} [id Z] └─ Subtensor{start:} [id Z]
├─ for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1) ├─ Scan{scan_fn, while_loop=False, inplace=none}.1 [id C] (outer_out_mit_sot-1)
│ └─ ··· │ └─ ···
└─ ScalarConstant{2} [id BA] └─ ScalarConstant{2} [id BA]
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id C] Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Add [id BB] (inner_out_mit_sot-0) ← Add [id BB] (inner_out_mit_sot-0)
├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1) ├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
└─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0) └─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
...@@ -502,11 +502,11 @@ def test_debugprint_mitmot(): ...@@ -502,11 +502,11 @@ def test_debugprint_mitmot():
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Subtensor{i} [id A] expected_output = """Subtensor{i} [id A]
├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0) ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=none}.1 [id B] (outer_out_sit_sot-0)
│ ├─ Sub [id C] (n_steps) │ ├─ Sub [id C] (n_steps)
│ │ ├─ Subtensor{i} [id D] │ │ ├─ Subtensor{i} [id D]
│ │ │ ├─ Shape [id E] │ │ │ ├─ Shape [id E]
│ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ └─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ k [id G] (n_steps) │ │ │ │ ├─ k [id G] (n_steps)
│ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0) │ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I] │ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
...@@ -537,7 +537,7 @@ def test_debugprint_mitmot(): ...@@ -537,7 +537,7 @@ def test_debugprint_mitmot():
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0) │ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id BA] │ │ ├─ Subtensor{::step} [id BA]
│ │ │ ├─ Subtensor{:stop} [id BB] │ │ │ ├─ Subtensor{:stop} [id BB]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BC] │ │ │ │ └─ ScalarConstant{-1} [id BC]
│ │ │ └─ ScalarConstant{-1} [id BD] │ │ │ └─ ScalarConstant{-1} [id BD]
...@@ -547,7 +547,7 @@ def test_debugprint_mitmot(): ...@@ -547,7 +547,7 @@ def test_debugprint_mitmot():
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1) │ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BG] │ │ ├─ Subtensor{:stop} [id BG]
│ │ │ ├─ Subtensor{::step} [id BH] │ │ │ ├─ Subtensor{::step} [id BH]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BI] │ │ │ │ └─ ScalarConstant{-1} [id BI]
│ │ │ └─ ScalarConstant{-1} [id BJ] │ │ │ └─ ScalarConstant{-1} [id BJ]
...@@ -557,14 +557,14 @@ def test_debugprint_mitmot(): ...@@ -557,14 +557,14 @@ def test_debugprint_mitmot():
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0) │ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BM] │ │ ├─ IncSubtensor{start:} [id BM]
│ │ │ ├─ Second [id BN] │ │ │ ├─ Second [id BN]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO] │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ TensorConstant{0.0} [id BP] │ │ │ │ └─ TensorConstant{0.0} [id BP]
│ │ │ ├─ IncSubtensor{i} [id BQ] │ │ │ ├─ IncSubtensor{i} [id BQ]
│ │ │ │ ├─ Second [id BR] │ │ │ │ ├─ Second [id BR]
│ │ │ │ │ ├─ Subtensor{start:} [id BS] │ │ │ │ │ ├─ Subtensor{start:} [id BS]
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT] │ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU] │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
...@@ -598,7 +598,7 @@ def test_debugprint_mitmot(): ...@@ -598,7 +598,7 @@ def test_debugprint_mitmot():
Inner graphs: Inner graphs:
for{cpu,grad_of_scan_fn} [id B] Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
← Add [id CM] (inner_out_mit_mot-0-0) ← Add [id CM] (inner_out_mit_mot-0-0)
├─ Mul [id CN] ├─ Mul [id CN]
│ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0) │ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
...@@ -610,7 +610,7 @@ def test_debugprint_mitmot(): ...@@ -610,7 +610,7 @@ def test_debugprint_mitmot():
│ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0) │ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
└─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0) └─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
for{cpu,scan_fn} [id F] Scan{scan_fn, while_loop=False, inplace=none} [id F]
← Mul [id CV] (inner_out_sit_sot-0) ← Mul [id CV] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0) ├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)""" └─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
...@@ -641,7 +641,7 @@ def test_debugprint_compiled_fn(): ...@@ -641,7 +641,7 @@ def test_debugprint_compiled_fn():
# (i.e. from `Scan._fn`) # (i.e. from `Scan._fn`)
out = pytensor.function([M], out, updates=updates, mode="FAST_RUN") out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0) expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
├─ TensorConstant{20000} [id B] (n_steps) ├─ TensorConstant{20000} [id B] (n_steps)
├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0) ├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0) ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
...@@ -653,7 +653,7 @@ def test_debugprint_compiled_fn(): ...@@ -653,7 +653,7 @@ def test_debugprint_compiled_fn():
Inner graphs: Inner graphs:
forall_inplace,cpu,scan_fn} [id A] Scan{scan_fn, while_loop=False, inplace=all} [id A]
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0) ← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
├─ TensorConstant{0} [id J] ├─ TensorConstant{0} [id J]
├─ Subtensor{i, j, k} [id K] ├─ Subtensor{i, j, k} [id K]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论