Unverified 提交 14651fb5 authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Add support for negative axis in `specify_broadcastable` (#710)

上级 bcd81c78
......@@ -4,6 +4,7 @@ from textwrap import dedent
from typing import cast
import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
import pytensor
from pytensor.gradient import DisconnectedType
......@@ -994,9 +995,7 @@ def specify_broadcastable(x, *axes):
if not axes:
return x
if max(axes) >= x.type.ndim:
raise ValueError("Trying to specify broadcastable of non-existent dimension")
axes = normalize_axis_tuple(axes, x.type.ndim)
shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)]
return specify_shape(x, shape_info)
......
......@@ -562,16 +562,22 @@ class TestSpecifyBroadcastable:
x = matrix()
assert specify_broadcastable(x, 0).type.shape == (1, None)
assert specify_broadcastable(x, 1).type.shape == (None, 1)
assert specify_broadcastable(x, -1).type.shape == (None, 1)
assert specify_broadcastable(x, 0, 1).type.shape == (1, 1)
x = row()
assert specify_broadcastable(x, 0) is x
assert specify_broadcastable(x, 1) is not x
assert specify_broadcastable(x, -2) is x
def test_validation(self):
x = matrix()
with pytest.raises(ValueError, match="^Trying to specify broadcastable of*"):
specify_broadcastable(x, 2)
axis = 2
with pytest.raises(
ValueError,
match=f"axis {axis} is out of bounds for array of dimension {axis}",
):
specify_broadcastable(x, axis)
class TestRopLop(RopLopChecker):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论