Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e0a2a865
提交
e0a2a865
authored
11月 27, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 28, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove predefined inplace Elemwise Ops and redundant tests
上级
42e8490c
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
171 行增加
和
510 行删除
+171
-510
test.yml
.github/workflows/test.yml
+6
-6
elemwise.py
pytensor/tensor/elemwise.py
+6
-11
inplace.py
pytensor/tensor/inplace.py
+0
-427
test_elemwise.py
tests/link/numba/test_elemwise.py
+5
-3
test_math.py
tests/tensor/rewriting/test_math.py
+39
-40
test_blas.py
tests/tensor/test_blas.py
+7
-3
test_elemwise.py
tests/tensor/test_elemwise.py
+90
-3
test_inplace.py
tests/tensor/test_inplace.py
+0
-0
test_math_scipy.py
tests/tensor/test_math_scipy.py
+0
-0
utils.py
tests/tensor/utils.py
+18
-17
没有找到文件。
.github/workflows/test.yml
浏览文件 @
e0a2a865
...
...
@@ -84,13 +84,13 @@ jobs:
install-mlx
:
[
0
]
install-xarray
:
[
0
]
part
:
-
"
tests
--ignore=tests/
tensor
--ignore=tests/scan
--ignore=tests/xtensor"
-
"
tests
--ignore=tests/
scan
--ignore=tests/tensor
--ignore=tests/xtensor"
-
"
tests/scan"
-
"
tests/tensor
--ignore=tests/tensor/
rewriting
--ignore=tests/tensor/test_math.py
--ignore=tests/tensor/test_basic.py
--ignore=tests/tensor/test_inplace.py
--ignore=tests/tensor/conv
--ignore=tests/tensor/test_blas.py
--ignore=tests/tensor/test_elemwise.py
--ignore=tests/tensor/test_math_scipy.py
"
-
"
tests/tensor/
rewriting
"
-
"
tests/tensor
--ignore=tests/tensor/
test_basic.py
--ignore=tests/tensor/test_elemwise.py
--ignore=tests/tensor/test_math.py
--ignore=tests/tensor/test_math_scipy.py
--ignore=tests/tensor/test_blas.py
--ignore=tests/tensor/conv
--ignore=tests/tensor/rewriting
"
-
"
tests/tensor/
test_basic.py
tests/tensor/test_elemwise.py
"
-
"
tests/tensor/test_math.py"
-
"
tests/tensor/test_
basic.py
tests/tensor/test_inplace
.py
tests/tensor/conv"
-
"
tests/tensor/
test_blas.py
tests/tensor/test_elemwise.py
tests/tensor/test_math_scipy.py
"
-
"
tests/tensor/test_
math_scipy.py
tests/tensor/test_blas
.py
tests/tensor/conv"
-
"
tests/tensor/
rewriting
"
exclude
:
-
python-version
:
"
3.11"
fast-compile
:
1
...
...
@@ -167,7 +167,7 @@ jobs:
install-numba
:
0
install-jax
:
0
install-torch
:
0
part
:
"
tests/tensor/test_
blas.py
tests/tensor/test_elemwise.py
tests/tensor/test_math_scipy
.py"
part
:
"
tests/tensor/test_
elemwise.py
tests/tensor/test_math_scipy.py
tests/tensor/test_blas
.py"
steps
:
-
uses
:
actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8
# v5.0.0
...
...
pytensor/tensor/elemwise.py
浏览文件 @
e0a2a865
...
...
@@ -20,7 +20,7 @@ from pytensor.misc.frozendict import frozendict
from
pytensor.printing
import
Printer
,
pprint
from
pytensor.scalar
import
get_scalar_type
from
pytensor.scalar.basic
import
identity
as
scalar_identity
from
pytensor.scalar.basic
import
int64
,
transfer_type
,
upcast
from
pytensor.scalar.basic
import
int64
,
upcast
from
pytensor.tensor
import
elemwise_cgen
as
cgen
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.basic
import
_get_vector_length
,
as_tensor_variable
...
...
@@ -1634,17 +1634,12 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
symbolname
=
symbolname
or
symbol
.
__name__
if
symbolname
.
endswith
(
"_inplace"
):
base_symbol_name
=
symbolname
[:
-
len
(
"_inplace"
)]
scalar_op
=
getattr
(
scalar
,
base_symbol_name
)
inplace_scalar_op
=
scalar_op
.
__class__
(
transfer_type
(
0
))
rval
=
Elemwise
(
inplace_scalar_op
,
{
0
:
0
},
nfunc_spec
=
(
nfunc
and
(
nfunc
,
nin
,
nout
)),
raise
ValueError
(
"Creation of automatic inplace elemwise operations deprecated"
)
else
:
scalar_op
=
getattr
(
scalar
,
symbolname
)
rval
=
Elemwise
(
scalar_op
,
nfunc_spec
=
(
nfunc
and
(
nfunc
,
nin
,
nout
)))
scalar_op
=
getattr
(
scalar
,
symbolname
)
rval
=
Elemwise
(
scalar_op
,
nfunc_spec
=
(
nfunc
and
(
nfunc
,
nin
,
nout
)))
if
getattr
(
symbol
,
"__doc__"
):
rval
.
__doc__
=
symbol
.
__doc__
...
...
pytensor/tensor/inplace.py
deleted
100644 → 0
浏览文件 @
42e8490c
from
pytensor
import
printing
from
pytensor.printing
import
pprint
from
pytensor.tensor.elemwise
import
scalar_elemwise
@scalar_elemwise
def
lt_inplace
(
a
,
b
):
"""a < b (inplace on a)"""
@scalar_elemwise
def
gt_inplace
(
a
,
b
):
"""a > b (inplace on a)"""
@scalar_elemwise
def
le_inplace
(
a
,
b
):
"""a <= b (inplace on a)"""
@scalar_elemwise
def
ge_inplace
(
a
,
b
):
"""a >= b (inplace on a)"""
@scalar_elemwise
def
eq_inplace
(
a
,
b
):
"""a == b (inplace on a)"""
@scalar_elemwise
def
neq_inplace
(
a
,
b
):
"""a != b (inplace on a)"""
@scalar_elemwise
def
and__inplace
(
a
,
b
):
"""bitwise a & b (inplace on a)"""
@scalar_elemwise
def
or__inplace
(
a
,
b
):
"""bitwise a | b (inplace on a)"""
@scalar_elemwise
def
xor_inplace
(
a
,
b
):
"""bitwise a ^ b (inplace on a)"""
@scalar_elemwise
def
invert_inplace
(
a
):
"""bitwise ~a (inplace on a)"""
@scalar_elemwise
def
abs_inplace
(
a
):
"""|`a`| (inplace on `a`)"""
@scalar_elemwise
def
exp_inplace
(
a
):
"""e^`a` (inplace on `a`)"""
@scalar_elemwise
def
exp2_inplace
(
a
):
"""2^`a` (inplace on `a`)"""
@scalar_elemwise
def
expm1_inplace
(
a
):
"""e^`a` - 1 (inplace on `a`)"""
@scalar_elemwise
def
neg_inplace
(
a
):
"""-a (inplace on a)"""
@scalar_elemwise
def
reciprocal_inplace
(
a
):
"""1.0/a (inplace on a)"""
@scalar_elemwise
def
log_inplace
(
a
):
"""base e logarithm of a (inplace on a)"""
@scalar_elemwise
def
log1p_inplace
(
a
):
"""log(1+a)"""
@scalar_elemwise
def
log2_inplace
(
a
):
"""base 2 logarithm of a (inplace on a)"""
@scalar_elemwise
def
log10_inplace
(
a
):
"""base 10 logarithm of a (inplace on a)"""
@scalar_elemwise
def
sign_inplace
(
a
):
"""sign of `a` (inplace on `a`)"""
@scalar_elemwise
def
ceil_inplace
(
a
):
"""ceil of `a` (inplace on `a`)"""
@scalar_elemwise
def
floor_inplace
(
a
):
"""floor of `a` (inplace on `a`)"""
@scalar_elemwise
def
trunc_inplace
(
a
):
"""trunc of `a` (inplace on `a`)"""
@scalar_elemwise
def
round_half_to_even_inplace
(
a
):
"""round_half_to_even_inplace(a) (inplace on `a`)"""
@scalar_elemwise
def
round_half_away_from_zero_inplace
(
a
):
"""round_half_away_from_zero_inplace(a) (inplace on `a`)"""
@scalar_elemwise
def
sqr_inplace
(
a
):
"""square of `a` (inplace on `a`)"""
@scalar_elemwise
def
sqrt_inplace
(
a
):
"""square root of `a` (inplace on `a`)"""
@scalar_elemwise
def
deg2rad_inplace
(
a
):
"""convert degree `a` to radian(inplace on `a`)"""
@scalar_elemwise
def
rad2deg_inplace
(
a
):
"""convert radian `a` to degree(inplace on `a`)"""
@scalar_elemwise
def
cos_inplace
(
a
):
"""cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def
arccos_inplace
(
a
):
"""arccosine of `a` (inplace on `a`)"""
@scalar_elemwise
def
sin_inplace
(
a
):
"""sine of `a` (inplace on `a`)"""
@scalar_elemwise
def
arcsin_inplace
(
a
):
"""arcsine of `a` (inplace on `a`)"""
@scalar_elemwise
def
tan_inplace
(
a
):
"""tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def
arctan_inplace
(
a
):
"""arctangent of `a` (inplace on `a`)"""
@scalar_elemwise
def
arctan2_inplace
(
a
,
b
):
"""arctangent of `a` / `b` (inplace on `a`)"""
@scalar_elemwise
def
cosh_inplace
(
a
):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def
arccosh_inplace
(
a
):
"""hyperbolic arc cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def
sinh_inplace
(
a
):
"""hyperbolic sine of `a` (inplace on `a`)"""
@scalar_elemwise
def
arcsinh_inplace
(
a
):
"""hyperbolic arc sine of `a` (inplace on `a`)"""
@scalar_elemwise
def
tanh_inplace
(
a
):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def
arctanh_inplace
(
a
):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def
erf_inplace
(
a
):
"""error function"""
@scalar_elemwise
def
erfc_inplace
(
a
):
"""complementary error function"""
@scalar_elemwise
def
erfcx_inplace
(
a
):
"""scaled complementary error function"""
@scalar_elemwise
def
owens_t_inplace
(
h
,
a
):
"""owens t function"""
@scalar_elemwise
def
gamma_inplace
(
a
):
"""gamma function"""
@scalar_elemwise
def
gammaln_inplace
(
a
):
"""log gamma function"""
@scalar_elemwise
def
psi_inplace
(
a
):
"""derivative of log gamma function"""
@scalar_elemwise
def
tri_gamma_inplace
(
a
):
"""second derivative of the log gamma function"""
@scalar_elemwise
def
gammainc_inplace
(
k
,
x
):
"""regularized lower gamma function (P)"""
@scalar_elemwise
def
gammaincc_inplace
(
k
,
x
):
"""regularized upper gamma function (Q)"""
@scalar_elemwise
def
gammau_inplace
(
k
,
x
):
"""upper incomplete gamma function"""
@scalar_elemwise
def
gammal_inplace
(
k
,
x
):
"""lower incomplete gamma function"""
@scalar_elemwise
def
gammaincinv_inplace
(
k
,
x
):
"""Inverse to the regularized lower incomplete gamma function"""
@scalar_elemwise
def
gammainccinv_inplace
(
k
,
x
):
"""Inverse of the regularized upper incomplete gamma function"""
@scalar_elemwise
def
j0_inplace
(
x
):
"""Bessel function of the first kind of order 0."""
@scalar_elemwise
def
j1_inplace
(
x
):
"""Bessel function of the first kind of order 1."""
@scalar_elemwise
def
jv_inplace
(
v
,
x
):
"""Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def
i0_inplace
(
x
):
"""Modified Bessel function of the first kind of order 0."""
@scalar_elemwise
def
i1_inplace
(
x
):
"""Modified Bessel function of the first kind of order 1."""
@scalar_elemwise
def
iv_inplace
(
v
,
x
):
"""Modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def
ive_inplace
(
v
,
x
):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def
sigmoid_inplace
(
x
):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@scalar_elemwise
def
softplus_inplace
(
x
):
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""
@scalar_elemwise
def
log1mexp_inplace
(
x
):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
def
betainc_inplace
(
a
,
b
,
x
):
"""Regularized incomplete beta function"""
@scalar_elemwise
def
betaincinv_inplace
(
a
,
b
,
x
):
"""Inverse of the regularized incomplete beta function"""
@scalar_elemwise
def
second_inplace
(
a
):
"""Fill `a` with `b`"""
fill_inplace
=
second_inplace
pprint
.
assign
(
fill_inplace
,
printing
.
FunctionPrinter
([
"fill="
]))
@scalar_elemwise
def
maximum_inplace
(
a
,
b
):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def
minimum_inplace
(
a
,
b
):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def
add_inplace
(
a
,
b
):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def
sub_inplace
(
a
,
b
):
"""elementwise subtraction (inplace on `a`)"""
@scalar_elemwise
def
mul_inplace
(
a
,
b
):
"""elementwise multiplication (inplace on `a`)"""
@scalar_elemwise
def
true_div_inplace
(
a
,
b
):
"""elementwise division (inplace on `a`)"""
@scalar_elemwise
def
int_div_inplace
(
a
,
b
):
"""elementwise division (inplace on `a`)"""
@scalar_elemwise
def
mod_inplace
(
a
,
b
):
"""elementwise modulo (inplace on `a`)"""
@scalar_elemwise
def
pow_inplace
(
a
,
b
):
"""elementwise power (inplace on `a`)"""
@scalar_elemwise
def
conj_inplace
(
a
):
"""elementwise conjugate (inplace on `a`)"""
@scalar_elemwise
def
hyp2f1_inplace
(
a
,
b
,
c
,
z
):
"""gaussian hypergeometric function"""
pprint
.
assign
(
add_inplace
,
printing
.
OperatorPrinter
(
"+="
,
-
2
,
"either"
))
pprint
.
assign
(
mul_inplace
,
printing
.
OperatorPrinter
(
"*="
,
-
1
,
"either"
))
pprint
.
assign
(
sub_inplace
,
printing
.
OperatorPrinter
(
"-="
,
-
2
,
"left"
))
pprint
.
assign
(
neg_inplace
,
printing
.
OperatorPrinter
(
"-="
,
0
,
"either"
))
pprint
.
assign
(
true_div_inplace
,
printing
.
OperatorPrinter
(
"/="
,
-
1
,
"left"
))
pprint
.
assign
(
int_div_inplace
,
printing
.
OperatorPrinter
(
"//="
,
-
1
,
"left"
))
pprint
.
assign
(
pow_inplace
,
printing
.
OperatorPrinter
(
"**="
,
1
,
"right"
))
def
transpose_inplace
(
x
,
**
kwargs
):
"Perform a transpose on a tensor without copying the underlying storage"
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
return
x
.
dimshuffle
(
dims
)
tests/link/numba/test_elemwise.py
浏览文件 @
e0a2a865
...
...
@@ -6,13 +6,13 @@ import scipy.special
import
pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor.inplace
as
pti
import
pytensor.tensor.math
as
ptm
from
pytensor
import
config
,
function
from
pytensor.compile
import
get_mode
from
pytensor.compile.ops
import
deep_copy_op
from
pytensor.gradient
import
grad
from
pytensor.scalar
import
Composite
,
float64
from
pytensor.scalar
import
add
as
scalar_add
from
pytensor.tensor
import
blas
,
tensor
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
...
...
@@ -30,6 +30,8 @@ from tests.tensor.test_elemwise import (
rng
=
np
.
random
.
default_rng
(
42849
)
add_inplace
=
Elemwise
(
scalar_add
,
{
0
:
0
})
@pytest.mark.parametrize
(
"inputs, input_vals, output_fn"
,
...
...
@@ -80,7 +82,7 @@ rng = np.random.default_rng(42849)
np
.
array
(
1.0
,
dtype
=
config
.
floatX
),
np
.
array
(
1.0
,
dtype
=
config
.
floatX
),
],
lambda
x
,
y
:
pti
.
add_inplace
(
deep_copy_op
(
x
),
deep_copy_op
(
y
)),
lambda
x
,
y
:
add_inplace
(
deep_copy_op
(
x
),
deep_copy_op
(
y
)),
),
(
[
pt
.
vector
(),
pt
.
vector
()],
...
...
@@ -88,7 +90,7 @@ rng = np.random.default_rng(42849)
rng
.
standard_normal
(
100
)
.
astype
(
config
.
floatX
),
rng
.
standard_normal
(
100
)
.
astype
(
config
.
floatX
),
],
lambda
x
,
y
:
pti
.
add_inplace
(
deep_copy_op
(
x
),
deep_copy_op
(
y
)),
lambda
x
,
y
:
add_inplace
(
deep_copy_op
(
x
),
deep_copy_op
(
y
)),
),
(
[
pt
.
vector
(),
pt
.
vector
()],
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
e0a2a865
...
...
@@ -31,7 +31,6 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from
pytensor.graph.traversal
import
ancestors
from
pytensor.printing
import
debugprint
from
pytensor.scalar
import
PolyGamma
,
Psi
,
TriGamma
from
pytensor.tensor
import
inplace
from
pytensor.tensor.basic
import
Alloc
,
constant
,
join
,
second
,
switch
from
pytensor.tensor.blas
import
Dot22
,
Gemv
from
pytensor.tensor.blas_c
import
CGemv
...
...
@@ -1134,15 +1133,15 @@ def test_log1p():
f
=
function
([
x
],
log
(
1
+
(
x
)),
mode
=
m
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
log1p
]
f
=
function
([
x
],
log
(
1
+
(
-
x
)),
mode
=
m
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
neg
,
inplace
.
log1p_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
ps
.
neg
,
ps
.
log1p
,
]
f
=
function
([
x
],
-
log
(
1
+
(
-
x
)),
mode
=
m
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
neg
,
inplace
.
log1p_inplace
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
ps
.
neg
,
ps
.
log1p
,
ps
.
neg
,
]
# check trickier cases (and use different dtype)
...
...
@@ -4035,27 +4034,27 @@ class TestSigmoidRewrites:
# todo: solve issue #4589 first
# assert check_stack_trace(
# f, ops_to_check=[sigmoid, neg_inplace])
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
sigmoid
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
ps
.
sigmoid
,
ps
.
neg
,
]
f
(
data
)
f
=
pytensor
.
function
([
x
],
pt
.
fill
(
x
,
-
1.0
)
/
(
1
-
exp
(
-
x
)),
mode
=
m
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
neg
,
]
f
(
data
)
f
=
pytensor
.
function
([
x
],
pt
.
fill
(
x
,
-
1.0
)
/
(
2
+
exp
(
-
x
)),
mode
=
m
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
neg
,
]
f
(
data
)
f
=
pytensor
.
function
([
x
],
pt
.
fill
(
x
,
-
1.1
)
/
(
1
+
exp
(
-
x
)),
mode
=
m
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
neg
,
]
f
(
data
)
...
...
@@ -4077,10 +4076,10 @@ class TestSigmoidRewrites:
(
pt
.
fill
(
x
,
-
1.1
)
*
exp
(
x
))
/
((
1
+
exp
(
x
))
*
(
1
+
exp
(
-
x
))),
mode
=
m
,
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
mul
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
mul
,
ps
.
neg
,
]
f
(
data
)
f
=
pytensor
.
function
(
...
...
@@ -4088,10 +4087,10 @@ class TestSigmoidRewrites:
(
pt
.
fill
(
x
,
-
1.0
)
*
exp
(
x
))
/
((
2
+
exp
(
x
))
*
(
1
+
exp
(
-
x
))),
mode
=
m
,
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
mul
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
mul
,
ps
.
neg
,
]
f
(
data
)
f
=
pytensor
.
function
(
...
...
@@ -4099,10 +4098,10 @@ class TestSigmoidRewrites:
(
pt
.
fill
(
x
,
-
1.0
)
*
exp
(
x
))
/
((
1
+
exp
(
x
))
*
(
2
+
exp
(
-
x
))),
mode
=
m
,
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
mul
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
mul
,
ps
.
neg
,
]
f
(
data
)
f
=
pytensor
.
function
(
...
...
@@ -4110,10 +4109,10 @@ class TestSigmoidRewrites:
(
pt
.
fill
(
x
,
-
1.0
)
*
exp
(
x
))
/
((
1
+
exp
(
x
))
*
(
1
+
exp
(
x
))),
mode
=
m
,
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
mul
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
mul
,
ps
.
neg
,
]
f
(
data
)
f
=
pytensor
.
function
(
...
...
@@ -4121,10 +4120,10 @@ class TestSigmoidRewrites:
(
pt
.
fill
(
x
,
-
1.0
)
*
exp
(
x
))
/
((
1
+
exp
(
x
))
*
(
2
+
exp
(
-
x
))),
mode
=
m
,
)
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
sigmoid
,
mul
,
inplace
.
neg_inplace
,
assert
[
node
.
op
.
scalar_op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
!=
[
ps
.
sigmoid
,
ps
.
mul
,
ps
.
neg
,
]
f
(
data
)
...
...
tests/tensor/test_blas.py
浏览文件 @
e0a2a865
...
...
@@ -17,7 +17,6 @@ from pytensor.configdefaults import config
from
pytensor.gradient
import
grad
from
pytensor.graph.rewriting.basic
import
in2out
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.tensor
import
inplace
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.blas
import
(
BatchedDot
,
...
...
@@ -40,6 +39,7 @@ from pytensor.tensor.blas import (
ger
,
ger_destructive
,
)
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
Dot
,
dot
,
mean
,
mul
,
outer
,
sigmoid
from
pytensor.tensor.rewriting.blas
import
local_dot22_to_dot22scalar
,
local_gemm_to_ger
from
pytensor.tensor.type
import
(
...
...
@@ -258,16 +258,20 @@ class TestGemm:
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
())
Z
=
as_tensor_variable
(
rng
.
random
((
2
,
2
)))
A
=
as_tensor_variable
(
rng
.
random
((
2
,
2
)))
Zt
=
Z
.
transpose
()
assert
isinstance
(
Zt
.
owner
.
op
,
DimShuffle
)
and
Zt
.
owner
.
op
.
view_map
==
{
0
:
[
0
]}
with
pytest
.
raises
(
InconsistencyError
,
match
=
Gemm
.
E_z_uniq
):
gemm_inplace
(
Z
,
1.0
,
A
,
inplace
.
transpose_inplace
(
Z
)
,
1.0
)
gemm_inplace
(
Z
,
1.0
,
A
,
Zt
,
1.0
)
def
test_destroy_map2
(
self
):
# test that only first input can be overwritten.
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
())
Z
=
as_tensor_variable
(
rng
.
random
((
2
,
2
)))
A
=
as_tensor_variable
(
rng
.
random
((
2
,
2
)))
Zt
=
Z
.
transpose
()
assert
isinstance
(
Zt
.
owner
.
op
,
DimShuffle
)
and
Zt
.
owner
.
op
.
view_map
==
{
0
:
[
0
]}
with
pytest
.
raises
(
InconsistencyError
,
match
=
Gemm
.
E_z_uniq
):
gemm_inplace
(
Z
,
1.0
,
inplace
.
transpose_inplace
(
Z
)
,
A
,
1.0
)
gemm_inplace
(
Z
,
1.0
,
Zt
,
A
,
1.0
)
def
test_destroy_map3
(
self
):
# test that only first input can be overwritten
...
...
tests/tensor/test_elemwise.py
浏览文件 @
e0a2a865
...
...
@@ -20,6 +20,9 @@ from pytensor.graph.replace import vectorize_node
from
pytensor.link.basic
import
PerformLinker
from
pytensor.link.c.basic
import
CLinker
,
OpWiseCLinker
from
pytensor.scalar
import
ScalarOp
,
float32
,
float64
,
int32
,
int64
from
pytensor.scalar
import
add
as
scalar_add
from
pytensor.scalar
import
exp
as
scalar_exp
from
pytensor.scalar
import
xor
as
scalar_xor
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor.basic
import
get_scalar_constant_value
,
second
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
...
...
@@ -43,6 +46,16 @@ from pytensor.tensor.type import (
)
from
tests
import
unittest_tools
from
tests.link.test_link
import
make_function
from
tests.tensor.utils
import
(
_bad_runtime_broadcast_binary_normal
,
inplace_func
,
integers
,
integers_uint16
,
integers_uint32
,
makeBroadcastTester
,
random
,
random_complex
,
)
def
reduce_bitwise_and
(
x
,
axis
=-
1
,
dtype
=
"int8"
):
...
...
@@ -334,7 +347,7 @@ class TestBroadcast:
x
=
x_type
(
"x"
)
y
=
y_type
(
"y"
)
e
=
op
(
ps
.
Add
(
ps
.
transfer_type
(
0
))
,
{
0
:
0
})(
x
,
y
)
e
=
op
(
ps
.
add
,
{
0
:
0
})(
x
,
y
)
f
=
make_function
(
copy
(
linker
)
.
accept
(
FunctionGraph
([
x
,
y
],
[
e
])))
xv
=
rand_val
(
xsh
)
yv
=
rand_val
(
ysh
)
...
...
@@ -348,7 +361,7 @@ class TestBroadcast:
if
isinstance
(
linker
,
PerformLinker
):
x
=
x_type
(
"x"
)
y
=
y_type
(
"y"
)
e
=
op
(
ps
.
Add
(
ps
.
transfer_type
(
0
))
,
{
0
:
0
})(
x
,
y
)
e
=
op
(
ps
.
add
,
{
0
:
0
})(
x
,
y
)
f
=
make_function
(
copy
(
linker
)
.
accept
(
FunctionGraph
([
x
,
y
],
[
e
.
shape
])))
xv
=
rand_val
(
xsh
)
yv
=
rand_val
(
ysh
)
...
...
@@ -390,7 +403,10 @@ class TestBroadcast:
):
x
=
t
(
pytensor
.
config
.
floatX
,
shape
=
(
None
,
None
))(
"x"
)
y
=
t
(
pytensor
.
config
.
floatX
,
shape
=
(
1
,
1
))(
"y"
)
e
=
op
(
ps
.
Second
(
ps
.
transfer_type
(
0
)),
{
0
:
0
})(
x
,
y
)
op1
=
op
(
ps
.
second
,
{
0
:
0
})
op2
=
op
(
ps
.
second
,
{
0
:
0
})
assert
op1
==
op2
e
=
op
(
ps
.
Second
(),
{
0
:
0
})(
x
,
y
)
f
=
make_function
(
linker
()
.
accept
(
FunctionGraph
([
x
,
y
],
[
e
])))
xv
=
rval
((
5
,
5
))
yv
=
rval
((
1
,
1
))
...
...
@@ -1113,3 +1129,74 @@ def test_numpy_warning_suppressed():
y
=
pt
.
log
(
x
)
fn
=
pytensor
.
function
([
x
],
y
,
mode
=
Mode
(
linker
=
"py"
))
assert
fn
(
0
)
==
-
np
.
inf
rng
=
np
.
random
.
default_rng
(
18
)
_good_add_inplace
=
dict
(
same_shapes
=
(
random
(
2
,
3
,
rng
=
rng
),
random
(
2
,
3
,
rng
=
rng
)),
not_same_dimensions
=
(
random
(
2
,
2
,
rng
=
rng
),
random
(
2
,
rng
=
rng
)),
scalar
=
(
random
(
2
,
3
,
rng
=
rng
),
random
(
1
,
1
,
rng
=
rng
)),
row
=
(
random
(
2
,
3
,
rng
=
rng
),
random
(
1
,
3
,
rng
=
rng
)),
column
=
(
random
(
2
,
3
,
rng
=
rng
),
random
(
2
,
1
,
rng
=
rng
)),
integers
=
(
integers
(
2
,
3
,
rng
=
rng
),
integers
(
2
,
3
,
rng
=
rng
)),
uint32
=
(
integers_uint32
(
2
,
3
,
rng
=
rng
),
integers_uint32
(
2
,
3
,
rng
=
rng
)),
uint16
=
(
integers_uint16
(
2
,
3
,
rng
=
rng
),
integers_uint16
(
2
,
3
,
rng
=
rng
)),
# (float32, >int16) upcasts to float64 by default
dtype_valid_mixup
=
(
random
(
2
,
3
,
rng
=
rng
),
integers
(
2
,
3
,
rng
=
rng
)
.
astype
(
"int16"
if
config
.
floatX
==
"float32"
else
"int64"
),
),
complex1
=
(
random_complex
(
2
,
3
,
rng
=
rng
),
random_complex
(
2
,
3
,
rng
=
rng
)),
complex2
=
(
random_complex
(
2
,
3
,
rng
=
rng
),
random
(
2
,
3
,
rng
=
rng
)),
empty
=
(
np
.
asarray
([],
dtype
=
config
.
floatX
),
np
.
asarray
([
1
],
dtype
=
config
.
floatX
)),
)
TestAddInplaceBroadcast
=
makeBroadcastTester
(
op
=
Elemwise
(
scalar_add
,
{
0
:
0
}),
expected
=
lambda
x
,
y
:
x
+
y
,
good
=
_good_add_inplace
,
# Cannot inplace on first input if it doesn't match output dtype (upcast of inputs)
bad_build
=
dict
(
dtype_invalid_mixup
=
_good_add_inplace
[
"dtype_valid_mixup"
][::
-
1
]),
bad_runtime
=
_bad_runtime_broadcast_binary_normal
,
inplace
=
True
,
)
@pytest.mark.xfail
(
config
.
cycle_detection
==
"fast"
and
config
.
mode
!=
"FAST_COMPILE"
,
reason
=
"Cycle detection is fast and mode is FAST_COMPILE"
,
)
def
test_exp_inplace_grad_1
():
utt
.
verify_grad
(
Elemwise
(
scalar_exp
,
{
0
:
0
}),
[
np
.
asarray
(
[
[
1.5089518
,
1.48439076
,
-
4.7820262
],
[
2.04832468
,
0.50791564
,
-
1.58892269
],
]
)
],
)
def
test_XOR_inplace
():
dtype
=
[
"int8"
,
"int16"
,
"int32"
,
"int64"
,
]
xor_inplace
=
Elemwise
(
scalar_xor
,
{
0
:
0
})
for
dtype
in
dtype
:
x
,
y
=
vector
(
dtype
=
dtype
),
vector
(
dtype
=
dtype
)
l
=
np
.
asarray
([
0
,
0
,
1
,
1
],
dtype
=
dtype
)
r
=
np
.
asarray
([
0
,
1
,
0
,
1
],
dtype
=
dtype
)
ix
=
x
ix
=
xor_inplace
(
ix
,
y
)
gn
=
inplace_func
([
x
,
y
],
ix
)
_
=
gn
(
l
,
r
)
# test the in-place stuff
assert
np
.
all
(
l
==
np
.
asarray
([
0
,
1
,
1
,
0
])),
l
tests/tensor/test_inplace.py
deleted
100644 → 0
浏览文件 @
42e8490c
差异被折叠。
点击展开。
tests/tensor/test_math_scipy.py
浏览文件 @
e0a2a865
差异被折叠。
点击展开。
tests/tensor/utils.py
浏览文件 @
e0a2a865
...
...
@@ -672,7 +672,9 @@ def makeTester(
return
Checker
def
makeBroadcastTester
(
op
,
expected
,
checks
=
None
,
name
=
None
,
**
kwargs
):
def
makeBroadcastTester
(
op
,
expected
,
checks
=
None
,
name
=
None
,
*
,
inplace
=
False
,
**
kwargs
):
if
checks
is
None
:
checks
=
{}
if
name
is
None
:
...
...
@@ -695,22 +697,20 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
# cases we need to add it manually.
if
not
name
.
endswith
(
"Tester"
):
name
+=
"Tester"
if
"inplace"
in
kwargs
:
if
kwargs
[
"inplace"
]:
_expected
=
expected
if
not
isinstance
(
_expected
,
dict
):
def
expected
(
*
inputs
):
return
np
.
array
(
_expected
(
*
inputs
),
dtype
=
inputs
[
0
]
.
dtype
)
def
inplace_check
(
inputs
,
outputs
):
# this used to be inputs[0] is output[0]
# I changed it so that it was easier to satisfy by the
# DebugMode
return
np
.
all
(
inputs
[
0
]
==
outputs
[
0
])
checks
=
dict
(
checks
,
inplace_check
=
inplace_check
)
del
kwargs
[
"inplace"
]
if
inplace
:
_expected
=
expected
if
not
isinstance
(
_expected
,
dict
):
def
expected
(
*
inputs
):
return
np
.
array
(
_expected
(
*
inputs
),
dtype
=
inputs
[
0
]
.
dtype
)
def
inplace_check
(
inputs
,
outputs
):
# this used to be inputs[0] is output[0]
# I changed it so that it was easier to satisfy by the
# DebugMode
return
np
.
all
(
inputs
[
0
]
==
outputs
[
0
])
checks
=
dict
(
checks
,
inplace_check
=
inplace_check
)
return
makeTester
(
name
,
op
,
expected
,
checks
,
**
kwargs
)
...
...
@@ -815,6 +815,7 @@ _good_broadcast_unary_normal_no_complex = dict(
big_scalar
=
[
np
.
arange
(
17.0
,
29.0
,
0.5
,
dtype
=
config
.
floatX
)],
)
# FIXME: Why is this empty?
_bad_build_broadcast_binary_normal
=
dict
()
_bad_runtime_broadcast_binary_normal
=
dict
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论