提交 ab304cb9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate use of "default" and Variable as OpFromGrah overrides

上级 6dfc811f
......@@ -11,7 +11,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.null_type import NullType, null_type
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
......@@ -93,6 +93,20 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert res.shape == (2, 5)
assert np.all(180.0 == res)
def test_overrides_deprecated_api(self):
inp = scalar("x")
out = inp + 1
for kwarg in ("lop_overrides", "grad_overrides", "rop_overrides"):
with pytest.raises(
ValueError, match="'default' is no longer a valid value for overrides"
):
OpFromGraph([inp], [out], **{kwarg: "default"})
with pytest.raises(
TypeError, match="Variables are no longer valid types for overrides"
):
OpFromGraph([inp], [out], **{kwarg: null_type()})
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
)
......@@ -211,9 +225,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
w, b = vectors("wb")
# we make the 3rd gradient default (no override)
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op_linear = cls_ofg(
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
)
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, None])
xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
zz = pt_sum(op_linear(xx, ww, bb))
dx, dw, db = grad(zz, [xx, ww, bb])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论