Unverified 提交 14651fb5 authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Add support for negative axis in `specify_broadcastable` (#710)

上级 bcd81c78
...@@ -4,6 +4,7 @@ from textwrap import dedent ...@@ -4,6 +4,7 @@ from textwrap import dedent
from typing import cast from typing import cast
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
import pytensor import pytensor
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
...@@ -994,9 +995,7 @@ def specify_broadcastable(x, *axes): ...@@ -994,9 +995,7 @@ def specify_broadcastable(x, *axes):
if not axes: if not axes:
return x return x
if max(axes) >= x.type.ndim: axes = normalize_axis_tuple(axes, x.type.ndim)
raise ValueError("Trying to specify broadcastable of non-existent dimension")
shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)] shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)]
return specify_shape(x, shape_info) return specify_shape(x, shape_info)
......
...@@ -562,16 +562,22 @@ class TestSpecifyBroadcastable: ...@@ -562,16 +562,22 @@ class TestSpecifyBroadcastable:
x = matrix() x = matrix()
assert specify_broadcastable(x, 0).type.shape == (1, None) assert specify_broadcastable(x, 0).type.shape == (1, None)
assert specify_broadcastable(x, 1).type.shape == (None, 1) assert specify_broadcastable(x, 1).type.shape == (None, 1)
assert specify_broadcastable(x, -1).type.shape == (None, 1)
assert specify_broadcastable(x, 0, 1).type.shape == (1, 1) assert specify_broadcastable(x, 0, 1).type.shape == (1, 1)
x = row() x = row()
assert specify_broadcastable(x, 0) is x assert specify_broadcastable(x, 0) is x
assert specify_broadcastable(x, 1) is not x assert specify_broadcastable(x, 1) is not x
assert specify_broadcastable(x, -2) is x
def test_validation(self): def test_validation(self):
x = matrix() x = matrix()
with pytest.raises(ValueError, match="^Trying to specify broadcastable of*"): axis = 2
specify_broadcastable(x, 2) with pytest.raises(
ValueError,
match=f"axis {axis} is out of bounds for array of dimension {axis}",
):
specify_broadcastable(x, axis)
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论