提交 7b13a955 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate rarely used Function functionality

上级 82f6a14f
......@@ -387,6 +387,9 @@ class Function:
self.nodes_with_inner_function = []
self.output_keys = output_keys
if self.output_keys is not None:
warnings.warn("output_keys is deprecated.", FutureWarning)
assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
......@@ -836,7 +839,9 @@ class Function:
t0 = time.perf_counter()
output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None:
if output_subset is not None:
warnings.warn("output_subset is deprecated.", FutureWarning)
if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter
......@@ -1560,6 +1565,8 @@ class FunctionMaker:
)
for i in self.inputs
]
if any(self.refeed):
warnings.warn("Inputs with default values are deprecated.", FutureWarning)
def create(self, input_storage=None, storage_map=None):
"""
......
......@@ -35,6 +35,9 @@ from pytensor.tensor.type import (
)
pytestmark = pytest.mark.filterwarnings("error")
def PatternOptimizer(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
......@@ -195,6 +198,9 @@ class TestFunction:
x, s = scalars("xs")
# x's name is not ignored (as in test_naming_rule2) because a has a default value.
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function([x, In(a, value=1.0), s], a / s + x)
assert f(9, 2, 4) == 9.5 # can specify all args in order
assert f(9, 2, s=4) == 9.5 # can give s as kwarg
......@@ -214,6 +220,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function([x, In(a, value=1.0, name="a"), s], a / s + x)
assert f(9, 2, 4) == 9.5 # can specify all args in order
......@@ -248,6 +257,9 @@ class TestFunction:
a = scalar()
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x,
......@@ -303,6 +315,9 @@ class TestFunction:
a = scalar()
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
......@@ -504,6 +519,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
......@@ -532,6 +550,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
......@@ -556,6 +577,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
......@@ -718,6 +742,9 @@ class TestFunction:
a, b = dscalars("a", "b")
c = a + b
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
funct = function([In(a, name="first"), In(b, value=1, name="second")], c)
x = funct(first=1)
try:
......@@ -775,6 +802,7 @@ class TestFunction:
# Tests that function works when outputs is a dictionary
x = scalar()
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4})
outputs = f(10.0)
......@@ -790,6 +818,7 @@ class TestFunction:
x = scalar("x")
y = scalar("y")
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x, y], outputs={"a": x + y, "b": x * y})
assert f(2, 4) == {"a": 6, "b": 8}
......@@ -805,6 +834,7 @@ class TestFunction:
e1 = scalar("1")
e2 = scalar("2")
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function(
[x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2}
)
......@@ -825,6 +855,7 @@ class TestFunction:
a = x + y
b = x * y
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x, y], outputs={"a": a, "b": b})
a = scalar("a")
......@@ -880,6 +911,9 @@ class TestPicklefunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
......@@ -941,6 +975,9 @@ class TestPicklefunction:
a = dscalar() # the a is for 'anonymous' (un-named).
x, s = dscalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
......@@ -967,10 +1004,12 @@ class TestPicklefunction:
def test_output_keys(self):
x = vector()
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x], {"vec": x**2})
o = f([2, 3, 4])
assert isinstance(o, dict)
assert np.allclose(o["vec"], [4, 9, 16])
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
g = copy.deepcopy(f)
o = g([2, 3, 4])
assert isinstance(o, dict)
......@@ -980,6 +1019,9 @@ class TestPicklefunction:
# Ensure that shared containers remain shared after a deep copy.
a, x = scalars("ax")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
h = function([In(a, value=0.0)], a)
f = function([x, In(a, value=h.container[a], implicit=True)], x + a)
......@@ -1004,6 +1046,9 @@ class TestPicklefunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
......@@ -1105,6 +1150,9 @@ class TestPicklefunction:
# some derived thing, whose inputs aren't all in the list
list_of_things.append(a * x + s)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f1 = function(
[
x,
......@@ -1116,6 +1164,9 @@ class TestPicklefunction:
list_of_things.append(f1)
# now put in a function sharing container with the previous one
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f2 = function(
[
x,
......@@ -1131,6 +1182,9 @@ class TestPicklefunction:
# now put in a function with non-scalar
v_value = np.asarray([2, 3, 4.0], dtype=config.floatX)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f3 = function([x, In(v, value=v_value)], x + v)
list_of_things.append(f3)
......@@ -1263,6 +1317,9 @@ class SomethingToPickle:
self.e = a * x + s
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
self.f1 = function(
[
x,
......@@ -1272,6 +1329,9 @@ class SomethingToPickle:
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
self.f2 = function(
[
x,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论