提交 7b13a955 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate rarely used Function functionality

上级 82f6a14f
......@@ -387,6 +387,9 @@ class Function:
self.nodes_with_inner_function = []
self.output_keys = output_keys
if self.output_keys is not None:
warnings.warn("output_keys is deprecated.", FutureWarning)
assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
......@@ -836,8 +839,10 @@ class Function:
t0 = time.perf_counter()
output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]
if output_subset is not None:
warnings.warn("output_subset is deprecated.", FutureWarning)
if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter
if self.trust_input:
......@@ -1560,6 +1565,8 @@ class FunctionMaker:
)
for i in self.inputs
]
if any(self.refeed):
warnings.warn("Inputs with default values are deprecated.", FutureWarning)
def create(self, input_storage=None, storage_map=None):
"""
......
......@@ -35,6 +35,9 @@ from pytensor.tensor.type import (
)
pytestmark = pytest.mark.filterwarnings("error")
def PatternOptimizer(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
......@@ -195,7 +198,10 @@ class TestFunction:
x, s = scalars("xs")
# x's name is not ignored (as in test_naming_rule2) because a has a default value.
f = function([x, In(a, value=1.0), s], a / s + x)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function([x, In(a, value=1.0), s], a / s + x)
assert f(9, 2, 4) == 9.5 # can specify all args in order
assert f(9, 2, s=4) == 9.5 # can give s as kwarg
assert f(9, s=4) == 9.25 # can give s as kwarg, get default a
......@@ -214,7 +220,10 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
f = function([x, In(a, value=1.0, name="a"), s], a / s + x)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function([x, In(a, value=1.0, name="a"), s], a / s + x)
assert f(9, 2, 4) == 9.5 # can specify all args in order
assert f(9, 2, s=4) == 9.5 # can give s as kwarg
......@@ -248,11 +257,14 @@ class TestFunction:
a = scalar()
x, s = scalars("xs")
f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x,
mode=mode,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x,
mode=mode,
)
assert f[a] == 1.0
assert f[s] == 0.0
......@@ -303,16 +315,19 @@ class TestFunction:
a = scalar()
x, s = scalars("xs")
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
g = copy.copy(f)
g = copy.copy(f)
assert f.unpack_single == g.unpack_single
assert f.trust_input == g.trust_input
......@@ -504,22 +519,25 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
g = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=f.container[s], update=s - a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
g = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=f.container[s], update=s - a * x, mutable=True),
],
s + a * x,
)
f(1, 2)
assert f[s] == 2
......@@ -532,17 +550,20 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
g = function(
[x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
g = function(
[x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x
)
f(1, 2)
assert f[s] == 2
......@@ -556,17 +577,20 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=False),
],
s + a * x,
)
g = function(
[x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=False),
],
s + a * x,
)
g = function(
[x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x
)
f(1, 2)
assert f[s] == 2
......@@ -718,7 +742,10 @@ class TestFunction:
a, b = dscalars("a", "b")
c = a + b
funct = function([In(a, name="first"), In(b, value=1, name="second")], c)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
funct = function([In(a, name="first"), In(b, value=1, name="second")], c)
x = funct(first=1)
try:
funct(second=2)
......@@ -775,7 +802,8 @@ class TestFunction:
# Tests that function works when outputs is a dictionary
x = scalar()
f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4})
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4})
outputs = f(10.0)
......@@ -790,7 +818,8 @@ class TestFunction:
x = scalar("x")
y = scalar("y")
f = function([x, y], outputs={"a": x + y, "b": x * y})
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x, y], outputs={"a": x + y, "b": x * y})
assert f(2, 4) == {"a": 6, "b": 8}
assert f(2, y=4) == f(2, 4)
......@@ -805,9 +834,10 @@ class TestFunction:
e1 = scalar("1")
e2 = scalar("2")
f = function(
[x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2}
)
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function(
[x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2}
)
assert "1" in str(f.outputs[0])
assert "2" in str(f.outputs[1])
......@@ -825,7 +855,8 @@ class TestFunction:
a = x + y
b = x * y
f = function([x, y], outputs={"a": a, "b": b})
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x, y], outputs={"a": a, "b": b})
a = scalar("a")
b = scalar("b")
......@@ -880,14 +911,17 @@ class TestPicklefunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
f = function(
[
x,
In(a, value=1.0, name="a", mutable=True),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
In(a, value=1.0, name="a", mutable=True),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
try:
g = copy.deepcopy(f)
except NotImplementedError as e:
......@@ -941,14 +975,17 @@ class TestPicklefunction:
a = dscalar() # the a is for 'anonymous' (un-named).
x, s = dscalars("xs")
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
f.trust_input = True
try:
g = copy.deepcopy(f)
......@@ -967,11 +1004,13 @@ class TestPicklefunction:
def test_output_keys(self):
x = vector()
f = function([x], {"vec": x**2})
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x], {"vec": x**2})
o = f([2, 3, 4])
assert isinstance(o, dict)
assert np.allclose(o["vec"], [4, 9, 16])
g = copy.deepcopy(f)
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
g = copy.deepcopy(f)
o = g([2, 3, 4])
assert isinstance(o, dict)
assert np.allclose(o["vec"], [4, 9, 16])
......@@ -980,7 +1019,10 @@ class TestPicklefunction:
# Ensure that shared containers remain shared after a deep copy.
a, x = scalars("ax")
h = function([In(a, value=0.0)], a)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
h = function([In(a, value=0.0)], a)
f = function([x, In(a, value=h.container[a], implicit=True)], x + a)
try:
......@@ -1004,14 +1046,17 @@ class TestPicklefunction:
a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs")
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
try:
# Note that here we also test protocol 0 on purpose, since it
......@@ -1105,25 +1150,31 @@ class TestPicklefunction:
# some derived thing, whose inputs aren't all in the list
list_of_things.append(a * x + s)
f1 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f1 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
list_of_things.append(f1)
# now put in a function sharing container with the previous one
f2 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=f1.container[s], update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f2 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=f1.container[s], update=s + a * x, mutable=True),
],
s + a * x,
)
list_of_things.append(f2)
assert isinstance(f2.container[s].storage, list)
......@@ -1131,7 +1182,10 @@ class TestPicklefunction:
# now put in a function with non-scalar
v_value = np.asarray([2, 3, 4.0], dtype=config.floatX)
f3 = function([x, In(v, value=v_value)], x + v)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f3 = function([x, In(v, value=v_value)], x + v)
list_of_things.append(f3)
# try to pickle the entire things
......@@ -1263,23 +1317,29 @@ class SomethingToPickle:
self.e = a * x + s
self.f1 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
self.f1 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=0.0, update=s + a * x, mutable=True),
],
s + a * x,
)
self.f2 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=self.f1.container[s], update=s + a * x, mutable=True),
],
s + a * x,
)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
self.f2 = function(
[
x,
In(a, value=1.0, name="a"),
In(s, value=self.f1.container[s], update=s + a * x, mutable=True),
],
s + a * x,
)
def test_empty_givens_updates():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论