Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f6bb307d
提交
f6bb307d
authored
5月 24, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 08, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve string representation of Subtensor Ops
上级
b74cf3f6
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
112 行增加
和
113 行删除
+112
-113
subtensor.py
pytensor/tensor/subtensor.py
+30
-31
test_printing.py
tests/scan/test_printing.py
+82
-82
没有找到文件。
pytensor/tensor/subtensor.py
浏览文件 @
f6bb307d
...
@@ -840,22 +840,34 @@ class Subtensor(COp):
...
@@ -840,22 +840,34 @@ class Subtensor(COp):
@staticmethod
@staticmethod
def
str_from_slice
(
entry
):
def
str_from_slice
(
entry
):
msg
=
[]
if
entry
.
step
:
for
x
in
[
entry
.
start
,
entry
.
stop
,
entry
.
step
]:
return
":"
.
join
(
if
x
is
None
:
(
msg
.
append
(
""
)
"start"
if
entry
.
start
else
""
,
else
:
"stop"
if
entry
.
stop
else
""
,
msg
.
append
(
str
(
x
))
"step"
,
return
":"
.
join
(
msg
)
)
)
if
entry
.
stop
:
return
f
"{'start' if entry.start else ''}:stop"
if
entry
.
start
:
return
"start:"
return
":"
def
__str__
(
self
):
@staticmethod
def
str_from_indices
(
idx_list
):
indices
=
[]
indices
=
[]
for
entry
in
self
.
idx_list
:
letter_indexes
=
0
for
entry
in
idx_list
:
if
isinstance
(
entry
,
slice
):
if
isinstance
(
entry
,
slice
):
indices
.
append
(
self
.
str_from_slice
(
entry
))
indices
.
append
(
Subtensor
.
str_from_slice
(
entry
))
else
:
else
:
indices
.
append
(
str
(
entry
))
indices
.
append
(
"ijk"
[
letter_indexes
%
3
]
*
(
letter_indexes
//
3
+
1
))
return
f
"{self.__class__.__name__}{{{', '.join(indices)}}}"
letter_indexes
+=
1
return
", "
.
join
(
indices
)
def
__str__
(
self
):
return
f
"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"
@staticmethod
@staticmethod
def
default_helper_c_code_args
():
def
default_helper_c_code_args
():
...
@@ -1498,21 +1510,8 @@ class IncSubtensor(COp):
...
@@ -1498,21 +1510,8 @@ class IncSubtensor(COp):
return
hash
((
type
(
self
),
idx_list
,
self
.
inplace
,
self
.
set_instead_of_inc
))
return
hash
((
type
(
self
),
idx_list
,
self
.
inplace
,
self
.
set_instead_of_inc
))
def
__str__
(
self
):
def
__str__
(
self
):
indices
=
[]
name
=
"SetSubtensor"
if
self
.
set_instead_of_inc
else
"IncSubtensor"
for
entry
in
self
.
idx_list
:
return
f
"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
if
isinstance
(
entry
,
slice
):
indices
.
append
(
Subtensor
.
str_from_slice
(
entry
))
else
:
indices
.
append
(
str
(
entry
))
if
self
.
inplace
:
msg
=
"Inplace"
else
:
msg
=
""
if
not
self
.
set_instead_of_inc
:
msg
+=
"Inc"
else
:
msg
+=
"Set"
return
f
"{self.__class__.__name__}{{{msg};{', '.join(indices)}}}"
def
make_node
(
self
,
x
,
y
,
*
inputs
):
def
make_node
(
self
,
x
,
y
,
*
inputs
):
"""
"""
...
@@ -2661,10 +2660,10 @@ class AdvancedIncSubtensor(Op):
...
@@ -2661,10 +2660,10 @@ class AdvancedIncSubtensor(Op):
self
.
ignore_duplicates
=
ignore_duplicates
self
.
ignore_duplicates
=
ignore_duplicates
def
__str__
(
self
):
def
__str__
(
self
):
return
"{}{{{}, {}}}"
.
format
(
return
(
self
.
__class__
.
__name__
,
"AdvancedSetSubtensor"
"inplace="
+
str
(
self
.
inplace
),
if
self
.
set_instead_of_inc
" set_instead_of_inc="
+
str
(
self
.
set_instead_of_inc
),
else
"AdvancedIncSubtensor"
)
)
def
make_node
(
self
,
x
,
y
,
*
inputs
):
def
make_node
(
self
,
x
,
y
,
*
inputs
):
...
...
tests/scan/test_printing.py
浏览文件 @
f6bb307d
...
@@ -26,15 +26,15 @@ def test_debugprint_sitsot():
...
@@ -26,15 +26,15 @@ def test_debugprint_sitsot():
output_str
=
debugprint
(
final_result
,
file
=
"str"
,
print_op_info
=
True
)
output_str
=
debugprint
(
final_result
,
file
=
"str"
,
print_op_info
=
True
)
lines
=
output_str
.
split
(
"
\n
"
)
lines
=
output_str
.
split
(
"
\n
"
)
expected_output
=
"""Subtensor{i
nt64
} [id A]
expected_output
=
"""Subtensor{i} [id A]
├─ Subtensor{
int64:
:} [id B]
├─ Subtensor{
start
:} [id B]
│ ├─ for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
│ ├─ for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
│ │ ├─ k [id D] (n_steps)
│ │ ├─ k [id D] (n_steps)
│ │ ├─
IncSubtensor{Set;:int64:
} [id E] (outer_in_sit_sot-0)
│ │ ├─
SetSubtensor{:stop
} [id E] (outer_in_sit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ │ ├─ Add [id G]
│ │ │ │ ├─ Add [id G]
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{i
nt64
} [id H]
│ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
...
@@ -43,7 +43,7 @@ def test_debugprint_sitsot():
...
@@ -43,7 +43,7 @@ def test_debugprint_sitsot():
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ └─ Subtensor{i
nt64
} [id Q]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
...
@@ -51,7 +51,7 @@ def test_debugprint_sitsot():
...
@@ -51,7 +51,7 @@ def test_debugprint_sitsot():
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ···
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ Subtensor{i
nt64
} [id H]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ │ └─ ···
│ │ └─ A [id M] (outer_in_non_seqs-0)
│ │ └─ A [id M] (outer_in_non_seqs-0)
│ └─ ScalarConstant{1} [id U]
│ └─ ScalarConstant{1} [id U]
...
@@ -84,15 +84,15 @@ def test_debugprint_sitsot_no_extra_info():
...
@@ -84,15 +84,15 @@ def test_debugprint_sitsot_no_extra_info():
output_str
=
debugprint
(
final_result
,
file
=
"str"
,
print_op_info
=
False
)
output_str
=
debugprint
(
final_result
,
file
=
"str"
,
print_op_info
=
False
)
lines
=
output_str
.
split
(
"
\n
"
)
lines
=
output_str
.
split
(
"
\n
"
)
expected_output
=
"""Subtensor{i
nt64
} [id A]
expected_output
=
"""Subtensor{i} [id A]
├─ Subtensor{
int64:
:} [id B]
├─ Subtensor{
start
:} [id B]
│ ├─ for{cpu,scan_fn} [id C]
│ ├─ for{cpu,scan_fn} [id C]
│ │ ├─ k [id D]
│ │ ├─ k [id D]
│ │ ├─
IncSubtensor{Set;:int64:
} [id E]
│ │ ├─
SetSubtensor{:stop
} [id E]
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ │ ├─ Add [id G]
│ │ │ │ ├─ Add [id G]
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{i
nt64
} [id H]
│ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
...
@@ -101,7 +101,7 @@ def test_debugprint_sitsot_no_extra_info():
...
@@ -101,7 +101,7 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ └─ Subtensor{i
nt64
} [id Q]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ ├─ Shape [id R]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
...
@@ -109,7 +109,7 @@ def test_debugprint_sitsot_no_extra_info():
...
@@ -109,7 +109,7 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ ···
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ ScalarFromTensor [id T]
│ │ │ └─ Subtensor{i
nt64
} [id H]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ │ └─ ···
│ │ └─ A [id M]
│ │ └─ A [id M]
│ └─ ScalarConstant{1} [id U]
│ └─ ScalarConstant{1} [id U]
...
@@ -150,29 +150,29 @@ def test_debugprint_nitsot():
...
@@ -150,29 +150,29 @@ def test_debugprint_nitsot():
expected_output
=
"""Sum{axes=None} [id A]
expected_output
=
"""Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{i
nt64
} [id D]
│ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E]
│ │ ├─ Shape [id E]
│ │ │ └─ Subtensor{
int64:
:} [id F] 'coefficients[0:]'
│ │ │ └─ Subtensor{
start
:} [id F] 'coefficients[0:]'
│ │ │ ├─ coefficients [id G]
│ │ │ ├─ coefficients [id G]
│ │ │ └─ ScalarConstant{0} [id H]
│ │ │ └─ ScalarConstant{0} [id H]
│ │ └─ ScalarConstant{0} [id I]
│ │ └─ ScalarConstant{0} [id I]
│ └─ Subtensor{i
nt64
} [id J]
│ └─ Subtensor{i} [id J]
│ ├─ Shape [id K]
│ ├─ Shape [id K]
│ │ └─ Subtensor{
int64:
:} [id L]
│ │ └─ Subtensor{
start
:} [id L]
│ │ ├─ ARange{dtype='int64'} [id M]
│ │ ├─ ARange{dtype='int64'} [id M]
│ │ │ ├─ TensorConstant{0} [id N]
│ │ │ ├─ TensorConstant{0} [id N]
│ │ │ ├─ TensorConstant{10000} [id O]
│ │ │ ├─ TensorConstant{10000} [id O]
│ │ │ └─ TensorConstant{1} [id P]
│ │ │ └─ TensorConstant{1} [id P]
│ │ └─ ScalarConstant{0} [id Q]
│ │ └─ ScalarConstant{0} [id Q]
│ └─ ScalarConstant{0} [id R]
│ └─ ScalarConstant{0} [id R]
├─ Subtensor{:
int64:
} [id S] (outer_in_seqs-0)
├─ Subtensor{:
stop
} [id S] (outer_in_seqs-0)
│ ├─ Subtensor{
int64:
:} [id F] 'coefficients[0:]'
│ ├─ Subtensor{
start
:} [id F] 'coefficients[0:]'
│ │ └─ ···
│ │ └─ ···
│ └─ ScalarFromTensor [id T]
│ └─ ScalarFromTensor [id T]
│ └─ Minimum [id C]
│ └─ Minimum [id C]
│ └─ ···
│ └─ ···
├─ Subtensor{:
int64:
} [id U] (outer_in_seqs-1)
├─ Subtensor{:
stop
} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{
int64:
:} [id L]
│ ├─ Subtensor{
start
:} [id L]
│ │ └─ ···
│ │ └─ ···
│ └─ ScalarFromTensor [id V]
│ └─ ScalarFromTensor [id V]
│ └─ Minimum [id C]
│ └─ Minimum [id C]
...
@@ -228,29 +228,29 @@ def test_debugprint_nested_scans():
...
@@ -228,29 +228,29 @@ def test_debugprint_nested_scans():
expected_output
=
"""Sum{axes=None} [id A]
expected_output
=
"""Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{i
nt64
} [id D]
│ ├─ Subtensor{i} [id D]
│ │ ├─ Shape [id E]
│ │ ├─ Shape [id E]
│ │ │ └─ Subtensor{
int64:
:} [id F] 'c[0:]'
│ │ │ └─ Subtensor{
start
:} [id F] 'c[0:]'
│ │ │ ├─ c [id G]
│ │ │ ├─ c [id G]
│ │ │ └─ ScalarConstant{0} [id H]
│ │ │ └─ ScalarConstant{0} [id H]
│ │ └─ ScalarConstant{0} [id I]
│ │ └─ ScalarConstant{0} [id I]
│ └─ Subtensor{i
nt64
} [id J]
│ └─ Subtensor{i} [id J]
│ ├─ Shape [id K]
│ ├─ Shape [id K]
│ │ └─ Subtensor{
int64:
:} [id L]
│ │ └─ Subtensor{
start
:} [id L]
│ │ ├─ ARange{dtype='int64'} [id M]
│ │ ├─ ARange{dtype='int64'} [id M]
│ │ │ ├─ TensorConstant{0} [id N]
│ │ │ ├─ TensorConstant{0} [id N]
│ │ │ ├─ TensorConstant{10} [id O]
│ │ │ ├─ TensorConstant{10} [id O]
│ │ │ └─ TensorConstant{1} [id P]
│ │ │ └─ TensorConstant{1} [id P]
│ │ └─ ScalarConstant{0} [id Q]
│ │ └─ ScalarConstant{0} [id Q]
│ └─ ScalarConstant{0} [id R]
│ └─ ScalarConstant{0} [id R]
├─ Subtensor{:
int64:
} [id S] (outer_in_seqs-0)
├─ Subtensor{:
stop
} [id S] (outer_in_seqs-0)
│ ├─ Subtensor{
int64:
:} [id F] 'c[0:]'
│ ├─ Subtensor{
start
:} [id F] 'c[0:]'
│ │ └─ ···
│ │ └─ ···
│ └─ ScalarFromTensor [id T]
│ └─ ScalarFromTensor [id T]
│ └─ Minimum [id C]
│ └─ Minimum [id C]
│ └─ ···
│ └─ ···
├─ Subtensor{:
int64:
} [id U] (outer_in_seqs-1)
├─ Subtensor{:
stop
} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{
int64:
:} [id L]
│ ├─ Subtensor{
start
:} [id L]
│ │ └─ ···
│ │ └─ ···
│ └─ ScalarFromTensor [id V]
│ └─ ScalarFromTensor [id V]
│ └─ Minimum [id C]
│ └─ Minimum [id C]
...
@@ -267,15 +267,15 @@ def test_debugprint_nested_scans():
...
@@ -267,15 +267,15 @@ def test_debugprint_nested_scans():
├─ ExpandDims{axis=0} [id Z]
├─ ExpandDims{axis=0} [id Z]
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
└─ Pow [id BB]
└─ Pow [id BB]
├─ Subtensor{i
nt64
} [id BC]
├─ Subtensor{i} [id BC]
│ ├─ Subtensor{
int64:
:} [id BD]
│ ├─ Subtensor{
start
:} [id BD]
│ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
│ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─
IncSubtensor{Set;:int64:
} [id BG] (outer_in_sit_sot-0)
│ │ │ ├─
SetSubtensor{:stop
} [id BG] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
│ │ │ │ │ ├─ Add [id BI]
│ │ │ │ │ ├─ Add [id BI]
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{i
nt64
} [id BJ]
│ │ │ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
...
@@ -284,7 +284,7 @@ def test_debugprint_nested_scans():
...
@@ -284,7 +284,7 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BR]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BR]
│ │ │ │ │ └─ Subtensor{i
nt64
} [id BS]
│ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ ├─ Shape [id BT]
│ │ │ │ │ ├─ Shape [id BT]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ···
...
@@ -292,7 +292,7 @@ def test_debugprint_nested_scans():
...
@@ -292,7 +292,7 @@ def test_debugprint_nested_scans():
│ │ │ │ ├─ Unbroadcast{0} [id BL]
│ │ │ │ ├─ Unbroadcast{0} [id BL]
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BV]
│ │ │ │ └─ ScalarFromTensor [id BV]
│ │ │ │ └─ Subtensor{i
nt64
} [id BJ]
│ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ └─ ···
│ │ │ │ └─ ···
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ ScalarConstant{1} [id BW]
│ │ └─ ScalarConstant{1} [id BW]
...
@@ -321,29 +321,29 @@ def test_debugprint_nested_scans():
...
@@ -321,29 +321,29 @@ def test_debugprint_nested_scans():
Sum{axes=None} [id D] 13
Sum{axes=None} [id D] 13
└─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
└─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
├─ Minimum [id F] 7 (outer_in_nit_sot-0)
├─ Minimum [id F] 7 (outer_in_nit_sot-0)
│ ├─ Subtensor{i
nt64
} [id G] 6
│ ├─ Subtensor{i} [id G] 6
│ │ ├─ Shape [id H] 5
│ │ ├─ Shape [id H] 5
│ │ │ └─ Subtensor{
int64:
:} [id I] 'c[0:]' 4
│ │ │ └─ Subtensor{
start
:} [id I] 'c[0:]' 4
│ │ │ ├─ c [id A]
│ │ │ ├─ c [id A]
│ │ │ └─ ScalarConstant{0} [id J]
│ │ │ └─ ScalarConstant{0} [id J]
│ │ └─ ScalarConstant{0} [id K]
│ │ └─ ScalarConstant{0} [id K]
│ └─ Subtensor{i
nt64
} [id L] 3
│ └─ Subtensor{i} [id L] 3
│ ├─ Shape [id M] 2
│ ├─ Shape [id M] 2
│ │ └─ Subtensor{
int64:
:} [id N] 1
│ │ └─ Subtensor{
start
:} [id N] 1
│ │ ├─ ARange{dtype='int64'} [id O] 0
│ │ ├─ ARange{dtype='int64'} [id O] 0
│ │ │ ├─ TensorConstant{0} [id P]
│ │ │ ├─ TensorConstant{0} [id P]
│ │ │ ├─ TensorConstant{10} [id Q]
│ │ │ ├─ TensorConstant{10} [id Q]
│ │ │ └─ TensorConstant{1} [id R]
│ │ │ └─ TensorConstant{1} [id R]
│ │ └─ ScalarConstant{0} [id S]
│ │ └─ ScalarConstant{0} [id S]
│ └─ ScalarConstant{0} [id T]
│ └─ ScalarConstant{0} [id T]
├─ Subtensor{:
int64:
} [id U] 11 (outer_in_seqs-0)
├─ Subtensor{:
stop
} [id U] 11 (outer_in_seqs-0)
│ ├─ Subtensor{
int64:
:} [id I] 'c[0:]' 4
│ ├─ Subtensor{
start
:} [id I] 'c[0:]' 4
│ │ └─ ···
│ │ └─ ···
│ └─ ScalarFromTensor [id V] 10
│ └─ ScalarFromTensor [id V] 10
│ └─ Minimum [id F] 7
│ └─ Minimum [id F] 7
│ └─ ···
│ └─ ···
├─ Subtensor{:
int64:
} [id W] 9 (outer_in_seqs-1)
├─ Subtensor{:
stop
} [id W] 9 (outer_in_seqs-1)
│ ├─ Subtensor{
int64:
:} [id N] 1
│ ├─ Subtensor{
start
:} [id N] 1
│ │ └─ ···
│ │ └─ ···
│ └─ ScalarFromTensor [id X] 8
│ └─ ScalarFromTensor [id X] 8
│ └─ Minimum [id F] 7
│ └─ Minimum [id F] 7
...
@@ -364,15 +364,15 @@ def test_debugprint_nested_scans():
...
@@ -364,15 +364,15 @@ def test_debugprint_nested_scans():
├─ ExpandDims{axis=0} [id BD]
├─ ExpandDims{axis=0} [id BD]
│ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
│ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
└─ Pow [id BE]
└─ Pow [id BE]
├─ Subtensor{i
nt64
} [id BF]
├─ Subtensor{i} [id BF]
│ ├─ Subtensor{
int64:
:} [id BG]
│ ├─ Subtensor{
start
:} [id BG]
│ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
│ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─
IncSubtensor{Set;:int64:
} [id BI] (outer_in_sit_sot-0)
│ │ │ ├─
SetSubtensor{:stop
} [id BI] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
│ │ │ │ │ ├─ Add [id BK]
│ │ │ │ │ ├─ Add [id BK]
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{i
nt64
} [id BL]
│ │ │ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
...
@@ -381,7 +381,7 @@ def test_debugprint_nested_scans():
...
@@ -381,7 +381,7 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BS]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BS]
│ │ │ │ │ └─ Subtensor{i
nt64
} [id BT]
│ │ │ │ │ └─ Subtensor{i} [id BT]
│ │ │ │ │ ├─ Shape [id BU]
│ │ │ │ │ ├─ Shape [id BU]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ···
...
@@ -389,7 +389,7 @@ def test_debugprint_nested_scans():
...
@@ -389,7 +389,7 @@ def test_debugprint_nested_scans():
│ │ │ │ ├─ Unbroadcast{0} [id BN]
│ │ │ │ ├─ Unbroadcast{0} [id BN]
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BW]
│ │ │ │ └─ ScalarFromTensor [id BW]
│ │ │ │ └─ Subtensor{i
nt64
} [id BL]
│ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ └─ ···
│ │ │ │ └─ ···
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ ScalarConstant{1} [id BX]
│ │ └─ ScalarConstant{1} [id BX]
...
@@ -430,41 +430,41 @@ def test_debugprint_mitsot():
...
@@ -430,41 +430,41 @@ def test_debugprint_mitsot():
lines
=
output_str
.
split
(
"
\n
"
)
lines
=
output_str
.
split
(
"
\n
"
)
expected_output
=
"""Add [id A]
expected_output
=
"""Add [id A]
├─ Subtensor{
int64:
:} [id B]
├─ Subtensor{
start
:} [id B]
│ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
│ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
│ │ ├─ TensorConstant{5} [id D] (n_steps)
│ │ ├─ TensorConstant{5} [id D] (n_steps)
│ │ ├─
IncSubtensor{Set;:int64:
} [id E] (outer_in_mit_sot-0)
│ │ ├─
SetSubtensor{:stop
} [id E] (outer_in_mit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
│ │ │ │ └─ Add [id G]
│ │ │ │ └─ Add [id G]
│ │ │ │ ├─ TensorConstant{5} [id D]
│ │ │ │ ├─ TensorConstant{5} [id D]
│ │ │ │ └─ Subtensor{i
nt64
} [id H]
│ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ ├─ Shape [id I]
│ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ Subtensor{:
int64:
} [id J]
│ │ │ │ │ └─ Subtensor{:
stop
} [id J]
│ │ │ │ │ ├─ <TensorType(int64, (?,))> [id K]
│ │ │ │ │ ├─ <TensorType(int64, (?,))> [id K]
│ │ │ │ │ └─ ScalarConstant{2} [id L]
│ │ │ │ │ └─ ScalarConstant{2} [id L]
│ │ │ │ └─ ScalarConstant{0} [id M]
│ │ │ │ └─ ScalarConstant{0} [id M]
│ │ │ ├─ Subtensor{:
int64:
} [id J]
│ │ │ ├─ Subtensor{:
stop
} [id J]
│ │ │ │ └─ ···
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id N]
│ │ │ └─ ScalarFromTensor [id N]
│ │ │ └─ Subtensor{i
nt64
} [id H]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ │ └─ ···
│ │ └─
IncSubtensor{Set;:int64:
} [id O] (outer_in_mit_sot-1)
│ │ └─
SetSubtensor{:stop
} [id O] (outer_in_mit_sot-1)
│ │ ├─ AllocEmpty{dtype='int64'} [id P]
│ │ ├─ AllocEmpty{dtype='int64'} [id P]
│ │ │ └─ Add [id Q]
│ │ │ └─ Add [id Q]
│ │ │ ├─ TensorConstant{5} [id D]
│ │ │ ├─ TensorConstant{5} [id D]
│ │ │ └─ Subtensor{i
nt64
} [id R]
│ │ │ └─ Subtensor{i} [id R]
│ │ │ ├─ Shape [id S]
│ │ │ ├─ Shape [id S]
│ │ │ │ └─ Subtensor{:
int64:
} [id T]
│ │ │ │ └─ Subtensor{:
stop
} [id T]
│ │ │ │ ├─ <TensorType(int64, (?,))> [id U]
│ │ │ │ ├─ <TensorType(int64, (?,))> [id U]
│ │ │ │ └─ ScalarConstant{2} [id V]
│ │ │ │ └─ ScalarConstant{2} [id V]
│ │ │ └─ ScalarConstant{0} [id W]
│ │ │ └─ ScalarConstant{0} [id W]
│ │ ├─ Subtensor{:
int64:
} [id T]
│ │ ├─ Subtensor{:
stop
} [id T]
│ │ │ └─ ···
│ │ │ └─ ···
│ │ └─ ScalarFromTensor [id X]
│ │ └─ ScalarFromTensor [id X]
│ │ └─ Subtensor{i
nt64
} [id R]
│ │ └─ Subtensor{i} [id R]
│ │ └─ ···
│ │ └─ ···
│ └─ ScalarConstant{2} [id Y]
│ └─ ScalarConstant{2} [id Y]
└─ Subtensor{
int64:
:} [id Z]
└─ Subtensor{
start
:} [id Z]
├─ for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
├─ for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
│ └─ ···
│ └─ ···
└─ ScalarConstant{2} [id BA]
└─ ScalarConstant{2} [id BA]
...
@@ -501,18 +501,18 @@ def test_debugprint_mitmot():
...
@@ -501,18 +501,18 @@ def test_debugprint_mitmot():
output_str
=
debugprint
(
final_result
,
file
=
"str"
,
print_op_info
=
True
)
output_str
=
debugprint
(
final_result
,
file
=
"str"
,
print_op_info
=
True
)
lines
=
output_str
.
split
(
"
\n
"
)
lines
=
output_str
.
split
(
"
\n
"
)
expected_output
=
"""Subtensor{i
nt64
} [id A]
expected_output
=
"""Subtensor{i} [id A]
├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
│ ├─ Sub [id C] (n_steps)
│ ├─ Sub [id C] (n_steps)
│ │ ├─ Subtensor{i
nt64
} [id D]
│ │ ├─ Subtensor{i} [id D]
│ │ │ ├─ Shape [id E]
│ │ │ ├─ Shape [id E]
│ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ k [id G] (n_steps)
│ │ │ │ ├─ k [id G] (n_steps)
│ │ │ │ ├─
IncSubtensor{Set;:int64:
} [id H] (outer_in_sit_sot-0)
│ │ │ │ ├─
SetSubtensor{:stop
} [id H] (outer_in_sit_sot-0)
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
│ │ │ │ │ │ ├─ Add [id J]
│ │ │ │ │ │ ├─ Add [id J]
│ │ │ │ │ │ │ ├─ k [id G]
│ │ │ │ │ │ │ ├─ k [id G]
│ │ │ │ │ │ │ └─ Subtensor{i
nt64
} [id K]
│ │ │ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
...
@@ -521,7 +521,7 @@ def test_debugprint_mitmot():
...
@@ -521,7 +521,7 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
│ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R]
│ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R]
│ │ │ │ │ │ │ └─ ScalarConstant{0} [id S]
│ │ │ │ │ │ │ └─ ScalarConstant{0} [id S]
│ │ │ │ │ │ └─ Subtensor{i
nt64
} [id T]
│ │ │ │ │ │ └─ Subtensor{i} [id T]
│ │ │ │ │ │ ├─ Shape [id U]
│ │ │ │ │ │ ├─ Shape [id U]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ ···
...
@@ -529,14 +529,14 @@ def test_debugprint_mitmot():
...
@@ -529,14 +529,14 @@ def test_debugprint_mitmot():
│ │ │ │ │ ├─ Unbroadcast{0} [id M]
│ │ │ │ │ ├─ Unbroadcast{0} [id M]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ScalarFromTensor [id W]
│ │ │ │ │ └─ ScalarFromTensor [id W]
│ │ │ │ │ └─ Subtensor{i
nt64
} [id K]
│ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
│ │ │ └─ ScalarConstant{0} [id X]
│ │ │ └─ ScalarConstant{0} [id X]
│ │ └─ TensorConstant{1} [id Y]
│ │ └─ TensorConstant{1} [id Y]
│ ├─ Subtensor{:
int64:
} [id Z] (outer_in_seqs-0)
│ ├─ Subtensor{:
stop
} [id Z] (outer_in_seqs-0)
│ │ ├─ Subtensor{::
int64
} [id BA]
│ │ ├─ Subtensor{::
step
} [id BA]
│ │ │ ├─ Subtensor{:
int64:
} [id BB]
│ │ │ ├─ Subtensor{:
stop
} [id BB]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BC]
│ │ │ │ └─ ScalarConstant{-1} [id BC]
...
@@ -544,9 +544,9 @@ def test_debugprint_mitmot():
...
@@ -544,9 +544,9 @@ def test_debugprint_mitmot():
│ │ └─ ScalarFromTensor [id BE]
│ │ └─ ScalarFromTensor [id BE]
│ │ └─ Sub [id C]
│ │ └─ Sub [id C]
│ │ └─ ···
│ │ └─ ···
│ ├─ Subtensor{:
int64:
} [id BF] (outer_in_seqs-1)
│ ├─ Subtensor{:
stop
} [id BF] (outer_in_seqs-1)
│ │ ├─ Subtensor{:
int64:
} [id BG]
│ │ ├─ Subtensor{:
stop
} [id BG]
│ │ │ ├─ Subtensor{::
int64
} [id BH]
│ │ │ ├─ Subtensor{::
step
} [id BH]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarConstant{-1} [id BI]
│ │ │ │ └─ ScalarConstant{-1} [id BI]
...
@@ -554,30 +554,30 @@ def test_debugprint_mitmot():
...
@@ -554,30 +554,30 @@ def test_debugprint_mitmot():
│ │ └─ ScalarFromTensor [id BK]
│ │ └─ ScalarFromTensor [id BK]
│ │ └─ Sub [id C]
│ │ └─ Sub [id C]
│ │ └─ ···
│ │ └─ ···
│ ├─ Subtensor{::
int64
} [id BL] (outer_in_mit_mot-0)
│ ├─ Subtensor{::
step
} [id BL] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{
Inc;int64:
:} [id BM]
│ │ ├─ IncSubtensor{
start
:} [id BM]
│ │ │ ├─ Second [id BN]
│ │ │ ├─ Second [id BN]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ TensorConstant{0.0} [id BP]
│ │ │ │ └─ TensorConstant{0.0} [id BP]
│ │ │ ├─ IncSubtensor{
Inc;int64
} [id BQ]
│ │ │ ├─ IncSubtensor{
i
} [id BQ]
│ │ │ │ ├─ Second [id BR]
│ │ │ │ ├─ Second [id BR]
│ │ │ │ │ ├─ Subtensor{
int64:
:} [id BS]
│ │ │ │ │ ├─ Subtensor{
start
:} [id BS]
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
│ │ │ │ │ └─ TensorConstant{0.0} [id BV]
│ │ │ │ │ └─ TensorConstant{0.0} [id BV]
│ │ │ │ ├─ Second [id BW]
│ │ │ │ ├─ Second [id BW]
│ │ │ │ │ ├─ Subtensor{i
nt64
} [id BX]
│ │ │ │ │ ├─ Subtensor{i} [id BX]
│ │ │ │ │ │ ├─ Subtensor{
int64:
:} [id BS]
│ │ │ │ │ │ ├─ Subtensor{
start
:} [id BS]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{-1} [id BY]
│ │ │ │ │ │ └─ ScalarConstant{-1} [id BY]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
│ │ │ │ │ └─ Second [id CA]
│ │ │ │ │ └─ Second [id CA]
│ │ │ │ │ ├─ Sum{axes=None} [id CB]
│ │ │ │ │ ├─ Sum{axes=None} [id CB]
│ │ │ │ │ │ └─ Subtensor{i
nt64
} [id BX]
│ │ │ │ │ │ └─ Subtensor{i} [id BX]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ TensorConstant{1.0} [id CC]
│ │ │ │ │ └─ TensorConstant{1.0} [id CC]
│ │ │ │ └─ ScalarConstant{-1} [id BY]
│ │ │ │ └─ ScalarConstant{-1} [id BY]
...
@@ -589,7 +589,7 @@ def test_debugprint_mitmot():
...
@@ -589,7 +589,7 @@ def test_debugprint_mitmot():
│ │ │ ├─ Sub [id C]
│ │ │ ├─ Sub [id C]
│ │ │ │ └─ ···
│ │ │ │ └─ ···
│ │ │ └─ TensorConstant{1} [id CH]
│ │ │ └─ TensorConstant{1} [id CH]
│ │ └─ Subtensor{i
nt64
} [id CI]
│ │ └─ Subtensor{i} [id CI]
│ │ ├─ Shape [id CJ]
│ │ ├─ Shape [id CJ]
│ │ │ └─ A [id P]
│ │ │ └─ A [id P]
│ │ └─ ScalarConstant{0} [id CK]
│ │ └─ ScalarConstant{0} [id CK]
...
@@ -644,7 +644,7 @@ def test_debugprint_compiled_fn():
...
@@ -644,7 +644,7 @@ def test_debugprint_compiled_fn():
expected_output
=
"""forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
expected_output
=
"""forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
├─ TensorConstant{20000} [id B] (n_steps)
├─ TensorConstant{20000} [id B] (n_steps)
├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
├─
IncSubtensor{InplaceSet;:int64:
} [id D] 1 (outer_in_sit_sot-0)
├─
SetSubtensor{:stop
} [id D] 1 (outer_in_sit_sot-0)
│ ├─ AllocEmpty{dtype='int64'} [id E] 0
│ ├─ AllocEmpty{dtype='int64'} [id E] 0
│ │ └─ TensorConstant{20000} [id B]
│ │ └─ TensorConstant{20000} [id B]
│ ├─ TensorConstant{(1,) of 0} [id F]
│ ├─ TensorConstant{(1,) of 0} [id F]
...
@@ -656,7 +656,7 @@ def test_debugprint_compiled_fn():
...
@@ -656,7 +656,7 @@ def test_debugprint_compiled_fn():
forall_inplace,cpu,scan_fn} [id A]
forall_inplace,cpu,scan_fn} [id A]
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
├─ TensorConstant{0} [id J]
├─ TensorConstant{0} [id J]
├─ Subtensor{i
nt64, int64, uint8
} [id K]
├─ Subtensor{i
, j, k
} [id K]
│ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
│ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
│ ├─ ScalarFromTensor [id M]
│ ├─ ScalarFromTensor [id M]
│ │ └─ *0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
│ │ └─ *0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论