Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1d5b1d94
提交
1d5b1d94
authored
8月 29, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
9月 20, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace uses of in2out and out2in by a depth-first search rewriter
上级
aa7e4d6b
隐藏空白字符变更
内嵌
并排
正在显示
17 个修改的文件
包含
86 行增加
和
101 行删除
+86
-101
basic.py
pytensor/graph/rewriting/basic.py
+16
-5
rewriting.py
pytensor/scan/rewriting.py
+14
-14
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+6
-4
basic.py
pytensor/tensor/random/rewriting/basic.py
+6
-2
jax.py
pytensor/tensor/random/rewriting/jax.py
+11
-47
numba.py
pytensor/tensor/random/rewriting/numba.py
+2
-2
basic.py
pytensor/tensor/rewriting/basic.py
+2
-1
blas.py
pytensor/tensor/rewriting/blas.py
+4
-4
blas_c.py
pytensor/tensor/rewriting/blas_c.py
+5
-3
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+2
-2
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+4
-3
jax.py
pytensor/tensor/rewriting/jax.py
+4
-4
linalg.py
pytensor/tensor/rewriting/linalg.py
+2
-2
numba.py
pytensor/tensor/rewriting/numba.py
+2
-2
ofg.py
pytensor/tensor/rewriting/ofg.py
+2
-2
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+2
-2
utils.py
pytensor/xtensor/rewriting/utils.py
+2
-2
没有找到文件。
pytensor/graph/rewriting/basic.py
浏览文件 @
1d5b1d94
...
@@ -27,7 +27,12 @@ from pytensor.graph.features import AlreadyThere, Feature
...
@@ -27,7 +27,12 @@ from pytensor.graph.features import AlreadyThere, Feature
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.fg
import
FunctionGraph
,
Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.rewriting.unify
import
OpPattern
,
Var
,
convert_strs_to_vars
from
pytensor.graph.rewriting.unify
import
OpPattern
,
Var
,
convert_strs_to_vars
from
pytensor.graph.traversal
import
applys_between
,
toposort
,
vars_between
from
pytensor.graph.traversal
import
(
apply_ancestors
,
applys_between
,
toposort
,
vars_between
,
)
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.graph.utils
import
AssocList
,
InconsistencyError
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.misc.ordered_set
import
OrderedSet
from
pytensor.utils
import
flatten
from
pytensor.utils
import
flatten
...
@@ -1995,12 +2000,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
...
@@ -1995,12 +2000,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
def
__init__
(
def
__init__
(
self
,
self
,
node_rewriter
:
NodeRewriter
,
node_rewriter
:
NodeRewriter
,
order
:
Literal
[
"out_to_in"
,
"in_to_out"
]
=
"in_to_out"
,
order
:
Literal
[
"out_to_in"
,
"in_to_out"
,
"dfs"
]
=
"in_to_out"
,
ignore_newtrees
:
bool
=
False
,
ignore_newtrees
:
bool
=
False
,
failure_callback
:
FailureCallbackType
|
None
=
None
,
failure_callback
:
FailureCallbackType
|
None
=
None
,
):
):
if
order
not
in
(
"out_to_in"
,
"in_to_out"
):
valid_orders
=
(
"out_to_in"
,
"in_to_out"
,
"dfs"
)
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
if
order
not
in
valid_orders
:
raise
ValueError
(
f
"order must be one of {valid_orders}, got {order}"
)
self
.
order
=
order
self
.
order
=
order
super
()
.
__init__
(
node_rewriter
,
ignore_newtrees
,
failure_callback
)
super
()
.
__init__
(
node_rewriter
,
ignore_newtrees
,
failure_callback
)
...
@@ -2010,7 +2016,11 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
...
@@ -2010,7 +2016,11 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
callback_before
=
fgraph
.
execute_callbacks_time
callback_before
=
fgraph
.
execute_callbacks_time
nb_nodes_start
=
len
(
fgraph
.
apply_nodes
)
nb_nodes_start
=
len
(
fgraph
.
apply_nodes
)
t0
=
time
.
perf_counter
()
t0
=
time
.
perf_counter
()
q
=
deque
(
toposort
(
start_from
))
q
=
deque
(
apply_ancestors
(
start_from
)
if
(
self
.
order
==
"dfs"
)
else
toposort
(
start_from
)
)
io_t
=
time
.
perf_counter
()
-
t0
io_t
=
time
.
perf_counter
()
-
t0
def
importer
(
node
):
def
importer
(
node
):
...
@@ -2134,6 +2144,7 @@ def walking_rewriter(
...
@@ -2134,6 +2144,7 @@ def walking_rewriter(
in2out
=
partial
(
walking_rewriter
,
"in_to_out"
)
in2out
=
partial
(
walking_rewriter
,
"in_to_out"
)
out2in
=
partial
(
walking_rewriter
,
"out_to_in"
)
out2in
=
partial
(
walking_rewriter
,
"out_to_in"
)
dfs_rewriter
=
partial
(
walking_rewriter
,
"dfs"
)
class
ChangeTracker
(
Feature
):
class
ChangeTracker
(
Feature
):
...
...
pytensor/scan/rewriting.py
浏览文件 @
1d5b1d94
...
@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.basic import (
...
@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter
,
EquilibriumGraphRewriter
,
GraphRewriter
,
GraphRewriter
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
dfs_rewriter
,
node_rewriter
,
node_rewriter
,
)
)
from
pytensor.graph.rewriting.db
import
EquilibriumDB
,
SequenceDB
from
pytensor.graph.rewriting.db
import
EquilibriumDB
,
SequenceDB
...
@@ -2558,7 +2558,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
...
@@ -2558,7 +2558,7 @@ optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
# ScanSaveMem should execute only once per node.
# ScanSaveMem should execute only once per node.
optdb
.
register
(
optdb
.
register
(
"scan_save_mem_prealloc"
,
"scan_save_mem_prealloc"
,
in2out
(
scan_save_mem_prealloc
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_save_mem_prealloc
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
"scan_save_mem"
,
"scan_save_mem"
,
...
@@ -2566,7 +2566,7 @@ optdb.register(
...
@@ -2566,7 +2566,7 @@ optdb.register(
)
)
optdb
.
register
(
optdb
.
register
(
"scan_save_mem_no_prealloc"
,
"scan_save_mem_no_prealloc"
,
in2out
(
scan_save_mem_no_prealloc
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_save_mem_no_prealloc
,
ignore_newtrees
=
True
),
"numba"
,
"numba"
,
"jax"
,
"jax"
,
"pytorch"
,
"pytorch"
,
...
@@ -2587,7 +2587,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")
...
@@ -2587,7 +2587,7 @@ scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_remove_constants_and_unused_inputs0"
,
"scan_remove_constants_and_unused_inputs0"
,
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
dfs_rewriter
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
"remove_constants_and_unused_inputs_scan"
,
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
...
@@ -2596,7 +2596,7 @@ scan_seqopt1.register(
...
@@ -2596,7 +2596,7 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_push_out_non_seq"
,
"scan_push_out_non_seq"
,
in2out
(
scan_push_out_non_seq
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_push_out_non_seq
,
ignore_newtrees
=
True
),
"scan_pushout_nonseqs_ops"
,
# For backcompat: so it can be tagged with old name
"scan_pushout_nonseqs_ops"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
...
@@ -2606,7 +2606,7 @@ scan_seqopt1.register(
...
@@ -2606,7 +2606,7 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_push_out_seq"
,
"scan_push_out_seq"
,
in2out
(
scan_push_out_seq
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_push_out_seq
,
ignore_newtrees
=
True
),
"scan_pushout_seqs_ops"
,
# For backcompat: so it can be tagged with old name
"scan_pushout_seqs_ops"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
...
@@ -2617,7 +2617,7 @@ scan_seqopt1.register(
...
@@ -2617,7 +2617,7 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_push_out_dot1"
,
"scan_push_out_dot1"
,
in2out
(
scan_push_out_dot1
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_push_out_dot1
,
ignore_newtrees
=
True
),
"scan_pushout_dot1"
,
# For backcompat: so it can be tagged with old name
"scan_pushout_dot1"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"fast_run"
,
"more_mem"
,
"more_mem"
,
...
@@ -2630,7 +2630,7 @@ scan_seqopt1.register(
...
@@ -2630,7 +2630,7 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
"scan_push_out_add"
,
"scan_push_out_add"
,
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
in2out
(
scan_push_out_add
,
ignore_newtrees
=
False
),
dfs_rewriter
(
scan_push_out_add
,
ignore_newtrees
=
False
),
"scan_pushout_add"
,
# For backcompat: so it can be tagged with old name
"scan_pushout_add"
,
# For backcompat: so it can be tagged with old name
"fast_run"
,
"fast_run"
,
"more_mem"
,
"more_mem"
,
...
@@ -2641,14 +2641,14 @@ scan_seqopt1.register(
...
@@ -2641,14 +2641,14 @@ scan_seqopt1.register(
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
"while_scan_merge_subtensor_last_element"
,
"while_scan_merge_subtensor_last_element"
,
in2out
(
while_scan_merge_subtensor_last_element
,
ignore_newtrees
=
True
),
dfs_rewriter
(
while_scan_merge_subtensor_last_element
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
)
)
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
"constant_folding_for_scan2"
,
"constant_folding_for_scan2"
,
in2out
(
constant_folding
,
ignore_newtrees
=
True
),
dfs_rewriter
(
constant_folding
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
)
)
...
@@ -2656,7 +2656,7 @@ scan_eqopt2.register(
...
@@ -2656,7 +2656,7 @@ scan_eqopt2.register(
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
"scan_remove_constants_and_unused_inputs1"
,
"scan_remove_constants_and_unused_inputs1"
,
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
dfs_rewriter
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
"remove_constants_and_unused_inputs_scan"
,
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
...
@@ -2671,7 +2671,7 @@ scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan")
...
@@ -2671,7 +2671,7 @@ scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan")
# After Merge optimization
# After Merge optimization
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
"scan_remove_constants_and_unused_inputs2"
,
"scan_remove_constants_and_unused_inputs2"
,
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
dfs_rewriter
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
"remove_constants_and_unused_inputs_scan"
,
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
...
@@ -2679,7 +2679,7 @@ scan_eqopt2.register(
...
@@ -2679,7 +2679,7 @@ scan_eqopt2.register(
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
"scan_merge_inouts"
,
"scan_merge_inouts"
,
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
)
)
...
@@ -2687,7 +2687,7 @@ scan_eqopt2.register(
...
@@ -2687,7 +2687,7 @@ scan_eqopt2.register(
# After everything else
# After everything else
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
"scan_remove_constants_and_unused_inputs3"
,
"scan_remove_constants_and_unused_inputs3"
,
in2out
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
dfs_rewriter
(
remove_constants_and_unused_inputs_scan
,
ignore_newtrees
=
True
),
"remove_constants_and_unused_inputs_scan"
,
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
...
...
pytensor/tensor/_linalg/solve/rewriting.py
浏览文件 @
1d5b1d94
...
@@ -3,7 +3,7 @@ from copy import copy
...
@@ -3,7 +3,7 @@ from copy import copy
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Constant
,
graph_inputs
from
pytensor.graph
import
Constant
,
graph_inputs
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
dfs_rewriter
,
node_rewriter
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.rewriting
import
scan_seqopt1
from
pytensor.scan.rewriting
import
scan_seqopt1
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
from
pytensor.tensor._linalg.solve.tridiagonal
import
(
...
@@ -244,7 +244,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
...
@@ -244,7 +244,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
scan_split_non_sequence_decomposition_and_solve
.
__name__
,
scan_split_non_sequence_decomposition_and_solve
.
__name__
,
in2out
(
scan_split_non_sequence_decomposition_and_solve
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_split_non_sequence_decomposition_and_solve
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"scan"
,
"scan"
,
"scan_pushout"
,
"scan_pushout"
,
...
@@ -261,7 +261,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node):
...
@@ -261,7 +261,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node):
optdb
[
"specialize"
]
.
register
(
optdb
[
"specialize"
]
.
register
(
reuse_decomposition_multiple_solves_jax
.
__name__
,
reuse_decomposition_multiple_solves_jax
.
__name__
,
in2out
(
reuse_decomposition_multiple_solves_jax
,
ignore_newtrees
=
True
),
dfs_rewriter
(
reuse_decomposition_multiple_solves_jax
,
ignore_newtrees
=
True
),
"jax"
,
"jax"
,
use_db_name_as_tag
=
False
,
use_db_name_as_tag
=
False
,
)
)
...
@@ -276,7 +276,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
...
@@ -276,7 +276,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
scan_seqopt1
.
register
(
scan_seqopt1
.
register
(
scan_split_non_sequence_decomposition_and_solve_jax
.
__name__
,
scan_split_non_sequence_decomposition_and_solve_jax
.
__name__
,
in2out
(
scan_split_non_sequence_decomposition_and_solve_jax
,
ignore_newtrees
=
True
),
dfs_rewriter
(
scan_split_non_sequence_decomposition_and_solve_jax
,
ignore_newtrees
=
True
),
"jax"
,
"jax"
,
use_db_name_as_tag
=
False
,
use_db_name_as_tag
=
False
,
position
=
2
,
position
=
2
,
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
1d5b1d94
...
@@ -4,7 +4,11 @@ from pytensor.compile import optdb
...
@@ -4,7 +4,11 @@ from pytensor.compile import optdb
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
ancestors
from
pytensor.graph
import
ancestors
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
(
copy_stack_trace
,
dfs_rewriter
,
node_rewriter
,
)
from
pytensor.tensor
import
NoneConst
,
TensorVariable
from
pytensor.tensor
import
NoneConst
,
TensorVariable
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
...
@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node):
...
@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node):
optdb
.
register
(
optdb
.
register
(
"random_make_inplace"
,
"random_make_inplace"
,
in2out
(
random_make_inplace
,
ignore_newtrees
=
True
),
dfs_rewriter
(
random_make_inplace
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"inplace"
,
"inplace"
,
position
=
50.9
,
position
=
50.9
,
...
...
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
1d5b1d94
...
@@ -2,8 +2,7 @@ import re
...
@@ -2,8 +2,7 @@ import re
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Constant
from
pytensor.graph
import
Constant
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
dfs_rewriter
,
in2out
,
node_rewriter
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.tensor
import
abs
as
abs_t
from
pytensor.tensor
import
abs
as
abs_t
from
pytensor.tensor
import
broadcast_arrays
,
exp
,
floor
,
log
,
log1p
,
reciprocal
,
sqrt
from
pytensor.tensor
import
broadcast_arrays
,
exp
,
floor
,
log
,
log1p
,
reciprocal
,
sqrt
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
...
@@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
...
@@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
return
new_op
.
make_node
(
rng
,
size
,
a_vector_param
,
*
other_params
)
.
outputs
return
new_op
.
make_node
(
rng
,
size
,
a_vector_param
,
*
other_params
)
.
outputs
random_vars_opt
=
SequenceDB
()
random_vars_opt
=
dfs_rewriter
(
random_vars_opt
.
register
(
lognormal_from_normal
,
"lognormal_from_normal"
,
halfnormal_from_normal
,
in2out
(
lognormal_from_normal
),
geometric_from_uniform
,
"jax"
,
negative_binomial_from_gamma_poisson
,
)
inverse_gamma_from_gamma
,
random_vars_opt
.
register
(
generalized_gamma_from_gamma
,
"halfnormal_from_normal"
,
wald_from_normal_uniform
,
in2out
(
halfnormal_from_normal
),
beta_binomial_from_beta_binomial
,
"jax"
,
materialize_implicit_arange_choice_without_replacement
,
)
random_vars_opt
.
register
(
"geometric_from_uniform"
,
in2out
(
geometric_from_uniform
),
"jax"
,
)
random_vars_opt
.
register
(
"negative_binomial_from_gamma_poisson"
,
in2out
(
negative_binomial_from_gamma_poisson
),
"jax"
,
)
random_vars_opt
.
register
(
"inverse_gamma_from_gamma"
,
in2out
(
inverse_gamma_from_gamma
),
"jax"
,
)
random_vars_opt
.
register
(
"generalized_gamma_from_gamma"
,
in2out
(
generalized_gamma_from_gamma
),
"jax"
,
)
random_vars_opt
.
register
(
"wald_from_normal_uniform"
,
in2out
(
wald_from_normal_uniform
),
"jax"
,
)
random_vars_opt
.
register
(
"beta_binomial_from_beta_binomial"
,
in2out
(
beta_binomial_from_beta_binomial
),
"jax"
,
)
random_vars_opt
.
register
(
"materialize_implicit_arange_choice_without_replacement"
,
in2out
(
materialize_implicit_arange_choice_without_replacement
),
"jax"
,
)
)
optdb
.
register
(
"jax_random_vars_rewrites"
,
random_vars_opt
,
"jax"
,
position
=
110
)
optdb
.
register
(
"jax_random_vars_rewrites"
,
random_vars_opt
,
"jax"
,
position
=
110
)
...
...
pytensor/tensor/random/rewriting/numba.py
浏览文件 @
1d5b1d94
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph
import
node_rewriter
from
pytensor.graph
import
node_rewriter
from
pytensor.graph.rewriting.basic
import
out2in
from
pytensor.graph.rewriting.basic
import
dfs_rewriter
from
pytensor.tensor
import
as_tensor
,
constant
from
pytensor.tensor
import
as_tensor
,
constant
from
pytensor.tensor.random.op
import
RandomVariable
,
RandomVariableWithCoreShape
from
pytensor.tensor.random.op
import
RandomVariable
,
RandomVariableWithCoreShape
from
pytensor.tensor.rewriting.shape
import
ShapeFeature
from
pytensor.tensor.rewriting.shape
import
ShapeFeature
...
@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
...
@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
optdb
.
register
(
optdb
.
register
(
introduce_explicit_core_shape_rv
.
__name__
,
introduce_explicit_core_shape_rv
.
__name__
,
out2in
(
introduce_explicit_core_shape_rv
),
dfs_rewriter
(
introduce_explicit_core_shape_rv
),
"numba"
,
"numba"
,
position
=
100
,
position
=
100
,
)
)
pytensor/tensor/rewriting/basic.py
浏览文件 @
1d5b1d94
...
@@ -35,6 +35,7 @@ from pytensor.graph.rewriting.basic import (
...
@@ -35,6 +35,7 @@ from pytensor.graph.rewriting.basic import (
NodeRewriter
,
NodeRewriter
,
Rewriter
,
Rewriter
,
copy_stack_trace
,
copy_stack_trace
,
dfs_rewriter
,
in2out
,
in2out
,
node_rewriter
,
node_rewriter
,
)
)
...
@@ -538,7 +539,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
...
@@ -538,7 +539,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
compile
.
optdb
.
register
(
compile
.
optdb
.
register
(
"local_alloc_empty_to_zeros"
,
"local_alloc_empty_to_zeros"
,
in2out
(
local_alloc_empty_to_zeros
),
dfs_rewriter
(
local_alloc_empty_to_zeros
),
# After move to gpu and merge2, before inplace.
# After move to gpu and merge2, before inplace.
"alloc_empty_to_zeros"
,
"alloc_empty_to_zeros"
,
position
=
49.3
,
position
=
49.3
,
...
...
pytensor/tensor/rewriting/blas.py
浏览文件 @
1d5b1d94
...
@@ -77,7 +77,7 @@ from pytensor.graph.rewriting.basic import (
...
@@ -77,7 +77,7 @@ from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter
,
EquilibriumGraphRewriter
,
GraphRewriter
,
GraphRewriter
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
dfs_rewriter
,
node_rewriter
,
node_rewriter
,
)
)
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.graph.rewriting.db
import
SequenceDB
...
@@ -721,7 +721,7 @@ optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
...
@@ -721,7 +721,7 @@ optdb.register("BlasOpt", blas_optdb, "fast_run", "fast_compile", position=1.7)
# fast_compile is needed to have GpuDot22 created.
# fast_compile is needed to have GpuDot22 created.
blas_optdb
.
register
(
blas_optdb
.
register
(
"local_dot_to_dot22"
,
"local_dot_to_dot22"
,
in2out
(
local_dot_to_dot22
),
dfs_rewriter
(
local_dot_to_dot22
),
"fast_run"
,
"fast_run"
,
"fast_compile"
,
"fast_compile"
,
position
=
0
,
position
=
0
,
...
@@ -744,7 +744,7 @@ blas_optdb.register(
...
@@ -744,7 +744,7 @@ blas_optdb.register(
)
)
blas_opt_inplace
=
in2out
(
blas_opt_inplace
=
dfs_rewriter
(
local_inplace_gemm
,
local_inplace_gemv
,
local_inplace_ger
,
name
=
"blas_opt_inplace"
local_inplace_gemm
,
local_inplace_gemv
,
local_inplace_ger
,
name
=
"blas_opt_inplace"
)
)
optdb
.
register
(
optdb
.
register
(
...
@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
...
@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
# dot22scalar and gemm give more speed up then dot22scalar
# dot22scalar and gemm give more speed up then dot22scalar
blas_optdb
.
register
(
blas_optdb
.
register
(
"local_dot22_to_dot22scalar"
,
"local_dot22_to_dot22scalar"
,
in2out
(
local_dot22_to_dot22scalar
),
dfs_rewriter
(
local_dot22_to_dot22scalar
),
"fast_run"
,
"fast_run"
,
position
=
12
,
position
=
12
,
)
)
...
...
pytensor/tensor/rewriting/blas_c.py
浏览文件 @
1d5b1d94
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.rewriting.basic
import
in2out
from
pytensor.graph.rewriting.basic
import
dfs_rewriter
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor.blas
import
gemv_inplace
,
gemv_no_inplace
,
ger
,
ger_destructive
from
pytensor.tensor.blas
import
gemv_inplace
,
gemv_no_inplace
,
ger
,
ger_destructive
from
pytensor.tensor.blas_c
import
(
from
pytensor.tensor.blas_c
import
(
...
@@ -56,13 +56,15 @@ def make_c_gemv_destructive(fgraph, node):
...
@@ -56,13 +56,15 @@ def make_c_gemv_destructive(fgraph, node):
blas_optdb
.
register
(
blas_optdb
.
register
(
"use_c_blas"
,
in2out
(
use_c_ger
,
use_c_gemv
),
"fast_run"
,
"c_blas"
,
position
=
20
"use_c_blas"
,
dfs_rewriter
(
use_c_ger
,
use_c_gemv
),
"fast_run"
,
"c_blas"
,
position
=
20
)
)
# this matches the InplaceBlasOpt defined in blas.py
# this matches the InplaceBlasOpt defined in blas.py
optdb
.
register
(
optdb
.
register
(
"c_blas_destructive"
,
"c_blas_destructive"
,
in2out
(
make_c_ger_destructive
,
make_c_gemv_destructive
,
name
=
"c_blas_destructive"
),
dfs_rewriter
(
make_c_ger_destructive
,
make_c_gemv_destructive
,
name
=
"c_blas_destructive"
),
"fast_run"
,
"fast_run"
,
"inplace"
,
"inplace"
,
"c_blas"
,
"c_blas"
,
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
1d5b1d94
...
@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb
...
@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb
from
pytensor.graph
import
Constant
,
Op
,
node_rewriter
from
pytensor.graph
import
Constant
,
Op
,
node_rewriter
from
pytensor.graph.destroyhandler
import
inplace_candidates
from
pytensor.graph.destroyhandler
import
inplace_candidates
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
dfs_rewriter
from
pytensor.graph.rewriting.unify
import
OpPattern
,
OpPatternOpTypeType
from
pytensor.graph.rewriting.unify
import
OpPattern
,
OpPatternOpTypeType
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.blockwise
import
Blockwise
,
_squeeze_left
from
pytensor.tensor.blockwise
import
Blockwise
,
_squeeze_left
...
@@ -66,7 +66,7 @@ def local_useless_unbatched_blockwise(fgraph, node):
...
@@ -66,7 +66,7 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb
.
register
(
optdb
.
register
(
"local_useless_unbatched_blockwise"
,
"local_useless_unbatched_blockwise"
,
out2in
(
local_useless_unbatched_blockwise
,
ignore_newtrees
=
True
),
dfs_rewriter
(
local_useless_unbatched_blockwise
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"fast_compile"
,
"fast_compile"
,
"blockwise"
,
"blockwise"
,
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
1d5b1d94
...
@@ -21,6 +21,7 @@ from pytensor.graph.op import Op
...
@@ -21,6 +21,7 @@ from pytensor.graph.op import Op
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
GraphRewriter
,
GraphRewriter
,
copy_stack_trace
,
copy_stack_trace
,
dfs_rewriter
,
in2out
,
in2out
,
node_rewriter
,
node_rewriter
,
out2in
,
out2in
,
...
@@ -1237,21 +1238,21 @@ fuse_seqopt.register(
...
@@ -1237,21 +1238,21 @@ fuse_seqopt.register(
)
)
fuse_seqopt
.
register
(
fuse_seqopt
.
register
(
"local_useless_composite_outputs"
,
"local_useless_composite_outputs"
,
in2out
(
local_useless_composite_outputs
),
dfs_rewriter
(
local_useless_composite_outputs
),
"fast_run"
,
"fast_run"
,
"fusion"
,
"fusion"
,
position
=
2
,
position
=
2
,
)
)
fuse_seqopt
.
register
(
fuse_seqopt
.
register
(
"local_careduce_fusion"
,
"local_careduce_fusion"
,
in2out
(
local_careduce_fusion
),
dfs_rewriter
(
local_careduce_fusion
),
"fast_run"
,
"fast_run"
,
"fusion"
,
"fusion"
,
position
=
10
,
position
=
10
,
)
)
fuse_seqopt
.
register
(
fuse_seqopt
.
register
(
"local_inline_composite_constants"
,
"local_inline_composite_constants"
,
in2out
(
local_inline_composite_constants
,
ignore_newtrees
=
True
),
dfs_rewriter
(
local_inline_composite_constants
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"fusion"
,
"fusion"
,
position
=
20
,
position
=
20
,
...
...
pytensor/tensor/rewriting/jax.py
浏览文件 @
1d5b1d94
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
dfs_rewriter
,
node_rewriter
from
pytensor.tensor.basic
import
MakeVector
from
pytensor.tensor.basic
import
MakeVector
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
Sum
from
pytensor.tensor.math
import
Sum
...
@@ -46,7 +46,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
...
@@ -46,7 +46,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
optdb
.
register
(
optdb
.
register
(
"jax_boolean_indexing_set_or_inc"
,
"jax_boolean_indexing_set_or_inc"
,
in2out
(
boolean_indexing_set_or_inc
),
dfs_rewriter
(
boolean_indexing_set_or_inc
),
"jax"
,
"jax"
,
position
=
100
,
position
=
100
,
)
)
...
@@ -96,7 +96,7 @@ def boolean_indexing_sum(fgraph, node):
...
@@ -96,7 +96,7 @@ def boolean_indexing_sum(fgraph, node):
optdb
.
register
(
optdb
.
register
(
"jax_boolean_indexing_sum"
,
in2out
(
boolean_indexing_sum
),
"jax"
,
position
=
100
"jax_boolean_indexing_sum"
,
dfs_rewriter
(
boolean_indexing_sum
),
"jax"
,
position
=
100
)
)
...
@@ -144,7 +144,7 @@ def shape_parameter_as_tuple(fgraph, node):
...
@@ -144,7 +144,7 @@ def shape_parameter_as_tuple(fgraph, node):
optdb
.
register
(
optdb
.
register
(
"jax_shape_parameter_as_tuple"
,
"jax_shape_parameter_as_tuple"
,
in2out
(
shape_parameter_as_tuple
),
dfs_rewriter
(
shape_parameter_as_tuple
),
"jax"
,
"jax"
,
position
=
100
,
position
=
100
,
)
)
pytensor/tensor/rewriting/linalg.py
浏览文件 @
1d5b1d94
...
@@ -10,7 +10,7 @@ from pytensor.compile import optdb
...
@@ -10,7 +10,7 @@ from pytensor.compile import optdb
from
pytensor.graph
import
Apply
,
FunctionGraph
from
pytensor.graph
import
Apply
,
FunctionGraph
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
copy_stack_trace
,
copy_stack_trace
,
in2out
,
dfs_rewriter
,
node_rewriter
,
node_rewriter
,
)
)
from
pytensor.graph.rewriting.unify
import
OpPattern
from
pytensor.graph.rewriting.unify
import
OpPattern
...
@@ -905,7 +905,7 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
...
@@ -905,7 +905,7 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
optdb
.
register
(
optdb
.
register
(
"jax_bilinaer_lyapunov_to_direct"
,
"jax_bilinaer_lyapunov_to_direct"
,
in2out
(
jax_bilinaer_lyapunov_to_direct
),
dfs_rewriter
(
jax_bilinaer_lyapunov_to_direct
),
"jax"
,
"jax"
,
position
=
0.9
,
# Run before canonicalization
position
=
0.9
,
# Run before canonicalization
)
)
...
...
pytensor/tensor/rewriting/numba.py
浏览文件 @
1d5b1d94
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph
import
node_rewriter
from
pytensor.graph
import
node_rewriter
from
pytensor.graph.rewriting.basic
import
out2in
from
pytensor.graph.rewriting.basic
import
dfs_rewriter
from
pytensor.graph.traversal
import
applys_between
from
pytensor.graph.traversal
import
applys_between
from
pytensor.tensor.basic
import
as_tensor
,
constant
from
pytensor.tensor.basic
import
as_tensor
,
constant
from
pytensor.tensor.blockwise
import
Blockwise
,
BlockwiseWithCoreShape
from
pytensor.tensor.blockwise
import
Blockwise
,
BlockwiseWithCoreShape
...
@@ -102,7 +102,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
...
@@ -102,7 +102,7 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
optdb
.
register
(
optdb
.
register
(
introduce_explicit_core_shape_blockwise
.
__name__
,
introduce_explicit_core_shape_blockwise
.
__name__
,
out2in
(
introduce_explicit_core_shape_blockwise
),
dfs_rewriter
(
introduce_explicit_core_shape_blockwise
),
"numba"
,
"numba"
,
position
=
100
,
position
=
100
,
)
)
pytensor/tensor/rewriting/ofg.py
浏览文件 @
1d5b1d94
...
@@ -4,7 +4,7 @@ from pytensor import Variable, clone_replace
...
@@ -4,7 +4,7 @@ from pytensor import Variable, clone_replace
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.graph
import
Apply
,
node_rewriter
from
pytensor.graph
import
Apply
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
dfs_rewriter
from
pytensor.tensor.basic
import
AllocDiag
from
pytensor.tensor.basic
import
AllocDiag
from
pytensor.tensor.rewriting.basic
import
register_specialize
from
pytensor.tensor.rewriting.basic
import
register_specialize
...
@@ -37,7 +37,7 @@ def inline_ofg_expansion(fgraph, node):
...
@@ -37,7 +37,7 @@ def inline_ofg_expansion(fgraph, node):
# and before the first scan optimizer.
# and before the first scan optimizer.
optdb
.
register
(
optdb
.
register
(
"inline_ofg_expansion"
,
"inline_ofg_expansion"
,
in2out
(
inline_ofg_expansion
),
dfs_rewriter
(
inline_ofg_expansion
),
"fast_compile"
,
"fast_compile"
,
"fast_run"
,
"fast_run"
,
position
=-
0.01
,
position
=-
0.01
,
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
1d5b1d94
...
@@ -10,9 +10,9 @@ from pytensor.graph.basic import Constant, Variable
...
@@ -10,9 +10,9 @@ from pytensor.graph.basic import Constant, Variable
from
pytensor.graph.rewriting.basic
import
(
from
pytensor.graph.rewriting.basic
import
(
WalkingGraphRewriter
,
WalkingGraphRewriter
,
copy_stack_trace
,
copy_stack_trace
,
dfs_rewriter
,
in2out
,
in2out
,
node_rewriter
,
node_rewriter
,
out2in
,
)
)
from
pytensor.raise_op
import
Assert
from
pytensor.raise_op
import
Assert
from
pytensor.scalar
import
Add
,
ScalarConstant
,
ScalarType
from
pytensor.scalar
import
Add
,
ScalarConstant
,
ScalarType
...
@@ -1560,7 +1560,7 @@ def local_uint_constant_indices(fgraph, node):
...
@@ -1560,7 +1560,7 @@ def local_uint_constant_indices(fgraph, node):
compile
.
optdb
.
register
(
compile
.
optdb
.
register
(
local_uint_constant_indices
.
__name__
,
local_uint_constant_indices
.
__name__
,
out2in
(
local_uint_constant_indices
),
dfs_rewriter
(
local_uint_constant_indices
),
# We don't include in the Python / C because those always cast indices to int64 internally.
# We don't include in the Python / C because those always cast indices to int64 internally.
"numba"
,
"numba"
,
"jax"
,
"jax"
,
...
...
pytensor/xtensor/rewriting/utils.py
浏览文件 @
1d5b1d94
...
@@ -2,7 +2,7 @@ import typing
...
@@ -2,7 +2,7 @@ import typing
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph.rewriting.basic
import
NodeRewriter
,
in2out
from
pytensor.graph.rewriting.basic
import
NodeRewriter
,
dfs_rewriter
from
pytensor.graph.rewriting.db
import
EquilibriumDB
,
RewriteDatabase
from
pytensor.graph.rewriting.db
import
EquilibriumDB
,
RewriteDatabase
from
pytensor.tensor.rewriting.ofg
import
inline_ofg_expansion
from
pytensor.tensor.rewriting.ofg
import
inline_ofg_expansion
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
...
@@ -23,7 +23,7 @@ optdb.register(
...
@@ -23,7 +23,7 @@ optdb.register(
# Register OFG inline again after lowering xtensor
# Register OFG inline again after lowering xtensor
optdb
.
register
(
optdb
.
register
(
"inline_ofg_expansion_xtensor"
,
"inline_ofg_expansion_xtensor"
,
in2out
(
inline_ofg_expansion
),
dfs_rewriter
(
inline_ofg_expansion
),
"fast_run"
,
"fast_run"
,
"fast_compile"
,
"fast_compile"
,
position
=
0.11
,
position
=
0.11
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论