Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bad5f528
提交
bad5f528
authored
1月 29, 2026
作者:
Tomas Capretto
提交者:
Ricardo Vieira
2月 12, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement StructuredDotGradCSR and StructuredDotGradCSC in numba backend
上级
00a11b60
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
239 行增加
和
3 行删除
+239
-3
math.py
pytensor/link/numba/dispatch/sparse/math.py
+233
-0
math.py
pytensor/sparse/math.py
+6
-3
没有找到文件。
pytensor/link/numba/dispatch/sparse/math.py
浏览文件 @
bad5f528
...
@@ -4,6 +4,7 @@ import numpy as np
...
@@ -4,6 +4,7 @@ import numpy as np
import
scipy.sparse
as
sp
import
scipy.sparse
as
sp
import
pytensor.sparse.basic
as
psb
import
pytensor.sparse.basic
as
psb
from
pytensor
import
config
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
register_funcify_and_cache_key
,
register_funcify_and_cache_key
,
...
@@ -14,6 +15,8 @@ from pytensor.sparse import (
...
@@ -14,6 +15,8 @@ from pytensor.sparse import (
SparseDenseMultiply
,
SparseDenseMultiply
,
SparseDenseVectorMultiply
,
SparseDenseVectorMultiply
,
StructuredDot
,
StructuredDot
,
StructuredDotGradCSC
,
StructuredDotGradCSR
,
)
)
...
@@ -402,3 +405,233 @@ def numba_funcify_SparseDot(op, node, **kwargs):
...
@@ -402,3 +405,233 @@ def numba_funcify_SparseDot(op, node, **kwargs):
return
spmdm_csr
(
y
.
T
,
x
.
T
)
.
T
return
spmdm_csr
(
y
.
T
,
x
.
T
)
.
T
return
dmspm
,
cache_key
return
dmspm
,
cache_key
@register_funcify_and_cache_key(StructuredDotGradCSR)
@register_funcify_and_cache_key(StructuredDotGradCSC)
def numba_funcify_StructuredDotGrad(op, node, **kwargs):
    """Overload StructuredDotGrad in Numba.

    Let:
        Z = structured_dot(X, Y)
        L = L(Z), a scalar loss depending on Z.

    This function computes the gradient of the loss with respect to X:

        dL/dX = dot(dL/dZ, Y^T)

    where G = dL/dZ is the accumulated (upstream) gradient.

    The returned gradient is structured, preserving the sparsity pattern of X,
    and only the `.data` component of the sparse matrix is computed.

    If Y is sparse, the sparsity pattern of the result is not recomputed.
    The output may contain explicit zeros at positions that would be
    structural zeros if the sparsity structure were updated.

    The core of the algorithm is:

        dot(g_xy[i], y[j])

    where g_xy[i] (row of G) and y[j] (column of Y^T) are vectors of
    length 'k'.

    Reminder:
        x.shape (n, p)
        y.shape (p, k)
        g_xy.shape (n, k)

    Parameters
    ----------
    op
        The Op being compiled; only used to tell whether X is CSC or CSR.
    node
        The Apply node. Its inputs are unpacked as
        ``(x_indices, x_indptr, y, g_xy)`` and its first output provides the
        dtype of the returned array.
    **kwargs
        Ignored.

    Returns
    -------
    tuple
        ``(jitted_function, cache_key)``. The jitted function takes
        ``(x_indices, x_ptr, y, g_xy)`` and returns a 1D array with one entry
        per stored element of X.
    """
    _, _, y, g_xy = node.inputs

    y_dtype = y.type.dtype
    y_is_sparse = psb._is_sparse_variable(y)
    y_format = y.type.format if y_is_sparse else None

    g_xy_dtype = g_xy.type.dtype
    g_xy_is_sparse = psb._is_sparse_variable(g_xy)
    g_xy_format = g_xy.type.format if g_xy_is_sparse else None

    x_format = "csc" if isinstance(op, StructuredDotGradCSC) else "csr"

    # The result must carry the dtype the node declares for its output, which
    # is also what the Python `perform` implementations allocate.
    out_dtype = node.outputs[0].type.dtype
    # np.dot is called on operands of a single dtype below. Promote instead of
    # casting everything to one input's dtype so no operand loses precision.
    calc_dtype = str(np.promote_types(y_dtype, g_xy_dtype))

    cache_key = sha256(
        str(
            (
                type(op),
                x_format,
                y_format,
                y_dtype,
                g_xy_format,
                g_xy_dtype,
                out_dtype,
                y.type.shape,
            )
        ).encode()
    ).hexdigest()

    if not g_xy_is_sparse:
        # X is sparse, Y and G_xy are dense.
        if x_format == "csr":
            if y.type.shape[1] == 1:
                # If Y is actually 1D, use more performant specialized algorithm
                # Inputs with ndims > 2 will never appear in the StructuredDot Op

                @numba_basic.numba_njit
                def _grad_spmdv_csr(x_indices, x_ptr, y, g_xy):
                    # Every slot is written by the loops below, provided
                    # x_ptr spans all of x_indices, so `empty` is enough.
                    output = np.empty(len(x_indices), dtype=out_dtype)
                    size = len(x_ptr) - 1
                    # NOTE(review): reinterpreting as uint32 assumes 32-bit
                    # index arrays - confirm against the sparse type.
                    x_indices = x_indices.view(np.uint32)
                    x_ptr = x_ptr.view(np.uint32)
                    for row_idx in range(size):
                        for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
                            output[value_idx] = g_xy[row_idx] * y[x_indices[value_idx]]
                    return output

                @numba_basic.numba_njit
                def grad_spmdv_csr(x_indices, x_ptr, y, g_xy):
                    # Drop the trailing length-1 axis of both dense inputs.
                    return _grad_spmdv_csr(x_indices, x_ptr, y[:, 0], g_xy[:, 0])

                return grad_spmdv_csr, cache_key
            else:
                # Y is a matrix
                if config.compiler_verbose and y_dtype != g_xy_dtype:
                    print(  # noqa: T201
                        "Numba StructuredDotGrad requires a type casting of inputs: "
                        f"{y_dtype=}, {g_xy_dtype=}."
                    )

                @numba_basic.numba_njit
                def grad_spmdm_csr(x_indices, x_ptr, y, g_xy):
                    size = len(x_ptr) - 1
                    x_indices = x_indices.view(np.uint32)
                    x_ptr = x_ptr.view(np.uint32)
                    # Both conditions compare compile-time constants, so each
                    # cast is only compiled in when it is actually needed.
                    if y_dtype != calc_dtype:
                        y = y.astype(calc_dtype)
                    if g_xy_dtype != calc_dtype:
                        g_xy = g_xy.astype(calc_dtype)
                    output = np.zeros(len(x_indices), dtype=out_dtype)
                    for row_idx in range(size):
                        for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
                            output[value_idx] = np.dot(
                                g_xy[row_idx], y[x_indices[value_idx]]
                            )
                    return output

                return grad_spmdm_csr, cache_key
        else:
            # X is CSC

            @numba_basic.numba_njit
            def grad_spmdm_csc(x_indices, x_ptr, y, g_xy):
                # len(x_indices) gives the number of non-zero elements in X.
                output = np.zeros(len(x_indices), dtype=out_dtype)
                size = len(x_ptr) - 1
                x_indices = x_indices.view(np.uint32)
                x_ptr = x_ptr.view(np.uint32)
                # Same dtype harmonisation as the CSR matrix branch.
                if y_dtype != calc_dtype:
                    y = y.astype(calc_dtype)
                if g_xy_dtype != calc_dtype:
                    g_xy = g_xy.astype(calc_dtype)
                for col_idx in range(size):
                    for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
                        # In CSC the stored index is the row of X.
                        output[value_idx] = np.dot(
                            g_xy[x_indices[value_idx]], y[col_idx]
                        )
                return output

            return grad_spmdm_csc, cache_key

    # Y is sparse. In either case we need 'dot_csr_rows'
    @numba_basic.numba_njit
    def dot_csr_rows(
        x_ptr, x_indices, x_data, x_row, y_ptr, y_indices, y_data, y_row
    ):
        # Sparse dot product of row `x_row` of one CSR matrix with row `y_row`
        # of another, as a two-pointer merge over their column indices.
        # The merge only matches every shared column if the indices within
        # each row are sorted in ascending order.
        x_p = x_ptr[x_row]
        x_end = x_ptr[x_row + 1]
        y_p = y_ptr[y_row]
        y_end = y_ptr[y_row + 1]
        acc = 0.0
        while x_p < x_end and y_p < y_end:
            x_col = x_indices[x_p]
            y_col = y_indices[y_p]
            if x_col == y_col:
                acc += x_data[x_p] * y_data[y_p]
                x_p += 1
                y_p += 1
            elif x_col < y_col:
                x_p += 1
            else:
                y_p += 1
        return acc

    if x_format == "csr":
        assert g_xy_format == "csr"
        assert psb._is_sparse_variable(y)

        @numba_basic.numba_njit
        def grad_spmspm_csr(x_indices, x_ptr, y, g_xy):
            # Rows of Y are needed (entry (i, j) is dot(G[i, :], Y[j, :])),
            # so a CSC Y is converted first.
            if y_format == "csc":
                y = y.tocsr()

            g_xy_data = g_xy.data
            g_xy_indices = g_xy.indices.view(np.uint32)
            g_xy_ptr = g_xy.indptr.view(np.uint32)

            y_data = y.data
            y_indices = y.indices.view(np.uint32)
            y_ptr = y.indptr.view(np.uint32)

            n_row = len(x_ptr) - 1
            output = np.zeros(len(x_indices), dtype=out_dtype)
            for x_row in range(n_row):
                for data_idx in range(x_ptr[x_row], x_ptr[x_row + 1]):
                    x_col = x_indices[data_idx]
                    output[data_idx] = dot_csr_rows(
                        g_xy_ptr,
                        g_xy_indices,
                        g_xy_data,
                        x_row,
                        y_ptr,
                        y_indices,
                        y_data,
                        x_col,
                    )
            return output

        return grad_spmspm_csr, cache_key
    else:
        assert g_xy_format == "csc"
        assert psb._is_sparse_variable(y)

        @numba_basic.numba_njit
        def grad_spmspm_csc(x_indices, x_ptr, y, g_xy):
            if y_format == "csc":
                y = y.tocsr()

            # Looping a CSC matrix rowwise is too painful, slow, and cryptic.
            g_xy = g_xy.tocsr()

            g_xy_data = g_xy.data
            g_xy_indices = g_xy.indices.view(np.uint32)
            g_xy_ptr = g_xy.indptr.view(np.uint32)

            y_data = y.data
            y_indices = y.indices.view(np.uint32)
            y_ptr = y.indptr.view(np.uint32)

            n_cols = len(x_ptr) - 1
            output = np.empty(len(x_indices), dtype=out_dtype)
            for x_col in range(n_cols):
                for data_idx in range(x_ptr[x_col], x_ptr[x_col + 1]):
                    # In CSC the stored index is the row of X.
                    x_row = x_indices[data_idx]
                    output[data_idx] = dot_csr_rows(
                        g_xy_ptr,
                        g_xy_indices,
                        g_xy_data,
                        x_row,
                        y_ptr,
                        y_indices,
                        y_data,
                        x_col,
                    )
            return output

        return grad_spmspm_csc, cache_key
pytensor/sparse/math.py
浏览文件 @
bad5f528
...
@@ -1394,6 +1394,8 @@ class StructuredDot(Op):
...
@@ -1394,6 +1394,8 @@ class StructuredDot(Op):
out
[
0
]
=
np
.
asarray
(
variable
,
str
(
variable
.
dtype
))
out
[
0
]
=
np
.
asarray
(
variable
,
str
(
variable
.
dtype
))
def
grad
(
self
,
inputs
,
gout
):
def
grad
(
self
,
inputs
,
gout
):
# FIXME: It's not always true that b and g_out are dense.
# Python implementation (and numba) support sparse 'b' (and thus, 'g_out') as well.
# a is sparse, b is dense, g_out is dense
# a is sparse, b is dense, g_out is dense
# ga = g_out x b.T
# ga = g_out x b.T
# gb = a.T x g_out
# gb = a.T x g_out
...
@@ -1474,16 +1476,17 @@ class StructuredDotGradCSC(COp):
...
@@ -1474,16 +1476,17 @@ class StructuredDotGradCSC(COp):
__props__
=
()
__props__
=
()
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
out_dtype
=
ps
.
upcast
(
b
.
dtype
,
g_ab
.
dtype
)
return
Apply
(
return
Apply
(
self
,
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
tensor
(
dtype
=
g_ab
.
dtype
,
shape
=
(
None
,))],
[
tensor
(
dtype
=
out_
dtype
,
shape
=
(
None
,))],
)
)
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
a_indices
,
a_indptr
,
b
,
g_ab
)
=
inputs
(
a_indices
,
a_indptr
,
b
,
g_ab
)
=
inputs
(
out
,)
=
outputs
(
out
,)
=
outputs
g_a_data
=
np
.
zeros
(
a_indices
.
shape
,
dtype
=
g_ab
.
dtype
)
g_a_data
=
np
.
zeros
(
a_indices
.
shape
,
dtype
=
node
.
outputs
[
0
]
.
dtype
)
for
j
in
range
(
len
(
a_indptr
)
-
1
):
for
j
in
range
(
len
(
a_indptr
)
-
1
):
ind0
=
a_indptr
[
j
]
ind0
=
a_indptr
[
j
]
ind1
=
a_indptr
[
j
+
1
]
ind1
=
a_indptr
[
j
+
1
]
...
@@ -1615,7 +1618,7 @@ class StructuredDotGradCSR(COp):
...
@@ -1615,7 +1618,7 @@ class StructuredDotGradCSR(COp):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
(
a_indices
,
a_indptr
,
b
,
g_ab
)
=
inputs
(
a_indices
,
a_indptr
,
b
,
g_ab
)
=
inputs
(
out
,)
=
outputs
(
out
,)
=
outputs
g_a_data
=
np
.
zeros
(
a_indices
.
shape
,
dtype
=
g_ab
.
dtype
)
g_a_data
=
np
.
zeros
(
a_indices
.
shape
,
dtype
=
node
.
outputs
[
0
]
.
dtype
)
for
i
in
range
(
len
(
a_indptr
)
-
1
):
# loop over rows
for
i
in
range
(
len
(
a_indptr
)
-
1
):
# loop over rows
ind0
=
a_indptr
[
i
]
ind0
=
a_indptr
[
i
]
ind1
=
a_indptr
[
i
+
1
]
ind1
=
a_indptr
[
i
+
1
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论