Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9d2f8f11
提交
9d2f8f11
authored
4月 29, 2025
作者:
Ricardo Vieira
提交者:
Jesse Grabowski
5月 21, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reuse LU decomposition in Solve
上级
88c07f4b
显示空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
383 行增加
和
7 行删除
+383
-7
mode.py
pytensor/compile/mode.py
+2
-0
rewriting.py
pytensor/scan/rewriting.py
+4
-6
__init__.py
pytensor/tensor/__init__.py
+1
-0
__init__.py
pytensor/tensor/_linalg/__init__.py
+2
-0
__init__.py
pytensor/tensor/_linalg/solve/__init__.py
+2
-0
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+198
-0
linalg.py
pytensor/tensor/rewriting/linalg.py
+7
-0
__init__.py
tests/tensor/linalg/__init__.py
+0
-0
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+163
-0
test_blockwise.py
tests/tensor/test_blockwise.py
+4
-1
没有找到文件。
pytensor/compile/mode.py
浏览文件 @
9d2f8f11
...
@@ -490,6 +490,8 @@ PYTORCH = Mode(
...
@@ -490,6 +490,8 @@ PYTORCH = Mode(
"fusion"
,
"fusion"
,
"inplace"
,
"inplace"
,
"scan_save_mem_prealloc"
,
"scan_save_mem_prealloc"
,
"reuse_lu_decomposition_multiple_solves"
,
"scan_split_non_sequence_lu_decomposition_solve"
,
],
],
),
),
)
)
...
...
pytensor/scan/rewriting.py
浏览文件 @
9d2f8f11
...
@@ -2561,7 +2561,6 @@ scan_seqopt1.register(
...
@@ -2561,7 +2561,6 @@ scan_seqopt1.register(
position
=
1
,
position
=
1
,
)
)
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_push_out_non_seq"
,
"scan_push_out_non_seq"
,
in2out
(
scan_push_out_non_seq
,
ignore_newtrees
=
True
),
in2out
(
scan_push_out_non_seq
,
ignore_newtrees
=
True
),
...
@@ -2569,10 +2568,9 @@ scan_seqopt1.register(
...
@@ -2569,10 +2568,9 @@ scan_seqopt1.register(
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
"scan_pushout"
,
"scan_pushout"
,
position
=
2
,
position
=
3
,
)
)
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_push_out_seq"
,
"scan_push_out_seq"
,
in2out
(
scan_push_out_seq
,
ignore_newtrees
=
True
),
in2out
(
scan_push_out_seq
,
ignore_newtrees
=
True
),
...
@@ -2580,7 +2578,7 @@ scan_seqopt1.register(
...
@@ -2580,7 +2578,7 @@ scan_seqopt1.register(
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
"scan_pushout"
,
"scan_pushout"
,
position
=
3
,
position
=
4
,
)
)
...
@@ -2592,7 +2590,7 @@ scan_seqopt1.register(
...
@@ -2592,7 +2590,7 @@ scan_seqopt1.register(
"more_mem"
,
"more_mem"
,
"scan"
,
"scan"
,
"scan_pushout"
,
"scan_pushout"
,
position
=
4
,
position
=
5
,
)
)
...
@@ -2605,7 +2603,7 @@ scan_seqopt1.register(
...
@@ -2605,7 +2603,7 @@ scan_seqopt1.register(
"more_mem"
,
"more_mem"
,
"scan"
,
"scan"
,
"scan_pushout"
,
"scan_pushout"
,
position
=
5
,
position
=
6
,
)
)
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
...
...
pytensor/tensor/__init__.py
浏览文件 @
9d2f8f11
...
@@ -114,6 +114,7 @@ from pytensor.tensor import (
...
@@ -114,6 +114,7 @@ from pytensor.tensor import (
# isort: off
# isort: off
import
pytensor.tensor._linalg
from
pytensor.tensor
import
linalg
from
pytensor.tensor
import
linalg
from
pytensor.tensor
import
special
from
pytensor.tensor
import
special
from
pytensor.tensor
import
signal
from
pytensor.tensor
import
signal
...
...
pytensor/tensor/_linalg/__init__.py
0 → 100644
浏览文件 @
9d2f8f11
# Register rewrites
import
pytensor.tensor._linalg.solve
pytensor/tensor/_linalg/solve/__init__.py
0 → 100644
浏览文件 @
9d2f8f11
# Register rewrites in the database
import
pytensor.tensor._linalg.solve.rewriting
pytensor/tensor/_linalg/solve/rewriting.py
0 → 100644
浏览文件 @
9d2f8f11
from
collections.abc
import
Container
from
copy
import
copy
from
pytensor.graph
import
Constant
,
graph_inputs
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
node_rewriter
from
pytensor.scan.op
import
Scan
from
pytensor.scan.rewriting
import
scan_seqopt1
from
pytensor.tensor.basic
import
atleast_Nd
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.rewriting.basic
import
register_specialize
from
pytensor.tensor.rewriting.linalg
import
is_matrix_transpose
from
pytensor.tensor.slinalg
import
Solve
,
lu_factor
,
lu_solve
from
pytensor.tensor.variable
import
TensorVariable
def
decompose_A
(
A
,
assume_a
):
if
assume_a
==
"gen"
:
return
lu_factor
(
A
,
check_finite
=
False
)
else
:
raise
NotImplementedError
def
solve_lu_decomposed_system
(
A_decomp
,
b
,
b_ndim
,
assume_a
,
transposed
=
False
):
if
assume_a
==
"gen"
:
return
lu_solve
(
A_decomp
,
b
,
b_ndim
=
b_ndim
,
trans
=
transposed
)
else
:
raise
NotImplementedError
def
_split_lu_solve_steps
(
fgraph
,
node
,
*
,
eager
:
bool
,
allowed_assume_a
:
Container
[
str
]
):
if
not
isinstance
(
node
.
op
.
core_op
,
Solve
):
return
None
def
get_root_A
(
a
:
TensorVariable
)
->
tuple
[
TensorVariable
,
bool
]:
# Find the root variable of the first input to Solve
# If `a` is a left expand_dims or matrix transpose (DimShuffle variants),
# the root variable is the pre-DimShuffled input.
# Otherwise, `a` is considered the root variable.
# We also return whether the root `a` is transposed.
transposed
=
False
if
a
.
owner
is
not
None
and
isinstance
(
a
.
owner
.
op
,
DimShuffle
):
if
a
.
owner
.
op
.
is_left_expand_dims
:
[
a
]
=
a
.
owner
.
inputs
elif
is_matrix_transpose
(
a
):
[
a
]
=
a
.
owner
.
inputs
transposed
=
True
return
a
,
transposed
def
find_solve_clients
(
var
,
assume_a
):
clients
=
[]
for
cl
,
idx
in
fgraph
.
clients
[
var
]:
if
(
idx
==
0
and
isinstance
(
cl
.
op
,
Blockwise
)
and
isinstance
(
cl
.
op
.
core_op
,
Solve
)
and
(
cl
.
op
.
core_op
.
assume_a
==
assume_a
)
):
clients
.
append
(
cl
)
elif
isinstance
(
cl
.
op
,
DimShuffle
)
and
cl
.
op
.
is_left_expand_dims
:
# If it's a left expand_dims, recurse on the output
clients
.
extend
(
find_solve_clients
(
cl
.
outputs
[
0
],
assume_a
))
return
clients
assume_a
=
node
.
op
.
core_op
.
assume_a
if
assume_a
not
in
allowed_assume_a
:
return
None
A
,
_
=
get_root_A
(
node
.
inputs
[
0
])
# Find Solve using A (or left expand_dims of A)
# TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
# that to the A_decomp outputs
A_solve_clients_and_transpose
=
[
(
client
,
False
)
for
client
in
find_solve_clients
(
A
,
assume_a
)
]
# Find Solves using A.T
for
cl
,
_
in
fgraph
.
clients
[
A
]:
if
isinstance
(
cl
.
op
,
DimShuffle
)
and
is_matrix_transpose
(
cl
.
out
):
A_T
=
cl
.
out
A_solve_clients_and_transpose
.
extend
(
(
client
,
True
)
for
client
in
find_solve_clients
(
A_T
,
assume_a
)
)
if
not
eager
and
len
(
A_solve_clients_and_transpose
)
==
1
:
# If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
# That's a "reuse" inside the inner vectorized loop
batch_ndim
=
node
.
op
.
batch_ndim
(
node
)
(
client
,
_
)
=
A_solve_clients_and_transpose
[
0
]
original_A
,
b
=
client
.
inputs
if
not
any
(
a_bcast
and
not
b_bcast
for
a_bcast
,
b_bcast
in
zip
(
original_A
.
type
.
broadcastable
[:
batch_ndim
],
b
.
type
.
broadcastable
[:
batch_ndim
],
strict
=
True
,
)
):
return
None
A_decomp
=
decompose_A
(
A
,
assume_a
=
assume_a
)
replacements
=
{}
for
client
,
transposed
in
A_solve_clients_and_transpose
:
_
,
b
=
client
.
inputs
b_ndim
=
client
.
op
.
core_op
.
b_ndim
new_x
=
solve_lu_decomposed_system
(
A_decomp
,
b
,
b_ndim
=
b_ndim
,
assume_a
=
assume_a
,
transposed
=
transposed
)
[
old_x
]
=
client
.
outputs
new_x
=
atleast_Nd
(
new_x
,
n
=
old_x
.
type
.
ndim
)
.
astype
(
old_x
.
type
.
dtype
)
copy_stack_trace
(
old_x
,
new_x
)
replacements
[
old_x
]
=
new_x
return
replacements
def
_scan_split_non_sequence_lu_decomposition_solve
(
fgraph
,
node
,
*
,
allowed_assume_a
:
Container
[
str
]
):
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite.
"""
scan_op
:
Scan
=
node
.
op
non_sequences
=
set
(
scan_op
.
inner_non_seqs
(
scan_op
.
inner_inputs
))
new_scan_fgraph
=
scan_op
.
fgraph
changed
=
False
while
True
:
for
inner_node
in
new_scan_fgraph
.
toposort
():
if
(
isinstance
(
inner_node
.
op
,
Blockwise
)
and
isinstance
(
inner_node
.
op
.
core_op
,
Solve
)
and
inner_node
.
op
.
core_op
.
assume_a
in
allowed_assume_a
):
A
,
b
=
inner_node
.
inputs
if
all
(
(
isinstance
(
root_inp
,
Constant
)
or
(
root_inp
in
non_sequences
))
for
root_inp
in
graph_inputs
([
A
])
):
if
new_scan_fgraph
is
scan_op
.
fgraph
:
# Clone the first time to avoid mutating the original fgraph
new_scan_fgraph
,
equiv
=
new_scan_fgraph
.
clone_get_equiv
()
non_sequences
=
{
equiv
[
non_seq
]
for
non_seq
in
non_sequences
}
inner_node
=
equiv
[
inner_node
]
# type: ignore
replace_dict
=
_split_lu_solve_steps
(
new_scan_fgraph
,
inner_node
,
eager
=
True
,
allowed_assume_a
=
allowed_assume_a
,
)
assert
(
isinstance
(
replace_dict
,
dict
)
and
len
(
replace_dict
)
>
0
),
"Rewrite failed"
new_scan_fgraph
.
replace_all
(
replace_dict
.
items
())
changed
=
True
break
# Break to start over with a fresh toposort
else
:
# no_break
break
# Nothing else changed
if
not
changed
:
return
# Return a new scan to indicate that a rewrite was done
new_scan_op
=
copy
(
scan_op
)
new_scan_op
.
fgraph
=
new_scan_fgraph
new_outs
=
new_scan_op
.
make_node
(
*
node
.
inputs
)
.
outputs
copy_stack_trace
(
node
.
outputs
,
new_outs
)
return
new_outs
@register_specialize
@node_rewriter
([
Blockwise
])
def
reuse_lu_decomposition_multiple_solves
(
fgraph
,
node
):
return
_split_lu_solve_steps
(
fgraph
,
node
,
eager
=
False
,
allowed_assume_a
=
{
"gen"
})
@node_rewriter
([
Scan
])
def
scan_split_non_sequence_lu_decomposition_solve
(
fgraph
,
node
):
return
_scan_split_non_sequence_lu_decomposition_solve
(
fgraph
,
node
,
allowed_assume_a
=
{
"gen"
}
)
scan_seqopt1
.
register
(
"scan_split_non_sequence_lu_decomposition_solve"
,
in2out
(
scan_split_non_sequence_lu_decomposition_solve
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
"scan_pushout"
,
position
=
2
,
)
pytensor/tensor/rewriting/linalg.py
浏览文件 @
9d2f8f11
...
@@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
...
@@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if
ndims
<
2
:
if
ndims
<
2
:
return
False
return
False
transpose_order
=
(
*
range
(
ndims
-
2
),
ndims
-
1
,
ndims
-
2
)
transpose_order
=
(
*
range
(
ndims
-
2
),
ndims
-
1
,
ndims
-
2
)
# Allow expand_dims on the left of the transpose
if
(
diff
:
=
len
(
transpose_order
)
-
len
(
node
.
op
.
new_order
))
>
0
:
transpose_order
=
(
*
([
"x"
]
*
diff
),
*
transpose_order
,
)
return
node
.
op
.
new_order
==
transpose_order
return
node
.
op
.
new_order
==
transpose_order
return
False
return
False
...
...
tests/tensor/linalg/__init__.py
0 → 100644
浏览文件 @
9d2f8f11
tests/tensor/linalg/test_rewriting.py
0 → 100644
浏览文件 @
9d2f8f11
import
numpy
as
np
import
pytest
from
pytensor
import
config
,
function
,
scan
from
pytensor.compile.mode
import
get_default_mode
from
pytensor.gradient
import
grad
from
pytensor.scan.op
import
Scan
from
pytensor.tensor._linalg.solve.rewriting
import
(
reuse_lu_decomposition_multiple_solves
,
scan_split_non_sequence_lu_decomposition_solve
,
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.linalg
import
solve
from
pytensor.tensor.slinalg
import
LUFactor
,
Solve
,
SolveTriangular
from
pytensor.tensor.type
import
tensor
def
count_vanilla_solve_nodes
(
nodes
)
->
int
:
return
sum
(
(
isinstance
(
node
.
op
,
Solve
)
or
(
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
Solve
))
)
for
node
in
nodes
)
def
count_lu_decom_nodes
(
nodes
)
->
int
:
return
sum
(
(
isinstance
(
node
.
op
,
LUFactor
)
or
(
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
LUFactor
)
)
)
for
node
in
nodes
)
def
count_lu_solve_nodes
(
nodes
)
->
int
:
count
=
sum
(
(
isinstance
(
node
.
op
,
SolveTriangular
)
or
(
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
SolveTriangular
)
)
)
for
node
in
nodes
)
# Each LU solve uses two Triangular solves
return
count
//
2
@pytest.mark.parametrize
(
"transposed"
,
(
False
,
True
))
def
test_lu_decomposition_reused_forward_and_gradient
(
transposed
):
rewrite_name
=
reuse_lu_decomposition_multiple_solves
.
__name__
mode
=
get_default_mode
()
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b
=
tensor
(
"b"
,
shape
=
(
2
,
3
))
x
=
solve
(
A
,
b
,
assume_a
=
"gen"
,
transposed
=
transposed
)
grad_x_wrt_A
=
grad
(
x
.
sum
(),
A
)
fn_no_opt
=
function
([
A
,
b
],
[
x
,
grad_x_wrt_A
],
mode
=
mode
.
excluding
(
rewrite_name
))
no_opt_nodes
=
fn_no_opt
.
maker
.
fgraph
.
apply_nodes
assert
count_vanilla_solve_nodes
(
no_opt_nodes
)
==
2
assert
count_lu_decom_nodes
(
no_opt_nodes
)
==
0
assert
count_lu_solve_nodes
(
no_opt_nodes
)
==
0
fn_opt
=
function
([
A
,
b
],
[
x
,
grad_x_wrt_A
],
mode
=
mode
.
including
(
rewrite_name
))
opt_nodes
=
fn_opt
.
maker
.
fgraph
.
apply_nodes
assert
count_vanilla_solve_nodes
(
opt_nodes
)
==
0
assert
count_lu_decom_nodes
(
opt_nodes
)
==
1
assert
count_lu_solve_nodes
(
opt_nodes
)
==
2
# Make sure results are correct
rng
=
np
.
random
.
default_rng
(
31
)
A_test
=
rng
.
random
(
A
.
type
.
shape
,
dtype
=
A
.
type
.
dtype
)
b_test
=
rng
.
random
(
b
.
type
.
shape
,
dtype
=
b
.
type
.
dtype
)
resx0
,
resg0
=
fn_no_opt
(
A_test
,
b_test
)
resx1
,
resg1
=
fn_opt
(
A_test
,
b_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-6
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
resg0
,
resg1
,
rtol
=
rtol
)
@pytest.mark.parametrize
(
"transposed"
,
(
False
,
True
))
def
test_lu_decomposition_reused_blockwise
(
transposed
):
rewrite_name
=
reuse_lu_decomposition_multiple_solves
.
__name__
mode
=
get_default_mode
()
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
b
=
tensor
(
"b"
,
shape
=
(
2
,
2
,
3
))
x
=
solve
(
A
,
b
,
transposed
=
transposed
)
fn_no_opt
=
function
([
A
,
b
],
[
x
],
mode
=
mode
.
excluding
(
rewrite_name
))
no_opt_nodes
=
fn_no_opt
.
maker
.
fgraph
.
apply_nodes
assert
count_vanilla_solve_nodes
(
no_opt_nodes
)
==
1
assert
count_lu_decom_nodes
(
no_opt_nodes
)
==
0
assert
count_lu_solve_nodes
(
no_opt_nodes
)
==
0
fn_opt
=
function
([
A
,
b
],
[
x
],
mode
=
mode
.
including
(
rewrite_name
))
opt_nodes
=
fn_opt
.
maker
.
fgraph
.
apply_nodes
assert
count_vanilla_solve_nodes
(
opt_nodes
)
==
0
assert
count_lu_decom_nodes
(
opt_nodes
)
==
1
assert
count_lu_solve_nodes
(
opt_nodes
)
==
1
# Make sure results are correct
rng
=
np
.
random
.
default_rng
(
31
)
A_test
=
rng
.
random
(
A
.
type
.
shape
,
dtype
=
A
.
type
.
dtype
)
b_test
=
rng
.
random
(
b
.
type
.
shape
,
dtype
=
b
.
type
.
dtype
)
resx0
=
fn_no_opt
(
A_test
,
b_test
)
resx1
=
fn_opt
(
A_test
,
b_test
)
np
.
testing
.
assert_allclose
(
resx0
,
resx1
)
@pytest.mark.parametrize
(
"transposed"
,
(
False
,
True
))
def
test_lu_decomposition_reused_scan
(
transposed
):
rewrite_name
=
scan_split_non_sequence_lu_decomposition_solve
.
__name__
mode
=
get_default_mode
()
A
=
tensor
(
"A"
,
shape
=
(
2
,
2
))
x0
=
tensor
(
"b"
,
shape
=
(
2
,
3
))
xs
,
_
=
scan
(
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
"general"
,
transposed
=
transposed
),
outputs_info
=
[
x0
],
non_sequences
=
[
A
],
n_steps
=
10
,
)
fn_no_opt
=
function
(
[
A
,
x0
],
[
xs
],
mode
=
mode
.
excluding
(
rewrite_name
),
)
[
no_opt_scan_node
]
=
[
node
for
node
in
fn_no_opt
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
no_opt_nodes
=
no_opt_scan_node
.
op
.
fgraph
.
apply_nodes
assert
count_vanilla_solve_nodes
(
no_opt_nodes
)
==
1
assert
count_lu_decom_nodes
(
no_opt_nodes
)
==
0
assert
count_lu_solve_nodes
(
no_opt_nodes
)
==
0
fn_opt
=
function
([
A
,
x0
],
[
xs
],
mode
=
mode
.
including
(
"scan"
,
rewrite_name
))
[
opt_scan_node
]
=
[
node
for
node
in
fn_opt
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
opt_nodes
=
opt_scan_node
.
op
.
fgraph
.
apply_nodes
assert
count_vanilla_solve_nodes
(
opt_nodes
)
==
0
# The LU decomp is outside of the scan!
assert
count_lu_decom_nodes
(
opt_nodes
)
==
0
assert
count_lu_solve_nodes
(
opt_nodes
)
==
1
# Make sure results are correct
rng
=
np
.
random
.
default_rng
(
170
)
A_test
=
rng
.
random
(
A
.
type
.
shape
,
dtype
=
A
.
type
.
dtype
)
x0_test
=
rng
.
random
(
x0
.
type
.
shape
,
dtype
=
x0
.
type
.
dtype
)
resx0
=
fn_no_opt
(
A_test
,
x0_test
)
resx1
=
fn_opt
(
A_test
,
x0_test
)
rtol
=
1e-7
if
config
.
floatX
==
"float64"
else
1e-6
np
.
testing
.
assert_allclose
(
resx0
,
resx1
,
rtol
=
rtol
)
tests/tensor/test_blockwise.py
浏览文件 @
9d2f8f11
...
@@ -579,7 +579,10 @@ class TestInplace:
...
@@ -579,7 +579,10 @@ class TestInplace:
else
:
else
:
x
=
solve_fn
(
A
,
b
,
b_ndim
=
1
)
x
=
solve_fn
(
A
,
b
,
b_ndim
=
1
)
mode
=
get_default_mode
()
.
excluding
(
"batched_vector_b_solve_to_matrix_b_solve"
)
mode
=
get_default_mode
()
.
excluding
(
"batched_vector_b_solve_to_matrix_b_solve"
,
"reuse_lu_decomposition_multiple_solves"
,
)
fn
=
function
([
In
(
A
,
mutable
=
True
),
In
(
b
,
mutable
=
True
)],
x
,
mode
=
mode
)
fn
=
function
([
In
(
A
,
mutable
=
True
),
In
(
b
,
mutable
=
True
)],
x
,
mode
=
mode
)
op
=
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
op
=
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论