Commit f4536c30, authored by Ricardo Vieira, committed by Ricardo Vieira

Make Split a view_op

This allows the outputs to be views of the inputs. The Python and Numba implementations do that, but the C implementation still performs a copy.
Parent 91966e85
......@@ -1903,6 +1903,7 @@ class Split(COp):
b == [3, 4]
c == [5]
TODO: Don't make a copy in C impl
"""
len_splits = None
......@@ -1913,6 +1914,7 @@ class Split(COp):
def __init__(self, len_splits):
    """Store the number of pieces to split into and declare each output a view.

    Every output aliases input 0 (the tensor being split), so ``view_map``
    maps each output index to ``[0]``.
    """
    self.len_splits = int(len_splits)
    # One entry per output; all of them point back at the first input.
    self.view_map = {out_idx: [0] for out_idx in range(self.len_splits)}
def __str__(self):
    """Render as e.g. ``Split{3}``, exposing the number of splits."""
    return "{}{{{}}}".format(type(self).__name__, self.len_splits)
......@@ -1949,7 +1951,7 @@ class Split(COp):
split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
for i, out in enumerate(split_outs):
outputs[i][0] = out.copy()
outputs[i][0] = out
def infer_shape(self, fgraph, node, in_shapes):
axis = node.inputs[1]
......
......@@ -4,10 +4,11 @@ import pytest
import pytensor.scalar as aes
import pytensor.tensor as at
import pytensor.tensor.basic as atb
from pytensor import config
from pytensor import config, function
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import (
compare_numba_and_py,
......@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
)
def test_Split_view():
    """Regression test for https://github.com/pymc-devs/pytensor/issues/343.

    ``Split`` outputs are views of the split input, so an elemwise ``Add``
    that consumes one of them must not be rewritten to run inplace, and
    repeated evaluation must give identical results.
    """
    x1 = at.matrix("x1")
    x2 = at.matrix("x2", shape=(None, 1))
    v = at.vector("v", shape=(2,), dtype=int)
    out = at.split(x1, v, n_splits=2, axis=1)[0] + x2
    fn = function([x1, x2, v], out, mode="NUMBA")

    # The final Add must not destroy its inputs: one of them aliases the
    # Split input, so an inplace Add would corrupt it.
    final_node = fn.maker.fgraph.outputs[0].owner
    assert isinstance(final_node.op.scalar_op, Add)
    assert not final_node.op.inplace_pattern

    rng = np.random.default_rng(123)
    test_x1 = rng.normal(size=(2, 2))
    test_x2 = rng.normal(size=(2, 1))
    test_v = np.array([1, 1])

    # Two evaluations on the same inputs must agree — if the view were
    # written to inplace, the second call would see clobbered data.
    first_result = fn(test_x1, test_x2, test_v).copy()
    second_result = fn(test_x1, test_x2, test_v).copy()
    np.testing.assert_allclose(first_result, second_result)
@pytest.mark.parametrize(
"val, offset",
[
......
......@@ -1372,15 +1372,28 @@ def test_local_useless_split():
f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4])
f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1])
graph_rewritten = f_rewritten.maker.fgraph.toposort()
graph_not_rewritten = f_not_rewritten.maker.fgraph.toposort()
graph_rewritten = f_rewritten.maker.fgraph
graph_not_rewritten = f_not_rewritten.maker.fgraph
assert isinstance(graph_rewritten[-1].op, DeepCopyOp)
assert len(graph_not_rewritten) == 1
assert isinstance(graph_not_rewritten[0].op, Split)
assert all(
isinstance(out.owner.op, DeepCopyOp) for out in graph_not_rewritten.outputs
)
assert all(isinstance(out.owner.op, DeepCopyOp) for out in graph_rewritten.outputs)
assert sum(isinstance(node.op, Split) for node in graph_rewritten.apply_nodes) == 0
assert (
sum(isinstance(node.op, Split) for node in graph_not_rewritten.apply_nodes) == 1
)
assert sum(isinstance(node.op, Assert) for node in graph_rewritten.apply_nodes) == 2
assert (
sum(isinstance(node.op, Assert) for node in graph_not_rewritten.apply_nodes)
== 0
)
# The DeepCopy Ops don't have traces, so we can't check "all"
assert check_stack_trace(f_rewritten, ops_to_check=[Assert])
assert check_stack_trace(f_not_rewritten, ops_to_check="all")
assert check_stack_trace(f_not_rewritten, ops_to_check=[Split])
@pytest.mark.parametrize("i", list(range(1, 4)))
......
......@@ -11,7 +11,7 @@ import pytensor.tensor.basic as at
import pytensor.tensor.math as tm
from pytensor import compile, config, function, shared
from pytensor.compile.io import In, Out
from pytensor.compile.mode import get_default_mode
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply
......@@ -2002,8 +2002,7 @@ class TestJoinAndSplit:
y = Split(2)(x, 0, [s, 5 - s])[0]
assert y.type.shape == (None,)
def test_join_inplace():
def test_join_inplace(self):
# Test join to work inplace.
#
# This function tests the case when several elements are passed to the
......@@ -2025,8 +2024,7 @@ def test_join_inplace():
assert f(data, 0) is data
assert np.allclose(f(data, 0), [3, 4, 5])
def test_join_oneInput():
def test_join_oneInput(self):
# Test join when only 1 input is given.
#
# This functions tests the case when concatenate is called
......@@ -2042,6 +2040,28 @@ def test_join_oneInput():
assert join_0 is x_0
assert join_1 is not x_0
@pytest.mark.parametrize("linker", ("py", "c"))
def test_split_view(self, linker):
    # `Split` declares every output a view of its tensor input; check both
    # the declared view_map and the actual runtime aliasing per backend.
    x = vector("x")
    axis = 0
    op = Split(len_splits=3)
    assert op.view_map == {0: [0], 1: [0], 2: [0]}
    splits = op(x, axis, [0, 3, 2])
    mode = Mode(linker)
    # borrow=True on both input and outputs allows the backend to hand back
    # genuine views instead of protective copies.
    f = pytensor.function(
        [In(x, borrow=True)], [Out(s, borrow=True) for s in splits], mode=mode
    )
    x_test = np.arange(5, dtype=config.floatX)
    res = f(x_test)
    for r, expected in zip(res, ([], [0, 1, 2], [3, 4])):
        assert np.allclose(r, expected)
        if linker == "py":
            # Python impl returns true views: each piece shares x_test's memory.
            assert r.base is x_test
        else:
            # C impl always makes a copy
            assert r.base is not x_test
def test_TensorFromScalar():
s = aes.constant(56)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment