Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0c138495
提交
0c138495
authored
7月 24, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 25, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make Dot only accept matrix inputs
上级
d1be796e
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
161 行增加
和
480 行删除
+161
-480
__init__.py
pytensor/tensor/__init__.py
+0
-1
basic.py
pytensor/tensor/basic.py
+1
-2
blas.py
pytensor/tensor/blas.py
+14
-22
blas_scipy.py
pytensor/tensor/blas_scipy.py
+0
-34
math.py
pytensor/tensor/math.py
+52
-112
__init__.py
pytensor/tensor/rewriting/__init__.py
+0
-1
blas.py
pytensor/tensor/rewriting/blas.py
+1
-17
blas_scipy.py
pytensor/tensor/rewriting/blas_scipy.py
+0
-37
linalg.py
pytensor/tensor/rewriting/linalg.py
+1
-9
math.py
pytensor/tensor/rewriting/math.py
+43
-83
subtensor_lift.py
pytensor/tensor/rewriting/subtensor_lift.py
+3
-18
test_kanren.py
tests/graph/rewriting/test_kanren.py
+20
-20
test_basic.py
tests/link/numba/test_basic.py
+1
-1
test_math.py
tests/tensor/rewriting/test_math.py
+3
-2
test_blas.py
tests/tensor/test_blas.py
+0
-1
test_blas_c.py
tests/tensor/test_blas_c.py
+0
-3
test_blas_scipy.py
tests/tensor/test_blas_scipy.py
+0
-75
test_math.py
tests/tensor/test_math.py
+15
-37
test_printing.py
tests/test_printing.py
+4
-4
test_math.py
tests/xtensor/test_math.py
+3
-1
没有找到文件。
pytensor/tensor/__init__.py
浏览文件 @
0c138495
...
@@ -107,7 +107,6 @@ from pytensor.gradient import grad, hessian, jacobian
...
@@ -107,7 +107,6 @@ from pytensor.gradient import grad, hessian, jacobian
from
pytensor.tensor
import
(
from
pytensor.tensor
import
(
blas
,
blas
,
blas_c
,
blas_c
,
blas_scipy
,
sharedvar
,
sharedvar
,
xlogx
,
xlogx
,
)
)
...
...
pytensor/tensor/basic.py
浏览文件 @
0c138495
...
@@ -1801,8 +1801,7 @@ class Alloc(COp):
...
@@ -1801,8 +1801,7 @@ class Alloc(COp):
|
pytensor
.
tensor
.
blas
.
Gemv
|
pytensor
.
tensor
.
blas
.
Gemv
|
pytensor
.
tensor
.
blas_c
.
CGemv
|
pytensor
.
tensor
.
blas_c
.
CGemv
|
pytensor
.
tensor
.
blas
.
Ger
|
pytensor
.
tensor
.
blas
.
Ger
|
pytensor
.
tensor
.
blas_c
.
CGer
|
pytensor
.
tensor
.
blas_c
.
CGer
,
|
pytensor
.
tensor
.
blas_scipy
.
ScipyGer
,
)
)
):
):
# Ops that will work inplace on the Alloc. So if they
# Ops that will work inplace on the Alloc. So if they
...
...
pytensor/tensor/blas.py
浏览文件 @
0c138495
...
@@ -83,6 +83,7 @@ import warnings
...
@@ -83,6 +83,7 @@ import warnings
from
pathlib
import
Path
from
pathlib
import
Path
import
numpy
as
np
import
numpy
as
np
from
scipy.linalg
import
get_blas_funcs
from
pytensor.graph
import
vectorize_graph
from
pytensor.graph
import
vectorize_graph
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.npy_2_compat
import
normalize_axis_tuple
...
@@ -288,18 +289,17 @@ class Ger(Op):
...
@@ -288,18 +289,17 @@ class Ger(Op):
return
Apply
(
self
,
inputs
,
[
A
.
type
()])
return
Apply
(
self
,
inputs
,
[
A
.
type
()])
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
cA
,
calpha
,
cx
,
cy
=
inp
A
,
alpha
,
x
,
y
=
inputs
(
cZ
,)
=
out
if
A
.
size
:
if
self
.
destructive
:
# GER doesn't handle zero-sized inputs
A
=
cA
ger_func
=
get_blas_funcs
(
"ger"
,
dtype
=
A
.
dtype
)
else
:
if
A
.
flags
[
"C_CONTIGUOUS"
]:
A
=
cA
.
copy
()
# Work on transposed system to avoid copying
if
calpha
!=
1
:
A
=
ger_func
(
alpha
,
y
,
x
,
a
=
A
.
T
,
overwrite_a
=
self
.
destructive
)
.
T
A
+=
calpha
*
np
.
outer
(
cx
,
cy
)
else
:
else
:
A
=
ger_func
(
alpha
,
x
,
y
,
a
=
A
,
overwrite_a
=
self
.
destructive
)
A
+=
np
.
outer
(
cx
,
cy
)
output_storage
[
0
][
0
]
=
A
cZ
[
0
]
=
A
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
return
[
input_shapes
[
0
]]
return
[
input_shapes
[
0
]]
...
@@ -1128,16 +1128,8 @@ class Dot22(GemmRelated):
...
@@ -1128,16 +1128,8 @@ class Dot22(GemmRelated):
outputs
=
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
(
x
.
type
.
shape
[
0
],
y
.
type
.
shape
[
1
]))]
outputs
=
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
(
x
.
type
.
shape
[
0
],
y
.
type
.
shape
[
1
]))]
return
Apply
(
self
,
[
x
,
y
],
outputs
)
return
Apply
(
self
,
[
x
,
y
],
outputs
)
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
x
,
y
=
inp
output_storage
[
0
][
0
]
=
np
.
dot
(
*
inputs
)
(
z
,)
=
out
try
:
z
[
0
]
=
np
.
asarray
(
np
.
dot
(
x
,
y
))
except
ValueError
as
e
:
# The error raised by numpy has no shape information, we mean to
# add that
e
.
args
=
(
*
e
.
args
,
x
.
shape
,
y
.
shape
)
raise
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
return
[[
input_shapes
[
0
][
0
],
input_shapes
[
1
][
1
]]]
return
[[
input_shapes
[
0
][
0
],
input_shapes
[
1
][
1
]]]
...
...
pytensor/tensor/blas_scipy.py
deleted
100644 → 0
浏览文件 @
d1be796e
"""
Implementations of BLAS Ops based on scipy's BLAS bindings.
"""
from
pytensor.tensor.blas
import
Ger
class
ScipyGer
(
Ger
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
from
scipy.linalg.blas
import
get_blas_funcs
cA
,
calpha
,
cx
,
cy
=
inputs
(
cZ
,)
=
output_storage
# N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to.
A
=
cA
local_ger
=
get_blas_funcs
(
"ger"
,
dtype
=
cA
.
dtype
)
if
A
.
size
==
0
:
# We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it
# C-contiguous, which is confusing.
if
not
self
.
destructive
:
# Sometimes numpy thinks empty matrices can share memory,
# so here to stop DebugMode from complaining.
A
=
A
.
copy
()
elif
A
.
flags
[
"C_CONTIGUOUS"
]:
A
=
local_ger
(
calpha
,
cy
,
cx
,
a
=
A
.
T
,
overwrite_a
=
int
(
self
.
destructive
))
.
T
else
:
A
=
local_ger
(
calpha
,
cx
,
cy
,
a
=
A
,
overwrite_a
=
int
(
self
.
destructive
))
cZ
[
0
]
=
A
scipy_ger_no_inplace
=
ScipyGer
(
False
)
scipy_ger_inplace
=
ScipyGer
(
True
)
pytensor/tensor/math.py
浏览文件 @
0c138495
...
@@ -40,12 +40,13 @@ from pytensor.tensor.elemwise import (
...
@@ -40,12 +40,13 @@ from pytensor.tensor.elemwise import (
get_normalized_batch_axes
,
get_normalized_batch_axes
,
scalar_elemwise
,
scalar_elemwise
,
)
)
from
pytensor.tensor.shape
import
shape
,
specify_
broadcastabl
e
from
pytensor.tensor.shape
import
shape
,
specify_
shap
e
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
DenseTensorType
,
DenseTensorType
,
complex_dtypes
,
complex_dtypes
,
continuous_dtypes
,
continuous_dtypes
,
discrete_dtypes
,
discrete_dtypes
,
float_dtypes
,
int_dtypes
,
int_dtypes
,
tensor
,
tensor
,
uint_dtypes
,
uint_dtypes
,
...
@@ -2986,9 +2987,7 @@ pprint.assign(pow, printing.OperatorPrinter("**", 1, "right"))
...
@@ -2986,9 +2987,7 @@ pprint.assign(pow, printing.OperatorPrinter("**", 1, "right"))
class
Dot
(
Op
):
class
Dot
(
Op
):
"""
"""
Computes the dot product of two variables. For two matrices, this is
Computes the dot product of two matrices variables
equivalent to matrix multiplication. For two vectors, this is the inner
product.
Notes
Notes
-----
-----
...
@@ -3001,97 +3000,58 @@ class Dot(Op):
...
@@ -3001,97 +3000,58 @@ class Dot(Op):
"""
"""
gufunc_signature
=
"(m,n),(n,p)->(m,p)"
gufunc_spec
=
(
"matmul"
,
2
,
1
)
__props__
=
()
__props__
=
()
# the rationale for Dot22 is related to getting GEMM Ops into the
def
make_node
(
self
,
x
,
y
):
# graph. See Dot22 in tensor.blas for details.
x
=
as_tensor_variable
(
x
)
y
=
as_tensor_variable
(
y
)
def
make_node
(
self
,
*
inputs
):
inputs
=
list
(
map
(
as_tensor_variable
,
inputs
))
if
len
(
inputs
)
!=
2
:
if
x
.
type
.
ndim
!=
2
:
raise
TypeError
(
f
"Two arguments required, {len(inputs)} given "
)
if
inputs
[
0
]
.
ndim
not
in
(
1
,
2
):
raise
TypeError
(
raise
TypeError
(
"Input 0 (0-indexed) must have ndim of "
f
"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
f
"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
"pytensor.tensor.dot instead."
)
)
if
inputs
[
1
]
.
ndim
not
in
(
1
,
2
)
:
if
y
.
type
.
ndim
!=
2
:
raise
TypeError
(
raise
TypeError
(
"Input 1 (0-indexed) must have ndim of "
f
"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
f
"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
"pytensor.tensor.dot instead."
)
)
sx
,
sy
=
(
input
.
type
.
shape
for
input
in
inputs
)
sx
,
sy
=
x
.
type
.
shape
,
y
.
type
.
shape
if
sx
[
-
1
]
is
not
None
and
sy
[
0
]
is
not
None
and
sx
[
-
1
]
!=
sy
[
0
]:
if
sx
[
-
1
]
is
not
None
and
sy
[
0
]
is
not
None
and
sx
[
-
1
]
!=
sy
[
0
]:
raise
ValueError
(
raise
ValueError
(
f
"Incompatible shared dimension for dot product: {sx}, {sy}"
f
"Incompatible shared dimension for dot product: {sx}, {sy}"
)
)
out_shape
=
(
sx
[
0
],
sy
[
1
])
out_dtype
=
ps
.
upcast
(
x
.
type
.
dtype
,
y
.
type
.
dtype
)
outputs
=
[
tensor
(
dtype
=
out_dtype
,
shape
=
out_shape
)]
return
Apply
(
self
,
[
x
,
y
],
outputs
)
if
len
(
sy
)
==
2
:
def
perform
(
self
,
node
,
inputs
,
output_storage
):
sz
=
sx
[:
-
1
]
+
sy
[
-
1
:]
output_storage
[
0
][
0
]
=
np
.
matmul
(
*
inputs
)
elif
len
(
sy
)
==
1
:
sz
=
sx
[:
-
1
]
i_dtypes
=
[
input
.
type
.
dtype
for
input
in
inputs
]
outputs
=
[
tensor
(
dtype
=
ps
.
upcast
(
*
i_dtypes
),
shape
=
sz
)]
return
Apply
(
self
,
inputs
,
outputs
)
def
perform
(
self
,
node
,
inp
,
out
):
x
,
y
=
inp
(
z
,)
=
out
# the asarray is here because dot between two vectors
# gives a numpy float object but we need to return a 0d
# ndarray
z
[
0
]
=
np
.
asarray
(
np
.
dot
(
x
,
y
))
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
x
,
y
=
inp
x
,
y
=
inp
(
gz
,)
=
grads
(
gz
,)
=
grads
xdim
,
ydim
,
gdim
=
x
.
type
.
ndim
,
y
.
type
.
ndim
,
gz
.
type
.
ndim
# grad is scalar, so x is vector and y is vector
if
gdim
==
0
:
xgrad
=
gz
*
y
ygrad
=
gz
*
x
# x is vector, y is matrix, grad is vector
elif
xdim
==
1
and
ydim
==
2
:
xgrad
=
dot
(
gz
,
y
.
T
)
ygrad
=
outer
(
x
.
T
,
gz
)
# x is matrix, y is vector, grad is vector
elif
xdim
==
2
and
ydim
==
1
:
xgrad
=
outer
(
gz
,
y
.
T
)
ygrad
=
dot
(
x
.
T
,
gz
)
# x is matrix, y is matrix, grad is matrix
xgrad
=
self
(
gz
,
y
.
T
)
elif
xdim
==
ydim
==
2
:
ygrad
=
self
(
x
.
T
,
gz
)
xgrad
=
dot
(
gz
,
y
.
T
)
ygrad
=
dot
(
x
.
T
,
gz
)
# If x or y contain broadcastable dimensions but only one of
# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
# This cause problem down the road. See gh-1461.
if
xgrad
.
broadcastable
!=
x
.
broadcastable
:
if
xgrad
.
type
.
shape
!=
x
.
type
.
shape
:
xgrad
=
specify_broadcastable
(
xgrad
=
specify_shape
(
xgrad
,
x
.
type
.
shape
)
xgrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
x
.
type
.
broadcastable
)
if
b
)
if
ygrad
.
type
.
shape
!=
y
.
type
.
shape
:
)
ygrad
=
specify_shape
(
ygrad
,
y
.
type
.
shape
)
if
ygrad
.
broadcastable
!=
y
.
broadcastable
:
ygrad
=
specify_broadcastable
(
ygrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
y
.
type
.
broadcastable
)
if
b
)
)
rval
=
xgrad
,
ygrad
for
elem
in
rval
:
if
xgrad
.
type
.
dtype
not
in
float_dtypes
:
assert
elem
.
dtype
.
find
(
"float"
)
!=
-
1
raise
TypeError
(
"Dot grad x output must be a float type"
)
if
ygrad
.
type
.
dtype
not
in
float_dtypes
:
raise
TypeError
(
"Dot grad y output must be a float type"
)
return
rval
return
xgrad
,
ygrad
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
# R_op for a \dot b evaluated at c for a and d for b is
# R_op for a \dot b evaluated at c for a and d for b is
...
@@ -3116,24 +3076,7 @@ class Dot(Op):
...
@@ -3116,24 +3076,7 @@ class Dot(Op):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
xshp
,
yshp
=
shapes
xshp
,
yshp
=
shapes
x
,
y
=
node
.
inputs
return
[[
xshp
[
0
],
yshp
[
1
]]]
# vector / vector
if
x
.
ndim
==
1
and
y
.
ndim
==
1
:
return
[()]
# matrix / vector
if
x
.
ndim
==
2
and
y
.
ndim
==
1
:
return
[
xshp
[:
-
1
]]
# vector / matrix
if
x
.
ndim
==
1
and
y
.
ndim
==
2
:
return
[
yshp
[
-
1
:]]
# matrix / matrix
if
x
.
ndim
==
2
and
y
.
ndim
==
2
:
return
[
xshp
[:
-
1
]
+
yshp
[
-
1
:]]
raise
NotImplementedError
()
def
__str__
(
self
):
return
"dot"
_dot
=
Dot
()
_dot
=
Dot
()
...
@@ -3215,7 +3158,24 @@ def dense_dot(a, b):
...
@@ -3215,7 +3158,24 @@ def dense_dot(a, b):
elif
a
.
ndim
>
2
or
b
.
ndim
>
2
:
elif
a
.
ndim
>
2
or
b
.
ndim
>
2
:
return
tensordot
(
a
,
b
,
[[
a
.
ndim
-
1
],
[
np
.
maximum
(
0
,
b
.
ndim
-
2
)]])
return
tensordot
(
a
,
b
,
[[
a
.
ndim
-
1
],
[
np
.
maximum
(
0
,
b
.
ndim
-
2
)]])
else
:
else
:
return
_dot
(
a
,
b
)
row_vector
=
a
.
ndim
==
1
if
row_vector
:
# Promote to row matrix
a
=
a
[
None
]
col_vector
=
b
.
ndim
==
1
if
col_vector
:
# Promote to column matrix
b
=
b
[:,
None
]
out
=
_dot
(
a
,
b
)
if
row_vector
:
# If we promoted a to a row matrix, we need to squeeze the first dimension
out
=
out
.
squeeze
(
0
)
if
col_vector
:
# If we promoted b to a column matrix, we need to squeeze the last dimension
out
=
out
.
squeeze
(
-
1
)
return
out
def
tensordot
(
def
tensordot
(
...
@@ -3921,11 +3881,7 @@ def logsumexp(x, axis=None, keepdims=False):
...
@@ -3921,11 +3881,7 @@ def logsumexp(x, axis=None, keepdims=False):
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
_matmul
=
Blockwise
(
_matmul
=
Blockwise
(
_dot
,
name
=
"Matmul"
)
_dot
,
signature
=
"(m,k),(k,n)->(m,n)"
,
gufunc_spec
=
(
"numpy.matmul"
,
2
,
1
),
)
def
matmul
(
x1
:
"ArrayLike"
,
x2
:
"ArrayLike"
,
dtype
:
Optional
[
"DTypeLike"
]
=
None
):
def
matmul
(
x1
:
"ArrayLike"
,
x2
:
"ArrayLike"
,
dtype
:
Optional
[
"DTypeLike"
]
=
None
):
...
@@ -3975,7 +3931,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
...
@@ -3975,7 +3931,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if
x1
.
type
.
ndim
==
0
or
x2
.
type
.
ndim
==
0
:
if
x1
.
type
.
ndim
==
0
or
x2
.
type
.
ndim
==
0
:
raise
ValueError
(
"matmul operand cannot be scalar"
)
raise
ValueError
(
"matmul operand cannot be scalar"
)
if
x1
.
type
.
ndim
==
1
and
x2
.
type
.
ndim
==
1
:
if
x1
.
type
.
ndim
==
1
and
x2
.
type
.
ndim
==
1
:
out
=
_
dot
(
x1
,
x2
)
out
=
vec
dot
(
x1
,
x2
)
elif
x1
.
type
.
ndim
==
1
:
elif
x1
.
type
.
ndim
==
1
:
out
=
vecmat
(
x1
,
x2
)
out
=
vecmat
(
x1
,
x2
)
elif
x2
.
type
.
ndim
==
1
:
elif
x2
.
type
.
ndim
==
1
:
...
@@ -4139,23 +4095,7 @@ def vecmat(
...
@@ -4139,23 +4095,7 @@ def vecmat(
@_vectorize_node.register
(
Dot
)
@_vectorize_node.register
(
Dot
)
def
vectorize_node_dot
(
op
,
node
,
batched_x
,
batched_y
):
def
vectorize_node_dot
(
op
,
node
,
batched_x
,
batched_y
):
old_x
,
old_y
=
node
.
inputs
return
matmul
(
batched_x
,
batched_y
)
.
owner
old_x_ndim
=
old_x
.
type
.
ndim
old_y_ndim
=
old_y
.
type
.
ndim
match
(
old_x_ndim
,
old_y_ndim
):
case
(
1
,
1
):
batch_fn
=
vecdot
case
(
2
,
1
):
batch_fn
=
matvec
case
(
1
,
2
):
batch_fn
=
vecmat
case
(
2
,
2
):
batch_fn
=
matmul
case
_
:
raise
ValueError
(
f
"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return
batch_fn
(
batched_x
,
batched_y
)
.
owner
def
nan_to_num
(
x
,
nan
=
0.0
,
posinf
=
None
,
neginf
=
None
):
def
nan_to_num
(
x
,
nan
=
0.0
,
posinf
=
None
,
neginf
=
None
):
...
...
pytensor/tensor/rewriting/__init__.py
浏览文件 @
0c138495
import
pytensor.tensor.rewriting.basic
import
pytensor.tensor.rewriting.basic
import
pytensor.tensor.rewriting.blas
import
pytensor.tensor.rewriting.blas
import
pytensor.tensor.rewriting.blas_c
import
pytensor.tensor.rewriting.blas_c
import
pytensor.tensor.rewriting.blas_scipy
import
pytensor.tensor.rewriting.blockwise
import
pytensor.tensor.rewriting.blockwise
import
pytensor.tensor.rewriting.einsum
import
pytensor.tensor.rewriting.einsum
import
pytensor.tensor.rewriting.elemwise
import
pytensor.tensor.rewriting.elemwise
...
...
pytensor/tensor/rewriting/blas.py
浏览文件 @
0c138495
...
@@ -107,7 +107,6 @@ from pytensor.tensor.math import (
...
@@ -107,7 +107,6 @@ from pytensor.tensor.math import (
)
)
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
DenseTensorType
,
TensorType
,
TensorType
,
integer_dtypes
,
integer_dtypes
,
values_eq_approx_remove_inf_nan
,
values_eq_approx_remove_inf_nan
,
...
@@ -580,12 +579,6 @@ class GemmOptimizer(GraphRewriter):
...
@@ -580,12 +579,6 @@ class GemmOptimizer(GraphRewriter):
def
local_dot_to_dot22
(
fgraph
,
node
):
def
local_dot_to_dot22
(
fgraph
,
node
):
# This works for tensor.outer too because basic.outer is a macro that
# This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below
# produces a dot(dimshuffle,dimshuffle) of form 4 below
if
not
isinstance
(
node
.
op
,
Dot
):
return
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
node
.
inputs
):
return
False
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
if
y
.
type
.
dtype
!=
x
.
type
.
dtype
:
if
y
.
type
.
dtype
!=
x
.
type
.
dtype
:
# TODO: upcast one so the types match
# TODO: upcast one so the types match
...
@@ -593,16 +586,7 @@ def local_dot_to_dot22(fgraph, node):
...
@@ -593,16 +586,7 @@ def local_dot_to_dot22(fgraph, node):
return
return
if
y
.
type
.
dtype
in
(
"float16"
,
"float32"
,
"float64"
,
"complex64"
,
"complex128"
):
if
y
.
type
.
dtype
in
(
"float16"
,
"float32"
,
"float64"
,
"complex64"
,
"complex128"
):
if
x
.
ndim
==
2
and
y
.
ndim
==
2
:
new_out
=
[
_dot22
(
*
node
.
inputs
)]
new_out
=
[
_dot22
(
*
node
.
inputs
)]
elif
x
.
ndim
==
2
and
y
.
ndim
==
1
:
new_out
=
[
_dot22
(
x
,
y
.
dimshuffle
(
0
,
"x"
))
.
dimshuffle
(
0
)]
elif
x
.
ndim
==
1
and
y
.
ndim
==
2
:
new_out
=
[
_dot22
(
x
.
dimshuffle
(
"x"
,
0
),
y
)
.
dimshuffle
(
1
)]
elif
x
.
ndim
==
1
and
y
.
ndim
==
1
:
new_out
=
[
_dot22
(
x
.
dimshuffle
(
"x"
,
0
),
y
.
dimshuffle
(
0
,
"x"
))
.
dimshuffle
()]
else
:
return
copy_stack_trace
(
node
.
outputs
,
new_out
)
copy_stack_trace
(
node
.
outputs
,
new_out
)
return
new_out
return
new_out
...
...
pytensor/tensor/rewriting/blas_scipy.py
deleted
100644 → 0
浏览文件 @
d1be796e
from
pytensor.graph.rewriting.basic
import
in2out
from
pytensor.tensor.blas
import
ger
,
ger_destructive
from
pytensor.tensor.blas_scipy
import
scipy_ger_inplace
,
scipy_ger_no_inplace
from
pytensor.tensor.rewriting.blas
import
blas_optdb
,
node_rewriter
,
optdb
@node_rewriter
([
ger
,
ger_destructive
])
def
use_scipy_ger
(
fgraph
,
node
):
if
node
.
op
==
ger
:
return
[
scipy_ger_no_inplace
(
*
node
.
inputs
)]
@node_rewriter
([
scipy_ger_no_inplace
])
def
make_ger_destructive
(
fgraph
,
node
):
if
node
.
op
==
scipy_ger_no_inplace
:
return
[
scipy_ger_inplace
(
*
node
.
inputs
)]
use_scipy_blas
=
in2out
(
use_scipy_ger
)
make_scipy_blas_destructive
=
in2out
(
make_ger_destructive
)
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
# sucks [citation needed], but it is almost always present.
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb
.
register
(
"scipy_blas"
,
use_scipy_blas
,
"fast_run"
,
position
=
100
)
# this matches the InplaceBlasOpt defined in blas.py
optdb
.
register
(
"make_scipy_blas_destructive"
,
make_scipy_blas_destructive
,
"fast_run"
,
"inplace"
,
position
=
50.2
,
)
pytensor/tensor/rewriting/linalg.py
浏览文件 @
0c138495
...
@@ -276,15 +276,7 @@ def cholesky_ldotlt(fgraph, node):
...
@@ -276,15 +276,7 @@ def cholesky_ldotlt(fgraph, node):
A
=
node
.
inputs
[
0
]
A
=
node
.
inputs
[
0
]
if
not
(
if
not
(
A
.
owner
is
not
None
A
.
owner
is
not
None
and
(
isinstance
(
A
.
owner
.
op
,
Dot
)
or
(
A
.
owner
.
op
==
_matmul
))
and
(
(
isinstance
(
A
.
owner
.
op
,
Dot
)
# This rewrite only applies to matrix Dot
and
A
.
owner
.
inputs
[
0
]
.
type
.
ndim
==
2
)
or
(
A
.
owner
.
op
==
_matmul
)
)
):
):
return
return
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
0c138495
...
@@ -19,7 +19,6 @@ from pytensor.graph.rewriting.basic import (
...
@@ -19,7 +19,6 @@ from pytensor.graph.rewriting.basic import (
node_rewriter
,
node_rewriter
,
)
)
from
pytensor.graph.rewriting.utils
import
get_clients_at_depth
from
pytensor.graph.rewriting.utils
import
get_clients_at_depth
from
pytensor.raise_op
import
assert_op
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
Alloc
,
Alloc
,
Join
,
Join
,
...
@@ -34,6 +33,7 @@ from pytensor.tensor.basic import (
...
@@ -34,6 +33,7 @@ from pytensor.tensor.basic import (
ones_like
,
ones_like
,
register_infer_shape
,
register_infer_shape
,
switch
,
switch
,
zeros
,
zeros_like
,
zeros_like
,
)
)
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
...
@@ -44,12 +44,10 @@ from pytensor.tensor.math import (
...
@@ -44,12 +44,10 @@ from pytensor.tensor.math import (
Prod
,
Prod
,
Sum
,
Sum
,
_conj
,
_conj
,
_dot
,
_matmul
,
_matmul
,
add
,
add
,
digamma
,
digamma
,
dot
,
dot
,
eq
,
erf
,
erf
,
erfc
,
erfc
,
exp
,
exp
,
...
@@ -130,16 +128,12 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
...
@@ -130,16 +128,12 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
return
consts
,
origconsts
,
nonconsts
return
consts
,
origconsts
,
nonconsts
@register_canonicalize
@register_canonicalize
(
"shape_unsafe"
)
@register_stabilize
@register_stabilize
(
"shape_unsafe"
)
@node_rewriter
([
Dot
])
@node_rewriter
([
Dot
])
def
local_0_dot_x
(
fgraph
,
node
):
def
local_0_dot_x
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
Dot
):
x
,
y
=
node
.
inputs
return
False
if
(
x
=
node
.
inputs
[
0
]
y
=
node
.
inputs
[
1
]
replace
=
(
get_underlying_scalar_constant_value
(
get_underlying_scalar_constant_value
(
x
,
only_process_constants
=
True
,
raise_not_constant
=
False
x
,
only_process_constants
=
True
,
raise_not_constant
=
False
)
)
...
@@ -148,26 +142,12 @@ def local_0_dot_x(fgraph, node):
...
@@ -148,26 +142,12 @@ def local_0_dot_x(fgraph, node):
y
,
only_process_constants
=
True
,
raise_not_constant
=
False
y
,
only_process_constants
=
True
,
raise_not_constant
=
False
)
)
==
0
==
0
)
):
return
[
zeros
((
x
.
shape
[
0
],
y
.
shape
[
1
]),
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)]
if
replace
:
constant_zero
=
constant
(
0
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
if
x
.
ndim
==
2
and
y
.
ndim
==
2
:
constant_zero
=
assert_op
(
constant_zero
,
eq
(
x
.
shape
[
1
],
y
.
shape
[
0
]))
return
[
alloc
(
constant_zero
,
x
.
shape
[
0
],
y
.
shape
[
1
])]
elif
x
.
ndim
==
1
and
y
.
ndim
==
2
:
constant_zero
=
assert_op
(
constant_zero
,
eq
(
x
.
shape
[
0
],
y
.
shape
[
0
]))
return
[
alloc
(
constant_zero
,
y
.
shape
[
1
])]
elif
x
.
ndim
==
2
and
y
.
ndim
==
1
:
constant_zero
=
assert_op
(
constant_zero
,
eq
(
x
.
shape
[
1
],
y
.
shape
[
0
]))
return
[
alloc
(
constant_zero
,
x
.
shape
[
0
])]
elif
x
.
ndim
==
1
and
y
.
ndim
==
1
:
constant_zero
=
assert_op
(
constant_zero
,
eq
(
x
.
shape
[
0
],
y
.
shape
[
0
]))
return
[
constant_zero
]
@register_canonicalize
@register_canonicalize
@node_rewriter
([
D
imShuffle
])
@node_rewriter
([
D
ot
,
_matmul
])
def
local_lift_transpose_through_dot
(
fgraph
,
node
):
def
local_lift_transpose_through_dot
(
fgraph
,
node
):
r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``.
r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``.
...
@@ -176,22 +156,25 @@ def local_lift_transpose_through_dot(fgraph, node):
...
@@ -176,22 +156,25 @@ def local_lift_transpose_through_dot(fgraph, node):
and to later merge consecutive `DimShuffle`\s.
and to later merge consecutive `DimShuffle`\s.
"""
"""
if
not
(
clients
=
fgraph
.
clients
[
node
.
out
]
is_matrix_transpose
(
node
.
outputs
[
0
])
and
node
.
inputs
[
0
]
.
owner
and
((
dot_op
:
=
node
.
inputs
[
0
]
.
owner
.
op
)
in
(
_dot
,
_matmul
))
):
return
False
x
,
y
=
node
.
inputs
[
0
]
.
owner
.
inputs
if
len
(
clients
)
!=
1
:
# If the dot is used in more than one place, we don't want to duplicate it
return
None
if
x
.
ndim
>=
y
.
ndim
>=
2
:
[(
client
,
_
)]
=
clients
# Output is dot product of transposed inputs in reverse order
ret
=
[
dot_op
(
y
.
mT
,
x
.
mT
)]
# Copy over stack trace to output from result of dot-product
if
not
(
isinstance
(
client
.
op
,
DimShuffle
)
and
is_matrix_transpose
(
client
.
out
)):
copy_stack_trace
(
node
.
inputs
[
0
],
ret
)
return
None
return
ret
x
,
y
=
node
.
inputs
# Output is dot product of transposed inputs in reverse order
ret
=
node
.
op
(
y
.
mT
,
x
.
mT
)
# Copy over stack trace to output from result of dot-product
copy_stack_trace
(
node
.
out
,
ret
)
return
{
client
.
out
:
ret
}
def
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
:
bool
):
def
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
:
bool
):
...
@@ -344,57 +327,34 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
...
@@ -344,57 +327,34 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@node_rewriter
([
_matmul
])
@node_rewriter
([
_matmul
,
Dot
])
def
local_
blockwise_
dot_to_mul
(
fgraph
,
node
):
def
local_dot_to_mul
(
fgraph
,
node
):
"""Rewrite
blockwise
dots that correspond to multiplication without summation.
"""Rewrite dots that correspond to multiplication without summation.
We don't touch the regular dot, to not interfere with the BLAS optimizations.
We don't touch outer product without batch-dimensions, to allow rewriting into GER,
which seems more performant in that case.
# TODO: Once we blockwise Blas operations we shouldn't do it for outer product with batch-dimensions either
# TODO: We may still want to canonicalize outer dot as mul, and detect that for GER.
"""
"""
a
,
b
=
node
.
inputs
a
,
b
=
node
.
inputs
a_static_shape
=
a
.
type
.
shape
a_static_shape
=
a
.
type
.
shape
b_static_shape
=
b
.
type
.
shape
b_static_shape
=
b
.
type
.
shape
core_a_ndim
=
len
(
node
.
op
.
inputs_sig
[
0
])
core_b_ndim
=
len
(
node
.
op
.
inputs_sig
[
1
])
if
core_a_ndim
>
2
or
core_b_ndim
>
2
:
# Check if we have matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
# Shouldn't happen, but here just in case
if
not
(
a_static_shape
[
-
1
]
==
1
or
b_static_shape
[
-
2
]
==
1
):
return
None
return
None
if
core_b_ndim
==
1
:
# If it's a core Dot we only rewrite if there's no outer product
if
a_static_shape
[
-
1
]
==
1
or
b_static_shape
[
-
1
]
==
1
:
# (1, 1) * (1, n) or (m, 1) * (1, 1)
if
core_a_ndim
==
1
:
# Otherwise we leave as is, so GER can be used instead
# inner product: (..., 1) * (..., 1) -> (...)
if
isinstance
(
node
.
op
,
Dot
)
and
not
(
# just squeeze the last dimensions of a and b
a_static_shape
[
-
2
]
==
1
or
b_static_shape
[
-
1
]
==
1
new_a
=
a
.
squeeze
(
-
1
)
):
new_b
=
b
.
squeeze
(
-
1
)
return
None
else
:
# matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
# the last dimension of b is already aligned for the elemwise multiplication
# after we squeeze the last dimension of a
new_a
=
a
.
squeeze
(
-
1
)
new_b
=
b
else
:
return
None
else
:
if
a_static_shape
[
-
1
]
==
1
or
b_static_shape
[
-
2
]
==
1
:
if
core_a_ndim
==
1
:
# vector_matrix product: (..., 1) * (..., 1, n) -> (..., n)
# the last dimension of a is already aligned for the elemwise multiplication
# after we squeeze the one to last dimension of b
new_a
=
a
new_b
=
b
.
squeeze
(
-
2
)
else
:
# matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
# the dimensions of a and b are already aligned for the elemwise multiplication
new_a
=
a
new_b
=
b
else
:
return
None
new_a
=
copy_stack_trace
(
a
,
new_a
)
new_out
=
mul
(
a
,
b
)
new_b
=
copy_stack_trace
(
b
,
new_b
)
copy_stack_trace
(
node
.
out
,
new_out
)
new_out
=
copy_stack_trace
(
node
.
out
,
mul
(
new_a
,
new_b
))
return
[
new_out
]
return
[
new_out
]
...
...
pytensor/tensor/rewriting/subtensor_lift.py
浏览文件 @
0c138495
...
@@ -158,26 +158,11 @@ def local_subtensor_of_dot(fgraph, node):
...
@@ -158,26 +158,11 @@ def local_subtensor_of_dot(fgraph, node):
a
=
a
.
type
.
clone
(
shape
=
a
.
type
.
shape
[
batch_ndim
:])()
a
=
a
.
type
.
clone
(
shape
=
a
.
type
.
shape
[
batch_ndim
:])()
b
=
b
.
type
.
clone
(
shape
=
b
.
type
.
shape
[
batch_ndim
:])()
b
=
b
.
type
.
clone
(
shape
=
b
.
type
.
shape
[
batch_ndim
:])()
a_ndim
=
a
.
ndim
a_indices
=
idx_list
[:
1
]
b_ndim
=
b
.
ndim
b_indices
=
(
slice
(
None
),
*
idx_list
[
1
:])
num_a_indices
=
min
(
a_ndim
-
1
,
len
(
idx_list
))
a_indices
=
idx_list
[:
num_a_indices
]
b_indices
=
idx_list
[
num_a_indices
:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
# This wasn't necessary for a, because we just omitted the last index.
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if
b_ndim
>
1
and
len
(
b_indices
)
>=
b_ndim
-
1
:
b_indices
=
(
b_indices
[:
b_ndim
-
2
]
+
(
slice
(
None
,
None
,
None
),)
+
b_indices
[
b_ndim
-
2
:]
)
a_sub
=
a
[
tuple
(
a_indices
)]
a_sub
=
a
[
tuple
(
a_indices
)]
b_sub
=
b
[
tuple
(
b_indices
)]
if
b_indices
else
b
b_sub
=
b
[
tuple
(
b_indices
)]
r
=
dot
(
a_sub
,
b_sub
)
r
=
dot
(
a_sub
,
b_sub
)
if
batch_ndim
:
if
batch_ndim
:
...
...
tests/graph/rewriting/test_kanren.py
浏览文件 @
0c138495
...
@@ -37,51 +37,51 @@ def clear_assoccomm():
...
@@ -37,51 +37,51 @@ def clear_assoccomm():
def
test_kanren_basic
():
def
test_kanren_basic
():
A_pt
=
pt
.
matrix
(
"A"
)
A_pt
=
pt
.
matrix
(
"A"
)
x_pt
=
pt
.
vector
(
"x
"
)
B_pt
=
pt
.
matrix
(
"B
"
)
y_pt
=
pt
.
dot
(
A_pt
,
x
_pt
)
y_pt
=
pt
.
dot
(
A_pt
,
B
_pt
)
q
=
var
()
q
=
var
()
res
=
list
(
run
(
None
,
q
,
eq
(
y_pt
,
etuple
(
_dot
,
q
,
x
_pt
))))
res
=
list
(
run
(
None
,
q
,
eq
(
y_pt
,
etuple
(
_dot
,
q
,
B
_pt
))))
assert
res
==
[
A_pt
]
assert
res
==
[
A_pt
]
def
test_KanrenRelationSub_filters
():
def
test_KanrenRelationSub_filters
():
x_pt
=
pt
.
vector
(
"x"
)
y_pt
=
pt
.
vector
(
"y"
)
z_pt
=
pt
.
vector
(
"z"
)
A_pt
=
pt
.
matrix
(
"A"
)
A_pt
=
pt
.
matrix
(
"A"
)
B_pt
=
pt
.
matrix
(
"B"
)
C_pt
=
pt
.
matrix
(
"C"
)
D_pt
=
pt
.
matrix
(
"D"
)
fact
(
commutative
,
_dot
)
fact
(
commutative
,
_dot
)
fact
(
commutative
,
pt
.
add
)
fact
(
commutative
,
pt
.
add
)
fact
(
associative
,
pt
.
add
)
fact
(
associative
,
pt
.
add
)
Z_pt
=
A_pt
.
dot
((
x_pt
+
y_pt
)
+
z
_pt
)
Z_pt
=
A_pt
.
dot
((
B_pt
+
C_pt
)
+
D
_pt
)
fgraph
=
FunctionGraph
(
outputs
=
[
Z_pt
],
clone
=
False
)
fgraph
=
FunctionGraph
(
outputs
=
[
Z_pt
],
clone
=
False
)
def
distributes
(
in_lv
,
out_lv
):
def
distributes
(
in_lv
,
out_lv
):
A_lv
,
x_lv
,
y_lv
,
z
_lv
=
vars
(
4
)
A_lv
,
B_lv
,
C_lv
,
D
_lv
=
vars
(
4
)
return
lall
(
return
lall
(
# lhs == A * (x + y + z)
# lhs == A * (x + y + z)
eq_assoccomm
(
eq_assoccomm
(
etuple
(
_dot
,
A_lv
,
etuple
(
pt
.
add
,
x_lv
,
etuple
(
pt
.
add
,
y_lv
,
z
_lv
))),
etuple
(
_dot
,
A_lv
,
etuple
(
pt
.
add
,
B_lv
,
etuple
(
pt
.
add
,
C_lv
,
D
_lv
))),
in_lv
,
in_lv
,
),
),
# This relation does nothing but provide us with a means of
# This relation does nothing but provide us with a means of
# generating associative-commutative matches in the `kanren`
# generating associative-commutative matches in the `kanren`
# output.
# output.
eq
((
A_lv
,
x_lv
,
y_lv
,
z
_lv
),
out_lv
),
eq
((
A_lv
,
B_lv
,
C_lv
,
D
_lv
),
out_lv
),
)
)
def
results_filter
(
results
):
def
results_filter
(
results
):
_results
=
[
eval_if_etuple
(
v
)
for
v
in
results
]
_results
=
[
eval_if_etuple
(
v
)
for
v
in
results
]
# Make sure that at least a couple permutations are present
# Make sure that at least a couple permutations are present
assert
(
A_pt
,
x_pt
,
y_pt
,
z
_pt
)
in
_results
assert
(
A_pt
,
B_pt
,
C_pt
,
D
_pt
)
in
_results
assert
(
A_pt
,
y_pt
,
x_pt
,
z
_pt
)
in
_results
assert
(
A_pt
,
C_pt
,
B_pt
,
D
_pt
)
in
_results
assert
(
A_pt
,
z_pt
,
x_pt
,
y
_pt
)
in
_results
assert
(
A_pt
,
D_pt
,
B_pt
,
C
_pt
)
in
_results
return
None
return
None
...
@@ -121,13 +121,13 @@ def test_KanrenRelationSub_multiout():
...
@@ -121,13 +121,13 @@ def test_KanrenRelationSub_multiout():
def
test_KanrenRelationSub_dot
():
def
test_KanrenRelationSub_dot
():
"""Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
"""Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
x_pt
=
pt
.
vector
(
"x"
)
c_pt
=
pt
.
vector
(
"c"
)
d_pt
=
pt
.
vector
(
"d"
)
A_pt
=
pt
.
matrix
(
"A"
)
A_pt
=
pt
.
matrix
(
"A"
)
B_pt
=
pt
.
matrix
(
"B"
)
B_pt
=
pt
.
matrix
(
"B"
)
C_pt
=
pt
.
matrix
(
"C"
)
D_pt
=
pt
.
matrix
(
"D"
)
E_pt
=
pt
.
matrix
(
"E"
)
Z_pt
=
A_pt
.
dot
(
x_pt
+
B_pt
.
dot
(
c_pt
+
d
_pt
))
Z_pt
=
A_pt
.
dot
(
E_pt
+
B_pt
.
dot
(
C_pt
+
D
_pt
))
fgraph
=
FunctionGraph
(
outputs
=
[
Z_pt
],
clone
=
False
)
fgraph
=
FunctionGraph
(
outputs
=
[
Z_pt
],
clone
=
False
)
...
@@ -137,15 +137,15 @@ def test_KanrenRelationSub_dot():
...
@@ -137,15 +137,15 @@ def test_KanrenRelationSub_dot():
return
lall
(
return
lall
(
# lhs == A * (x + b)
# lhs == A * (x + b)
eq
(
eq
(
etuple
(
_dot
,
var
(
"A"
),
etuple
(
pt
.
add
,
var
(
"
x"
),
var
(
"b
"
))),
etuple
(
_dot
,
var
(
"A"
),
etuple
(
pt
.
add
,
var
(
"
E"
),
var
(
"B
"
))),
in_lv
,
in_lv
,
),
),
# rhs == A * x + A * b
# rhs == A * x + A * b
eq
(
eq
(
etuple
(
etuple
(
pt
.
add
,
pt
.
add
,
etuple
(
_dot
,
var
(
"A"
),
var
(
"
x
"
)),
etuple
(
_dot
,
var
(
"A"
),
var
(
"
E
"
)),
etuple
(
_dot
,
var
(
"A"
),
var
(
"
b
"
)),
etuple
(
_dot
,
var
(
"A"
),
var
(
"
B
"
)),
),
),
out_lv
,
out_lv
,
),
),
...
...
tests/link/numba/test_basic.py
浏览文件 @
0c138495
...
@@ -631,7 +631,7 @@ def test_Dot(x, y):
...
@@ -631,7 +631,7 @@ def test_Dot(x, y):
x
,
x_test_value
=
x
x
,
x_test_value
=
x
y
,
y_test_value
=
y
y
,
y_test_value
=
y
g
=
ptm
.
Dot
()
(
x
,
y
)
g
=
ptm
.
dot
(
x
,
y
)
compare_numba_and_py
(
compare_numba_and_py
(
[
x
,
y
],
[
x
,
y
],
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
0c138495
...
@@ -4714,14 +4714,15 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
...
@@ -4714,14 +4714,15 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
==
1
==
1
)
)
# For now rewrite only applies to Batched Dots
# For now we do not rewrite only the case of unbatched outer
core_outer
=
(
not
batched
)
and
(
a_shape
==
(
3
,
1
))
and
(
b_shape
==
(
1
,
3
))
rewritten_out
=
rewrite_graph
(
out
)
rewritten_out
=
rewrite_graph
(
out
)
assert
rewritten_out
.
type
.
shape
==
out
.
type
.
shape
assert
rewritten_out
.
type
.
shape
==
out
.
type
.
shape
assert
sum
(
assert
sum
(
isinstance
(
var
.
owner
.
op
,
(
Blockwise
|
Dot
))
isinstance
(
var
.
owner
.
op
,
(
Blockwise
|
Dot
))
for
var
in
ancestors
([
rewritten_out
])
for
var
in
ancestors
([
rewritten_out
])
if
var
.
owner
if
var
.
owner
)
==
(
0
if
batched
else
1
)
)
==
(
1
if
core_outer
else
0
)
a_test
=
np
.
random
.
normal
(
size
=
a
.
type
.
shape
)
.
astype
(
a
.
type
.
dtype
)
a_test
=
np
.
random
.
normal
(
size
=
a
.
type
.
shape
)
.
astype
(
a
.
type
.
dtype
)
b_test
=
np
.
random
.
normal
(
size
=
b
.
type
.
shape
)
.
astype
(
b
.
type
.
dtype
)
b_test
=
np
.
random
.
normal
(
size
=
b
.
type
.
shape
)
.
astype
(
b
.
type
.
dtype
)
...
...
tests/tensor/test_blas.py
浏览文件 @
0c138495
...
@@ -9,7 +9,6 @@ from numpy.testing import assert_array_almost_equal
...
@@ -9,7 +9,6 @@ from numpy.testing import assert_array_almost_equal
import
pytensor
import
pytensor
import
pytensor.scalar
as
ps
import
pytensor.scalar
as
ps
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
import
pytensor.tensor.blas_scipy
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.io
import
In
from
pytensor.compile.io
import
In
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.mode
import
Mode
...
...
tests/tensor/test_blas_c.py
浏览文件 @
0c138495
...
@@ -8,7 +8,6 @@ import pytensor.tensor as pt
...
@@ -8,7 +8,6 @@ import pytensor.tensor as pt
from
pytensor.tensor.basic
import
AllocEmpty
from
pytensor.tensor.basic
import
AllocEmpty
from
pytensor.tensor.blas
import
Ger
from
pytensor.tensor.blas
import
Ger
from
pytensor.tensor.blas_c
import
CGemv
,
CGer
,
must_initialize_y_gemv
from
pytensor.tensor.blas_c
import
CGemv
,
CGer
,
must_initialize_y_gemv
from
pytensor.tensor.blas_scipy
import
ScipyGer
from
pytensor.tensor.type
import
dmatrix
,
dvector
,
matrix
,
scalar
,
tensor
,
vector
from
pytensor.tensor.type
import
dmatrix
,
dvector
,
matrix
,
scalar
,
tensor
,
vector
from
tests
import
unittest_tools
from
tests
import
unittest_tools
from
tests.tensor.test_blas
import
BaseGemv
,
TestBlasStrides
from
tests.tensor.test_blas
import
BaseGemv
,
TestBlasStrides
...
@@ -68,8 +67,6 @@ class TestCGer(OptimizationTestMixin):
...
@@ -68,8 +67,6 @@ class TestCGer(OptimizationTestMixin):
assert
CGer
(
False
)
==
CGer
(
False
)
assert
CGer
(
False
)
==
CGer
(
False
)
assert
CGer
(
False
)
!=
CGer
(
True
)
assert
CGer
(
False
)
!=
CGer
(
True
)
assert
CGer
(
True
)
!=
ScipyGer
(
True
)
assert
CGer
(
False
)
!=
ScipyGer
(
False
)
assert
CGer
(
True
)
!=
Ger
(
True
)
assert
CGer
(
True
)
!=
Ger
(
True
)
assert
CGer
(
False
)
!=
Ger
(
False
)
assert
CGer
(
False
)
!=
Ger
(
False
)
...
...
tests/tensor/test_blas_scipy.py
deleted
100644 → 0
浏览文件 @
d1be796e
import
pickle
import
numpy
as
np
import
pytensor
from
pytensor
import
tensor
as
pt
from
pytensor.tensor.blas_scipy
import
ScipyGer
from
pytensor.tensor.math
import
outer
from
pytensor.tensor.type
import
tensor
from
tests.tensor.test_blas
import
TestBlasStrides
,
gemm_no_inplace
from
tests.unittest_tools
import
OptimizationTestMixin
class
TestScipyGer
(
OptimizationTestMixin
):
def
setup_method
(
self
):
self
.
mode
=
pytensor
.
compile
.
get_default_mode
()
self
.
mode
=
self
.
mode
.
including
(
"fast_run"
)
self
.
mode
=
self
.
mode
.
excluding
(
"c_blas"
)
# c_blas trumps scipy Ops
dtype
=
self
.
dtype
=
"float64"
# optimization isn't dtype-dependent
self
.
A
=
tensor
(
dtype
=
dtype
,
shape
=
(
None
,
None
))
self
.
a
=
tensor
(
dtype
=
dtype
,
shape
=
())
self
.
x
=
tensor
(
dtype
=
dtype
,
shape
=
(
None
,))
self
.
y
=
tensor
(
dtype
=
dtype
,
shape
=
(
None
,))
self
.
Aval
=
np
.
ones
((
2
,
3
),
dtype
=
dtype
)
self
.
xval
=
np
.
asarray
([
1
,
2
],
dtype
=
dtype
)
self
.
yval
=
np
.
asarray
([
1.5
,
2.7
,
3.9
],
dtype
=
dtype
)
def
function
(
self
,
inputs
,
outputs
):
return
pytensor
.
function
(
inputs
,
outputs
,
self
.
mode
)
def
run_f
(
self
,
f
):
f
(
self
.
Aval
,
self
.
xval
,
self
.
yval
)
f
(
self
.
Aval
[::
-
1
,
::
-
1
],
self
.
xval
[::
-
1
],
self
.
yval
[::
-
1
])
def
b
(
self
,
bval
):
return
pt
.
as_tensor_variable
(
np
.
asarray
(
bval
,
dtype
=
self
.
dtype
))
def
test_outer
(
self
):
f
=
self
.
function
([
self
.
x
,
self
.
y
],
outer
(
self
.
x
,
self
.
y
))
self
.
assertFunctionContains
(
f
,
ScipyGer
(
destructive
=
True
))
def
test_A_plus_outer
(
self
):
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
self
.
A
+
outer
(
self
.
x
,
self
.
y
))
self
.
assertFunctionContains
(
f
,
ScipyGer
(
destructive
=
False
))
self
.
run_f
(
f
)
# DebugMode tests correctness
def
test_A_plus_scaled_outer
(
self
):
f
=
self
.
function
(
[
self
.
A
,
self
.
x
,
self
.
y
],
self
.
A
+
0.1
*
outer
(
self
.
x
,
self
.
y
)
)
self
.
assertFunctionContains
(
f
,
ScipyGer
(
destructive
=
False
))
self
.
run_f
(
f
)
# DebugMode tests correctness
def
test_scaled_A_plus_scaled_outer
(
self
):
f
=
self
.
function
(
[
self
.
A
,
self
.
x
,
self
.
y
],
0.2
*
self
.
A
+
0.1
*
outer
(
self
.
x
,
self
.
y
)
)
self
.
assertFunctionContains
(
f
,
gemm_no_inplace
)
self
.
run_f
(
f
)
# DebugMode tests correctness
def
test_pickle
(
self
):
out
=
ScipyGer
(
destructive
=
False
)(
self
.
A
,
self
.
a
,
self
.
x
,
self
.
y
)
f
=
pytensor
.
function
([
self
.
A
,
self
.
a
,
self
.
x
,
self
.
y
],
out
)
new_f
=
pickle
.
loads
(
pickle
.
dumps
(
f
))
assert
isinstance
(
new_f
.
maker
.
fgraph
.
toposort
()[
-
1
]
.
op
,
ScipyGer
)
assert
np
.
allclose
(
f
(
self
.
Aval
,
1.0
,
self
.
xval
,
self
.
yval
),
new_f
(
self
.
Aval
,
1.0
,
self
.
xval
,
self
.
yval
),
)
class
TestBlasStridesScipy
(
TestBlasStrides
):
mode
=
pytensor
.
compile
.
get_default_mode
()
mode
=
mode
.
including
(
"fast_run"
)
.
excluding
(
"gpu"
,
"c_blas"
)
tests/tensor/test_math.py
浏览文件 @
0c138495
...
@@ -1998,50 +1998,20 @@ class TestMean:
...
@@ -1998,50 +1998,20 @@ class TestMean:
assert
mean
(
ll
)
.
eval
()
==
1
assert
mean
(
ll
)
.
eval
()
==
1
def
test_dot_numpy_inputs
():
"""Test the `PyTensor.tensor.dot` interface function with NumPy inputs."""
a
=
np
.
ones
(
2
)
b
=
np
.
ones
(
2
)
res
=
dot
(
a
,
b
)
assert
isinstance
(
res
,
Variable
)
assert
isinstance
(
res
.
owner
.
op
,
Dot
)
class
TestDot
:
class
TestDot
:
def
test_
Op_dims
(
self
):
def
test_
valid_ndim
(
self
):
d0
=
scalar
()
d0
=
scalar
()
d1
=
vector
()
d1
=
vector
()
d2
=
matrix
()
d2
=
matrix
()
d3
=
tensor3
()
d3
=
tensor3
()
with
pytest
.
raises
(
TypeError
):
_dot
(
d0
,
d0
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d0
,
d1
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
_dot
(
d0
,
d2
)
_dot
(
d0
,
d2
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
_dot
(
d0
,
d3
)
_dot
(
d1
,
d2
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d1
,
d0
)
_dot
(
d1
,
d1
)
_dot
(
d1
,
d2
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d1
,
d3
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d2
,
d0
)
_dot
(
d2
,
d1
)
_dot
(
d2
,
d2
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d2
,
d3
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d3
,
d0
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d3
,
d1
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
_dot
(
d3
,
d2
)
_dot
(
d3
,
d2
)
with
pytest
.
raises
(
TypeError
):
_dot
(
d2
,
d2
)
# Fine
_dot
(
d3
,
d3
)
def
test_grad
(
self
):
def
test_grad
(
self
):
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
())
...
@@ -2089,6 +2059,14 @@ class TestDot:
...
@@ -2089,6 +2059,14 @@ class TestDot:
g
=
grad
(
z
.
sum
(),
y
)
g
=
grad
(
z
.
sum
(),
y
)
assert
is_super_shape
(
y
,
g
)
assert
is_super_shape
(
y
,
g
)
def
test_dot_numpy_inputs
(
self
):
"""Test the `PyTensor.tensor.dot` interface function with NumPy inputs."""
a
=
np
.
ones
((
2
,
2
))
b
=
np
.
ones
((
2
,
2
))
res
=
dot
(
a
,
b
)
assert
isinstance
(
res
,
Variable
)
assert
isinstance
(
res
.
owner
.
op
,
Dot
)
def
test_matrix_vector_ops
():
def
test_matrix_vector_ops
():
"""Test vecdot, matvec, and vecmat helper functions."""
"""Test vecdot, matvec, and vecmat helper functions."""
...
@@ -2796,7 +2774,7 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2796,7 +2774,7 @@ class TestInferShape(utt.InferShapeTester):
bdvec_val
=
random
(
4
,
rng
=
rng
)
bdvec_val
=
random
(
4
,
rng
=
rng
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
advec
,
bdvec
],
[
advec
,
bdvec
],
[
Dot
()
(
advec
,
bdvec
)],
[
dot
(
advec
,
bdvec
)],
[
advec_val
,
bdvec_val
],
[
advec_val
,
bdvec_val
],
(
Dot
,
blas
.
Dot22
,
blas
.
Gemv
,
blas_c
.
CGemv
),
(
Dot
,
blas
.
Dot22
,
blas
.
Gemv
,
blas_c
.
CGemv
),
)
)
...
@@ -2808,7 +2786,7 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2808,7 +2786,7 @@ class TestInferShape(utt.InferShapeTester):
bdmat_val
=
random
(
5
,
3
,
rng
=
rng
)
bdmat_val
=
random
(
5
,
3
,
rng
=
rng
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
admat
,
bdmat
],
[
admat
,
bdmat
],
[
Dot
()
(
admat
,
bdmat
)],
[
dot
(
admat
,
bdmat
)],
[
admat_val
,
bdmat_val
],
[
admat_val
,
bdmat_val
],
(
Dot
,
blas
.
Dot22
),
(
Dot
,
blas
.
Dot22
),
)
)
...
@@ -2817,7 +2795,7 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2817,7 +2795,7 @@ class TestInferShape(utt.InferShapeTester):
bdmat_val
=
random
(
4
,
5
,
rng
=
rng
)
bdmat_val
=
random
(
4
,
5
,
rng
=
rng
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
advec
,
bdmat
],
[
advec
,
bdmat
],
[
Dot
()
(
advec
,
bdmat
)],
[
dot
(
advec
,
bdmat
)],
[
advec_val
,
bdmat_val
],
[
advec_val
,
bdmat_val
],
(
Dot
,
blas
.
Dot22
,
blas
.
Gemv
,
blas_c
.
CGemv
),
(
Dot
,
blas
.
Dot22
,
blas
.
Gemv
,
blas_c
.
CGemv
),
)
)
...
@@ -2826,7 +2804,7 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2826,7 +2804,7 @@ class TestInferShape(utt.InferShapeTester):
admat_val
=
random
(
5
,
4
,
rng
=
rng
)
admat_val
=
random
(
5
,
4
,
rng
=
rng
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
admat
,
bdvec
],
[
admat
,
bdvec
],
[
Dot
()
(
admat
,
bdvec
)],
[
dot
(
admat
,
bdvec
)],
[
admat_val
,
bdvec_val
],
[
admat_val
,
bdvec_val
],
(
Dot
,
blas
.
Dot22
,
blas
.
Gemv
,
blas_c
.
CGemv
),
(
Dot
,
blas
.
Dot22
,
blas
.
Gemv
,
blas_c
.
CGemv
),
)
)
...
...
tests/test_printing.py
浏览文件 @
0c138495
...
@@ -333,7 +333,7 @@ def test_debugprint():
...
@@ -333,7 +333,7 @@ def test_debugprint():
def
test_debugprint_id_type
():
def
test_debugprint_id_type
():
a_at
=
d
vector
()
a_at
=
d
matrix
()
b_at
=
dmatrix
()
b_at
=
dmatrix
()
d_at
=
b_at
.
dot
(
a_at
)
d_at
=
b_at
.
dot
(
a_at
)
...
@@ -344,10 +344,10 @@ def test_debugprint_id_type():
...
@@ -344,10 +344,10 @@ def test_debugprint_id_type():
s
=
s
.
getvalue
()
s
=
s
.
getvalue
()
exp_res
=
f
"""Add [id {e_at.auto_name}]
exp_res
=
f
"""Add [id {e_at.auto_name}]
├─
d
ot [id {d_at.auto_name}]
├─
D
ot [id {d_at.auto_name}]
│ ├─ <Matrix(float64, shape=(?, ?))> [id {b_at.auto_name}]
│ ├─ <Matrix(float64, shape=(?, ?))> [id {b_at.auto_name}]
│ └─ <
Vector(float64, shape=(?,
))> [id {a_at.auto_name}]
│ └─ <
Matrix(float64, shape=(?, ?
))> [id {a_at.auto_name}]
└─ <
Vector(float64, shape=(?,
))> [id {a_at.auto_name}]
└─ <
Matrix(float64, shape=(?, ?
))> [id {a_at.auto_name}]
"""
"""
assert
[
l
.
strip
()
for
l
in
s
.
split
(
"
\n
"
)]
==
[
assert
[
l
.
strip
()
for
l
in
s
.
split
(
"
\n
"
)]
==
[
...
...
tests/xtensor/test_math.py
浏览文件 @
0c138495
...
@@ -312,5 +312,7 @@ def test_dot_errors():
...
@@ -312,5 +312,7 @@ def test_dot_errors():
x_test
=
DataArray
(
np
.
ones
((
2
,
3
)),
dims
=
(
"a"
,
"b"
))
x_test
=
DataArray
(
np
.
ones
((
2
,
3
)),
dims
=
(
"a"
,
"b"
))
y_test
=
DataArray
(
np
.
ones
((
4
,
5
)),
dims
=
(
"b"
,
"c"
))
y_test
=
DataArray
(
np
.
ones
((
4
,
5
)),
dims
=
(
"b"
,
"c"
))
# Doesn't fail until the rewrite
# Doesn't fail until the rewrite
with
pytest
.
raises
(
ValueError
,
match
=
"not aligned"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Input operand 1 has a mismatch in its core dimension 0"
):
fn
(
x_test
,
y_test
)
fn
(
x_test
,
y_test
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论