Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
56327779
提交
56327779
authored
7月 15, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
10月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify logic with `variadic_add` and `variadic_mul` helpers
上级
cdae9037
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
53 行增加
和
59 行删除
+53
-59
blas.py
pytensor/tensor/blas.py
+2
-6
math.py
pytensor/tensor/math.py
+24
-12
basic.py
pytensor/tensor/rewriting/basic.py
+4
-9
blas.py
pytensor/tensor/rewriting/blas.py
+10
-5
math.py
pytensor/tensor/rewriting/math.py
+7
-18
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+6
-9
没有找到文件。
pytensor/tensor/blas.py
浏览文件 @
56327779
...
@@ -102,7 +102,7 @@ from pytensor.tensor import basic as ptb
...
@@ -102,7 +102,7 @@ from pytensor.tensor import basic as ptb
from
pytensor.tensor.basic
import
expand_dims
from
pytensor.tensor.basic
import
expand_dims
from
pytensor.tensor.blas_headers
import
blas_header_text
,
blas_header_version
from
pytensor.tensor.blas_headers
import
blas_header_text
,
blas_header_version
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
,
variadic_add
from
pytensor.tensor.shape
import
shape_padright
,
specify_broadcastable
from
pytensor.tensor.shape
import
shape_padright
,
specify_broadcastable
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
integer_dtypes
,
tensor
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
integer_dtypes
,
tensor
...
@@ -1399,11 +1399,7 @@ def _gemm_from_factored_list(fgraph, lst):
...
@@ -1399,11 +1399,7 @@ def _gemm_from_factored_list(fgraph, lst):
item_to_var
(
input
)
for
k
,
input
in
enumerate
(
lst
)
if
k
not
in
(
i
,
j
)
item_to_var
(
input
)
for
k
,
input
in
enumerate
(
lst
)
if
k
not
in
(
i
,
j
)
]
]
add_inputs
.
extend
(
gemm_of_sM_list
)
add_inputs
.
extend
(
gemm_of_sM_list
)
if
len
(
add_inputs
)
>
1
:
rval
=
[
variadic_add
(
*
add_inputs
)]
rval
=
[
add
(
*
add_inputs
)]
else
:
rval
=
add_inputs
# print "RETURNING GEMM THING", rval
return
rval
,
old_dot22
return
rval
,
old_dot22
...
...
pytensor/tensor/math.py
浏览文件 @
56327779
...
@@ -1429,18 +1429,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
...
@@ -1429,18 +1429,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
else
:
else
:
shp
=
cast
(
shp
,
"float64"
)
shp
=
cast
(
shp
,
"float64"
)
if
axis
is
None
:
reduced_dims
=
(
axis
=
list
(
range
(
input
.
ndim
))
shp
elif
isinstance
(
axis
,
int
|
np
.
integer
):
if
axis
is
None
axis
=
[
axis
]
else
[
shp
[
i
]
for
i
in
normalize_axis_tuple
(
axis
,
input
.
type
.
ndim
)]
elif
isinstance
(
axis
,
np
.
ndarray
)
and
axis
.
ndim
==
0
:
)
axis
=
[
int
(
axis
)]
s
/=
variadic_mul
(
*
reduced_dims
)
.
astype
(
shp
.
dtype
)
else
:
axis
=
[
int
(
a
)
for
a
in
axis
]
# This sequential division will possibly be optimized by PyTensor:
for
i
in
axis
:
s
=
true_div
(
s
,
shp
[
i
])
# This can happen when axis is an empty list/tuple
# This can happen when axis is an empty list/tuple
if
s
.
dtype
!=
shp
.
dtype
and
s
.
dtype
in
discrete_dtypes
:
if
s
.
dtype
!=
shp
.
dtype
and
s
.
dtype
in
discrete_dtypes
:
...
@@ -1596,6 +1590,15 @@ def add(a, *other_terms):
...
@@ -1596,6 +1590,15 @@ def add(a, *other_terms):
# see decorator for function body
# see decorator for function body
def
variadic_add
(
*
args
):
"""Add that accepts arbitrary number of inputs, including zero or one."""
if
not
args
:
return
constant
(
0
)
if
len
(
args
)
==
1
:
return
args
[
0
]
return
add
(
*
args
)
@scalar_elemwise
@scalar_elemwise
def
sub
(
a
,
b
):
def
sub
(
a
,
b
):
"""elementwise subtraction"""
"""elementwise subtraction"""
...
@@ -1608,6 +1611,15 @@ def mul(a, *other_terms):
...
@@ -1608,6 +1611,15 @@ def mul(a, *other_terms):
# see decorator for function body
# see decorator for function body
def
variadic_mul
(
*
args
):
"""Mul that accepts arbitrary number of inputs, including zero or one."""
if
not
args
:
return
constant
(
1
)
if
len
(
args
)
==
1
:
return
args
[
0
]
return
mul
(
*
args
)
@scalar_elemwise
@scalar_elemwise
def
true_div
(
a
,
b
):
def
true_div
(
a
,
b
):
"""elementwise [true] division (inverse of multiplication)"""
"""elementwise [true] division (inverse of multiplication)"""
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
56327779
...
@@ -68,7 +68,7 @@ from pytensor.tensor.basic import (
...
@@ -68,7 +68,7 @@ from pytensor.tensor.basic import (
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.extra_ops
import
broadcast_arrays
from
pytensor.tensor.extra_ops
import
broadcast_arrays
from
pytensor.tensor.math
import
Sum
,
add
,
eq
from
pytensor.tensor.math
import
Sum
,
add
,
eq
,
variadic_add
from
pytensor.tensor.shape
import
Shape_i
,
shape_padleft
from
pytensor.tensor.shape
import
Shape_i
,
shape_padleft
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
...
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
if
acc_dtype
==
"float64"
and
out_dtype
!=
"float64"
and
config
.
floatX
!=
"float64"
:
if
acc_dtype
==
"float64"
and
out_dtype
!=
"float64"
and
config
.
floatX
!=
"float64"
:
return
return
if
len
(
elements
)
==
0
:
element_sum
=
cast
(
element_sum
=
zeros
(
dtype
=
out_dtype
,
shape
=
())
variadic_add
(
*
[
cast
(
value
,
acc_dtype
)
for
value
in
elements
]),
out_dtype
elif
len
(
elements
)
==
1
:
)
element_sum
=
cast
(
elements
[
0
],
out_dtype
)
else
:
element_sum
=
cast
(
add
(
*
[
cast
(
value
,
acc_dtype
)
for
value
in
elements
]),
out_dtype
)
return
[
element_sum
]
return
[
element_sum
]
...
...
pytensor/tensor/rewriting/blas.py
浏览文件 @
56327779
...
@@ -96,7 +96,15 @@ from pytensor.tensor.blas import (
...
@@ -96,7 +96,15 @@ from pytensor.tensor.blas import (
)
)
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
Dot
,
_matrix_matrix_matmul
,
add
,
mul
,
neg
,
sub
from
pytensor.tensor.math
import
(
Dot
,
_matrix_matrix_matmul
,
add
,
mul
,
neg
,
sub
,
variadic_add
,
)
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
DenseTensorType
,
DenseTensorType
,
...
@@ -386,10 +394,7 @@ def _gemm_from_factored_list(fgraph, lst):
...
@@ -386,10 +394,7 @@ def _gemm_from_factored_list(fgraph, lst):
item_to_var
(
input
)
for
k
,
input
in
enumerate
(
lst
)
if
k
not
in
(
i
,
j
)
item_to_var
(
input
)
for
k
,
input
in
enumerate
(
lst
)
if
k
not
in
(
i
,
j
)
]
]
add_inputs
.
extend
(
gemm_of_sM_list
)
add_inputs
.
extend
(
gemm_of_sM_list
)
if
len
(
add_inputs
)
>
1
:
rval
=
[
variadic_add
(
*
add_inputs
)]
rval
=
[
add
(
*
add_inputs
)]
else
:
rval
=
add_inputs
# print "RETURNING GEMM THING", rval
# print "RETURNING GEMM THING", rval
return
rval
,
old_dot22
return
rval
,
old_dot22
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
56327779
...
@@ -76,6 +76,8 @@ from pytensor.tensor.math import (
...
@@ -76,6 +76,8 @@ from pytensor.tensor.math import (
sub
,
sub
,
tri_gamma
,
tri_gamma
,
true_div
,
true_div
,
variadic_add
,
variadic_mul
,
)
)
from
pytensor.tensor.math
import
abs
as
pt_abs
from
pytensor.tensor.math
import
abs
as
pt_abs
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
...
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
...
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
if
not
outer_terms
:
if
not
outer_terms
:
return
None
return
None
elif
len
(
outer_terms
)
==
1
:
[
outer_term
]
=
outer_terms
else
:
else
:
outer_term
=
mul
(
*
outer_terms
)
outer_term
=
variadic_
mul
(
*
outer_terms
)
if
not
inner_terms
:
if
not
inner_terms
:
inner_term
=
None
inner_term
=
None
elif
len
(
inner_terms
)
==
1
:
[
inner_term
]
=
inner_terms
else
:
else
:
inner_term
=
mul
(
*
inner_terms
)
inner_term
=
variadic_
mul
(
*
inner_terms
)
else
:
# true_div
else
:
# true_div
# We only care about removing the denominator out of the reduction
# We only care about removing the denominator out of the reduction
...
@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node):
...
@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node):
assert
cst
.
type
.
broadcastable
==
(
True
,)
*
ndim
assert
cst
.
type
.
broadcastable
==
(
True
,)
*
ndim
return
[
alloc_like
(
cst
,
node_output
,
fgraph
)]
return
[
alloc_like
(
cst
,
node_output
,
fgraph
)]
if
len
(
new_inputs
)
==
1
:
ret
=
[
alloc_like
(
variadic_add
(
*
new_inputs
),
node_output
,
fgraph
)]
ret
=
[
alloc_like
(
new_inputs
[
0
],
node_output
,
fgraph
)]
else
:
ret
=
[
alloc_like
(
add
(
*
new_inputs
),
node_output
,
fgraph
)]
# The dtype should not be changed. It can happen if the input
# The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0.
# that was forcing upcasting was equal to 0.
...
@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node):
...
@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node):
# scalar_inputs are potentially dimshuffled and fill'd scalars
# scalar_inputs are potentially dimshuffled and fill'd scalars
if
scalars
and
np
.
allclose
(
np
.
sum
(
scalars
),
1
):
if
scalars
and
np
.
allclose
(
np
.
sum
(
scalars
),
1
):
if
nonconsts
:
if
nonconsts
:
if
len
(
nonconsts
)
>
1
:
ninp
=
variadic_add
(
*
nonconsts
)
ninp
=
add
(
*
nonconsts
)
else
:
ninp
=
nonconsts
[
0
]
if
ninp
.
dtype
!=
log_arg
.
type
.
dtype
:
if
ninp
.
dtype
!=
log_arg
.
type
.
dtype
:
ninp
=
ninp
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
ninp
=
ninp
.
astype
(
node
.
outputs
[
0
]
.
dtype
)
return
[
alloc_like
(
log1p
(
ninp
),
node
.
outputs
[
0
],
fgraph
)]
return
[
alloc_like
(
log1p
(
ninp
),
node
.
outputs
[
0
],
fgraph
)]
...
@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
...
@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
return
return
# put the new numerator together
# put the new numerator together
new_num
=
sigmoids
+
[
exp
(
t
)
for
t
in
num_exp_x
]
+
num_rest
new_num
=
sigmoids
+
[
exp
(
t
)
for
t
in
num_exp_x
]
+
num_rest
if
len
(
new_num
)
==
1
:
new_num
=
variadic_mul
(
*
new_num
)
new_num
=
new_num
[
0
]
else
:
new_num
=
mul
(
*
new_num
)
if
num_neg
^
denom_neg
:
if
num_neg
^
denom_neg
:
new_num
=
-
new_num
new_num
=
-
new_num
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
56327779
...
@@ -48,6 +48,7 @@ from pytensor.tensor.math import (
...
@@ -48,6 +48,7 @@ from pytensor.tensor.math import (
maximum
,
maximum
,
minimum
,
minimum
,
or_
,
or_
,
variadic_add
,
)
)
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.rewriting.basic
import
(
from
pytensor.tensor.rewriting.basic
import
(
...
@@ -1241,15 +1242,11 @@ def local_IncSubtensor_serialize(fgraph, node):
...
@@ -1241,15 +1242,11 @@ def local_IncSubtensor_serialize(fgraph, node):
new_inputs
=
[
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
+
[
new_inputs
=
[
i
for
i
in
node
.
inputs
if
not
movable
(
i
)]
+
[
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
mi
.
owner
.
inputs
[
0
]
for
mi
in
movable_inputs
]
]
if
len
(
new_inputs
)
==
0
:
new_add
=
variadic_add
(
*
new_inputs
)
new_add
=
new_inputs
[
0
]
# Copy over stacktrace from original output, as an error
else
:
# (e.g. an index error) in this add operation should
new_add
=
add
(
*
new_inputs
)
# correspond to an error in the original add operation.
copy_stack_trace
(
node
.
outputs
[
0
],
new_add
)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace
(
node
.
outputs
[
0
],
new_add
)
# stack up the new incsubtensors
# stack up the new incsubtensors
tip
=
new_add
tip
=
new_add
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论