Unverified commit 0b731c27 authored by Tat Chan, committed by GitHub

Rewrite concatenate([x, x]) as tile (#1714)

Parent ee568260
@@ -77,6 +77,7 @@ from pytensor.tensor.basic import (
     register_infer_shape,
     switch,
     tensor_copy,
+    tile,
    zeros,
    zeros_like,
 )
@@ -910,6 +911,53 @@ def local_join_make_vector(fgraph, node):
     return [ret]
 
 
+@register_canonicalize
+@node_rewriter([Join])
+def local_join_to_repeat(fgraph, node):
+    """Join(axis, x, x, x, ...) -> tile(x, reps)
+
+    When the same tensor is concatenated multiple times along an axis,
+    replace the join with a single tile, which is more efficient.
+
+    Examples
+    --------
+    join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
+    join(1, x, x) -> tile(x, (1, 2, 1, ...))
+    """
+    axis, *tensors = node.inputs
+
+    # The rewrite only applies when the join axis is a constant
+    if not isinstance(axis, Constant):
+        return None
+
+    # Extract the Python integer from the constant
+    axis_val = axis.data
+
+    # Need at least 2 tensors for the rewrite to pay off
+    if len(tensors) <= 1:
+        return None
+
+    # Check that all joined tensors are the same variable
+    if not all(t == tensors[0] for t in tensors[1:]):
+        return None
+
+    n_reps = len(tensors)
+    first_tensor = tensors[0]
+    ndim = first_tensor.ndim
+
+    # Build a reps tuple that repeats only along the join axis.
+    # For shape (a, b, c) joined at axis 1: reps = (1, n_reps, 1),
+    # i.e. n_reps copies are concatenated along axis_val.
+    reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))
+    result = tile(first_tensor, reps)
+
+    # Preserve debugging information
+    copy_stack_trace(node.outputs[0], result)
+    return [result]
+
+
 @register_specialize
 @register_canonicalize
 @register_useless
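For context, a minimal sketch of the user-facing effect of this rewrite (not part of the diff; it assumes the default compilation mode, which runs the canonicalization pass this rewriter is registered in):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

x = pt.vector("x")
out = pt.concatenate([x, x, x])  # builds Join(0, x, x, x)
f = pytensor.function([x], out)

# After canonicalization the compiled graph contains no Join node;
# the repeated concatenation has been rewritten in terms of tile.
assert not any(isinstance(node.op, Join) for node in f.maker.fgraph.toposort())
print(f(np.array([1.0, 2.0], dtype=pytensor.config.floatX)))  # [1. 2. 1. 2. 1. 2.]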
@@ -1237,33 +1237,98 @@ def test_local_join_1():
     assert len([n for n in e if isinstance(n.op, Join)]) == 0
     assert f.maker.fgraph.outputs[0].dtype == config.floatX
 
-    # test we don't apply when their is 2 inputs
-    s = join(1, a, a)
+    # Test that join with 2 different inputs remains (not optimized away)
+    s = join(1, a, a[:, ::-1])
     f = function([a], s, mode=rewrite_mode)
-    val = f([[1]])
-    assert np.all(val == [[1]])
+    val = f([[1, 2]])
+    assert np.all(val == [[1, 2, 2, 1]])  # joined along axis 1
     e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
+    assert len([n for n in e if isinstance(n.op, Join)]) == 1  # join remains
     assert f.maker.fgraph.outputs[0].dtype == config.floatX
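A note on why a[:, ::-1] defeats the rewrite: the rewriter's check `t == tensors[0]` compares Variables, and PyTensor Variables keep Python's identity-based `__eq__` (symbolic elementwise equality is spelled `pt.eq`). The reversed view is a distinct variable, so the inputs are not "the same" and the Join survives. A small sketch of that semantics:

import pytensor.tensor as pt

a = pt.matrix("a")
b = a[:, ::-1]  # a new variable built from a
assert a == a        # same object -> equal
assert not (a == b)  # different objects -> not equal; use pt.eq(a, b) for symbolic equality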
+
+
+def test_local_join_to_tile():
+    """Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.
+
+    The rewrite applies whenever the *same* tensor is concatenated multiple
+    times along a given axis; the Join/concatenate is replaced by a tile.
+    """
+    # ---- Case 1: joining the same vector along axis 0 ----
+    x = vector("x")
+    s = join(0, x, x, x)  # shape (3n,)
+    f = function([x], s, mode=rewrite_mode)
+    test_val = np.array([1.0, 2.0], dtype=config.floatX)
+    result = f(test_val)
+    expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+    # The Join should be rewritten away
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 2: joining the same matrix along axis 0 ----
+    a = matrix("a")
+    s = join(0, a, a)  # shape (2m, n)
+    f = function([a], s, mode=rewrite_mode)
+    test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
+    result = f(test_mat)
+    expected = np.vstack([test_mat, test_mat])
+    assert np.allclose(result, expected)
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 3: joining the same matrix along axis 1 ----
+    s = join(1, a, a, a)  # shape (m, 3n)
+    f = function([a], s, mode=rewrite_mode)
+    result = f(test_mat)
+    expected = np.hstack([test_mat, test_mat, test_mat])
+    assert np.allclose(result, expected)
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 4: different tensors -> should NOT be rewritten ----
+    y = vector("y")
+    s = join(0, x, y)  # inputs differ
+    f = function([x, y], s, mode=rewrite_mode)
+    test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
+    test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
+    result = f(test_vec1, test_vec2)
+    expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+    # The Join must still be present since the inputs aren't identical
+    ops = f.maker.fgraph.toposort()
+    assert any(isinstance(n.op, Join) for n in ops)
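The same property can also be checked at the symbolic level, in the style of test_local_join_empty below. A sketch, assuming rewrite_graph's default canonicalize pass (as used in that test):

import pytensor.tensor as pt
from pytensor.graph.basic import applys_between
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.basic import Join

x = pt.matrix("x")
rewritten = rewrite_graph(pt.join(1, x, x))
# No Join node remains between the input and the rewritten output
assert not any(
    isinstance(node.op, Join) for node in applys_between([x], [rewritten])
)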
 
 
 def test_local_join_empty():
-    # Vector case
+    # Vector case - empty tensors should be removed
     empty_vec = np.asarray([], dtype=config.floatX)
     vec = vector("vec")
-    s = pt.join(0, vec, vec, empty_vec)
+    s = pt.join(0, vec, vec[::-1], empty_vec)
     new_s = rewrite_graph(s)
-    assert equal_computations([new_s], [join(0, vec, vec)])
     assert new_s.dtype == s.dtype
+    # Verify that empty tensors are removed from the join
+    expected = pt.join(0, vec, vec[::-1])
+    assert equal_computations([new_s], [expected])
 
-    # Matrix case
+    # Matrix case - empty tensors should be removed
     empty_mat = np.zeros((2, 0), dtype=config.floatX)
     empty_sym_mat = matrix("m", shape=(2, 0))
     mat = matrix("mat", shape=(2, 10))
-    s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
+    s = join(1, empty_mat, mat, empty_sym_mat, mat[:, ::-1])
     new_s = rewrite_graph(s)
-    assert equal_computations([new_s], [join(1, mat, mat, mat)])
     assert new_s.dtype == s.dtype
+    # Verify that empty tensors are removed from the join
+    expected = join(1, mat, mat[:, ::-1])
+    assert equal_computations([new_s], [expected])
 
     # Join can be completely removed, but casting and specify_shape are propagated
     int_mat = matrix("int_mat", dtype=int)
@@ -2020,25 +2020,6 @@ class TestJoinAndSplit:
         # This line used to crash.
         ptb.concatenate([x, -u], axis=2)
 
-    def test_concatenate_same(self):
-        # Test that we can concatenate the same tensor multiple time.
-        # In the past it was broken on the GPU.
-        rng = np.random.default_rng(seed=utt.fetch_seed())
-        T_shared = self.shared(rng.random((3, 4)).astype(self.floatX))
-        Tout = ptb.concatenate([T_shared, T_shared])
-        f = function([], Tout, mode=self.mode)
-        out = f()
-        if config.mode != "FAST_COMPILE":
-            assert [
-                True
-                for node in f.maker.fgraph.toposort()
-                if isinstance(node.op, type(self.join_op))
-            ]
-        assert np.allclose(
-            out, np.concatenate([T_shared.get_value(), T_shared.get_value()])
-        )
-
     def test_mixed_ndim_error(self):
         rng = np.random.default_rng(seed=utt.fetch_seed())
         v = self.shared(rng.random(4).astype(self.floatX))

(The removed test asserted that a Join-type node remains in the graph when the same tensor is concatenated with itself; the new Join -> tile rewrite makes that assertion obsolete.)