提交 f6bb307d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Improve string representation of Subtensor Ops

上级 b74cf3f6
......@@ -840,22 +840,34 @@ class Subtensor(COp):
@staticmethod
def str_from_slice(entry):
msg = []
for x in [entry.start, entry.stop, entry.step]:
if x is None:
msg.append("")
else:
msg.append(str(x))
return ":".join(msg)
if entry.step:
return ":".join(
(
"start" if entry.start else "",
"stop" if entry.stop else "",
"step",
)
)
if entry.stop:
return f"{'start' if entry.start else ''}:stop"
if entry.start:
return "start:"
return ":"
def __str__(self):
@staticmethod
def str_from_indices(idx_list):
indices = []
for entry in self.idx_list:
letter_indexes = 0
for entry in idx_list:
if isinstance(entry, slice):
indices.append(self.str_from_slice(entry))
indices.append(Subtensor.str_from_slice(entry))
else:
indices.append(str(entry))
return f"{self.__class__.__name__}{{{', '.join(indices)}}}"
indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
letter_indexes += 1
return ", ".join(indices)
def __str__(self):
return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"
@staticmethod
def default_helper_c_code_args():
......@@ -1498,21 +1510,8 @@ class IncSubtensor(COp):
return hash((type(self), idx_list, self.inplace, self.set_instead_of_inc))
def __str__(self):
indices = []
for entry in self.idx_list:
if isinstance(entry, slice):
indices.append(Subtensor.str_from_slice(entry))
else:
indices.append(str(entry))
if self.inplace:
msg = "Inplace"
else:
msg = ""
if not self.set_instead_of_inc:
msg += "Inc"
else:
msg += "Set"
return f"{self.__class__.__name__}{{{msg};{', '.join(indices)}}}"
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
def make_node(self, x, y, *inputs):
"""
......@@ -2661,10 +2660,10 @@ class AdvancedIncSubtensor(Op):
self.ignore_duplicates = ignore_duplicates
def __str__(self):
return "{}{{{}, {}}}".format(
self.__class__.__name__,
"inplace=" + str(self.inplace),
" set_instead_of_inc=" + str(self.set_instead_of_inc),
return (
"AdvancedSetSubtensor"
if self.set_instead_of_inc
else "AdvancedIncSubtensor"
)
def make_node(self, x, y, *inputs):
......
......@@ -26,15 +26,15 @@ def test_debugprint_sitsot():
output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n")
expected_output = """Subtensor{int64} [id A]
├─ Subtensor{int64::} [id B]
expected_output = """Subtensor{i} [id A]
├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
│ │ ├─ k [id D] (n_steps)
│ │ ├─ IncSubtensor{Set;:int64:} [id E] (outer_in_sit_sot-0)
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ │ ├─ Add [id G]
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
......@@ -43,7 +43,7 @@ def test_debugprint_sitsot():
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ └─ Subtensor{int64} [id Q]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ ···
......@@ -51,7 +51,7 @@ def test_debugprint_sitsot():
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ Subtensor{int64} [id H]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ └─ A [id M] (outer_in_non_seqs-0)
│ └─ ScalarConstant{1} [id U]
......@@ -84,15 +84,15 @@ def test_debugprint_sitsot_no_extra_info():
output_str = debugprint(final_result, file="str", print_op_info=False)
lines = output_str.split("\n")
expected_output = """Subtensor{int64} [id A]
├─ Subtensor{int64::} [id B]
expected_output = """Subtensor{i} [id A]
├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn} [id C]
│ │ ├─ k [id D]
│ │ ├─ IncSubtensor{Set;:int64:} [id E]
│ │ ├─ SetSubtensor{:stop} [id E]
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ │ ├─ Add [id G]
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
......@@ -101,7 +101,7 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ └─ Subtensor{int64} [id Q]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ ···
......@@ -109,7 +109,7 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ Subtensor{int64} [id H]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ └─ A [id M]
│ └─ ScalarConstant{1} [id U]
......@@ -150,29 +150,29 @@ def test_debugprint_nitsot():
expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{int64} [id D]
│ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E]
│ │ │ └─ Subtensor{int64::} [id F] 'coefficients[0:]'
│ │ │ └─ Subtensor{start:} [id F] 'coefficients[0:]'
│ │ │ ├─ coefficients [id G]
│ │ │ └─ ScalarConstant{0} [id H]
│ │ └─ ScalarConstant{0} [id I]
│ └─ Subtensor{int64} [id J]
│ └─ Subtensor{i} [id J]
│ ├─ Shape [id K]
│ │ └─ Subtensor{int64::} [id L]
│ │ └─ Subtensor{start:} [id L]
│ │ ├─ ARange{dtype='int64'} [id M]
│ │ │ ├─ TensorConstant{0} [id N]
│ │ │ ├─ TensorConstant{10000} [id O]
│ │ │ └─ TensorConstant{1} [id P]
│ │ └─ ScalarConstant{0} [id Q]
│ └─ ScalarConstant{0} [id R]
├─ Subtensor{:int64:} [id S] (outer_in_seqs-0)
│ ├─ Subtensor{int64::} [id F] 'coefficients[0:]'
├─ Subtensor{:stop} [id S] (outer_in_seqs-0)
│ ├─ Subtensor{start:} [id F] 'coefficients[0:]'
│ │ └─ ···
│ └─ ScalarFromTensor [id T]
│ └─ Minimum [id C]
│ └─ ···
├─ Subtensor{:int64:} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{int64::} [id L]
├─ Subtensor{:stop} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{start:} [id L]
│ │ └─ ···
│ └─ ScalarFromTensor [id V]
│ └─ Minimum [id C]
......@@ -228,29 +228,29 @@ def test_debugprint_nested_scans():
expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{int64} [id D]
│ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E]
│ │ │ └─ Subtensor{int64::} [id F] 'c[0:]'
│ │ │ └─ Subtensor{start:} [id F] 'c[0:]'
│ │ │ ├─ c [id G]
│ │ │ └─ ScalarConstant{0} [id H]
│ │ └─ ScalarConstant{0} [id I]
│ └─ Subtensor{int64} [id J]
│ └─ Subtensor{i} [id J]
│ ├─ Shape [id K]
│ │ └─ Subtensor{int64::} [id L]
│ │ └─ Subtensor{start:} [id L]
│ │ ├─ ARange{dtype='int64'} [id M]
│ │ │ ├─ TensorConstant{0} [id N]
│ │ │ ├─ TensorConstant{10} [id O]
│ │ │ └─ TensorConstant{1} [id P]
│ │ └─ ScalarConstant{0} [id Q]
│ └─ ScalarConstant{0} [id R]
├─ Subtensor{:int64:} [id S] (outer_in_seqs-0)
│ ├─ Subtensor{int64::} [id F] 'c[0:]'
├─ Subtensor{:stop} [id S] (outer_in_seqs-0)
│ ├─ Subtensor{start:} [id F] 'c[0:]'
│ │ └─ ···
│ └─ ScalarFromTensor [id T]
│ └─ Minimum [id C]
│ └─ ···
├─ Subtensor{:int64:} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{int64::} [id L]
├─ Subtensor{:stop} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{start:} [id L]
│ │ └─ ···
│ └─ ScalarFromTensor [id V]
│ └─ Minimum [id C]
......@@ -267,15 +267,15 @@ def test_debugprint_nested_scans():
├─ ExpandDims{axis=0} [id Z]
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
└─ Pow [id BB]
├─ Subtensor{int64} [id BC]
│ ├─ Subtensor{int64::} [id BD]
├─ Subtensor{i} [id BC]
│ ├─ Subtensor{start:} [id BD]
│ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
│ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
│ │ │ │ │ ├─ Add [id BI]
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{int64} [id BJ]
│ │ │ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
......@@ -284,7 +284,7 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BR]
│ │ │ │ │ └─ Subtensor{int64} [id BS]
│ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ ├─ Shape [id BT]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ └─ ···
......@@ -292,7 +292,7 @@ def test_debugprint_nested_scans():
│ │ │ │ ├─ Unbroadcast{0} [id BL]
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BV]
│ │ │ │ └─ Subtensor{int64} [id BJ]
│ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ └─ ···
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ ScalarConstant{1} [id BW]
......@@ -321,29 +321,29 @@ def test_debugprint_nested_scans():
Sum{axes=None} [id D] 13
└─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
├─ Minimum [id F] 7 (outer_in_nit_sot-0)
│ ├─ Subtensor{int64} [id G] 6
│ ├─ Subtensor{i} [id G] 6
│ │ ├─ Shape [id H] 5
│ │ │ └─ Subtensor{int64::} [id I] 'c[0:]' 4
│ │ │ └─ Subtensor{start:} [id I] 'c[0:]' 4
│ │ │ ├─ c [id A]
│ │ │ └─ ScalarConstant{0} [id J]
│ │ └─ ScalarConstant{0} [id K]
│ └─ Subtensor{int64} [id L] 3
│ └─ Subtensor{i} [id L] 3
│ ├─ Shape [id M] 2
│ │ └─ Subtensor{int64::} [id N] 1
│ │ └─ Subtensor{start:} [id N] 1
│ │ ├─ ARange{dtype='int64'} [id O] 0
│ │ │ ├─ TensorConstant{0} [id P]
│ │ │ ├─ TensorConstant{10} [id Q]
│ │ │ └─ TensorConstant{1} [id R]
│ │ └─ ScalarConstant{0} [id S]
│ └─ ScalarConstant{0} [id T]
├─ Subtensor{:int64:} [id U] 11 (outer_in_seqs-0)
│ ├─ Subtensor{int64::} [id I] 'c[0:]' 4
├─ Subtensor{:stop} [id U] 11 (outer_in_seqs-0)
│ ├─ Subtensor{start:} [id I] 'c[0:]' 4
│ │ └─ ···
│ └─ ScalarFromTensor [id V] 10
│ └─ Minimum [id F] 7
│ └─ ···
├─ Subtensor{:int64:} [id W] 9 (outer_in_seqs-1)
│ ├─ Subtensor{int64::} [id N] 1
├─ Subtensor{:stop} [id W] 9 (outer_in_seqs-1)
│ ├─ Subtensor{start:} [id N] 1
│ │ └─ ···
│ └─ ScalarFromTensor [id X] 8
│ └─ Minimum [id F] 7
......@@ -364,15 +364,15 @@ def test_debugprint_nested_scans():
├─ ExpandDims{axis=0} [id BD]
│ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
└─ Pow [id BE]
├─ Subtensor{int64} [id BF]
│ ├─ Subtensor{int64::} [id BG]
├─ Subtensor{i} [id BF]
│ ├─ Subtensor{start:} [id BG]
│ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ IncSubtensor{Set;:int64:} [id BI] (outer_in_sit_sot-0)
│ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
│ │ │ │ │ ├─ Add [id BK]
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{int64} [id BL]
│ │ │ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
......@@ -381,7 +381,7 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BS]
│ │ │ │ │ └─ Subtensor{int64} [id BT]
│ │ │ │ │ └─ Subtensor{i} [id BT]
│ │ │ │ │ ├─ Shape [id BU]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ └─ ···
......@@ -389,7 +389,7 @@ def test_debugprint_nested_scans():
│ │ │ │ ├─ Unbroadcast{0} [id BN]
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BW]
│ │ │ │ └─ Subtensor{int64} [id BL]
│ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ └─ ···
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ ScalarConstant{1} [id BX]
......@@ -430,41 +430,41 @@ def test_debugprint_mitsot():
lines = output_str.split("\n")
expected_output = """Add [id A]
├─ Subtensor{int64::} [id B]
├─ Subtensor{start:} [id B]
│ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
│ │ ├─ TensorConstant{5} [id D] (n_steps)
│ │ ├─ IncSubtensor{Set;:int64:} [id E] (outer_in_mit_sot-0)
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
│ │ │ │ └─ Add [id G]
│ │ │ │ ├─ TensorConstant{5} [id D]
│ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ Subtensor{:int64:} [id J]
│ │ │ │ │ └─ Subtensor{:stop} [id J]
│ │ │ │ │ ├─ <TensorType(int64, (?,))> [id K]
│ │ │ │ │ └─ ScalarConstant{2} [id L]
│ │ │ │ └─ ScalarConstant{0} [id M]
│ │ │ ├─ Subtensor{:int64:} [id J]
│ │ │ ├─ Subtensor{:stop} [id J]
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id N]
│ │ │ └─ Subtensor{int64} [id H]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ └─ IncSubtensor{Set;:int64:} [id O] (outer_in_mit_sot-1)
│ │ └─ SetSubtensor{:stop} [id O] (outer_in_mit_sot-1)
│ │ ├─ AllocEmpty{dtype='int64'} [id P]
│ │ │ └─ Add [id Q]
│ │ │ ├─ TensorConstant{5} [id D]
│ │ │ └─ Subtensor{int64} [id R]
│ │ │ └─ Subtensor{i} [id R]
│ │ │ ├─ Shape [id S]
│ │ │ │ └─ Subtensor{:int64:} [id T]
│ │ │ │ └─ Subtensor{:stop} [id T]
│ │ │ │ ├─ <TensorType(int64, (?,))> [id U]
│ │ │ │ └─ ScalarConstant{2} [id V]
│ │ │ └─ ScalarConstant{0} [id W]
│ │ ├─ Subtensor{:int64:} [id T]
│ │ ├─ Subtensor{:stop} [id T]
│ │ │ └─ ···
│ │ └─ ScalarFromTensor [id X]
│ │ └─ Subtensor{int64} [id R]
│ │ └─ Subtensor{i} [id R]
│ │ └─ ···
│ └─ ScalarConstant{2} [id Y]
└─ Subtensor{int64::} [id Z]
└─ Subtensor{start:} [id Z]
├─ for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
│ └─ ···
└─ ScalarConstant{2} [id BA]
......@@ -501,18 +501,18 @@ def test_debugprint_mitmot():
output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n")
expected_output = """Subtensor{int64} [id A]
expected_output = """Subtensor{i} [id A]
├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
│ ├─ Sub [id C] (n_steps)
│ │ ├─ Subtensor{int64} [id D]
│ │ ├─ Subtensor{i} [id D]
│ │ │ ├─ Shape [id E]
│ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ k [id G] (n_steps)
│ │ │ │ ├─ IncSubtensor{Set;:int64:} [id H] (outer_in_sit_sot-0)
│ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
│ │ │ │ │ │ ├─ Add [id J]
│ │ │ │ │ │ │ ├─ k [id G]
│ │ │ │ │ │ │ └─ Subtensor{int64} [id K]
│ │ │ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
......@@ -521,7 +521,7 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
│ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R]
│ │ │ │ │ │ │ └─ ScalarConstant{0} [id S]
│ │ │ │ │ │ └─ Subtensor{int64} [id T]
│ │ │ │ │ │ └─ Subtensor{i} [id T]
│ │ │ │ │ │ ├─ Shape [id U]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ └─ ···
......@@ -529,14 +529,14 @@ def test_debugprint_mitmot():
│ │ │ │ │ ├─ Unbroadcast{0} [id M]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ScalarFromTensor [id W]
│ │ │ │ │ └─ Subtensor{int64} [id K]
│ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ └─ ···
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
│ │ │ └─ ScalarConstant{0} [id X]
│ │ └─ TensorConstant{1} [id Y]
│ ├─ Subtensor{:int64:} [id Z] (outer_in_seqs-0)
│ │ ├─ Subtensor{::int64} [id BA]
│ │ │ ├─ Subtensor{:int64:} [id BB]
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id BA]
│ │ │ ├─ Subtensor{:stop} [id BB]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BC]
......@@ -544,9 +544,9 @@ def test_debugprint_mitmot():
│ │ └─ ScalarFromTensor [id BE]
│ │ └─ Sub [id C]
│ │ └─ ···
│ ├─ Subtensor{:int64:} [id BF] (outer_in_seqs-1)
│ │ ├─ Subtensor{:int64:} [id BG]
│ │ │ ├─ Subtensor{::int64} [id BH]
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BG]
│ │ │ ├─ Subtensor{::step} [id BH]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BI]
......@@ -554,30 +554,30 @@ def test_debugprint_mitmot():
│ │ └─ ScalarFromTensor [id BK]
│ │ └─ Sub [id C]
│ │ └─ ···
│ ├─ Subtensor{::int64} [id BL] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{Inc;int64::} [id BM]
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BM]
│ │ │ ├─ Second [id BN]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ TensorConstant{0.0} [id BP]
│ │ │ ├─ IncSubtensor{Inc;int64} [id BQ]
│ │ │ ├─ IncSubtensor{i} [id BQ]
│ │ │ │ ├─ Second [id BR]
│ │ │ │ │ ├─ Subtensor{int64::} [id BS]
│ │ │ │ │ ├─ Subtensor{start:} [id BS]
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
│ │ │ │ │ └─ TensorConstant{0.0} [id BV]
│ │ │ │ ├─ Second [id BW]
│ │ │ │ │ ├─ Subtensor{int64} [id BX]
│ │ │ │ │ │ ├─ Subtensor{int64::} [id BS]
│ │ │ │ │ ├─ Subtensor{i} [id BX]
│ │ │ │ │ │ ├─ Subtensor{start:} [id BS]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{-1} [id BY]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
│ │ │ │ │ └─ Second [id CA]
│ │ │ │ │ ├─ Sum{axes=None} [id CB]
│ │ │ │ │ │ └─ Subtensor{int64} [id BX]
│ │ │ │ │ │ └─ Subtensor{i} [id BX]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ TensorConstant{1.0} [id CC]
│ │ │ │ └─ ScalarConstant{-1} [id BY]
......@@ -589,7 +589,7 @@ def test_debugprint_mitmot():
│ │ │ ├─ Sub [id C]
│ │ │ │ └─ ···
│ │ │ └─ TensorConstant{1} [id CH]
│ │ └─ Subtensor{int64} [id CI]
│ │ └─ Subtensor{i} [id CI]
│ │ ├─ Shape [id CJ]
│ │ │ └─ A [id P]
│ │ └─ ScalarConstant{0} [id CK]
......@@ -644,7 +644,7 @@ def test_debugprint_compiled_fn():
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
├─ TensorConstant{20000} [id B] (n_steps)
├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
├─ IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
│ ├─ AllocEmpty{dtype='int64'} [id E] 0
│ │ └─ TensorConstant{20000} [id B]
│ ├─ TensorConstant{(1,) of 0} [id F]
......@@ -656,7 +656,7 @@ def test_debugprint_compiled_fn():
forall_inplace,cpu,scan_fn} [id A]
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
├─ TensorConstant{0} [id J]
├─ Subtensor{int64, int64, uint8} [id K]
├─ Subtensor{i, j, k} [id K]
│ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
│ ├─ ScalarFromTensor [id M]
│ │ └─ *0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论