Commit fca59b06 authored by khaotik, committed by khaotik

implement Rop for OfG with tests

Parent 03c6f9bb
@@ -40,6 +40,8 @@ class OpFromGraph(gof.Op):
- function : must return list of Variable.
- list : each function must return a single Variable. The order
      of the list must correspond to inputs
rop_overrides: None | undef | function | list of (None|undef|function), optional
    similar to grad_overrides, but the order of the list must correspond to outputs instead
TODO:
- examples for a multi-layer mlp. where?
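To illustrate the rop_overrides syntax documented above, here is a minimal sketch of the single-function form (the name custom_rop and the scaling factor are illustrative, not part of this patch):

import theano
import theano.tensor as T

x, y = T.vectors('xy')

# hypothetical override: return twice the true directional derivative
# of x * y along (u, v), which would be u * y + x * v
def custom_rop(inputs, eval_points):
    x_, y_ = inputs
    u, v = eval_points
    return [(u * y_ + x_ * v) * 2.]

op_mul = theano.OpFromGraph([x, y], [x * y], rop_overrides=custom_rop)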
@@ -156,10 +158,7 @@ class OpFromGraph(gof.Op):
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
self.set_grad_overrides(grad_overrides)
# TODO
if rop_overrides is not None:
raise NotImplementedError('Overriding Rop is not implemented yet.')
self.set_rop_overrides(rop_overrides)
def __eq__(self, other):
# TODO: recognize a copy
@@ -169,10 +168,6 @@ class OpFromGraph(gof.Op):
# TODO: use internal variables in hash
return hash(type(self))
# TODO impl me
# def R_op(self, inputs, eval_points):
# pass
def _recompute_grad_op(self):
output_grads = [out_t() for out_t in self.output_types]
if self._grad_op is None:
@@ -184,7 +179,8 @@ class OpFromGraph(gof.Op):
if len(goverrides_l) > len(self.local_inputs):
raise ValueError(
'Can override %d gradients at most, got %d' % (
len(self.local_inputs), len(goverrides_l)))
len(self.local_inputs), len(goverrides_l)),
                goverrides_l)
if len(goverrides_l) < len(self.local_inputs):
goverrides_l += [None] * (
len(self.local_inputs) - len(goverrides_l))
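As the padding logic above shows, an override list shorter than the number of inputs is extended with None, so each missing entry falls back to the default symbolic gradient. A minimal sketch under that assumption (gx is a hypothetical override for the first input only):

def gx(inputs, output_grads):
    # hypothetical: replace only d(output)/d(x); y keeps its default gradient
    g, = output_grads
    return g * 2.

op_mul = theano.OpFromGraph([x, y], [x * y], grad_overrides=[gx])
# after padding this behaves like grad_overrides=[gx, None]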
@@ -213,12 +209,81 @@ class OpFromGraph(gof.Op):
for inp in self.local_inputs]
else:
all_grads_l = self._grad_op(self.local_inputs, output_grads)
if not isinstance(all_grads_l, (tuple, list)):
all_grads_l = [all_grads_l]
if len(all_grads_l) != len(self.local_inputs):
raise ValueError(
'Gradient overriding function %s should return list of '
'%d outputs, got %d' % (
self._grad_op, len(self.local_inputs), len(all_grads_l)),
self._grad_op
)
self._grad_op = type(self)(
inputs=self.local_inputs + output_grads,
outputs=all_grads_l,
inline=self.is_inline, on_unused_input='ignore')
inline=self.is_inline, on_unused_input='ignore',
)
self._grad_op_is_cached = True
def _recompute_rop_op(self):
eval_points = [inp_t() for inp_t in self.input_types]
if self._rop_op is None:
self._rop_op = []
if isinstance(self._rop_op, list):
roverrides_l = self._rop_op
if len(roverrides_l) > len(self.local_outputs):
raise ValueError(
                'Can override %d Rops at most, got %d' % (
                    len(self.local_outputs), len(roverrides_l)),
                roverrides_l)
if len(roverrides_l) < len(self.local_outputs):
roverrides_l += [None] * (
len(self.local_outputs) - len(roverrides_l))
        # get outputs that do not have an Rop override
odefaults_l = [
lo for lo, rov in izip(self.local_outputs, roverrides_l)
if not rov]
        # compute default Rops for the non-overridden outputs from the eval points
        # it's normal that some inputs may be disconnected, hence the 'ignore'
rdefaults_li = theano.gradient.Rop(
f=odefaults_l,
wrt=self.local_inputs,
eval_points=eval_points
)
rdefaults = iter(rdefaults_li if odefaults_l else [])
        # combine default and overridden Rops
all_rops_l = []
for out, rov in izip(self.local_outputs, roverrides_l):
if rov is None:
all_rops_l.append(next(rdefaults))
elif rov is undef:
all_rops_l.append(
out.zeros_like().astype(theano.config.floatX))
else:
all_rops_l.append(rov(self.local_inputs, eval_points))
elif self._rop_op is undef:
all_rops_l = [
out.zeros_like().astype(theano.config.floatX)
for out in self.local_outputs]
else:
all_rops_l = self._rop_op(self.local_inputs, eval_points)
if not isinstance(all_rops_l, (tuple, list)):
all_rops_l = [all_rops_l]
if len(all_rops_l) != len(self.local_outputs):
raise ValueError(
'Rop overriding function %s should return list of '
'%d outputs, got %d' % (
self._rop_op,
len(self.local_outputs),
len(all_rops_l)),
self._rop_op)
self._rop_op = type(self)(
inputs=self.local_inputs + eval_points,
outputs=all_rops_l,
inline=self.is_inline, on_unused_input='ignore')
self._rop_op_is_cached = True
def get_grad_op(self):
"""
getter method for self._grad_op
@@ -227,21 +292,41 @@ class OpFromGraph(gof.Op):
self._recompute_grad_op()
return self._grad_op
def get_rop_op(self):
"""
getter method for self._rop_op
"""
if not self._rop_op_is_cached:
self._recompute_rop_op()
return self._rop_op
def set_rop_overrides(self, rop_overrides):
"""
Set R_op overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set R_op overrides
"""
self._rop_op = rop_overrides
self._rop_op_is_cached = False
def set_grad_overrides(self, grad_overrides):
"""
Set gradient overrides, see help(theano.OpFromGraph) for syntax
This will completed remove any previously set gradient overrides
This will completely remove any previously set gradient overrides
"""
self._grad_op = grad_overrides
self._grad_op_is_cached = False
def R_op(self, inputs, eval_points):
if not self._rop_op_is_cached:
self._recompute_rop_op()
return self._rop_op(*(list(inputs) + list(eval_points)), return_list=True)
def grad(self, inputs, output_grads):
if not self._grad_op_is_cached:
self._recompute_grad_op()
if self._grad_op is undef:
return [None for _ in self.input_types]
return self._grad_op(*(list(inputs) + list(output_grads)))
return self._grad_op(*(list(inputs) + list(output_grads)), return_list=True)
def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types):
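Note that set_rop_overrides and set_grad_overrides only store the new override and clear the cache flag; the override op is rebuilt lazily the next time R_op or grad runs. A minimal sketch (reusing the illustrative custom_rop from above):

op_mul = theano.OpFromGraph([x, y], [x * y])  # no override at construction
op_mul.set_rop_overrides(custom_rop)          # cached _rop_op is invalidated
z = op_mul(x, y)
u, v = T.vectors('uv')
dz = T.Rop(z, [x, y], [u, v])                 # triggers _recompute_rop_op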
@@ -298,7 +383,7 @@ class OpFromGraph(gof.Op):
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
for output, variable in izip(outputs, variables):
# TODO: when function's output-borrowing semantics are correct,
# we wont need this copy anymore
output[0] = variable.copy()
......
@@ -168,6 +168,47 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(xv * 1.5, dwv)
assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
@test_params
def test_rop(self, cls_ofg):
a = T.vector()
M = T.matrix()
b = T.dot(a, M)
op_matmul = cls_ofg([a, M], [b])
x = T.vector()
W = T.matrix()
y = op_matmul(x, W)
du = T.vector()
dv = T.Rop(y, x, du)
fn = function([x, W, du], dv)
xval = numpy.random.rand(16).astype(config.floatX)
Wval = numpy.random.rand(16, 16).astype(config.floatX)
duval = numpy.random.rand(16).astype(config.floatX)
dvval = numpy.dot(duval, Wval)
dvval2 = fn(xval, Wval, duval)
print(dvval, dvval2)
assert numpy.allclose(dvval2, dvval)
@test_params
def test_rop_override(self, cls_ofg):
x, y = T.vectors('xy')
def ro(inps, epts):
x, y = inps
u, v = epts
return [u * y * 2. + x * v * 1.5]
op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro)
xx, yy = T.vector('xx'), T.vector('yy')
zz = op_mul(xx, yy)
du, dv = T.vector('du'), T.vector('dv')
dw = T.Rop(zz, [xx, yy], [du, dv])
fn = function([xx, yy, du, dv], dw)
vals = numpy.random.rand(4, 32).astype(config.floatX)
dwval = fn(*vals)
assert numpy.allclose(
dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)
@test_params
def test_nested(self, cls_ofg):
x, y = T.vectors('xy')
......
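For reference, the true Rop of the elementwise product x * y along (du, dv) is du * y + x * dv, so the value asserted in test_rop_override can only come from the override. A quick NumPy sanity check of that reasoning (shapes mirror the test):

import numpy as np

rng = np.random.RandomState(0)
xv, yv, duv, dvv = rng.rand(4, 32).astype('float32')
true_rop = duv * yv + xv * dvv                 # default forward-mode derivative
overridden = duv * yv * 2. + xv * dvv * 1.5    # what the override computes
assert not np.allclose(true_rop, overridden)   # the two are visibly different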