提交 6dfc811f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate grad_overrides in OpFromGraph

上级 ca8d60a3
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from copy import copy from copy import copy
...@@ -189,7 +190,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -189,7 +190,7 @@ class OpFromGraph(Op, HasInnerGraph):
- For overriding, it's recommended to provide pure functions (no side - For overriding, it's recommended to provide pure functions (no side
effects like setting global variable) as callable(s). The callable(s) effects like setting global variable) as callable(s). The callable(s)
supplied for overriding gradient/rop will be called only once at the supplied for overriding gradient/rop will be called only once at the
first call to grad/R_op, and will be converted to OpFromGraph instances. first call to L_op/R_op, and will be converted to OpFromGraph instances.
Examples Examples
-------- --------
...@@ -224,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -224,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
e2 = op(x, y, z) + op(z, y, x) e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2]) fn = function([x, y, z], [e2])
Example 3 override gradient Example 3 override L_op
.. code-block:: python .. code-block:: python
...@@ -233,12 +234,15 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -233,12 +234,15 @@ class OpFromGraph(Op, HasInnerGraph):
x, y, z = pt.scalars('xyz') x, y, z = pt.scalars('xyz')
e = x + y * z e = x + y * z
def rescale_dy(inps, grads): def rescale_dy(inps, outputs, out_grads):
x, y, z = inps x, y, z = inps
g, = grads g, = out_grads
return z*2 return z*2
op = OpFromGraph( op = OpFromGraph(
[x, y, z], [e], grad_overrides=['default', rescale_dy, 'default'] [x, y, z],
[e],
lop_overrides=['default', rescale_dy, 'default'],
)
e2 = op(x, y, z) e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z]) dx, dy, dz = grad(e2, [x, y, z])
fn = function([x, y, z], [dx, dy, dz]) fn = function([x, y, z], [dx, dy, dz])
...@@ -455,6 +459,10 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -455,6 +459,10 @@ class OpFromGraph(Op, HasInnerGraph):
self.set_lop_overrides(lop_overrides) self.set_lop_overrides(lop_overrides)
self._lop_type = "lop" self._lop_type = "lop"
elif grad_overrides != "default": elif grad_overrides != "default":
warnings.warn(
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
FutureWarning,
)
self.set_lop_overrides(grad_overrides) self.set_lop_overrides(grad_overrides)
self._lop_type = "grad" self._lop_type = "grad"
else: else:
......
...@@ -181,8 +181,9 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -181,8 +181,9 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
dedz = vector("dedz") dedz = vector("dedz")
op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz])) op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))
op_mul = cls_ofg([x, y], [x * y], grad_overrides=go) with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad) op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
# single override case (function or OfG instance) # single override case (function or OfG instance)
xx, yy = vector("xx"), vector("yy") xx, yy = vector("xx"), vector("yy")
...@@ -209,9 +210,10 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -209,9 +210,10 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
w, b = vectors("wb") w, b = vectors("wb")
# we make the 3rd gradient default (no override) # we make the 3rd gradient default (no override)
op_linear = cls_ofg( with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"] op_linear = cls_ofg(
) [x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
)
xx, ww, bb = vector("xx"), vector("yy"), vector("bb") xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
zz = pt_sum(op_linear(xx, ww, bb)) zz = pt_sum(op_linear(xx, ww, bb))
dx, dw, db = grad(zz, [xx, ww, bb]) dx, dw, db = grad(zz, [xx, ww, bb])
...@@ -225,11 +227,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -225,11 +227,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
np.testing.assert_array_almost_equal(np.ones(16, dtype=config.floatX), dbv, 4) np.testing.assert_array_almost_equal(np.ones(16, dtype=config.floatX), dbv, 4)
# NullType and DisconnectedType # NullType and DisconnectedType
op_linear2 = cls_ofg( with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
[x, w, b], op_linear2 = cls_ofg(
[x * w + b], [x, w, b],
grad_overrides=[go1, NullType()(), DisconnectedType()()], [x * w + b],
) grad_overrides=[go1, NullType()(), DisconnectedType()()],
)
zz2 = pt_sum(op_linear2(xx, ww, bb)) zz2 = pt_sum(op_linear2(xx, ww, bb))
dx2, dw2, db2 = grad( dx2, dw2, db2 = grad(
zz2, zz2,
...@@ -339,13 +342,14 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -339,13 +342,14 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
def f1_back(inputs, output_gradients): def f1_back(inputs, output_gradients):
return [output_gradients[0], disconnected_type()] return [output_gradients[0], disconnected_type()]
op = cls_ofg( with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
inputs=[x, y], op = cls_ofg(
outputs=[f1(x, y)], inputs=[x, y],
grad_overrides=f1_back, outputs=[f1(x, y)],
connection_pattern=[[True], [False]], # This is new grad_overrides=f1_back,
on_unused_input="ignore", connection_pattern=[[True], [False]],
) # This is new on_unused_input="ignore",
)
c = op(x, y) c = op(x, y)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论