提交 c1a53d7d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

SparseMultiply: Cleanup Ops

* Handle static shape * Rename to more readable Op classes * Simplify perform
上级 2e4e3095
...@@ -12,6 +12,7 @@ from pytensor import config ...@@ -12,6 +12,7 @@ from pytensor import config
from pytensor.gradient import grad_not_implemented from pytensor.gradient import grad_not_implemented
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.sparse.type import SparseTensorType
from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.type import TensorType, Variable, complex_dtypes, tensor from pytensor.tensor.type import TensorType, Variable, complex_dtypes, tensor
...@@ -379,7 +380,7 @@ class AddSS(Op): ...@@ -379,7 +380,7 @@ class AddSS(Op):
return Apply( return Apply(
self, self,
[x, y], [x, y],
[psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], [SparseTensorType(dtype=out_dtype, format=x.type.format)()],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -439,7 +440,7 @@ class AddSSData(Op): ...@@ -439,7 +440,7 @@ class AddSSData(Op):
return Apply( return Apply(
self, self,
[x, y], [x, y],
[psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -542,7 +543,7 @@ class StructuredAddSV(Op): ...@@ -542,7 +543,7 @@ class StructuredAddSV(Op):
return Apply( return Apply(
self, self,
[x, y], [x, y],
[psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -658,7 +659,7 @@ def sub(x, y): ...@@ -658,7 +659,7 @@ def sub(x, y):
sub.__doc__ = subtract.__doc__ sub.__doc__ = subtract.__doc__
class MulSS(Op): class SparseSparseMultiply(Op):
# mul(sparse, sparse) # mul(sparse, sparse)
# See the doc of mul() for more detail # See the doc of mul() for more detail
__props__ = () __props__ = ()
...@@ -671,7 +672,7 @@ class MulSS(Op): ...@@ -671,7 +672,7 @@ class MulSS(Op):
return Apply( return Apply(
self, self,
[x, y], [x, y],
[psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], [SparseTensorType(dtype=out_dtype, format=x.type.format)()],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -693,10 +694,10 @@ class MulSS(Op): ...@@ -693,10 +694,10 @@ class MulSS(Op):
return [shapes[0]] return [shapes[0]]
mul_s_s = MulSS() mul_s_s = SparseSparseMultiply()
class MulSD(Op): class SparseDenseMultiply(Op):
# mul(sparse, dense) # mul(sparse, dense)
# See the doc of mul() for more detail # See the doc of mul() for more detail
__props__ = () __props__ = ()
...@@ -713,65 +714,63 @@ class MulSD(Op): ...@@ -713,65 +714,63 @@ class MulSD(Op):
# objects must be matrices (have dimension 2) # objects must be matrices (have dimension 2)
# Broadcasting of the sparse matrix is not supported. # Broadcasting of the sparse matrix is not supported.
# We support nd == 0 used by grad of SpSum() # We support nd == 0 used by grad of SpSum()
assert y.type.ndim in (0, 2) if y.type.ndim not in (0, 2):
out = psb.SparseTensorType(dtype=dtype, format=x.type.format)() raise ValueError(f"y {y} must have 0 or 2 dimensions. Got {y.type.ndim}")
if y.type.ndim == 0:
out_shape = x.type.shape
if y.type.ndim == 2:
# Combine with static shape information from y
out_shape = []
for x_st_dim_length, y_st_dim_length in zip(x.type.shape, y.type.shape):
if x_st_dim_length is None:
out_shape.append(y_st_dim_length)
else:
out_shape.append(x_st_dim_length)
# If both are known, they must match
if (
y_st_dim_length is not None
and y_st_dim_length != x_st_dim_length
):
raise ValueError(
f"Incompatible static shapes {x}: {x.type.shape}, {y}: {y.type.shape}"
)
out_shape = tuple(out_shape)
out = SparseTensorType(dtype=dtype, format=x.type.format, shape=out_shape)()
return Apply(self, [x, y], [out]) return Apply(self, [x, y], [out])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
(out,) = outputs (out,) = outputs
out_dtype = node.outputs[0].dtype
assert psb._is_sparse(x) and psb._is_dense(y) assert psb._is_sparse(x) and psb._is_dense(y)
if len(y.shape) == 0:
out_dtype = node.outputs[0].dtype if x.dtype == out_dtype:
if x.dtype == out_dtype: z = x.copy()
z = x.copy() else:
else: z = x.astype(out_dtype)
z = x.astype(out_dtype) out[0] = z
out[0] = z z_data = z.data
out[0].data *= y
elif len(y.shape) == 1: if y.ndim == 0:
raise NotImplementedError() # RowScale / ColScale z_data *= y
elif len(y.shape) == 2: else: # y_ndim == 2
# if we have enough memory to fit y, maybe we can fit x.asarray() # if we have enough memory to fit y, maybe we can fit x.asarray()
# too? # too?
# TODO: change runtime from O(M*N) to O(nonzeros) # TODO: change runtime from O(M*N) to O(nonzeros)
M, N = x.shape M, N = x.shape
assert x.shape == y.shape assert x.shape == y.shape
out_dtype = node.outputs[0].dtype indices = x.indices
indptr = x.indptr
if x.format == "csc": if x.format == "csc":
indices = x.indices
indptr = x.indptr
if x.dtype == out_dtype:
z = x.copy()
else:
z = x.astype(out_dtype)
z_data = z.data
for j in range(0, N): for j in range(0, N):
for i_idx in range(indptr[j], indptr[j + 1]): for i_idx in range(indptr[j], indptr[j + 1]):
i = indices[i_idx] i = indices[i_idx]
z_data[i_idx] *= y[i, j] z_data[i_idx] *= y[i, j]
out[0] = z
elif x.format == "csr": elif x.format == "csr":
indices = x.indices
indptr = x.indptr
if x.dtype == out_dtype:
z = x.copy()
else:
z = x.astype(out_dtype)
z_data = z.data
for i in range(0, M): for i in range(0, M):
for j_idx in range(indptr[i], indptr[i + 1]): for j_idx in range(indptr[i], indptr[i + 1]):
j = indices[j_idx] j = indices[j_idx]
z_data[j_idx] *= y[i, j] z_data[j_idx] *= y[i, j]
out[0] = z
else:
warn(
"This implementation of MulSD is deficient: {x.format}",
)
out[0] = type(x)(x.toarray() * y)
def grad(self, inputs, gout): def grad(self, inputs, gout):
(x, y) = inputs (x, y) = inputs
...@@ -784,10 +783,10 @@ class MulSD(Op): ...@@ -784,10 +783,10 @@ class MulSD(Op):
return [shapes[0]] return [shapes[0]]
mul_s_d = MulSD() mul_s_d = SparseDenseMultiply()
class MulSV(Op): class SparseDenseVectorMultiply(Op):
"""Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise. """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise.
Notes Notes
...@@ -796,6 +795,8 @@ class MulSV(Op): ...@@ -796,6 +795,8 @@ class MulSV(Op):
""" """
# TODO: Merge with the SparseDenseMultiply Op
__props__ = () __props__ = ()
def make_node(self, x, y): def make_node(self, x, y):
...@@ -812,17 +813,30 @@ class MulSV(Op): ...@@ -812,17 +813,30 @@ class MulSV(Op):
assert x.format in ("csr", "csc") assert x.format in ("csr", "csc")
y = ptb.as_tensor_variable(y) y = ptb.as_tensor_variable(y)
assert y.type.ndim == 1 if y.type.ndim != 1:
raise ValueError(f"y {y} must have 1 dimension. Got {y.type.ndim}")
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
raise NotImplementedError( raise NotImplementedError(
"MulSV not implemented for differing dtypes." f"Differing dtypes not supported. Got {x.type.dtype} and {y.type.dtype}."
f"Got {x.type.dtype} and {y.type.dtype}."
) )
out_shape = [x.type.shape[0]]
if x.type.shape[-1] is None:
out_shape.append(y.type.shape[0])
else:
out_shape.append(x.type.shape[-1])
if y.type.shape[-1] is not None and x.type.shape[-1] != y.type.shape[-1]:
raise ValueError(
f"Incompatible static shapes for multiplication {x}: {x.type.shape}, {y}: {y.type.shape}"
)
return Apply( return Apply(
self, self,
[x, y], [x, y],
[psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], [
SparseTensorType(
dtype=x.type.dtype, format=x.type.format, shape=tuple(out_shape)
)()
],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -852,7 +866,7 @@ class MulSV(Op): ...@@ -852,7 +866,7 @@ class MulSV(Op):
return [ins_shapes[0]] return [ins_shapes[0]]
mul_s_v = MulSV() mul_s_v = SparseDenseVectorMultiply()
def multiply(x, y): def multiply(x, y):
...@@ -891,16 +905,17 @@ def multiply(x, y): ...@@ -891,16 +905,17 @@ def multiply(x, y):
# mul_s_s is not implemented if the types differ # mul_s_s is not implemented if the types differ
if y.dtype == "float64" and x.dtype == "float32": if y.dtype == "float64" and x.dtype == "float32":
x = x.astype("float64") x = x.astype("float64")
return mul_s_s(x, y) return mul_s_s(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable: elif x_is_sparse_variable or y_is_sparse_variable:
if y_is_sparse_variable:
x, y = y, x
# mul is unimplemented if the dtypes differ # mul is unimplemented if the dtypes differ
if y.dtype == "float64" and x.dtype == "float32": if y.dtype == "float64" and x.dtype == "float32":
x = x.astype("float64") x = x.astype("float64")
if y.ndim == 1:
return mul_s_d(x, y) return mul_s_v(x, y)
elif y_is_sparse_variable and not x_is_sparse_variable: else:
return mul_s_d(y, x) return mul_s_d(x, y)
else: else:
raise NotImplementedError() raise NotImplementedError()
...@@ -950,7 +965,7 @@ class __ComparisonOpSS(Op): ...@@ -950,7 +965,7 @@ class __ComparisonOpSS(Op):
if x.type.format != y.type.format: if x.type.format != y.type.format:
raise NotImplementedError() raise NotImplementedError()
return Apply( return Apply(
self, [x, y], [psb.SparseTensorType(dtype="uint8", format=x.type.format)()] self, [x, y], [SparseTensorType(dtype="uint8", format=x.type.format)()]
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -1203,7 +1218,7 @@ class TrueDot(Op): ...@@ -1203,7 +1218,7 @@ class TrueDot(Op):
raise NotImplementedError() raise NotImplementedError()
inputs = [x, y] # Need to convert? e.g. assparse inputs = [x, y] # Need to convert? e.g. assparse
outputs = [psb.SparseTensorType(dtype=x.type.dtype, format=myformat)()] outputs = [SparseTensorType(dtype=x.type.dtype, format=myformat)()]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
...@@ -1324,9 +1339,7 @@ class StructuredDot(Op): ...@@ -1324,9 +1339,7 @@ class StructuredDot(Op):
raise NotImplementedError("non-matrix b") raise NotImplementedError("non-matrix b")
if psb._is_sparse_variable(b): if psb._is_sparse_variable(b):
return Apply( return Apply(self, [a, b], [SparseTensorType(a.type.format, dtype_out)()])
self, [a, b], [psb.SparseTensorType(a.type.format, dtype_out)()]
)
else: else:
return Apply( return Apply(
self, self,
...@@ -1348,7 +1361,7 @@ class StructuredDot(Op): ...@@ -1348,7 +1361,7 @@ class StructuredDot(Op):
) )
variable = a * b variable = a * b
if isinstance(node.outputs[0].type, psb.SparseTensorType): if isinstance(node.outputs[0].type, SparseTensorType):
assert psb._is_sparse(variable) assert psb._is_sparse(variable)
out[0] = variable out[0] = variable
return return
......
...@@ -75,7 +75,8 @@ def test_local_mul_s_d(): ...@@ -75,7 +75,8 @@ def test_local_mul_s_d():
f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode="CVM") f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode="CVM")
assert not any( assert not any(
isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort() isinstance(node.op, smath.SparseDenseMultiply)
for node in f.maker.fgraph.toposort()
) )
...@@ -92,7 +93,8 @@ def test_local_mul_s_v(): ...@@ -92,7 +93,8 @@ def test_local_mul_s_v():
f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode="CVM") f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode="CVM")
assert not any( assert not any(
isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort() isinstance(node.op, smath.SparseDenseVectorMultiply)
for node in f.maker.fgraph.toposort()
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论