Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
56327779
提交
56327779
authored
7月 15, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
10月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify logic with `variadic_add` and `variadic_mul` helpers
上级
cdae9037
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
47 行增加
和
53 行删除
+47
-53
blas.py
pytensor/tensor/blas.py
+2
-6
math.py
pytensor/tensor/math.py
+24
-12
basic.py
pytensor/tensor/rewriting/basic.py
+2
-7
blas.py
pytensor/tensor/rewriting/blas.py
+10
-5
math.py
pytensor/tensor/rewriting/math.py
+7
-18
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+2
-5
没有找到文件。
pytensor/tensor/blas.py
浏览文件 @
56327779
...
...
@@ -102,7 +102,7 @@ from pytensor.tensor import basic as ptb
from
pytensor.tensor.basic
import
expand_dims
from
pytensor.tensor.blas_headers
import
blas_header_text
,
blas_header_version
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
,
variadic_add
from
pytensor.tensor.shape
import
shape_padright
,
specify_broadcastable
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
integer_dtypes
,
tensor
...
...
@@ -1399,11 +1399,7 @@ def _gemm_from_factored_list(fgraph, lst):
item_to_var
(
input
)
for
k
,
input
in
enumerate
(
lst
)
if
k
not
in
(
i
,
j
)
]
add_inputs
.
extend
(
gemm_of_sM_list
)
if
len
(
add_inputs
)
>
1
:
rval
=
[
add
(
*
add_inputs
)]
else
:
rval
=
add_inputs
# print "RETURNING GEMM THING", rval
rval
=
[
variadic_add
(
*
add_inputs
)]
return
rval
,
old_dot22
...
...
pytensor/tensor/math.py
浏览文件 @
56327779
...
...
@@ -1429,18 +1429,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
else
:
shp
=
cast
(
shp
,
"float64"
)
if
axis
is
None
:
axis
=
list
(
range
(
input
.
ndim
))
elif
isinstance
(
axis
,
int
|
np
.
integer
):
axis
=
[
axis
]
elif
isinstance
(
axis
,
np
.
ndarray
)
and
axis
.
ndim
==
0
:
axis
=
[
int
(
axis
)]
else
:
axis
=
[
int
(
a
)
for
a
in
axis
]
# This sequential division will possibly be optimized by PyTensor:
for
i
in
axis
:
s
=
true_div
(
s
,
shp
[
i
])
reduced_dims
=
(
shp
if
axis
is
None
else
[
shp
[
i
]
for
i
in
normalize_axis_tuple
(
axis
,
input
.
type
.
ndim
)]
)
s
/=
variadic_mul
(
*
reduced_dims
)
.
astype
(
shp
.
dtype
)
# This can happen when axis is an empty list/tuple
if
s
.
dtype
!=
shp
.
dtype
and
s
.
dtype
in
discrete_dtypes
:
...
...
@@ -1596,6 +1590,15 @@ def add(a, *other_terms):
# see decorator for function body
def
variadic_add
(
*
args
):
"""Add that accepts arbitrary number of inputs, including zero or one."""
if
not
args
:
return
constant
(
0
)
if
len
(
args
)
==
1
:
return
args
[
0
]
return
add
(
*
args
)
@scalar_elemwise
def
sub
(
a
,
b
):
"""elementwise subtraction"""
...
...
@@ -1608,6 +1611,15 @@ def mul(a, *other_terms):
# see decorator for function body
def
variadic_mul
(
*
args
):
"""Mul that accepts arbitrary number of inputs, including zero or one."""
if
not
args
:
return
constant
(
1
)
if
len
(
args
)
==
1
:
return
args
[
0
]
return
mul
(
*
args
)
@scalar_elemwise
def
true_div
(
a
,
b
):
"""elementwise [true] division (inverse of multiplication)"""
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
56327779
...
...
@@ -68,7 +68,7 @@ from pytensor.tensor.basic import (
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.extra_ops
import
broadcast_arrays
from
pytensor.tensor.math
import
Sum
,
add
,
eq
from
pytensor.tensor.math
import
Sum
,
add
,
eq
,
variadic_add
from
pytensor.tensor.shape
import
Shape_i
,
shape_padleft
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
...
@@ -939,13 +939,8 @@ def local_sum_make_vector(fgraph, node):
if
acc_dtype
==
"float64"
and
out_dtype
!=
"float64"
and
config
.
floatX
!=
"float64"
:
return
if
len
(
elements
)
==
0
:
element_sum
=
zeros
(
dtype
=
out_dtype
,
shape
=
())
elif
len
(
elements
)
==
1
:
element_sum
=
cast
(
elements
[
0
],
out_dtype
)
else
:
element_sum
=
cast
(
add
(
*
[
cast
(
value
,
acc_dtype
)
for
value
in
elements
]),
out_dtype
variadic_
add
(
*
[
cast
(
value
,
acc_dtype
)
for
value
in
elements
]),
out_dtype
)
return
[
element_sum
]
...
...
pytensor/tensor/rewriting/blas.py
浏览文件 @
56327779
...
...
@@ -96,7 +96,15 @@ from pytensor.tensor.blas import (
)
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
Dot
,
_matrix_matrix_matmul
,
add
,
mul
,
neg
,
sub
from
pytensor.tensor.math
import
(
Dot
,
_matrix_matrix_matmul
,
add
,
mul
,
neg
,
sub
,
variadic_add
,
)
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.type
import
(
DenseTensorType
,
...
...
@@ -386,10 +394,7 @@ def _gemm_from_factored_list(fgraph, lst):
item_to_var
(
input
)
for
k
,
input
in
enumerate
(
lst
)
if
k
not
in
(
i
,
j
)
]
add_inputs
.
extend
(
gemm_of_sM_list
)
if
len
(
add_inputs
)
>
1
:
rval
=
[
add
(
*
add_inputs
)]
else
:
rval
=
add_inputs
rval
=
[
variadic_add
(
*
add_inputs
)]
# print "RETURNING GEMM THING", rval
return
rval
,
old_dot22
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
56327779
...
...
@@ -76,6 +76,8 @@ from pytensor.tensor.math import (
sub
,
tri_gamma
,
true_div
,
variadic_add
,
variadic_mul
,
)
from
pytensor.tensor.math
import
abs
as
pt_abs
from
pytensor.tensor.math
import
max
as
pt_max
...
...
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
if
not
outer_terms
:
return
None
elif
len
(
outer_terms
)
==
1
:
[
outer_term
]
=
outer_terms
else
:
outer_term
=
mul
(
*
outer_terms
)
outer_term
=
variadic_
mul
(
*
outer_terms
)
if
not
inner_terms
:
inner_term
=
None
elif
len
(
inner_terms
)
==
1
:
[
inner_term
]
=
inner_terms
else
:
inner_term
=
mul
(
*
inner_terms
)
inner_term
=
variadic_
mul
(
*
inner_terms
)
else
:
# true_div
# We only care about removing the denominator out of the reduction
...
...
@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node):
assert
cst
.
type
.
broadcastable
==
(
True
,)
*
ndim
return
[
alloc_like
(
cst
,
node_output
,
fgraph
)]
if
len
(
new_inputs
)
==
1
:
ret
=
[
alloc_like
(
new_inputs
[
0
],
node_output
,
fgraph
)]
else
:
ret
=
[
alloc_like
(
add
(
*
new_inputs
),
node_output
,
fgraph
)]
ret
=
[
alloc_like
(
variadic_add
(
*
new_inputs
),
node_output
,
fgraph
)]
# The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0.
...
...
@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node):
# scalar_inputs are potentially dimshuffled and fill'd scalars
if
scalars
and
np
.
allclose
(
np
.
sum
(
scalars
),
1
):
if
nonconsts
:
if
len
(
nonconsts
)
>
1
:
ninp
=
add
(
*
nonconsts
)
else
:
ninp
=
nonconsts
[
0
]
ninp
=
variadic_add
(
*
nonconsts
)
if
ninp
.
dtype
!=
log_arg
.
type
.
dtype
:
ninp
=
ninp
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
return
[
alloc_like
(
log1p
(
ninp
),
node
.
outputs
[
0
],
fgraph
)]
...
...
@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
return
# put the new numerator together
new_num
=
sigmoids
+
[
exp
(
t
)
for
t
in
num_exp_x
]
+
num_rest
if
len
(
new_num
)
==
1
:
new_num
=
new_num
[
0
]
else
:
new_num
=
mul
(
*
new_num
)
new_num
=
variadic_mul
(
*
new_num
)
if
num_neg
^
denom_neg
:
new_num
=
-
new_num
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
56327779
...
...
@@ -48,6 +48,7 @@ from pytensor.tensor.math import (
maximum
,
minimum
,
or_
,
variadic_add
,
)
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.rewriting.basic
import
(
...
...
@@ -1241,11 +1242,7 @@ def local_IncSubtensor_serialize(fgraph, node):
new_inputs
=
[
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
+
[
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
]
if
len
(
new_inputs
)
==
0
:
new_add
=
new_inputs
[
0
]
else
:
new_add
=
add
(
*
new_inputs
)
new_add
=
variadic_add
(
*
new_inputs
)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论