Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
27d79707
提交
27d79707
authored
8月 28, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
9月 20, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use OpPattern in tracks
上级
19f1486b
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
140 行增加
和
171 行删除
+140
-171
basic.py
pytensor/scalar/basic.py
+5
-3
rewriting.py
pytensor/tensor/_linalg/solve/rewriting.py
+2
-1
basic.py
pytensor/tensor/rewriting/basic.py
+35
-27
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+26
-20
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+14
-29
linalg.py
pytensor/tensor/rewriting/linalg.py
+0
-0
math.py
pytensor/tensor/rewriting/math.py
+33
-58
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+25
-33
没有找到文件。
pytensor/scalar/basic.py
浏览文件 @
27d79707
...
...
@@ -1228,6 +1228,8 @@ class ScalarOp(COp):
f
"(got: {output_types_preference})"
)
self
.
output_types_preference
=
output_types_preference
elif
not
hasattr
(
self
,
"output_types_preference"
):
self
.
output_types_preference
=
None
def
make_node
(
self
,
*
inputs
):
if
self
.
nin
>=
0
:
...
...
@@ -1247,7 +1249,7 @@ class ScalarOp(COp):
return
Apply
(
self
,
inputs
,
outputs
)
def
output_types
(
self
,
types
):
if
hasattr
(
self
,
"output_types_preference"
)
:
if
self
.
output_types_preference
is
not
None
:
variables
=
self
.
output_types_preference
(
*
types
)
if
not
isinstance
(
variables
,
list
|
tuple
)
or
any
(
not
isinstance
(
x
,
CType
)
for
x
in
variables
...
...
@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp):
nfunc_spec
=
(
"sign"
,
1
,
1
)
@staticmethod
def
output_types_preference
(
x
):
def
_
output_types_preference
(
x
):
if
x
==
bool
:
raise
TypeError
(
x
)
return
same_out_nocomplex
(
x
)
...
...
@@ -2737,7 +2739,7 @@ class Sign(UnaryScalarOp):
return
s
sign
=
Sign
(
name
=
"sign"
)
sign
=
Sign
(
name
=
"sign"
,
output_types_preference
=
Sign
.
_output_types_preference
)
class
Ceil
(
UnaryScalarOp
):
...
...
pytensor/tensor/_linalg/solve/rewriting.py
浏览文件 @
27d79707
...
...
@@ -14,6 +14,7 @@ from pytensor.tensor.basic import atleast_Nd
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.rewriting.basic
import
register_specialize
from
pytensor.tensor.rewriting.blockwise
import
blockwise_of
from
pytensor.tensor.rewriting.linalg
import
is_matrix_transpose
from
pytensor.tensor.slinalg
import
Solve
,
cho_solve
,
cholesky
,
lu_factor
,
lu_solve
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve(
@register_specialize
@node_rewriter
([
Blockwise
])
@node_rewriter
([
blockwise_of
(
Solve
)
])
def
reuse_decomposition_multiple_solves
(
fgraph
,
node
):
return
_split_decomp_and_solve_steps
(
fgraph
,
node
,
eager
=
False
,
allowed_assume_a
=
{
"gen"
,
"tridiagonal"
,
"pos"
}
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
27d79707
...
...
@@ -26,10 +26,9 @@ import logging
import
numpy
as
np
import
pytensor.scalar.basic
as
ps
from
pytensor
import
compile
,
config
from
pytensor.compile.ops
import
ViewOp
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
,
Op
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.rewriting.basic
import
(
NodeProcessingGraphRewriter
,
...
...
@@ -40,9 +39,24 @@ from pytensor.graph.rewriting.basic import (
node_rewriter
,
)
from
pytensor.graph.rewriting.db
import
RewriteDatabase
from
pytensor.graph.rewriting.unify
import
OpPattern
,
OpPatternOpTypeType
from
pytensor.npy_2_compat
import
normalize_axis_index
from
pytensor.raise_op
import
Assert
,
CheckAndRaise
,
assert_op
from
pytensor.scalar.basic
import
Second
from
pytensor.scalar
import
(
AND
,
EQ
,
LE
,
NEQ
,
OR
,
XOR
,
Add
,
BinaryScalarOp
,
Cast
,
Identity
,
Mul
,
Second
,
Switch
,
)
from
pytensor.tensor.basic
import
(
Alloc
,
AllocEmpty
,
...
...
@@ -225,6 +239,12 @@ def register_uncanonicalize(
return
node_rewriter
def
elemwise_of
(
scalar_op
:
OpPatternOpTypeType
|
OpPattern
)
->
OpPattern
:
if
not
isinstance
(
scalar_op
,
Op
|
OpPattern
):
scalar_op
=
OpPattern
(
scalar_op
)
return
OpPattern
(
Elemwise
,
scalar_op
=
scalar_op
)
@register_canonicalize
@register_specialize
@node_rewriter
([
TensorFromScalar
])
...
...
@@ -551,7 +571,7 @@ def local_useless_elemwise(fgraph, node):
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
scalar_op
=
node
.
op
.
scalar_op
if
isinstance
(
scalar_op
,
ps
.
EQ
)
and
len
(
node
.
inputs
)
==
2
:
if
isinstance
(
scalar_op
,
EQ
)
and
len
(
node
.
inputs
)
==
2
:
if
node
.
inputs
[
0
]
is
node
.
inputs
[
1
]:
# it is the same var in the graph. That will always be true
ret
=
ones_like
(
node
.
inputs
[
0
],
dtype
=
dtype
,
opt
=
True
)
...
...
@@ -559,7 +579,7 @@ def local_useless_elemwise(fgraph, node):
# Copy stack trace from input to constant output
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
elif
isinstance
(
scalar_op
,
ps
.
NEQ
|
ps
.
XOR
)
and
len
(
node
.
inputs
)
==
2
:
elif
isinstance
(
scalar_op
,
NEQ
|
XOR
)
and
len
(
node
.
inputs
)
==
2
:
if
node
.
inputs
[
0
]
is
node
.
inputs
[
1
]:
# it is the same var in the graph. That will always be false
ret
=
zeros_like
(
node
.
inputs
[
0
],
dtype
=
dtype
,
opt
=
True
)
...
...
@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node):
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
elif
(
isinstance
(
node
.
op
.
scalar_op
,
ps
.
Mul
|
ps
.
Add
|
ps
.
Identity
)
and
len
(
node
.
inputs
)
==
1
):
elif
isinstance
(
node
.
op
.
scalar_op
,
Mul
|
Add
|
Identity
)
and
len
(
node
.
inputs
)
==
1
:
# No need to copy over any stack trace
return
[
node
.
inputs
[
0
]]
elif
isinstance
(
node
.
op
.
scalar_op
,
ps
.
AND
)
and
len
(
node
.
inputs
)
==
2
:
elif
isinstance
(
node
.
op
.
scalar_op
,
AND
)
and
len
(
node
.
inputs
)
==
2
:
if
(
isinstance
(
node
.
inputs
[
0
],
TensorConstant
)
and
node
.
inputs
[
1
]
.
type
.
broadcastable
==
out_bcast
...
...
@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node):
# and this rewrite would be wrong
return
[
node
.
inputs
[
0
]
.
astype
(
node
.
outputs
[
0
]
.
dtype
)]
elif
isinstance
(
node
.
op
.
scalar_op
,
ps
.
OR
)
and
len
(
node
.
inputs
)
==
2
:
elif
isinstance
(
node
.
op
.
scalar_op
,
OR
)
and
len
(
node
.
inputs
)
==
2
:
if
(
isinstance
(
node
.
inputs
[
0
],
TensorConstant
)
and
node
.
inputs
[
1
]
.
type
.
broadcastable
==
out_bcast
...
...
@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node):
@register_canonicalize
@register_specialize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
elemwise_of
(
Cast
)
])
def
local_cast_cast
(
fgraph
,
node
):
"""cast(cast(x, dtype1), dtype2)
...
...
@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node):
and the first cast cause an upcast.
"""
if
not
(
isinstance
(
node
.
op
,
Elemwise
)
and
isinstance
(
node
.
op
.
scalar_op
,
ps
.
Cast
)):
return
x
=
node
.
inputs
[
0
]
if
not
(
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
Elemwise
)
and
isinstance
(
x
.
owner
.
op
.
scalar_op
,
ps
.
Cast
)
and
isinstance
(
x
.
owner
.
op
.
scalar_op
,
Cast
)
):
return
...
...
@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node):
node
.
outputs
[
0
]
.
type
.
ndim
==
0
and
cond_var
.
owner
and
isinstance
(
cond_var
.
owner
.
op
,
Elemwise
)
and
isinstance
(
cond_var
.
owner
.
op
.
scalar_op
,
ps
.
LE
)
and
isinstance
(
cond_var
.
owner
.
op
.
scalar_op
,
LE
)
and
cond_var
.
owner
.
inputs
[
0
]
.
owner
and
isinstance
(
cond_var
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Shape_i
)
and
get_scalar_constant_value
(
...
...
@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node):
@register_canonicalize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
elemwise_of
(
BinaryScalarOp
|
Add
|
Mul
)
])
def
local_merge_switch_same_cond
(
fgraph
,
node
):
"""
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
condition, to enable further simplification of their branches
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
"""
# node must be binary elemwise or add or mul
if
not
(
isinstance
(
node
.
op
,
Elemwise
)
and
isinstance
(
node
.
op
.
scalar_op
,
ps
.
BinaryScalarOp
|
ps
.
Add
|
ps
.
Mul
)
):
return
# all inputs must be switch
if
not
all
(
s
.
owner
and
isinstance
(
s
.
owner
.
op
,
Elemwise
)
and
isinstance
(
s
.
owner
.
op
.
scalar_op
,
ps
.
Switch
)
and
isinstance
(
s
.
owner
.
op
.
scalar_op
,
Switch
)
for
s
in
node
.
inputs
):
return
...
...
@@ -1174,10 +1183,9 @@ register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True)
@register_infer_shape
@register_canonicalize
(
"fast_compile"
)
@register_useless
(
"fast_compile"
)
@node_rewriter
(
None
)
@node_rewriter
(
[
ViewOp
]
)
def
local_view_op
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
ViewOp
):
return
node
.
inputs
return
node
.
inputs
@register_infer_shape
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
27d79707
from
pytensor.compile.mode
import
optdb
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph
import
Constant
,
Op
,
node_rewriter
from
pytensor.graph.destroyhandler
import
inplace_candidates
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.graph.rewriting.unify
import
OpPattern
,
OpPatternOpTypeType
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.blockwise
import
Blockwise
,
_squeeze_left
from
pytensor.tensor.math
import
Dot
...
...
@@ -20,6 +21,12 @@ from pytensor.tensor.subtensor import (
)
def
blockwise_of
(
core_op
:
OpPatternOpTypeType
|
OpPattern
)
->
OpPattern
:
if
not
isinstance
(
core_op
,
Op
|
OpPattern
):
core_op
=
OpPattern
(
core_op
)
return
OpPattern
(
Blockwise
,
core_op
=
core_op
)
@node_rewriter
([
Blockwise
])
def
local_useless_blockwise
(
fgraph
,
node
):
"""
...
...
@@ -71,22 +78,24 @@ optdb.register(
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter
(
tracks
=
[
Blockwise
])
@node_rewriter
(
tracks
=
[
blockwise_of
(
Dot
|
Alloc
|
ARange
|
Subtensor
|
AdvancedSubtensor
|
AdvancedIncSubtensor
|
Reshape
)
]
)
def
local_eager_useless_unbatched_blockwise
(
fgraph
,
node
):
if
isinstance
(
node
.
op
.
core_op
,
Dot
|
Alloc
|
ARange
|
Subtensor
|
AdvancedSubtensor
|
AdvancedIncSubtensor
|
Reshape
,
):
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return
local_useless_unbatched_blockwise
.
fn
(
fgraph
,
node
)
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return
local_useless_unbatched_blockwise
.
fn
(
fgraph
,
node
)
@register_specialize
(
"shape_unsafe"
)
...
...
@@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node):
@register_specialize
@node_rewriter
([
Blockwise
])
@node_rewriter
([
blockwise_of
(
Reshape
)
])
def
local_blockwise_reshape
(
fgraph
,
node
):
"""Rewrite away square Blockwise reshapes.
...
...
@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node):
For the square Reshape case, we must wait for all the intermediate
operations to be lifted as Allocs
"""
if
not
isinstance
(
node
.
op
.
core_op
,
Reshape
):
return
None
x
,
output_shape
=
node
.
inputs
batch_ndim
=
node
.
op
.
batch_ndim
(
node
)
if
all
(
output_shape
.
type
.
broadcastable
[:
batch_ndim
]):
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
27d79707
...
...
@@ -26,6 +26,7 @@ from pytensor.graph.rewriting.basic import (
out2in
,
)
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.graph.rewriting.unify
import
OpPattern
from
pytensor.graph.utils
import
InconsistencyError
,
MethodNotDefined
from
pytensor.scalar.math
import
Grad2F1Loop
,
_grad_2f1_loop
from
pytensor.tensor.basic
import
(
...
...
@@ -37,6 +38,7 @@ from pytensor.tensor.math import add, exp, mul
from
pytensor.tensor.rewriting.basic
import
(
alloc_like
,
broadcasted_by
,
elemwise_of
,
register_canonicalize
,
register_specialize
,
register_stabilize
,
...
...
@@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node):
@register_canonicalize
@node_rewriter
([
Elemwise
])
@node_rewriter
(
[
elemwise_of
(
OpPattern
(
ps
.
ScalarOp
,
output_types_preference
=
ps
.
upgrade_to_float
)
),
elemwise_of
(
OpPattern
(
ps
.
ScalarOp
,
output_types_preference
=
ps
.
upcast_out
)),
]
)
def
local_upcast_elemwise_constant_inputs
(
fgraph
,
node
):
"""This explicitly upcasts constant inputs to elemwise Ops, when
those Ops do implicit upcasting anyway.
...
...
@@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
if
len
(
node
.
outputs
)
>
1
:
return
None
if
getattr
(
node
.
op
.
scalar_op
,
"output_types_preference"
,
None
)
not
in
(
ps
.
upgrade_to_float
,
ps
.
upcast_out
,
):
return
None
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
[
old_out
]
=
node
.
outputs
...
...
@@ -988,13 +991,9 @@ class FusionOptimizer(GraphRewriter):
@register_canonicalize
@register_specialize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
elemwise_of
(
ps
.
Composite
)
])
def
local_useless_composite_outputs
(
fgraph
,
node
):
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
if
not
(
isinstance
(
node
.
op
,
Elemwise
)
and
isinstance
(
node
.
op
.
scalar_op
,
ps
.
Composite
)
):
return
comp
=
node
.
op
.
scalar_op
used_outputs_idxs
=
[
i
for
i
,
o_extern
in
enumerate
(
node
.
outputs
)
if
fgraph
.
clients
[
o_extern
]
...
...
@@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node):
return
[
new_car_op
(
*
elm_inputs
)]
@node_rewriter
([
Elemwise
])
@node_rewriter
([
elemwise_of
(
ps
.
Composite
)
])
def
local_inline_composite_constants
(
fgraph
,
node
):
"""Inline scalar constants in Composite graphs."""
composite_op
=
node
.
op
.
scalar_op
if
not
isinstance
(
composite_op
,
ps
.
Composite
):
return
None
new_outer_inputs
=
[]
new_inner_inputs
=
[]
inner_replacements
=
{}
...
...
@@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt):
@register_specialize
@node_rewriter
([
Elemwise
])
@node_rewriter
([
elemwise_of
(
Grad2F1Loop
)
])
def
local_useless_2f1grad_loop
(
fgraph
,
node
):
# Remove unused terms from the hyp2f1 grad loop
loop_op
=
node
.
op
.
scalar_op
if
not
isinstance
(
loop_op
,
Grad2F1Loop
):
return
grad_related_vars
=
node
.
outputs
[:
-
4
]
# Rewrite was already applied
if
len
(
grad_related_vars
)
//
3
!=
3
:
...
...
@@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node):
return
replacements
@node_rewriter
([
Elemwise
])
@node_rewriter
([
elemwise_of
(
Grad2F1Loop
)
])
def
split_2f1grad_loop
(
fgraph
,
node
):
"""
2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode.
This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied
"""
loop_op
=
node
.
op
.
scalar_op
if
not
isinstance
(
loop_op
,
Grad2F1Loop
):
return
None
grad_related_vars
=
node
.
outputs
[:
-
4
]
# local_useless_2f1grad_loop was used, we should be safe
if
len
(
grad_related_vars
)
//
3
!=
3
:
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
27d79707
差异被折叠。
点击展开。
pytensor/tensor/rewriting/math.py
浏览文件 @
27d79707
...
...
@@ -37,7 +37,6 @@ from pytensor.tensor.basic import (
zeros
,
zeros_like
,
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.extra_ops
import
broadcast_arrays
,
concat_with_broadcast
...
...
@@ -49,6 +48,11 @@ from pytensor.tensor.math import (
_dot
,
_matmul
,
add
,
arccosh
,
arcsinh
,
arctanh
,
cosh
,
deg2rad
,
digamma
,
dot
,
erf
,
...
...
@@ -70,13 +74,16 @@ from pytensor.tensor.math import (
neg
,
polygamma
,
prod
,
rad2deg
,
reciprocal
,
sigmoid
,
sign
,
sinh
,
softplus
,
sqr
,
sqrt
,
sub
,
tanh
,
tri_gamma
,
true_div
,
variadic_add
,
...
...
@@ -96,6 +103,7 @@ from pytensor.tensor.rewriting.basic import (
register_uncanonicalize
,
register_useless
,
)
from
pytensor.tensor.rewriting.blockwise
import
blockwise_of
from
pytensor.tensor.rewriting.elemwise
import
apply_local_dimshuffle_lift
from
pytensor.tensor.rewriting.linalg
import
is_matrix_transpose
from
pytensor.tensor.shape
import
Shape
,
Shape_i
...
...
@@ -151,7 +159,7 @@ def local_0_dot_x(fgraph, node):
@register_stabilize
@node_rewriter
([
Blockwise
])
@node_rewriter
([
blockwise_of
(
BlockDiagonal
)
])
def
local_block_diag_dot_to_dot_block_diag
(
fgraph
,
node
):
r"""
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
...
...
@@ -160,9 +168,6 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
a single dot on the larger matrix.
"""
if
not
isinstance
(
node
.
op
.
core_op
,
BlockDiagonal
):
return
# Check that the BlockDiagonal is an input to a Dot node:
for
client
in
itertools
.
chain
.
from_iterable
(
get_clients_at_depth
(
fgraph
,
node
,
depth
=
i
)
for
i
in
[
1
,
2
]
...
...
@@ -424,60 +429,30 @@ def local_dot_to_mul(fgraph, node):
return
[
new_out
]
def
is_inverse_pair
(
node_op
,
prev_op
,
inv_pair
):
"""
Given two consecutive operations, check if they are the
provided pair of inverse functions.
"""
node_is_op0
=
isinstance
(
node_op
,
inv_pair
[
0
])
node_is_op1
=
isinstance
(
node_op
,
inv_pair
[
1
])
prev_is_op0
=
isinstance
(
prev_op
,
inv_pair
[
0
])
prev_is_op1
=
isinstance
(
prev_op
,
inv_pair
[
1
])
return
(
node_is_op0
and
prev_is_op1
)
or
(
node_is_op1
and
prev_is_op0
)
@register_canonicalize
@register_specialize
@node_rewriter
([
Elemwise
])
def
local_func_inv
(
fgraph
,
node
):
"""
Check for two consecutive operations that are functional inverses
and remove them from the function graph.
"""
inv_pairs
=
(
(
ps
.
Deg2Rad
,
ps
.
Rad2Deg
),
(
ps
.
Cosh
,
ps
.
ArcCosh
),
(
ps
.
Tanh
,
ps
.
ArcTanh
),
(
ps
.
Sinh
,
ps
.
ArcSinh
),
(
ps
.
Conj
,
ps
.
Conj
),
(
ps
.
Neg
,
ps
.
Neg
),
(
ps
.
Reciprocal
,
ps
.
Reciprocal
),
)
x
=
node
.
inputs
[
0
]
if
not
isinstance
(
node
.
op
,
Elemwise
):
return
if
not
(
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
Elemwise
)):
return
prev_op
=
x
.
owner
.
op
.
scalar_op
node_op
=
node
.
op
.
scalar_op
for
inv_pair
in
inv_pairs
:
if
is_inverse_pair
(
node_op
,
prev_op
,
inv_pair
):
# We don't need to copy stack trace, because the rewrite
# is trivial and maintains the earlier stack trace
ottype
=
node
.
out
.
dtype
inp
=
x
.
owner
.
inputs
[
0
]
# Functions may have casted integer input to float
if
inp
.
dtype
!=
ottype
:
inp
=
cast
(
inp
,
ottype
)
return
[
inp
]
for
pair
in
(
(
deg2rad
,
rad2deg
),
(
cosh
,
arccosh
),
(
tanh
,
arctanh
),
(
sinh
,
arcsinh
),
(
_conj
,
_conj
),
(
neg
,
neg
),
(
reciprocal
,
reciprocal
),
):
# Create a simple PatternNodeRewriter for each pair of opposite ops
# instead of a general Op that is called to often for very few hits
for
op
,
inv_op
in
(
pair
,
reversed
(
pair
)):
rewrite
=
PatternNodeRewriter
(
(
op
,
(
inv_op
,
"x"
)),
"x"
,
allow_multiple_clients
=
True
,
allow_cast
=
True
,
name
=
f
"useless_{op}_of_{inv_op}"
,
)
register_canonicalize
(
rewrite
)
register_specialize
(
rewrite
)
return
if
op
is
inv_op
:
break
# Same Op, no need to define two rewrites
@register_canonicalize
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
27d79707
...
...
@@ -35,7 +35,7 @@ from pytensor.tensor.basic import (
switch
,
)
from
pytensor.tensor.basic
import
constant
as
tensor_constant
from
pytensor.tensor.blockwise
import
Blockwise
,
_squeeze_left
from
pytensor.tensor.blockwise
import
_squeeze_left
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.extra_ops
import
broadcast_to
...
...
@@ -58,6 +58,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize
,
register_stabilize
,
)
from
pytensor.tensor.rewriting.blockwise
import
blockwise_of
from
pytensor.tensor.shape
import
(
shape_padleft
,
shape_padright
,
...
...
@@ -974,33 +975,30 @@ def local_IncSubtensor_serialize(fgraph, node):
and
not
i
.
owner
.
op
.
set_instead_of_inc
)
if
node
.
op
==
add
:
o_type
=
node
.
outputs
[
0
]
.
type
o_type
=
node
.
outputs
[
0
]
.
type
movable_inputs
=
[
i
for
i
in
node
.
inputs
if
movable
(
i
)]
movable_inputs
=
[
i
for
i
in
node
.
inputs
if
movable
(
i
)]
if
movable_inputs
:
new_inputs
=
[
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
+
[
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
]
new_add
=
variadic_add
(
*
new_inputs
)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace
(
node
.
outputs
[
0
],
new_add
)
# stack up the new incsubtensors
tip
=
new_add
for
mi
in
movable_inputs
:
assert
o_type
.
is_super
(
tip
.
type
)
tip
=
mi
.
owner
.
op
(
tip
,
*
mi
.
owner
.
inputs
[
1
:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace
(
node
.
outputs
+
mi
.
owner
.
outputs
,
tip
)
if
movable_inputs
:
new_inputs
=
[
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
+
[
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
]
new_add
=
variadic_add
(
*
new_inputs
)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace
(
node
.
outputs
[
0
],
new_add
)
return
[
tip
]
# stack up the new incsubtensors
tip
=
new_add
for
mi
in
movable_inputs
:
assert
o_type
.
is_super
(
tip
.
type
)
tip
=
mi
.
owner
.
op
(
tip
,
*
mi
.
owner
.
inputs
[
1
:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
copy_stack_trace
(
node
.
outputs
+
mi
.
owner
.
outputs
,
tip
)
# print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs
]
return
[
tip
]
# We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer.
...
...
@@ -1576,7 +1574,7 @@ compile.optdb.register(
@register_stabilize
@register_specialize
@node_rewriter
([
Blockwise
])
@node_rewriter
([
blockwise_of
(
Subtensor
)
])
def
local_blockwise_of_subtensor
(
fgraph
,
node
):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
...
...
@@ -1585,9 +1583,6 @@ def local_blockwise_of_subtensor(fgraph, node):
TODO: Handle batched indices like we do with blockwise of inc_subtensor
TODO: Extend to AdvanceSubtensor
"""
if
not
isinstance
(
node
.
op
.
core_op
,
Subtensor
):
return
x
,
*
idxs
=
node
.
inputs
if
not
all
(
all
(
idx
.
type
.
broadcastable
)
for
idx
in
idxs
):
return
...
...
@@ -1603,7 +1598,7 @@ def local_blockwise_of_subtensor(fgraph, node):
@register_canonicalize
(
"shape_unsafe"
)
@register_stabilize
(
"shape_unsafe"
)
@register_specialize
(
"shape_unsafe"
)
@node_rewriter
([
Blockwise
])
@node_rewriter
([
blockwise_of
(
IncSubtensor
|
AdvancedIncSubtensor
)
])
def
local_blockwise_inc_subtensor
(
fgraph
,
node
):
"""Rewrite blockwised inc_subtensors.
...
...
@@ -1614,12 +1609,9 @@ def local_blockwise_inc_subtensor(fgraph, node):
and can be safely rewritten without Blockwise.
"""
core_op
=
node
.
op
.
core_op
if
not
isinstance
(
core_op
,
AdvancedIncSubtensor
|
IncSubtensor
):
return
None
x
,
y
,
*
idxs
=
node
.
inputs
[
out
]
=
node
.
outputs
if
isinstance
(
node
.
op
.
core_op
,
AdvancedIncSubtensor
):
if
isinstance
(
core_op
,
AdvancedIncSubtensor
):
if
any
(
(
# Blockwise requires all inputs to be tensors so it is not possible
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论