提交 6dfc811f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Deprecate grad_overrides in OpFromGraph

上级 ca8d60a3
"""Define new Ops from existing Ops"""
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from copy import copy
......@@ -189,7 +190,7 @@ class OpFromGraph(Op, HasInnerGraph):
- For overriding, it's recommended to provide pure functions (no side
effects like setting global variable) as callable(s). The callable(s)
supplied for overriding gradient/rop will be called only once at the
first call to grad/R_op, and will be converted to OpFromGraph instances.
first call to L_op/R_op, and will be converted to OpFromGraph instances.
Examples
--------
......@@ -224,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2])
Example 3 override gradient
Example 3 override L_op
.. code-block:: python
......@@ -233,12 +234,15 @@ class OpFromGraph(Op, HasInnerGraph):
x, y, z = pt.scalars('xyz')
e = x + y * z
def rescale_dy(inps, grads):
def rescale_dy(inps, outputs, out_grads):
x, y, z = inps
g, = grads
g, = out_grads
return z*2
op = OpFromGraph(
[x, y, z], [e], grad_overrides=['default', rescale_dy, 'default']
[x, y, z],
[e],
lop_overrides=['default', rescale_dy, 'default'],
)
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
fn = function([x, y, z], [dx, dy, dz])
......@@ -455,6 +459,10 @@ class OpFromGraph(Op, HasInnerGraph):
self.set_lop_overrides(lop_overrides)
self._lop_type = "lop"
elif grad_overrides != "default":
warnings.warn(
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
FutureWarning,
)
self.set_lop_overrides(grad_overrides)
self._lop_type = "grad"
else:
......
......@@ -181,8 +181,9 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
dedz = vector("dedz")
op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))
op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
# single override case (function or OfG instance)
xx, yy = vector("xx"), vector("yy")
......@@ -209,9 +210,10 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
w, b = vectors("wb")
# we make the 3rd gradient default (no override)
op_linear = cls_ofg(
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
)
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op_linear = cls_ofg(
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
)
xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
zz = pt_sum(op_linear(xx, ww, bb))
dx, dw, db = grad(zz, [xx, ww, bb])
......@@ -225,11 +227,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
np.testing.assert_array_almost_equal(np.ones(16, dtype=config.floatX), dbv, 4)
# NullType and DisconnectedType
op_linear2 = cls_ofg(
[x, w, b],
[x * w + b],
grad_overrides=[go1, NullType()(), DisconnectedType()()],
)
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op_linear2 = cls_ofg(
[x, w, b],
[x * w + b],
grad_overrides=[go1, NullType()(), DisconnectedType()()],
)
zz2 = pt_sum(op_linear2(xx, ww, bb))
dx2, dw2, db2 = grad(
zz2,
......@@ -339,13 +342,14 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
def f1_back(inputs, output_gradients):
return [output_gradients[0], disconnected_type()]
op = cls_ofg(
inputs=[x, y],
outputs=[f1(x, y)],
grad_overrides=f1_back,
connection_pattern=[[True], [False]], # This is new
on_unused_input="ignore",
) # This is new
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
op = cls_ofg(
inputs=[x, y],
outputs=[f1(x, y)],
grad_overrides=f1_back,
connection_pattern=[[True], [False]],
on_unused_input="ignore",
)
c = op(x, y)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论