提交 da7f2835 authored 作者: khaotik's avatar khaotik

allow directly assign OfG inst overriding grad/rop

上级 fca59b06
...@@ -29,7 +29,8 @@ class OpFromGraph(gof.Op): ...@@ -29,7 +29,8 @@ class OpFromGraph(gof.Op):
inline: bool, optional inline: bool, optional
if True, will cause the Op's original graph being used during if True, will cause the Op's original graph being used during
compilation, otherwise will use a pre-compiled function inside. compilation, otherwise will use a pre-compiled function inside.
grad_overrides: None | undef | function | list of (None|undef|function), optional grad_overrides: None | undef | OpFromGraph instance | function | \
list of (None|undef|function), optional
Used to override default gradient routine. Used to override default gradient routine.
Overriding function(s) must take two list of variable(s) as inputs, Overriding function(s) must take two list of variable(s) as inputs,
the original inputs and upstream gradients the original inputs and upstream gradients
...@@ -37,11 +38,15 @@ class OpFromGraph(gof.Op): ...@@ -37,11 +38,15 @@ class OpFromGraph(gof.Op):
- `None` : will use default gradient routine. - `None` : will use default gradient routine.
- theano.utils.undef : No gradient will be used (zero) - theano.utils.undef : No gradient will be used (zero)
- OpFromGraph instance: the OfG instance should accept inputs with same
order and types as specified in "inputs" and "outputs" arguments
- function : must return list of Variable. - function : must return list of Variable.
- list : each function must return a single Variable. The order - list : each function must return a single Variable. The order
of the list must correspond to inputs of the list must correspond to inputs
rop_overrides: None | undef | function | list of (None|undef|function), optional rop_overrides: None | undef | OpFromGraph instance | function | \
similar to grad_overrides, list order should match output instead list of (None|undef|function), optional
similar to grad_overrides, list order should match the two lists of "inputs"
concatenated.
TODO: TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
...@@ -169,6 +174,9 @@ class OpFromGraph(gof.Op): ...@@ -169,6 +174,9 @@ class OpFromGraph(gof.Op):
return hash(type(self)) return hash(type(self))
def _recompute_grad_op(self): def _recompute_grad_op(self):
if isinstance(self._grad_op, OpFromGraph):
self._grad_op_is_cached = True
return
output_grads = [out_t() for out_t in self.output_types] output_grads = [out_t() for out_t in self.output_types]
if self._grad_op is None: if self._grad_op is None:
self._grad_op = [] self._grad_op = []
...@@ -226,6 +234,9 @@ class OpFromGraph(gof.Op): ...@@ -226,6 +234,9 @@ class OpFromGraph(gof.Op):
self._grad_op_is_cached = True self._grad_op_is_cached = True
def _recompute_rop_op(self): def _recompute_rop_op(self):
if isinstance(self._rop_op, OpFromGraph):
self._rop_op_is_cached = True
return
eval_points = [inp_t() for inp_t in self.input_types] eval_points = [inp_t() for inp_t in self.input_types]
if self._rop_op is None: if self._rop_op is None:
self._rop_op = [] self._rop_op = []
......
...@@ -126,14 +126,19 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -126,14 +126,19 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
def go(inps, gs): def go(inps, gs):
x, y = inps x, y = inps
g = gs[0] g, = gs
return [g * y * 2, g * x * 1.5] return [g * y * 2, g * x * 1.5]
# no override case is coverd in "grad" test
# single override case dedz = T.vector('dedz')
op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))
op_mul = cls_ofg([x, y], [x * y], grad_overrides=go) op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
# single override case (function or OfG instance)
xx, yy = T.vector('xx'), T.vector('yy') xx, yy = T.vector('xx'), T.vector('yy')
zz = T.sum(op_mul(xx, yy)) for op in [op_mul, op_mul2]:
zz = T.sum(op(xx, yy))
dx, dy = T.grad(zz, [xx, yy]) dx, dy = T.grad(zz, [xx, yy])
fn = function([xx, yy], [dx, dy]) fn = function([xx, yy], [dx, dy])
xv = np.random.rand(16).astype(config.floatX) xv = np.random.rand(16).astype(config.floatX)
...@@ -180,13 +185,13 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -180,13 +185,13 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
du = T.vector() du = T.vector()
dv = T.Rop(y, x, du) dv = T.Rop(y, x, du)
fn = function([x, W, du], dv) fn = function([x, W, du], dv)
xval = numpy.random.rand(16).astype(config.floatX) xval = np.random.rand(16).astype(config.floatX)
Wval = numpy.random.rand(16, 16).astype(config.floatX) Wval = np.random.rand(16, 16).astype(config.floatX)
duval = numpy.random.rand(16).astype(config.floatX) duval = np.random.rand(16).astype(config.floatX)
dvval = numpy.dot(duval, Wval) dvval = np.dot(duval, Wval)
dvval2 = fn(xval, Wval, duval) dvval2 = fn(xval, Wval, duval)
print(dvval, dvval2) print(dvval, dvval2)
assert numpy.allclose(dvval2, dvval) assert np.allclose(dvval2, dvval)
@test_params @test_params
def test_rop_override(self, cls_ofg): def test_rop_override(self, cls_ofg):
...@@ -197,18 +202,25 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -197,18 +202,25 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
u, v = epts u, v = epts
return [u * y * 2. + x * v * 1.5] return [u * y * 2. + x * v * 1.5]
u, v = T.vectors('uv')
op_mul_rop = cls_ofg([x, y, u, v], ro([x, y], [u, v]))
op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro) op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro)
xx, yy = T.vector('xx'), T.vector('yy') op_mul2 = cls_ofg([x, y], [x * y], rop_overrides=op_mul_rop)
zz = op_mul(xx, yy)
# single override case
xx, yy = T.vector('xx'), T.vector('yy')
du, dv = T.vector('du'), T.vector('dv') du, dv = T.vector('du'), T.vector('dv')
for op in [op_mul, op_mul2]:
zz = op_mul(xx, yy)
dw = T.Rop(zz, [xx, yy], [du, dv]) dw = T.Rop(zz, [xx, yy], [du, dv])
fn = function([xx, yy, du, dv], dw) fn = function([xx, yy, du, dv], dw)
vals = numpy.random.rand(4, 32).astype(config.floatX) vals = np.random.rand(4, 32).astype(config.floatX)
dwval = fn(*vals) dwval = fn(*vals)
assert numpy.allclose( assert np.allclose(
dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.) dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)
# TODO list override case
@test_params @test_params
def test_nested(self, cls_ofg): def test_nested(self, cls_ofg):
x, y = T.vectors('xy') x, y = T.vectors('xy')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论