Unverified 提交 68d8dc72 authored 作者: Pablo de Roque's avatar Pablo de Roque 提交者: GitHub

Converts negative constant axis to positive if present in `Join(COp)` (#1527)

上级 4ce092fe
...@@ -2470,6 +2470,18 @@ class Join(COp): ...@@ -2470,6 +2470,18 @@ class Join(COp):
if axis.type.ndim > 0: if axis.type.ndim > 0:
raise TypeError(f"Axis {axis} must be 0-d.") raise TypeError(f"Axis {axis} must be 0-d.")
# Convert negative constant axis to positive during canonicalization
if isinstance(axis, Constant) and tensors:
# Get the axis value directly from the constant's data
axis_val = axis.data.item()
# Check if it's negative and needs normalization
if axis_val < 0:
ndim = tensors[0].ndim
# Convert negative axis to positive
axis_val = normalize_axis_index(axis_val, ndim)
# Replace the original axis with the normalized one
axis = constant(axis_val, dtype=axis.type.dtype)
tensors = [as_tensor_variable(x) for x in tensors] tensors = [as_tensor_variable(x) for x in tensors]
if not builtins.all(targs.type.ndim > 0 for targs in tensors): if not builtins.all(targs.type.ndim > 0 for targs in tensors):
......
...@@ -2179,6 +2179,15 @@ class TestJoinAndSplit: ...@@ -2179,6 +2179,15 @@ class TestJoinAndSplit:
assert fn(*test_values).shape == (n * 6, n)[:ndim] if axis == 0 else (n, n * 6) assert fn(*test_values).shape == (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
benchmark(fn, *test_values) benchmark(fn, *test_values)
def test_join_negative_axis_rewrite(self):
"""Test that constant negative axis is rewritten to positive axis in make_node."""
v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX)
a = self.shared(v)
b = as_tensor_variable(v)
assert equal_computations([join(-1, a, b)], [join(1, a, b)])
assert equal_computations([join(-2, a, b)], [join(0, a, b)])
def test_TensorFromScalar(): def test_TensorFromScalar():
s = ps.constant(56) s = ps.constant(56)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论