提交 88516d2e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify Scan string representation

上级 f6bb307d
......@@ -1282,27 +1282,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
def __str__(self):
device_str = "cpu"
if self.info.as_while:
name = "do_while"
else:
name = "for"
aux_txt = "%s"
inplace = "none"
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted(
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
):
aux_txt += "all_inplace,%s,%s}"
inplace = "all"
else:
aux_txt += "{inplace{"
for k in self.destroy_map.keys():
aux_txt += str(k) + ","
aux_txt += "},%s,%s}"
else:
aux_txt += "{%s,%s}"
aux_txt = aux_txt % (name, device_str, str(self.name))
return aux_txt
inplace = str(list(self.destroy_map.keys()))
return (
f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
)
def __hash__(self):
return hash(
......
......@@ -28,7 +28,7 @@ def test_debugprint_sitsot():
expected_output = """Subtensor{i} [id A]
├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
│ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C] (outer_out_sit_sot-0)
│ │ ├─ k [id D] (n_steps)
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
......@@ -59,7 +59,7 @@ def test_debugprint_sitsot():
Inner graphs:
for{cpu,scan_fn} [id C]
Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
......@@ -86,7 +86,7 @@ def test_debugprint_sitsot_no_extra_info():
expected_output = """Subtensor{i} [id A]
├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn} [id C]
│ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C]
│ │ ├─ k [id D]
│ │ ├─ SetSubtensor{:stop} [id E]
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
......@@ -117,7 +117,7 @@ def test_debugprint_sitsot_no_extra_info():
Inner graphs:
for{cpu,scan_fn} [id C]
Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id W]
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E]
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
......@@ -148,7 +148,7 @@ def test_debugprint_nitsot():
lines = output_str.split("\n")
expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
└─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E]
......@@ -183,7 +183,7 @@ def test_debugprint_nitsot():
Inner graphs:
for{cpu,scan_fn} [id B]
Scan{scan_fn, while_loop=False, inplace=none} [id B]
← Mul [id X] (inner_out_nit_sot-0)
├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
└─ Pow [id Z]
......@@ -226,7 +226,7 @@ def test_debugprint_nested_scans():
lines = output_str.split("\n")
expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
└─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E]
......@@ -262,14 +262,14 @@ def test_debugprint_nested_scans():
Inner graphs:
for{cpu,scan_fn} [id B]
Scan{scan_fn, while_loop=False, inplace=none} [id B]
← Mul [id Y] (inner_out_nit_sot-0)
├─ ExpandDims{axis=0} [id Z]
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
└─ Pow [id BB]
├─ Subtensor{i} [id BC]
│ ├─ Subtensor{start:} [id BD]
│ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
......@@ -300,7 +300,7 @@ def test_debugprint_nested_scans():
└─ ExpandDims{axis=0} [id BY]
└─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE]
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
← Mul [id CA] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
......@@ -319,7 +319,7 @@ def test_debugprint_nested_scans():
→ k [id B]
→ A [id C]
Sum{axes=None} [id D] 13
└─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
└─ Scan{scan_fn, while_loop=False, inplace=none} [id E] 12 (outer_out_nit_sot-0)
├─ Minimum [id F] 7 (outer_in_nit_sot-0)
│ ├─ Subtensor{i} [id G] 6
│ │ ├─ Shape [id H] 5
......@@ -355,7 +355,7 @@ def test_debugprint_nested_scans():
Inner graphs:
for{cpu,scan_fn} [id E]
Scan{scan_fn, while_loop=False, inplace=none} [id E]
→ *0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
→ *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
→ *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
......@@ -366,7 +366,7 @@ def test_debugprint_nested_scans():
└─ Pow [id BE]
├─ Subtensor{i} [id BF]
│ ├─ Subtensor{start:} [id BG]
│ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
......@@ -397,7 +397,7 @@ def test_debugprint_nested_scans():
└─ ExpandDims{axis=0} [id BZ]
└─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
for{cpu,scan_fn} [id BH]
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
→ *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
→ *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CC] (inner_out_sit_sot-0)
......@@ -431,7 +431,7 @@ def test_debugprint_mitsot():
expected_output = """Add [id A]
├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
│ ├─ Scan{scan_fn, while_loop=False, inplace=none}.0 [id C] (outer_out_mit_sot-0)
│ │ ├─ TensorConstant{5} [id D] (n_steps)
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
......@@ -465,13 +465,13 @@ def test_debugprint_mitsot():
│ │ └─ ···
│ └─ ScalarConstant{2} [id Y]
└─ Subtensor{start:} [id Z]
├─ for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
├─ Scan{scan_fn, while_loop=False, inplace=none}.1 [id C] (outer_out_mit_sot-1)
│ └─ ···
└─ ScalarConstant{2} [id BA]
Inner graphs:
for{cpu,scan_fn} [id C]
Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Add [id BB] (inner_out_mit_sot-0)
├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
└─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
......@@ -502,11 +502,11 @@ def test_debugprint_mitmot():
lines = output_str.split("\n")
expected_output = """Subtensor{i} [id A]
├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
├─ Scan{grad_of_scan_fn, while_loop=False, inplace=none}.1 [id B] (outer_out_sit_sot-0)
│ ├─ Sub [id C] (n_steps)
│ │ ├─ Subtensor{i} [id D]
│ │ │ ├─ Shape [id E]
│ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ └─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ k [id G] (n_steps)
│ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
......@@ -537,7 +537,7 @@ def test_debugprint_mitmot():
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id BA]
│ │ │ ├─ Subtensor{:stop} [id BB]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BC]
│ │ │ └─ ScalarConstant{-1} [id BD]
......@@ -547,7 +547,7 @@ def test_debugprint_mitmot():
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BG]
│ │ │ ├─ Subtensor{::step} [id BH]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BI]
│ │ │ └─ ScalarConstant{-1} [id BJ]
......@@ -557,14 +557,14 @@ def test_debugprint_mitmot():
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BM]
│ │ │ ├─ Second [id BN]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ TensorConstant{0.0} [id BP]
│ │ │ ├─ IncSubtensor{i} [id BQ]
│ │ │ │ ├─ Second [id BR]
│ │ │ │ │ ├─ Subtensor{start:} [id BS]
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
......@@ -598,7 +598,7 @@ def test_debugprint_mitmot():
Inner graphs:
for{cpu,grad_of_scan_fn} [id B]
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
← Add [id CM] (inner_out_mit_mot-0-0)
├─ Mul [id CN]
│ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
......@@ -610,7 +610,7 @@ def test_debugprint_mitmot():
│ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
└─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
for{cpu,scan_fn} [id F]
Scan{scan_fn, while_loop=False, inplace=none} [id F]
← Mul [id CV] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
......@@ -641,7 +641,7 @@ def test_debugprint_compiled_fn():
# (i.e. from `Scan._fn`)
out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
├─ TensorConstant{20000} [id B] (n_steps)
├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
......@@ -653,7 +653,7 @@ def test_debugprint_compiled_fn():
Inner graphs:
forall_inplace,cpu,scan_fn} [id A]
Scan{scan_fn, while_loop=False, inplace=all} [id A]
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
├─ TensorConstant{0} [id J]
├─ Subtensor{i, j, k} [id K]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论