提交 f4536c30 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Make Split a view_op

This allows the outputs to be views of the input. The Python and Numba implementations already do that, but the C implementation still performs a copy.
上级 91966e85
...@@ -1903,6 +1903,7 @@ class Split(COp): ...@@ -1903,6 +1903,7 @@ class Split(COp):
b == [3, 4] b == [3, 4]
c == [5] c == [5]
TODO: Don't make a copy in C impl
""" """
len_splits = None len_splits = None
...@@ -1913,6 +1914,7 @@ class Split(COp): ...@@ -1913,6 +1914,7 @@ class Split(COp):
def __init__(self, len_splits): def __init__(self, len_splits):
self.len_splits = int(len_splits) self.len_splits = int(len_splits)
self.view_map = {i: [0] for i in range(self.len_splits)}
def __str__(self): def __str__(self):
return f"{self.__class__.__name__ }{{{self.len_splits}}}" return f"{self.__class__.__name__ }{{{self.len_splits}}}"
...@@ -1949,7 +1951,7 @@ class Split(COp): ...@@ -1949,7 +1951,7 @@ class Split(COp):
split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis) split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
for i, out in enumerate(split_outs): for i, out in enumerate(split_outs):
outputs[i][0] = out.copy() outputs[i][0] = out
def infer_shape(self, fgraph, node, in_shapes): def infer_shape(self, fgraph, node, in_shapes):
axis = node.inputs[1] axis = node.inputs[1]
......
...@@ -4,10 +4,11 @@ import pytest ...@@ -4,10 +4,11 @@ import pytest
import pytensor.scalar as aes import pytensor.scalar as aes
import pytensor.tensor as at import pytensor.tensor as at
import pytensor.tensor.basic as atb import pytensor.tensor.basic as atb
from pytensor import config from pytensor import config, function
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
...@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes): ...@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
) )
def test_Split_view():
    """Regression test for https://github.com/pymc-devs/pytensor/issues/343.

    Compiles ``split(x1)[0] + x2`` with the NUMBA mode, asserts that the
    resulting addition carries no in-place pattern, and checks that two
    calls with the same inputs return the same values.
    """
    x1 = at.matrix("x1")
    x2 = at.matrix("x2", shape=(None, 1))
    v = at.vector("v", shape=(2,), dtype=int)
    first_piece = at.split(x1, v, n_splits=2, axis=1)[0]
    fn = function([x1, x2, v], first_piece + x2, mode="NUMBA")

    # The addition of split[0] and x2 must not be performed in place
    final_op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(final_op.scalar_op, Add)
    assert not final_op.inplace_pattern

    rng = np.random.default_rng(123)
    test_x1 = rng.normal(size=(2, 2))
    test_x2 = rng.normal(size=(2, 1))
    test_v = np.array([1, 1])
    inputs = (test_x1, test_x2, test_v)

    # Copy each result before calling again so the two calls are compared
    # on independent arrays.
    first_call = fn(*inputs).copy()
    second_call = fn(*inputs).copy()
    np.testing.assert_allclose(first_call, second_call)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"val, offset", "val, offset",
[ [
......
...@@ -1372,15 +1372,28 @@ def test_local_useless_split(): ...@@ -1372,15 +1372,28 @@ def test_local_useless_split():
f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4]) f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4])
f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1]) f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1])
graph_rewritten = f_rewritten.maker.fgraph.toposort() graph_rewritten = f_rewritten.maker.fgraph
graph_not_rewritten = f_not_rewritten.maker.fgraph.toposort() graph_not_rewritten = f_not_rewritten.maker.fgraph
assert isinstance(graph_rewritten[-1].op, DeepCopyOp) assert all(
assert len(graph_not_rewritten) == 1 isinstance(out.owner.op, DeepCopyOp) for out in graph_not_rewritten.outputs
assert isinstance(graph_not_rewritten[0].op, Split) )
assert all(isinstance(out.owner.op, DeepCopyOp) for out in graph_rewritten.outputs)
assert sum(isinstance(node.op, Split) for node in graph_rewritten.apply_nodes) == 0
assert (
sum(isinstance(node.op, Split) for node in graph_not_rewritten.apply_nodes) == 1
)
assert sum(isinstance(node.op, Assert) for node in graph_rewritten.apply_nodes) == 2
assert (
sum(isinstance(node.op, Assert) for node in graph_not_rewritten.apply_nodes)
== 0
)
# The DeepCopy Ops don't have traces, so we can't check "all"
assert check_stack_trace(f_rewritten, ops_to_check=[Assert]) assert check_stack_trace(f_rewritten, ops_to_check=[Assert])
assert check_stack_trace(f_not_rewritten, ops_to_check="all") assert check_stack_trace(f_not_rewritten, ops_to_check=[Split])
@pytest.mark.parametrize("i", list(range(1, 4))) @pytest.mark.parametrize("i", list(range(1, 4)))
......
...@@ -11,7 +11,7 @@ import pytensor.tensor.basic as at ...@@ -11,7 +11,7 @@ import pytensor.tensor.basic as at
import pytensor.tensor.math as tm import pytensor.tensor.math as tm
from pytensor import compile, config, function, shared from pytensor import compile, config, function, shared
from pytensor.compile.io import In, Out from pytensor.compile.io import In, Out
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.gradient import grad, hessian from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
...@@ -2002,8 +2002,7 @@ class TestJoinAndSplit: ...@@ -2002,8 +2002,7 @@ class TestJoinAndSplit:
y = Split(2)(x, 0, [s, 5 - s])[0] y = Split(2)(x, 0, [s, 5 - s])[0]
assert y.type.shape == (None,) assert y.type.shape == (None,)
def test_join_inplace(self):
def test_join_inplace():
# Test join to work inplace. # Test join to work inplace.
# #
# This function tests the case when several elements are passed to the # This function tests the case when several elements are passed to the
...@@ -2025,8 +2024,7 @@ def test_join_inplace(): ...@@ -2025,8 +2024,7 @@ def test_join_inplace():
assert f(data, 0) is data assert f(data, 0) is data
assert np.allclose(f(data, 0), [3, 4, 5]) assert np.allclose(f(data, 0), [3, 4, 5])
def test_join_oneInput(self):
def test_join_oneInput():
# Test join when only 1 input is given. # Test join when only 1 input is given.
# #
# This functions tests the case when concatenate is called # This functions tests the case when concatenate is called
...@@ -2042,6 +2040,28 @@ def test_join_oneInput(): ...@@ -2042,6 +2040,28 @@ def test_join_oneInput():
assert join_0 is x_0 assert join_0 is x_0
assert join_1 is not x_0 assert join_1 is not x_0
@pytest.mark.parametrize("linker", ("py", "c"))
def test_split_view(self, linker):
    """Check ``Split.view_map`` and, per linker, whether outputs are based on the input.

    With the "py" linker every output's ``base`` must be the input array;
    with the "c" linker it must not be.
    """
    x = vector("x")
    split_op = Split(len_splits=3)
    assert split_op.view_map == {0: [0], 1: [0], 2: [0]}

    pieces = split_op(x, 0, [0, 3, 2])
    f = pytensor.function(
        [In(x, borrow=True)],
        [Out(piece, borrow=True) for piece in pieces],
        mode=Mode(linker),
    )

    x_test = np.arange(5, dtype=config.floatX)
    expected_pieces = ([], [0, 1, 2], [3, 4])
    for result, expected in zip(f(x_test), expected_pieces):
        assert np.allclose(result, expected)
        if linker != "py":
            # C impl always makes a copy
            assert result.base is not x_test
        else:
            assert result.base is x_test
def test_TensorFromScalar(): def test_TensorFromScalar():
s = aes.constant(56) s = aes.constant(56)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论