提交 da7f2835 authored 作者: khaotik's avatar khaotik

Allow directly assigning an OpFromGraph instance to override grad/rop

上级 fca59b06
......@@ -29,7 +29,8 @@ class OpFromGraph(gof.Op):
inline: bool, optional
if True, the Op's original graph will be used during
compilation; otherwise a pre-compiled function will be used inside.
grad_overrides: None | undef | function | list of (None|undef|function), optional
grad_overrides: None | undef | OpFromGraph instance | function | \
list of (None|undef|function), optional
Used to override default gradient routine.
Overriding function(s) must take two lists of variable(s) as inputs:
the original inputs and the upstream gradients
......@@ -37,11 +38,15 @@ class OpFromGraph(gof.Op):
- `None` : will use default gradient routine.
- theano.utils.undef : No gradient will be used (zero)
- OpFromGraph instance: the OfG instance should accept inputs with the same
order and types as specified in the "inputs" and "outputs" arguments
- function : must return list of Variable.
- list : each function must return a single Variable. The order
of the list must correspond to the inputs
rop_overrides: None | undef | function | list of (None|undef|function), optional
similar to grad_overrides, list order should match output instead
rop_overrides: None | undef | OpFromGraph instance | function | \
list of (None|undef|function), optional
similar to grad_overrides, list order should match the two lists of "inputs"
concatenated.
TODO:
- examples for a multi-layer mlp. where?
......@@ -169,6 +174,9 @@ class OpFromGraph(gof.Op):
return hash(type(self))
def _recompute_grad_op(self):
if isinstance(self._grad_op, OpFromGraph):
self._grad_op_is_cached = True
return
output_grads = [out_t() for out_t in self.output_types]
if self._grad_op is None:
self._grad_op = []
......@@ -226,6 +234,9 @@ class OpFromGraph(gof.Op):
self._grad_op_is_cached = True
def _recompute_rop_op(self):
if isinstance(self._rop_op, OpFromGraph):
self._rop_op_is_cached = True
return
eval_points = [inp_t() for inp_t in self.input_types]
if self._rop_op is None:
self._rop_op = []
......
......@@ -126,21 +126,26 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
def go(inps, gs):
x, y = inps
g = gs[0]
g, = gs
return [g * y * 2, g * x * 1.5]
# no override case is covered in "grad" test
# single override case
dedz = T.vector('dedz')
op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))
op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
# single override case (function or OfG instance)
xx, yy = T.vector('xx'), T.vector('yy')
zz = T.sum(op_mul(xx, yy))
dx, dy = T.grad(zz, [xx, yy])
fn = function([xx, yy], [dx, dy])
xv = np.random.rand(16).astype(config.floatX)
yv = np.random.rand(16).astype(config.floatX)
dxv, dyv = fn(xv, yv)
assert np.allclose(yv * 2, dxv)
assert np.allclose(xv * 1.5, dyv)
for op in [op_mul, op_mul2]:
zz = T.sum(op(xx, yy))
dx, dy = T.grad(zz, [xx, yy])
fn = function([xx, yy], [dx, dy])
xv = np.random.rand(16).astype(config.floatX)
yv = np.random.rand(16).astype(config.floatX)
dxv, dyv = fn(xv, yv)
assert np.allclose(yv * 2, dxv)
assert np.allclose(xv * 1.5, dyv)
# list override case
def go1(inps, gs):
......@@ -180,13 +185,13 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
du = T.vector()
dv = T.Rop(y, x, du)
fn = function([x, W, du], dv)
xval = numpy.random.rand(16).astype(config.floatX)
Wval = numpy.random.rand(16, 16).astype(config.floatX)
duval = numpy.random.rand(16).astype(config.floatX)
dvval = numpy.dot(duval, Wval)
xval = np.random.rand(16).astype(config.floatX)
Wval = np.random.rand(16, 16).astype(config.floatX)
duval = np.random.rand(16).astype(config.floatX)
dvval = np.dot(duval, Wval)
dvval2 = fn(xval, Wval, duval)
print(dvval, dvval2)
assert numpy.allclose(dvval2, dvval)
assert np.allclose(dvval2, dvval)
@test_params
def test_rop_override(self, cls_ofg):
......@@ -197,17 +202,24 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
u, v = epts
return [u * y * 2. + x * v * 1.5]
u, v = T.vectors('uv')
op_mul_rop = cls_ofg([x, y, u, v], ro([x, y], [u, v]))
op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro)
xx, yy = T.vector('xx'), T.vector('yy')
zz = op_mul(xx, yy)
op_mul2 = cls_ofg([x, y], [x * y], rop_overrides=op_mul_rop)
# single override case
xx, yy = T.vector('xx'), T.vector('yy')
du, dv = T.vector('du'), T.vector('dv')
dw = T.Rop(zz, [xx, yy], [du, dv])
fn = function([xx, yy, du, dv], dw)
vals = numpy.random.rand(4, 32).astype(config.floatX)
dwval = fn(*vals)
assert numpy.allclose(
dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)
for op in [op_mul, op_mul2]:
zz = op_mul(xx, yy)
dw = T.Rop(zz, [xx, yy], [du, dv])
fn = function([xx, yy, du, dv], dw)
vals = np.random.rand(4, 32).astype(config.floatX)
dwval = fn(*vals)
assert np.allclose(
dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)
# TODO list override case
@test_params
def test_nested(self, cls_ofg):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论