提交 1dfb0455 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simpler Python implementation of Join Op using Numpy Split

上级 35f3cbf5
...@@ -1925,30 +1925,18 @@ class Split(COp): ...@@ -1925,30 +1925,18 @@ class Split(COp):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
x, axis, splits = inputs x, axis, splits = inputs
len_along_axis = x.shape[axis]
if len(splits) != self.len_splits: if len(splits) != self.len_splits:
raise ValueError("Length of `splits` is not equal to `len_splits`") raise ValueError("Length of splits is not equal to n_splits")
if np.sum(splits) != len_along_axis: if np.sum(splits) != x.shape[axis]:
raise ValueError(
f"The splits sum to {np.sum(splits)}; expected {len_along_axis}"
)
if builtins.any(nb < 0 for nb in splits):
raise ValueError( raise ValueError(
"Attempted to make an array with a negative number of elements" f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}"
) )
if np.any(splits < 0):
raise ValueError("Split sizes cannot be negative")
# Checking is done, let's roll the splitting algorithm! split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
# Basically we step along the given axis of x, extracting for i, out in enumerate(split_outs):
# subtensors of size splits[i] as we go along. outputs[i][0] = out.copy()
general_key = [slice(None, None, None) for s in x.shape]
lower_idx = 0
for i in range(self.len_splits):
upper_idx = lower_idx + splits[i]
general_key[axis] = slice(lower_idx, upper_idx, None)
outputs[i][0] = x.__getitem__(tuple(general_key)).copy()
lower_idx = upper_idx
def infer_shape(self, fgraph, node, in_shapes): def infer_shape(self, fgraph, node, in_shapes):
axis = node.inputs[1] axis = node.inputs[1]
......
...@@ -1491,13 +1491,21 @@ class TestJoinAndSplit: ...@@ -1491,13 +1491,21 @@ class TestJoinAndSplit:
out = self.eval_outputs_and_check_join([s]) out = self.eval_outputs_and_check_join([s])
assert (out == want).all() assert (out == want).all()
def test_join_matrix1(self): @pytest.mark.parametrize("py_impl", (False, True))
def test_join_matrix1(self, py_impl):
if py_impl:
impl_ctxt = pytensor.config.change_flags(cxx="")
else:
impl_ctxt = pytensor.config.change_flags()
av = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype="float32") av = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype="float32")
bv = np.array([[0.7], [0.8]], dtype="float32") bv = np.array([[0.7], [0.8]], dtype="float32")
a = self.shared(av) a = self.shared(av)
b = as_tensor_variable(bv) b = as_tensor_variable(bv)
s = join(1, a, b) s = join(1, a, b)
want = np.array([[0.1, 0.2, 0.3, 0.7], [0.4, 0.5, 0.6, 0.8]], dtype="float32") want = np.array([[0.1, 0.2, 0.3, 0.7], [0.4, 0.5, 0.6, 0.8]], dtype="float32")
with impl_ctxt:
out = self.eval_outputs_and_check_join([s]) out = self.eval_outputs_and_check_join([s])
assert (out == want).all() assert (out == want).all()
...@@ -1624,13 +1632,21 @@ class TestJoinAndSplit: ...@@ -1624,13 +1632,21 @@ class TestJoinAndSplit:
with pytest.raises(IndexError): with pytest.raises(IndexError):
f(-3) f(-3)
def test_join_matrixC_negative_axis(self): @pytest.mark.parametrize("py_impl", (False, True))
def test_join_matrixC_negative_axis(self, py_impl):
if py_impl:
impl_ctxt = pytensor.config.change_flags(cxx="")
else:
impl_ctxt = pytensor.config.change_flags()
# constant join negative axis # constant join negative axis
v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX) v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX)
a = self.shared(v) a = self.shared(v)
b = as_tensor_variable(v) b = as_tensor_variable(v)
s = join(-1, a, b) s = join(-1, a, b)
with impl_ctxt:
f = pytensor.function([], [s], mode=self.mode) f = pytensor.function([], [s], mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert [True for node in topo if isinstance(node.op, type(self.join_op))] assert [True for node in topo if isinstance(node.op, type(self.join_op))]
...@@ -1643,6 +1659,8 @@ class TestJoinAndSplit: ...@@ -1643,6 +1659,8 @@ class TestJoinAndSplit:
assert np.allclose(got, want) assert np.allclose(got, want)
s = join(-2, a, b) s = join(-2, a, b)
with impl_ctxt:
f = pytensor.function([], [s], mode=self.mode) f = pytensor.function([], [s], mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert [True for node in topo if isinstance(node.op, type(self.join_op))] assert [True for node in topo if isinstance(node.op, type(self.join_op))]
...@@ -1657,6 +1675,7 @@ class TestJoinAndSplit: ...@@ -1657,6 +1675,7 @@ class TestJoinAndSplit:
with pytest.raises(IndexError): with pytest.raises(IndexError):
join(-3, a, b) join(-3, a, b)
with impl_ctxt:
utt.verify_grad(lambda a, b: join(-1, a, b), [v, 2 * v], mode=self.mode) utt.verify_grad(lambda a, b: join(-1, a, b), [v, 2 * v], mode=self.mode)
def test_broadcastable_flag_assignment_mixed_otheraxes(self): def test_broadcastable_flag_assignment_mixed_otheraxes(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论