提交 7b13a955 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate rarely used Function functionality

上级 82f6a14f
...@@ -387,6 +387,9 @@ class Function: ...@@ -387,6 +387,9 @@ class Function:
self.nodes_with_inner_function = [] self.nodes_with_inner_function = []
self.output_keys = output_keys self.output_keys = output_keys
if self.output_keys is not None:
warnings.warn("output_keys is deprecated.", FutureWarning)
assert len(self.input_storage) == len(self.maker.fgraph.inputs) assert len(self.input_storage) == len(self.maker.fgraph.inputs)
assert len(self.output_storage) == len(self.maker.fgraph.outputs) assert len(self.output_storage) == len(self.maker.fgraph.outputs)
...@@ -836,7 +839,9 @@ class Function: ...@@ -836,7 +839,9 @@ class Function:
t0 = time.perf_counter() t0 = time.perf_counter()
output_subset = kwargs.pop("output_subset", None) output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None: if output_subset is not None:
warnings.warn("output_subset is deprecated.", FutureWarning)
if self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset] output_subset = [self.output_keys.index(key) for key in output_subset]
# Reinitialize each container's 'provided' counter # Reinitialize each container's 'provided' counter
...@@ -1560,6 +1565,8 @@ class FunctionMaker: ...@@ -1560,6 +1565,8 @@ class FunctionMaker:
) )
for i in self.inputs for i in self.inputs
] ]
if any(self.refeed):
warnings.warn("Inputs with default values are deprecated.", FutureWarning)
def create(self, input_storage=None, storage_map=None): def create(self, input_storage=None, storage_map=None):
""" """
......
...@@ -35,6 +35,9 @@ from pytensor.tensor.type import ( ...@@ -35,6 +35,9 @@ from pytensor.tensor.type import (
) )
pytestmark = pytest.mark.filterwarnings("error")
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
...@@ -195,6 +198,9 @@ class TestFunction: ...@@ -195,6 +198,9 @@ class TestFunction:
x, s = scalars("xs") x, s = scalars("xs")
# x's name is not ignored (as in test_naming_rule2) because a has a default value. # x's name is not ignored (as in test_naming_rule2) because a has a default value.
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function([x, In(a, value=1.0), s], a / s + x) f = function([x, In(a, value=1.0), s], a / s + x)
assert f(9, 2, 4) == 9.5 # can specify all args in order assert f(9, 2, 4) == 9.5 # can specify all args in order
assert f(9, 2, s=4) == 9.5 # can give s as kwarg assert f(9, 2, s=4) == 9.5 # can give s as kwarg
...@@ -214,6 +220,9 @@ class TestFunction: ...@@ -214,6 +220,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named). a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function([x, In(a, value=1.0, name="a"), s], a / s + x) f = function([x, In(a, value=1.0, name="a"), s], a / s + x)
assert f(9, 2, 4) == 9.5 # can specify all args in order assert f(9, 2, 4) == 9.5 # can specify all args in order
...@@ -248,6 +257,9 @@ class TestFunction: ...@@ -248,6 +257,9 @@ class TestFunction:
a = scalar() a = scalar()
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x, s + a * x,
...@@ -303,6 +315,9 @@ class TestFunction: ...@@ -303,6 +315,9 @@ class TestFunction:
a = scalar() a = scalar()
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[ [
x, x,
...@@ -504,6 +519,9 @@ class TestFunction: ...@@ -504,6 +519,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named). a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[ [
x, x,
...@@ -532,6 +550,9 @@ class TestFunction: ...@@ -532,6 +550,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named). a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[ [
x, x,
...@@ -556,6 +577,9 @@ class TestFunction: ...@@ -556,6 +577,9 @@ class TestFunction:
a = scalar() # the a is for 'anonymous' (un-named). a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[ [
x, x,
...@@ -718,6 +742,9 @@ class TestFunction: ...@@ -718,6 +742,9 @@ class TestFunction:
a, b = dscalars("a", "b") a, b = dscalars("a", "b")
c = a + b c = a + b
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
funct = function([In(a, name="first"), In(b, value=1, name="second")], c) funct = function([In(a, name="first"), In(b, value=1, name="second")], c)
x = funct(first=1) x = funct(first=1)
try: try:
...@@ -775,6 +802,7 @@ class TestFunction: ...@@ -775,6 +802,7 @@ class TestFunction:
# Tests that function works when outputs is a dictionary # Tests that function works when outputs is a dictionary
x = scalar() x = scalar()
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4}) f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4})
outputs = f(10.0) outputs = f(10.0)
...@@ -790,6 +818,7 @@ class TestFunction: ...@@ -790,6 +818,7 @@ class TestFunction:
x = scalar("x") x = scalar("x")
y = scalar("y") y = scalar("y")
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x, y], outputs={"a": x + y, "b": x * y}) f = function([x, y], outputs={"a": x + y, "b": x * y})
assert f(2, 4) == {"a": 6, "b": 8} assert f(2, 4) == {"a": 6, "b": 8}
...@@ -805,6 +834,7 @@ class TestFunction: ...@@ -805,6 +834,7 @@ class TestFunction:
e1 = scalar("1") e1 = scalar("1")
e2 = scalar("2") e2 = scalar("2")
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function( f = function(
[x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2} [x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2}
) )
...@@ -825,6 +855,7 @@ class TestFunction: ...@@ -825,6 +855,7 @@ class TestFunction:
a = x + y a = x + y
b = x * y b = x * y
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x, y], outputs={"a": a, "b": b}) f = function([x, y], outputs={"a": a, "b": b})
a = scalar("a") a = scalar("a")
...@@ -880,6 +911,9 @@ class TestPicklefunction: ...@@ -880,6 +911,9 @@ class TestPicklefunction:
a = scalar() # the a is for 'anonymous' (un-named). a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[ [
x, x,
...@@ -941,6 +975,9 @@ class TestPicklefunction: ...@@ -941,6 +975,9 @@ class TestPicklefunction:
a = dscalar() # the a is for 'anonymous' (un-named). a = dscalar() # the a is for 'anonymous' (un-named).
x, s = dscalars("xs") x, s = dscalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[ [
x, x,
...@@ -967,10 +1004,12 @@ class TestPicklefunction: ...@@ -967,10 +1004,12 @@ class TestPicklefunction:
def test_output_keys(self): def test_output_keys(self):
x = vector() x = vector()
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
f = function([x], {"vec": x**2}) f = function([x], {"vec": x**2})
o = f([2, 3, 4]) o = f([2, 3, 4])
assert isinstance(o, dict) assert isinstance(o, dict)
assert np.allclose(o["vec"], [4, 9, 16]) assert np.allclose(o["vec"], [4, 9, 16])
with pytest.warns(FutureWarning, match="output_keys is deprecated."):
g = copy.deepcopy(f) g = copy.deepcopy(f)
o = g([2, 3, 4]) o = g([2, 3, 4])
assert isinstance(o, dict) assert isinstance(o, dict)
...@@ -980,6 +1019,9 @@ class TestPicklefunction: ...@@ -980,6 +1019,9 @@ class TestPicklefunction:
# Ensure that shared containers remain shared after a deep copy. # Ensure that shared containers remain shared after a deep copy.
a, x = scalars("ax") a, x = scalars("ax")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
h = function([In(a, value=0.0)], a) h = function([In(a, value=0.0)], a)
f = function([x, In(a, value=h.container[a], implicit=True)], x + a) f = function([x, In(a, value=h.container[a], implicit=True)], x + a)
...@@ -1004,6 +1046,9 @@ class TestPicklefunction: ...@@ -1004,6 +1046,9 @@ class TestPicklefunction:
a = scalar() # the a is for 'anonymous' (un-named). a = scalar() # the a is for 'anonymous' (un-named).
x, s = scalars("xs") x, s = scalars("xs")
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f = function( f = function(
[ [
x, x,
...@@ -1105,6 +1150,9 @@ class TestPicklefunction: ...@@ -1105,6 +1150,9 @@ class TestPicklefunction:
# some derived thing, whose inputs aren't all in the list # some derived thing, whose inputs aren't all in the list
list_of_things.append(a * x + s) list_of_things.append(a * x + s)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f1 = function( f1 = function(
[ [
x, x,
...@@ -1116,6 +1164,9 @@ class TestPicklefunction: ...@@ -1116,6 +1164,9 @@ class TestPicklefunction:
list_of_things.append(f1) list_of_things.append(f1)
# now put in a function sharing container with the previous one # now put in a function sharing container with the previous one
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f2 = function( f2 = function(
[ [
x, x,
...@@ -1131,6 +1182,9 @@ class TestPicklefunction: ...@@ -1131,6 +1182,9 @@ class TestPicklefunction:
# now put in a function with non-scalar # now put in a function with non-scalar
v_value = np.asarray([2, 3, 4.0], dtype=config.floatX) v_value = np.asarray([2, 3, 4.0], dtype=config.floatX)
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
f3 = function([x, In(v, value=v_value)], x + v) f3 = function([x, In(v, value=v_value)], x + v)
list_of_things.append(f3) list_of_things.append(f3)
...@@ -1263,6 +1317,9 @@ class SomethingToPickle: ...@@ -1263,6 +1317,9 @@ class SomethingToPickle:
self.e = a * x + s self.e = a * x + s
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
self.f1 = function( self.f1 = function(
[ [
x, x,
...@@ -1272,6 +1329,9 @@ class SomethingToPickle: ...@@ -1272,6 +1329,9 @@ class SomethingToPickle:
s + a * x, s + a * x,
) )
with pytest.warns(
FutureWarning, match="Inputs with default values are deprecated."
):
self.f2 = function( self.f2 = function(
[ [
x, x,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论