Commit bad5f528 authored by Tomas Capretto, committed by Ricardo Vieira

Implement StructuredDotGradCSR and StructuredDotGradCSC in numba backend

Parent 00a11b60
@@ -4,6 +4,7 @@ import numpy as np
import scipy.sparse as sp
import pytensor.sparse.basic as psb
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key,
@@ -14,6 +15,8 @@ from pytensor.sparse import (
SparseDenseMultiply,
SparseDenseVectorMultiply,
StructuredDot,
StructuredDotGradCSC,
StructuredDotGradCSR,
)
@@ -402,3 +405,233 @@ def numba_funcify_SparseDot(op, node, **kwargs):
return spmdm_csr(y.T, x.T).T
return dmspm, cache_key

@register_funcify_and_cache_key(StructuredDotGradCSR)
@register_funcify_and_cache_key(StructuredDotGradCSC)
def numba_funcify_StructuredDotGrad(op, node, **kwargs):
"""Overload StructuredDotGrad in Numba.
Let:
Z = structured_dot(X, Y)
L = L(Z), a scalar loss depending on Z.
This function computes the gradient of the loss with respect to X:
dL/dX = dot(dL/dZ, Y^T)
where G = dL/dZ is the accumulated (upstream) gradient.
The returned gradient is structured, preserving the sparsity pattern of X,
and only the `.data` component of the sparse matrix is computed.
If Y is sparse, the sparsity pattern of the result is not recomputed.
The output may contain explicit zeros at positions that would be structural zeros
if the sparsity structure were updated.
The core of the algorithm is:
dot(g_xy[i], y[j])
where g_xy[i] (row of G) and y[j] (column of Y^T) are vectors of length 'k'
Reminder:
x.shape (n, p)
y.shape (p, k)
g_xy.shape (n, k)
"""
_, _, y, g_xy = node.inputs
y_dtype = y.type.dtype
y_is_sparse = psb._is_sparse_variable(y)
y_format = y.type.format if y_is_sparse else None
g_xy_dtype = g_xy.type.dtype
g_xy_is_sparse = psb._is_sparse_variable(g_xy)
g_xy_format = g_xy.type.format if g_xy_is_sparse else None
x_format = "csc" if isinstance(op, StructuredDotGradCSC) else "csr"
out_dtype = g_xy_dtype
cache_key = sha256(
str(
(
type(op),
x_format,
y_format,
y_dtype,
g_xy_format,
out_dtype,
y.type.shape,
)
).encode()
).hexdigest()
if not g_xy_is_sparse:
# X is sparse, Y and G_xy are dense.
if x_format == "csr":
if y.type.shape[1] == 1:
# If Y is effectively 1D (a single column), use a more performant specialized algorithm.
# Inputs with ndim > 2 never appear in the StructuredDot Op.
@numba_basic.numba_njit
def _grad_spmdv_csr(x_indices, x_ptr, y, g_xy):
output = np.empty(len(x_indices), dtype=out_dtype)
size = len(x_ptr) - 1
x_indices = x_indices.view(np.uint32)
x_ptr = x_ptr.view(np.uint32)
for row_idx in range(size):
for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
output[value_idx] = g_xy[row_idx] * y[x_indices[value_idx]]
return output
@numba_basic.numba_njit
def grad_spmdv_csr(x_indices, x_ptr, y, g_xy):
return _grad_spmdv_csr(x_indices, x_ptr, y[:, 0], g_xy[:, 0])
return grad_spmdv_csr, cache_key
else:
# Y is a matrix
if config.compiler_verbose and y_dtype != out_dtype:
print(  # noqa: T201
"Numba StructuredDotGrad requires type-casting its inputs: "
f"{y_dtype=}, {g_xy_dtype=}."
)
@numba_basic.numba_njit
def grad_spmdm_csr(x_indices, x_ptr, y, g_xy):
size = len(x_ptr) - 1
x_indices = x_indices.view(np.uint32)
x_ptr = x_ptr.view(np.uint32)
if y_dtype != out_dtype:
# Mixed input dtypes: allocate the output in the common (upcast) dtype,
# then cast both inputs so the inner np.dot operates on matching types.
new_out_dtype = np.result_type(y, g_xy)
output = np.zeros(len(x_indices), dtype=new_out_dtype)
y = y.astype(out_dtype)
g_xy = g_xy.astype(out_dtype)
else:
output = np.zeros(len(x_indices), dtype=out_dtype)
for row_idx in range(size):
for value_idx in range(x_ptr[row_idx], x_ptr[row_idx + 1]):
output[value_idx] = np.dot(
g_xy[row_idx], y[x_indices[value_idx]]
)
return output
return grad_spmdm_csr, cache_key
else:
# X is CSC
@numba_basic.numba_njit
def grad_spmdm_csc(x_indices, x_ptr, y, g_xy):
# len(x_indices) gives the number of non-zero elements in X.
output = np.zeros(len(x_indices), dtype=out_dtype)
size = len(x_ptr) - 1
x_indices = x_indices.view(np.uint32)
x_ptr = x_ptr.view(np.uint32)
for col_idx in range(size):
for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
output[value_idx] = np.dot(
g_xy[x_indices[value_idx]], y[col_idx]
)
return output
return grad_spmdm_csc, cache_key
# Y is sparse. Whether X is CSR or CSC, we need 'dot_csr_rows'.
@numba_basic.numba_njit
def dot_csr_rows(x_ptr, x_indices, x_data, x_row, y_ptr, y_indices, y_data, y_row):
# Dot product of row `x_row` of one CSR matrix with row `y_row` of another,
# via a two-pointer merge over their (sorted) column indices.
x_p = x_ptr[x_row]
x_end = x_ptr[x_row + 1]
y_p = y_ptr[y_row]
y_end = y_ptr[y_row + 1]
acc = 0.0
while x_p < x_end and y_p < y_end:
x_col = x_indices[x_p]
y_col = y_indices[y_p]
if x_col == y_col:
acc += x_data[x_p] * y_data[y_p]
x_p += 1
y_p += 1
elif x_col < y_col:
x_p += 1
else:
y_p += 1
return acc
if x_format == "csr":
assert g_xy_format == "csr"
assert psb._is_sparse_variable(y)
@numba_basic.numba_njit
def grad_spmspm_csr(x_indices, x_ptr, y, g_xy):
if y_format == "csc":
y = y.tocsr()
g_xy_data = g_xy.data
g_xy_indices = g_xy.indices.view(np.uint32)
g_xy_ptr = g_xy.indptr.view(np.uint32)
y_data = y.data
y_indices = y.indices.view(np.uint32)
y_ptr = y.indptr.view(np.uint32)
n_row = len(x_ptr) - 1
output = np.zeros(len(x_indices), dtype=out_dtype)
for x_row in range(n_row):
for data_idx in range(x_ptr[x_row], x_ptr[x_row + 1]):
x_col = x_indices[data_idx]
output[data_idx] = dot_csr_rows(
g_xy_ptr,
g_xy_indices,
g_xy_data,
x_row,
y_ptr,
y_indices,
y_data,
x_col,
)
return output
return grad_spmspm_csr, cache_key
else:
assert g_xy_format == "csc"
assert psb._is_sparse_variable(y)
@numba_basic.numba_njit
def grad_spmspm_csc(x_indices, x_ptr, y, g_xy):
if y_format == "csc":
y = y.tocsr()
# Looping over a CSC matrix row-wise is painful, slow, and cryptic; convert g_xy to CSR instead.
g_xy = g_xy.tocsr()
g_xy_data = g_xy.data
g_xy_indices = g_xy.indices.view(np.uint32)
g_xy_ptr = g_xy.indptr.view(np.uint32)
y_data = y.data
y_indices = y.indices.view(np.uint32)
y_ptr = y.indptr.view(np.uint32)
n_cols = len(x_ptr) - 1
output = np.empty(len(x_indices), dtype=out_dtype)
for x_col in range(n_cols):
for data_idx in range(x_ptr[x_col], x_ptr[x_col + 1]):
x_row = x_indices[data_idx]
output[data_idx] = dot_csr_rows(
g_xy_ptr,
g_xy_indices,
g_xy_data,
x_row,
y_ptr,
y_indices,
y_data,
x_col,
)
return output
return grad_spmspm_csc, cache_key
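
For reference, here is a minimal standalone sketch (plain NumPy/SciPy, independent of the numba kernels above, with hypothetical variable names) of the quantity the dense-Y CSR branch computes: for each stored entry (i, j) of X, the gradient datum is dot(g_xy[i], y[j]), which equals (G @ Y.T) restricted to X's sparsity pattern.

import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
n, p, k = 5, 4, 3
x = sp.random(n, p, density=0.4, format="csr", random_state=rng)
y = rng.normal(size=(p, k))     # dense Y
g_xy = rng.normal(size=(n, k))  # upstream gradient dL/dZ, with Z = X @ Y

# Per-stored-entry gradient, mirroring the CSR loop above:
# output[value_idx] = dot(g_xy[row], y[col]) for each stored (row, col) of X.
grad_data = np.empty(x.nnz)
for row in range(n):
    for idx in range(x.indptr[row], x.indptr[row + 1]):
        col = x.indices[idx]
        grad_data[idx] = np.dot(g_xy[row], y[col])

# Dense check: (G @ Y.T) restricted to X's sparsity pattern, in CSR order.
rows = np.repeat(np.arange(n), np.diff(x.indptr))
np.testing.assert_allclose(grad_data, (g_xy @ y.T)[rows, x.indices])
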
@@ -1394,6 +1394,8 @@ class StructuredDot(Op):
out[0] = np.asarray(variable, str(variable.dtype))
def grad(self, inputs, gout):
# FIXME: It's not always true that b and g_out are dense.
# The Python (and numba) implementations support a sparse 'b' (and thus a sparse 'g_out') as well.
# a is sparse, b is dense, g_out is dense
# ga = g_out x b.T
# gb = a.T x g_out
@@ -1474,16 +1476,17 @@ class StructuredDotGradCSC(COp):
__props__ = ()
def make_node(self, a_indices, a_indptr, b, g_ab):
out_dtype = ps.upcast(b.dtype, g_ab.dtype)
return Apply(
self,
[a_indices, a_indptr, b, g_ab],
[tensor(dtype=out_dtype, shape=(None,))],
)
def perform(self, node, inputs, outputs):
(a_indices, a_indptr, b, g_ab) = inputs
(out,) = outputs
g_a_data = np.zeros(a_indices.shape, dtype=node.outputs[0].dtype)
for j in range(len(a_indptr) - 1):
ind0 = a_indptr[j]
ind1 = a_indptr[j + 1]
@@ -1615,7 +1618,7 @@ class StructuredDotGradCSR(COp):
def perform(self, node, inputs, outputs):
(a_indices, a_indptr, b, g_ab) = inputs
(out,) = outputs
g_a_data = np.zeros(a_indices.shape, dtype=node.outputs[0].dtype)
for i in range(len(a_indptr) - 1):  # loop over rows
ind0 = a_indptr[i]
ind1 = a_indptr[i + 1]
...
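
The second half of the diff changes the output dtype of StructuredDotGradCSC/CSR from g_ab's dtype to the upcast of b's and g_ab's dtypes. A small sketch of the behavioral change, assuming NumPy's promotion rules (which ps.upcast should follow for the usual float dtypes):

import numpy as np

# Before: the output dtype was g_ab.dtype, so a float64 `b` against a
# float32 `g_ab` silently produced a float32 gradient.
b_dtype = np.dtype("float64")
g_ab_dtype = np.dtype("float32")

# After: make_node upcasts, matching NumPy promotion for these types.
out_dtype = np.promote_types(b_dtype, g_ab_dtype)
assert out_dtype == np.dtype("float64")
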