提交 3bd8ed9f authored 作者: Tomas Capretto's avatar Tomas Capretto 提交者: Ricardo Vieira

Implement specific make_node for sparse HStack/VStack with early static shape validation

上级 f0603aac
...@@ -1636,23 +1636,35 @@ class Stack(Op): ...@@ -1636,23 +1636,35 @@ class Stack(Op):
raise ValueError("The output dtype must be specified.") raise ValueError("The output dtype must be specified.")
self.dtype = dtype self.dtype = dtype
def make_node(self, *mat): def __str__(self):
if not mat: return f"{self.__class__.__name__}({self.format},{self.dtype})"
class HStack(Stack):
def make_node(self, *blocks):
if not blocks:
raise ValueError("Cannot join an empty list of sparses.") raise ValueError("Cannot join an empty list of sparses.")
var = [as_sparse_variable(x) for x in mat] input_blocks = [as_sparse_variable(block) for block in blocks]
for x in var: for x in input_blocks:
assert x.format in ("csr", "csc") assert x.format in ("csr", "csc")
return Apply( # Known rows numbers must be the same for all matrices.
self, var, [SparseTensorType(dtype=self.dtype, format=self.format)()] static_n_rows = {
x.type.shape[0] for x in input_blocks if x.type.shape[0] is not None
}
if len(static_n_rows) > 1:
raise ValueError(
"All matrices must have the same number of rows; "
f"got row counts: {static_n_rows}."
) )
def __str__(self): return Apply(
return f"{self.__class__.__name__}({self.format},{self.dtype})" self,
input_blocks,
[SparseTensorType(dtype=self.dtype, format=self.format)()],
)
class HStack(Stack):
def perform(self, node, block, outputs): def perform(self, node, block, outputs):
(out,) = outputs (out,) = outputs
for b in block: for b in block:
...@@ -1726,6 +1738,30 @@ def hstack(blocks, format=None, dtype=None): ...@@ -1726,6 +1738,30 @@ def hstack(blocks, format=None, dtype=None):
class VStack(Stack): class VStack(Stack):
def make_node(self, *blocks):
if not blocks:
raise ValueError("Cannot join an empty list of sparses.")
input_blocks = [as_sparse_variable(block) for block in blocks]
for x in input_blocks:
assert x.format in ("csr", "csc")
# Known column numbers must be the same for all matrices.
static_n_cols = {
x.type.shape[1] for x in input_blocks if x.type.shape[1] is not None
}
if len(static_n_cols) > 1:
raise ValueError(
"All matrices must have the same number of columns; "
f"got column counts: {static_n_cols}."
)
return Apply(
self,
input_blocks,
[SparseTensorType(dtype=self.dtype, format=self.format)()],
)
def perform(self, node, block, outputs): def perform(self, node, block, outputs):
(out,) = outputs (out,) = outputs
for b in block: for b in block:
......
...@@ -409,8 +409,8 @@ def test_sparse_vstack(output_format, input_formats): ...@@ -409,8 +409,8 @@ def test_sparse_vstack(output_format, input_formats):
def test_sparse_hstack_mismatched_rows_raises(): def test_sparse_hstack_mismatched_rows_raises():
x = ps.matrix(name="x", shape=(3, 5), format="csr", dtype=config.floatX) x = ps.matrix(name="x", shape=(None, 5), format="csr", dtype=config.floatX)
y = ps.matrix(name="y", shape=(4, 7), format="csr", dtype=config.floatX) y = ps.matrix(name="y", shape=(None, 7), format="csr", dtype=config.floatX)
z = ps.hstack([x, y], format="csr", dtype=config.floatX) z = ps.hstack([x, y], format="csr", dtype=config.floatX)
fn = function([x, y], z, mode="NUMBA") fn = function([x, y], z, mode="NUMBA")
...@@ -422,8 +422,8 @@ def test_sparse_hstack_mismatched_rows_raises(): ...@@ -422,8 +422,8 @@ def test_sparse_hstack_mismatched_rows_raises():
def test_sparse_vstack_mismatched_cols_raises(): def test_sparse_vstack_mismatched_cols_raises():
x = ps.matrix(name="x", shape=(10, 3), format="csr", dtype=config.floatX) x = ps.matrix(name="x", shape=(10, None), format="csr", dtype=config.floatX)
y = ps.matrix(name="y", shape=(13, 4), format="csr", dtype=config.floatX) y = ps.matrix(name="y", shape=(13, None), format="csr", dtype=config.floatX)
z = ps.vstack([x, y], format="csr", dtype=config.floatX) z = ps.vstack([x, y], format="csr", dtype=config.floatX)
fn = function([x, y], z, mode="NUMBA") fn = function([x, y], z, mode="NUMBA")
......
...@@ -1583,3 +1583,23 @@ def test_hstack_vstack(): ...@@ -1583,3 +1583,23 @@ def test_hstack_vstack():
stacked_blocks = stack_function(blocks, dtype=to_dtype) stacked_blocks = stack_function(blocks, dtype=to_dtype)
expected_dtype = get_expected_dtype(blocks, to_dtype) expected_dtype = get_expected_dtype(blocks, to_dtype)
assert stacked_blocks.dtype == expected_dtype assert stacked_blocks.dtype == expected_dtype
def test_sparse_hstack_mismatched_static_rows_raises():
x = sparse.matrix(name="x", shape=(3, 5), format="csr", dtype=config.floatX)
y = sparse.matrix(name="y", shape=(4, 7), format="csr", dtype=config.floatX)
with pytest.raises(
ValueError, match="All matrices must have the same number of rows"
):
sparse.hstack([x, y], format="csr", dtype=config.floatX)
def test_sparse_vstack_mismatched_static_cols_raises():
x = sparse.matrix(name="x", shape=(10, 3), format="csr", dtype=config.floatX)
y = sparse.matrix(name="y", shape=(13, 4), format="csr", dtype=config.floatX)
with pytest.raises(
ValueError, match="All matrices must have the same number of columns"
):
sparse.vstack([x, y], format="csr", dtype=config.floatX)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论