Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
75485c13
提交
75485c13
authored
12月 05, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
12月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify BatchedDot implementation
The Op now always expects rank 3 inputs, and any dimshuffles are added explicitly by the helper function
上级
18f245fa
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
85 行增加
和
165 行删除
+85
-165
nlinalg.py
pytensor/link/jax/dispatch/nlinalg.py
+1
-3
basic.py
pytensor/link/numba/dispatch/basic.py
+2
-0
blas.py
pytensor/tensor/blas.py
+74
-145
test_nlinalg.py
tests/link/jax/test_nlinalg.py
+0
-9
test_basic.py
tests/link/numba/test_basic.py
+8
-8
没有找到文件。
pytensor/link/jax/dispatch/nlinalg.py
浏览文件 @
75485c13
...
@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
...
@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
def
batched_dot
(
a
,
b
):
def
batched_dot
(
a
,
b
):
if
a
.
shape
[
0
]
!=
b
.
shape
[
0
]:
if
a
.
shape
[
0
]
!=
b
.
shape
[
0
]:
raise
TypeError
(
"Shapes must match in the 0-th dimension"
)
raise
TypeError
(
"Shapes must match in the 0-th dimension"
)
if
a
.
ndim
==
2
or
b
.
ndim
==
2
:
return
jnp
.
matmul
(
a
,
b
)
return
jnp
.
einsum
(
"n...j,nj...->n..."
,
a
,
b
)
return
jnp
.
einsum
(
"nij,njk->nik"
,
a
,
b
)
return
batched_dot
return
batched_dot
...
...
pytensor/link/numba/dispatch/basic.py
浏览文件 @
75485c13
...
@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
...
@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
@numba_njit
@numba_njit
def
batched_dot
(
x
,
y
):
def
batched_dot
(
x
,
y
):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
shape
=
x
.
shape
[:
-
1
]
+
y
.
shape
[
2
:]
shape
=
x
.
shape
[:
-
1
]
+
y
.
shape
[
2
:]
z0
=
np
.
empty
(
shape
,
dtype
=
dtype
)
z0
=
np
.
empty
(
shape
,
dtype
=
dtype
)
for
i
in
range
(
z0
.
shape
[
0
]):
for
i
in
range
(
z0
.
shape
[
0
]):
...
...
pytensor/tensor/blas.py
浏览文件 @
75485c13
...
@@ -98,10 +98,11 @@ from pytensor.link.c.params_type import ParamsType
...
@@ -98,10 +98,11 @@ from pytensor.link.c.params_type import ParamsType
from
pytensor.printing
import
FunctionPrinter
,
pprint
from
pytensor.printing
import
FunctionPrinter
,
pprint
from
pytensor.scalar
import
bool
as
bool_t
from
pytensor.scalar
import
bool
as
bool_t
from
pytensor.tensor
import
basic
as
at
from
pytensor.tensor
import
basic
as
at
from
pytensor.tensor.basic
import
expand_dims
from
pytensor.tensor.blas_headers
import
blas_header_text
,
blas_header_version
from
pytensor.tensor.blas_headers
import
blas_header_text
,
blas_header_version
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
from
pytensor.tensor.shape
import
specify_broadcastable
from
pytensor.tensor.shape
import
s
hape_padright
,
s
pecify_broadcastable
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
integer_dtypes
,
tensor
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
integer_dtypes
,
tensor
from
pytensor.utils
import
memoize
from
pytensor.utils
import
memoize
...
@@ -1637,48 +1638,53 @@ _dot22scalar = Dot22Scalar()
...
@@ -1637,48 +1638,53 @@ _dot22scalar = Dot22Scalar()
class
BatchedDot
(
COp
):
class
BatchedDot
(
COp
):
"""
"""
Computes
the batched dot product of two variables:
Computes
a batch matrix-matrix dot with tensor3 variables
batched_dot(a, b)[i] = dot(a[i], b[i])
batched_dot(a, b)[i] = dot(a[i], b[i])
"""
"""
__props__
=
()
__props__
=
()
gufunc_signature
=
"(b,m,k),(b,k,n)->(b,m,n)"
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
x
,
y
):
inputs
=
list
(
map
(
at
.
as_tensor_variable
,
inputs
))
x
=
at
.
as_tensor_variable
(
x
)
y
=
at
.
as_tensor_variable
(
y
)
if
any
(
not
isinstance
(
i
.
type
,
DenseTensorType
)
for
i
in
inputs
):
if
not
(
isinstance
(
x
.
type
,
DenseTensorType
)
and
isinstance
(
y
.
type
,
DenseTensorType
)
):
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
raise
NotImplementedError
(
"Only dense tensor types are supported"
)
if
len
(
inputs
)
!=
2
:
if
not
(
x
.
type
.
ndim
==
3
and
y
.
type
.
ndim
==
3
):
raise
TypeError
(
f
"Two arguments required, but {len(inputs)} given."
)
if
inputs
[
0
]
.
ndim
not
in
(
2
,
3
):
raise
TypeError
(
"Input 0 (0-indexed)"
f
" must have ndim of 2 or 3, {int(inputs[0].ndim)} given. Consider"
" calling batched_dot instead."
)
if
inputs
[
1
]
.
ndim
not
in
(
2
,
3
):
raise
TypeError
(
raise
TypeError
(
"Input 1 (0-indexed)"
f
"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. "
f
" must have ndim of 2 or 3, {int(inputs[1].ndim)} given. Consider"
"Consider calling batched_dot instead."
" calling batched_dot instead."
)
)
dtype
=
pytensor
.
scalar
.
upcast
(
*
[
input
.
type
.
dtype
for
input
in
inputs
])
def
extract_static_dim
(
dim_x
,
dim_y
):
# upcast inputs to common dtype if needed
dims
=
{
dim_x
,
dim_y
}
-
{
None
}
upcasted_inputs
=
[
at
.
cast
(
input
,
dtype
)
for
input
in
inputs
]
if
len
(
dims
)
>
1
:
out_shape
=
(
# BatchedDot doesn't allow broadcasting
(
raise
ValueError
(
1
f
"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}"
if
inputs
[
0
]
.
type
.
shape
[
0
]
==
1
or
inputs
[
1
]
.
type
.
shape
[
0
]
==
1
else
None
,
)
+
inputs
[
0
]
.
type
.
shape
[
1
:
-
1
]
+
inputs
[
1
]
.
type
.
shape
[
2
:]
)
)
out_shape
=
tuple
(
1
if
s
==
1
else
None
for
s
in
out_shape
)
elif
not
dims
:
return
Apply
(
self
,
upcasted_inputs
,
[
tensor
(
dtype
=
dtype
,
shape
=
out_shape
)])
return
None
else
:
return
dims
.
pop
()
x_batch_dim
,
x_row_dim
,
x_sum_dim
=
x
.
type
.
shape
y_batch_dim
,
y_sum_dim
,
y_col_dim
=
y
.
type
.
shape
batch_dim
=
extract_static_dim
(
x_batch_dim
,
y_batch_dim
)
# Raise if static sum dimensions do not match
_
=
extract_static_dim
(
x_sum_dim
,
y_sum_dim
)
out_shape
=
(
batch_dim
,
x_row_dim
,
y_col_dim
)
# Change dtype if needed
dtype
=
pytensor
.
scalar
.
upcast
(
x
.
type
.
dtype
,
y
.
type
.
dtype
)
x
,
y
=
at
.
cast
(
x
,
dtype
),
at
.
cast
(
y
,
dtype
)
out
=
tensor
(
dtype
=
dtype
,
shape
=
out_shape
)
return
Apply
(
self
,
[
x
,
y
],
[
out
])
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inp
,
out
):
x
,
y
=
inp
x
,
y
=
inp
...
@@ -1690,11 +1696,7 @@ class BatchedDot(COp):
...
@@ -1690,11 +1696,7 @@ class BatchedDot(COp):
f
" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]."
f
" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]."
)
)
shape
=
self
.
infer_shape
(
None
,
node
,
[
i
.
shape
for
i
in
inp
])[
0
]
z
[
0
]
=
np
.
matmul
(
x
,
y
)
dtype
=
node
.
outputs
[
0
]
.
dtype
z0
=
z
[
0
]
=
np
.
empty
(
shape
,
dtype
=
dtype
)
for
i
in
range
(
z0
.
shape
[
0
]):
z0
[
i
]
=
np
.
dot
(
x
[
i
],
y
[
i
])
def
c_support_code
(
self
,
**
kwargs
):
def
c_support_code
(
self
,
**
kwargs
):
batch_gemm_defn
=
"""
batch_gemm_defn
=
"""
...
@@ -1792,14 +1794,6 @@ class BatchedDot(COp):
...
@@ -1792,14 +1794,6 @@ class BatchedDot(COp):
def
c_header_dirs
(
self
,
**
kwargs
):
def
c_header_dirs
(
self
,
**
kwargs
):
return
ldflags
(
libs
=
False
,
include_dir
=
True
)
return
ldflags
(
libs
=
False
,
include_dir
=
True
)
def
c_code_cleanup
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
return
"""
// clean up views
Py_XDECREF(xs); xs = 0;
Py_XDECREF(ys); ys = 0;
Py_XDECREF(zs); zs = 0;
"""
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
_x
,
_y
=
inp
_x
,
_y
=
inp
(
_z
,)
=
out
(
_z
,)
=
out
...
@@ -1832,12 +1826,11 @@ class BatchedDot(COp):
...
@@ -1832,12 +1826,11 @@ class BatchedDot(COp):
)
)
# generate code to allocate output based on runtime input shapes
# generate code to allocate output based on runtime input shapes
z_dims
=
[
f
"PyArray_DIMS({_x})[0]"
]
z_dims
=
[
if
x_ndim
==
3
:
f
"PyArray_DIMS({_x})[0]"
,
z_dims
.
append
(
f
"PyArray_DIMS({_x})[1]"
)
f
"PyArray_DIMS({_x})[1]"
,
if
y_ndim
==
3
:
f
"PyArray_DIMS({_y})[2]"
,
z_dims
.
append
(
f
"PyArray_DIMS({_y})[2]"
)
]
assert
len
(
z_dims
)
==
z_ndim
z_shape_correct
=
" && "
.
join
(
z_shape_correct
=
" && "
.
join
(
"PyArray_DIMS(
%
s)[
%
i] ==
%
s"
%
(
_z
,
i
,
dim
)
for
i
,
dim
in
enumerate
(
z_dims
)
"PyArray_DIMS(
%
s)[
%
i] ==
%
s"
%
(
_z
,
i
,
dim
)
for
i
,
dim
in
enumerate
(
z_dims
)
...
@@ -1880,76 +1873,26 @@ class BatchedDot(COp):
...
@@ -1880,76 +1873,26 @@ class BatchedDot(COp):
)
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
contiguate
=
"
\n
"
.
join
(
contiguate
)
def
c_dimshuffle
(
newname
,
oldname
,
shape
):
_fail
=
fail
_shape
=
", "
.
join
(
"1"
if
axis
is
None
else
"PyArray_DIMS(
%
s)[
%
i]"
%
(
oldname
,
axis
)
for
axis
in
shape
)
return
(
"""{
npy_intp dims[3] = {
%(_shape)
s};
PyArray_Dims newshape = {dims, 3};
%(newname)
s = (PyArrayObject*)PyArray_Newshape(
%(oldname)
s, &newshape, NPY_ANYORDER);
if (!
%(newname)
s)
%(_fail)
s
// make sure we didn't accidentally copy
assert(PyArray_DATA(
%(oldname)
s) == PyArray_DATA(
%(newname)
s));
}"""
%
locals
()
)
# create tensor3 views for any of x, y, z that are not tensor3, so that
# we only need to implement the tensor3-tensor3 batched dot product.
# xs, ys and zs will point to these views, or to the original array if
# it was already tensor3.
# in the latter case, we artificially increase the reference count of
# the original array so that the c_code_cleanup method can decref them
# all indiscriminately.
upcast
=
[]
if
x_ndim
==
3
:
upcast
.
append
(
"xs =
%(_x)
s; Py_XINCREF(xs);"
)
elif
x_ndim
==
2
:
upcast
.
append
(
c_dimshuffle
(
"xs"
,
_x
,
(
0
,
None
,
1
)))
if
y_ndim
==
3
:
upcast
.
append
(
"ys =
%(_y)
s; Py_XINCREF(ys);"
)
elif
y_ndim
==
2
:
upcast
.
append
(
c_dimshuffle
(
"ys"
,
_y
,
(
0
,
1
,
None
)))
if
z_ndim
==
3
:
upcast
.
append
(
"zs =
%(_z)
s; Py_XINCREF(zs);"
)
else
:
upcast
.
append
(
c_dimshuffle
(
"zs"
,
_z
,
(
0
,
None
if
x_ndim
==
2
else
1
,
None
if
y_ndim
==
2
else
1
),
)
)
upcast
=
"
\n
"
.
join
(
upcast
)
%
locals
()
return
(
return
(
"""
"""
int type_num = PyArray_DESCR(
%(_x)
s)->type_num;
int type_num = PyArray_DESCR(
%(_x)
s)->type_num;
int type_size = PyArray_DESCR(
%(_x)
s)->elsize; // in bytes
int type_size = PyArray_DESCR(
%(_x)
s)->elsize; // in bytes
// xs, ys, zs will point to views onto
%(_x)
s,
%(_y)
s,
%(_z)
s
if (PyArray_NDIM(
%(_x)
s) != 3) {
PyArrayObject *xs = 0, *ys = 0, *zs = 0;
if (PyArray_NDIM(
%(_x)
s) !=
%(x_ndim)
s) {
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"rank(x) !=
%(x_ndim)
s
. rank(x) is
%%
d.",
"rank(x) !=
3
. rank(x) is
%%
d.",
PyArray_NDIM(
%(_x)
s));
PyArray_NDIM(
%(_x)
s));
%(fail)
s;
%(fail)
s;
}
}
if (PyArray_NDIM(
%(_y)
s) !=
%(y_ndim)
s
) {
if (PyArray_NDIM(
%(_y)
s) !=
3
) {
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"rank(y) !=
%(y_ndim)
s
. rank(y) is
%%
d.",
"rank(y) !=
3
. rank(y) is
%%
d.",
PyArray_NDIM(
%(_y)
s));
PyArray_NDIM(
%(_y)
s));
%(fail)
s;
%(fail)
s;
}
}
if (
%(_z)
s && PyArray_NDIM(
%(_z)
s) !=
%(z_ndim)
s
) {
if (
%(_z)
s && PyArray_NDIM(
%(_z)
s) !=
3
) {
PyErr_Format(PyExc_NotImplementedError,
PyErr_Format(PyExc_NotImplementedError,
"rank(z) !=
%(z_ndim)
s
. rank(z) is
%%
d.",
"rank(z) !=
3
. rank(z) is
%%
d.",
PyArray_NDIM(
%(_z)
s));
PyArray_NDIM(
%(_z)
s));
%(fail)
s;
%(fail)
s;
}
}
...
@@ -1958,36 +1901,32 @@ class BatchedDot(COp):
...
@@ -1958,36 +1901,32 @@ class BatchedDot(COp):
%(allocate)
s
%(allocate)
s
// reallocate any noncontiguous arrays or arrays with invalid strides
// reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)
s
%(contiguate)
s
// add dims to make sure everything is tensor3
%(upcast)
s
// from here on, use xs, ys and zs as they are tensor3 and share memory
// with the original
%(_x)
s,
%(_y)
s and
%(_z)
s arrays.
if ((PyArray_DESCR(
x
s)->type_num != NPY_DOUBLE)
if ((PyArray_DESCR(
%(_x)
s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
x
s)->type_num != NPY_FLOAT))
&& (PyArray_DESCR(
%(_x)
s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float");
%(fail)
s;}
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float");
%(fail)
s;}
if ((PyArray_DESCR(
y
s)->type_num != NPY_DOUBLE)
if ((PyArray_DESCR(
%(_y)
s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
y
s)->type_num != NPY_FLOAT))
&& (PyArray_DESCR(
%(_y)
s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float");
%(fail)
s;}
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float");
%(fail)
s;}
if ((PyArray_DESCR(
z
s)->type_num != NPY_DOUBLE)
if ((PyArray_DESCR(
%(_z)
s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(
z
s)->type_num != NPY_FLOAT))
&& (PyArray_DESCR(
%(_z)
s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float");
%(fail)
s;}
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float");
%(fail)
s;}
if ((PyArray_DESCR(
xs)->type_num != PyArray_DESCR(y
s)->type_num)
if ((PyArray_DESCR(
%(_x)
s)->type_num != PyArray_DESCR(
%(_y)
s)->type_num)
||(PyArray_DESCR(
xs)->type_num != PyArray_DESCR(z
s)->type_num))
||(PyArray_DESCR(
%(_x)
s)->type_num != PyArray_DESCR(
%(_z)
s)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same");
%(fail)
s; }
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same");
%(fail)
s; }
switch (type_num)
switch (type_num)
{
{
case NPY_FLOAT:
case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size,
xs, ys, z
s)) {
if (batch_gemm<float>(sgemm_, type_size,
%(_x)
s,
%(_y)
s,
%(_z)
s)) {
%(fail)
s;
%(fail)
s;
}
}
break;
break;
case NPY_DOUBLE:
case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size,
xs, ys, z
s)) {
if (batch_gemm<double>(dgemm_, type_size,
%(_x)
s,
%(_y)
s,
%(_z)
s)) {
%(fail)
s;
%(fail)
s;
}
}
break;
break;
...
@@ -1999,30 +1938,12 @@ class BatchedDot(COp):
...
@@ -1999,30 +1938,12 @@ class BatchedDot(COp):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
from
pytensor.tensor.blas_headers
import
blas_header_version
from
pytensor.tensor.blas_headers
import
blas_header_version
return
(
4
,
blas_header_version
())
return
(
5
,
blas_header_version
())
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
x
,
y
=
inp
x
,
y
=
inp
(
gz
,)
=
grads
(
gz
,)
=
grads
xdim
,
ydim
,
gdim
=
x
.
type
.
ndim
,
y
.
type
.
ndim
,
gz
.
type
.
ndim
# grad is a vector, so x is a matrix and y is a matrix
if
gdim
==
1
:
xgrad
=
gz
.
dimshuffle
(
0
,
"x"
)
*
y
ygrad
=
gz
.
dimshuffle
(
0
,
"x"
)
*
x
# x is a matrix, y is a tensor3, grad is a matrix
elif
xdim
==
2
and
ydim
==
3
:
xgrad
=
batched_dot
(
gz
,
y
.
dimshuffle
(
0
,
2
,
1
))
ygrad
=
x
.
dimshuffle
(
0
,
1
,
"x"
)
*
gz
.
dimshuffle
(
0
,
"x"
,
1
)
# x is a tensor3, y is a matrix, grad is a matrix
elif
xdim
==
3
and
ydim
==
2
:
xgrad
=
gz
.
dimshuffle
(
0
,
1
,
"x"
)
*
y
.
dimshuffle
(
0
,
"x"
,
1
)
ygrad
=
batched_dot
(
x
.
dimshuffle
(
0
,
2
,
1
),
gz
)
# x is a tensor3, y is a tensor3, grad is a tensor3
elif
xdim
==
ydim
==
3
:
xgrad
=
batched_dot
(
gz
,
y
.
dimshuffle
(
0
,
2
,
1
))
xgrad
=
batched_dot
(
gz
,
y
.
dimshuffle
(
0
,
2
,
1
))
ygrad
=
batched_dot
(
x
.
dimshuffle
(
0
,
2
,
1
),
gz
)
ygrad
=
batched_dot
(
x
.
dimshuffle
(
0
,
2
,
1
),
gz
)
...
@@ -2105,6 +2026,7 @@ class BatchedDot(COp):
...
@@ -2105,6 +2026,7 @@ class BatchedDot(COp):
+
" to BatchedDot.R_op should have the same shape, but "
+
" to BatchedDot.R_op should have the same shape, but "
f
"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively"
f
"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively"
)
)
if
eval_points
[
0
]:
if
eval_points
[
0
]:
t1
=
self
(
eval_points
[
0
],
inputs
[
1
])
t1
=
self
(
eval_points
[
0
],
inputs
[
1
])
if
eval_points
[
1
]:
if
eval_points
[
1
]:
...
@@ -2118,9 +2040,6 @@ class BatchedDot(COp):
...
@@ -2118,9 +2040,6 @@ class BatchedDot(COp):
return
[
t2
]
return
[
t2
]
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
for
shape_
in
shapes
:
if
len
(
shape_
)
not
in
(
2
,
3
):
raise
NotImplementedError
()
xshp
,
yshp
=
shapes
xshp
,
yshp
=
shapes
return
[
xshp
[:
-
1
]
+
yshp
[
2
:]]
return
[
xshp
[:
-
1
]
+
yshp
[
2
:]]
...
@@ -2157,14 +2076,24 @@ def batched_dot(a, b):
...
@@ -2157,14 +2076,24 @@ def batched_dot(a, b):
elif
b
.
ndim
==
0
:
elif
b
.
ndim
==
0
:
raise
TypeError
(
"b must have at least one (batch) axis"
)
raise
TypeError
(
"b must have at least one (batch) axis"
)
elif
a
.
ndim
==
1
:
elif
a
.
ndim
==
1
:
return
a
.
dimshuffle
(
*
([
0
]
+
[
"x"
]
*
(
b
.
ndim
-
1
)
))
*
b
return
shape_padright
(
a
,
(
b
.
ndim
-
1
))
*
b
elif
b
.
ndim
==
1
:
elif
b
.
ndim
==
1
:
return
a
*
b
.
dimshuffle
(
*
([
0
]
+
[
"x"
]
*
(
a
.
ndim
-
1
)
))
return
a
*
shape_padright
(
b
,
(
a
.
ndim
-
1
))
elif
a
.
ndim
>
3
or
b
.
ndim
>
3
:
elif
a
.
ndim
>
3
or
b
.
ndim
>
3
:
return
batched_tensordot
(
a
,
b
,
[[
a
.
ndim
-
1
],
[
np
.
maximum
(
1
,
b
.
ndim
-
2
)]])
return
batched_tensordot
(
a
,
b
,
[[
a
.
ndim
-
1
],
[
np
.
maximum
(
1
,
b
.
ndim
-
2
)]])
else
:
else
:
# avoid circular import
# If either a or b is a batched vector, expand dims and later squeeze them
return
_batched_dot
(
a
,
b
)
expanded_axis
=
[]
if
a
.
ndim
==
2
:
a
=
expand_dims
(
a
,
axis
=
1
)
expanded_axis
.
append
(
1
)
if
b
.
ndim
==
2
:
b
=
expand_dims
(
b
,
axis
=
2
)
expanded_axis
.
append
(
2
)
out
=
_batched_dot
(
a
,
b
)
if
expanded_axis
:
out
=
out
.
squeeze
(
axis
=
expanded_axis
)
return
out
def
batched_tensordot
(
x
,
y
,
axes
=
2
):
def
batched_tensordot
(
x
,
y
,
axes
=
2
):
...
...
tests/link/jax/test_nlinalg.py
浏览文件 @
75485c13
...
@@ -43,15 +43,6 @@ def test_jax_BatchedDot():
...
@@ -43,15 +43,6 @@ def test_jax_BatchedDot():
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
pytensor_jax_fn
(
*
inputs
)
pytensor_jax_fn
(
*
inputs
)
# matrix . matrix
a
=
matrix
(
"a"
)
a
.
tag
.
test_value
=
np
.
linspace
(
-
1
,
1
,
5
*
3
)
.
astype
(
config
.
floatX
)
.
reshape
((
5
,
3
))
b
=
matrix
(
"b"
)
b
.
tag
.
test_value
=
np
.
linspace
(
1
,
-
1
,
5
*
3
)
.
astype
(
config
.
floatX
)
.
reshape
((
5
,
3
))
out
=
at_blas
.
BatchedDot
()(
a
,
b
)
fgraph
=
FunctionGraph
([
a
,
b
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_jax_basic_multiout
():
def
test_jax_basic_multiout
():
rng
=
np
.
random
.
default_rng
(
213234
)
rng
=
np
.
random
.
default_rng
(
213234
)
...
...
tests/link/numba/test_basic.py
浏览文件 @
75485c13
...
@@ -843,23 +843,23 @@ def test_Softplus(x, exc):
...
@@ -843,23 +843,23 @@ def test_Softplus(x, exc):
[
[
(
(
set_test_value
(
set_test_value
(
at
.
d
matrix
(),
at
.
d
tensor3
(),
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
),
),
set_test_value
(
set_test_value
(
at
.
d
matrix
(),
at
.
d
tensor3
(),
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
set_test_value
(
at
.
d
matrix
(),
at
.
d
tensor3
(),
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
),
),
set_test_value
(
set_test_value
(
at
.
l
matrix
(),
at
.
l
tensor3
(),
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
),
rng
.
poisson
(
size
=
(
2
,
3
,
3
))
.
astype
(
"int64"
),
),
),
None
,
None
,
),
),
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论