Commit fca59b06 authored by khaotik, committed by khaotik

implement Rop for OfG with tests

Parent 03c6f9bb
@@ -40,6 +40,8 @@ class OpFromGraph(gof.Op):
- function : must return list of Variable.
- list : each function must return a single Variable. The order
of the list must correspond to the inputs
rop_overrides: None | undef | function | list of (None|undef|function), optional
similar to grad_overrides, but the list order should match the outputs instead of the inputs (see the usage sketch below)
TODO:
- examples for a multi-layer mlp. where?
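To make the new parameter concrete, here is a minimal usage sketch (not part of this diff, names such as mul_rop are illustrative), based on the test_rop_override test added further down:

import theano
import theano.tensor as T

x, y = T.vectors('xy')

def mul_rop(inps, eval_points):
    # R-operator of out = x * y along the direction (u, v) is u * y + x * v
    x_, y_ = inps
    u, v = eval_points
    return [u * y_ + x_ * v]

# override the R_op of the single output; the gradient keeps its default behaviour
op_mul = theano.OpFromGraph([x, y], [x * y], rop_overrides=mul_rop)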
@@ -156,10 +158,7 @@ class OpFromGraph(gof.Op):
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
self.set_grad_overrides(grad_overrides)
self.set_rop_overrides(rop_overrides)
# TODO
if rop_overrides is not None:
raise NotImplementedError('Overriding Rop is not implemented yet.')
def __eq__(self, other):
# TODO: recognize a copy
@@ -169,10 +168,6 @@ class OpFromGraph(gof.Op):
# TODO: use internal variables in hash
return hash(type(self))
# TODO impl me
# def R_op(self, inputs, eval_points):
# pass
def _recompute_grad_op(self):
output_grads = [out_t() for out_t in self.output_types]
if self._grad_op is None:
@@ -184,7 +179,8 @@ class OpFromGraph(gof.Op):
if len(goverrides_l) > len(self.local_inputs):
raise ValueError(
'Can override %d gradients at most, got %d' % (
len(self.local_inputs), len(goverrides_l)),
goverrides_l)
if len(goverrides_l) < len(self.local_inputs):
goverrides_l += [None] * (
len(self.local_inputs) - len(goverrides_l))
@@ -213,12 +209,81 @@ class OpFromGraph(gof.Op):
for inp in self.local_inputs]
else:
all_grads_l = self._grad_op(self.local_inputs, output_grads)
if not isinstance(all_grads_l, (tuple, list)):
all_grads_l = [all_grads_l]
if len(all_grads_l) != len(self.local_inputs):
raise ValueError(
'Gradient overriding function %s should return list of '
'%d outputs, got %d' % (
self._grad_op, len(self.local_inputs), len(all_grads_l)),
self._grad_op
)
self._grad_op = type(self)(
inputs=self.local_inputs + output_grads,
outputs=all_grads_l,
inline=self.is_inline, on_unused_input='ignore',
)
self._grad_op_is_cached = True
def _recompute_rop_op(self):
eval_points = [inp_t() for inp_t in self.input_types]
if self._rop_op is None:
self._rop_op = []
if isinstance(self._rop_op, list):
roverrides_l = self._rop_op
if len(roverrides_l) > len(self.local_outputs):
raise ValueError(
'Can override %d Rops at most, got %d' % (
len(self.local_outputs), len(roverrides_l)),
roverrides_l)
if len(roverrides_l) < len(self.local_outputs):
roverrides_l += [None] * (
len(self.local_outputs) - len(roverrides_l))
# get outputs that do not have an Rop override
odefaults_l = [
lo for lo, rov in izip(self.local_outputs, roverrides_l)
if not rov]
# compute Rops for the non-overridden outputs from the input eval points
# it's normal that some inputs may be disconnected, hence the 'ignore' below
rdefaults_li = theano.gradient.Rop(
f=odefaults_l,
wrt=self.local_inputs,
eval_points=eval_points
)
rdefaults = iter(rdefaults_li if odefaults_l else [])
# combine overridden and default Rops
all_rops_l = []
for out, rov in izip(self.local_outputs, roverrides_l):
if rov is None:
all_rops_l.append(next(rdefaults))
elif rov is undef:
all_rops_l.append(
out.zeros_like().astype(theano.config.floatX))
else:
all_rops_l.append(rov(self.local_inputs, eval_points))
elif self._rop_op is undef:
all_rops_l = [
out.zeros_like().astype(theano.config.floatX)
for out in self.local_outputs]
else:
all_rops_l = self._rop_op(self.local_inputs, eval_points)
if not isinstance(all_rops_l, (tuple, list)):
all_rops_l = [all_rops_l]
if len(all_rops_l) != len(self.local_outputs):
raise ValueError(
'Rop overriding function %s should return list of '
'%d outputs, got %d' % (
self._rop_op,
len(self.local_outputs),
len(all_rops_l)),
self._rop_op)
self._rop_op = type(self)(
inputs=self.local_inputs + eval_points,
outputs=all_rops_l,
inline=self.is_inline, on_unused_input='ignore')
self._rop_op_is_cached = True
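For outputs without an override, the branch above falls back to theano.gradient.Rop. A hedged sketch (not part of the commit, variable names are illustrative) of what that call computes for the matmul graph used in test_rop below; the test only perturbs x, so its expected value reduces to dot(du, W):

import theano
import theano.tensor as T

a = T.vector('a')
u = T.vector('u')   # eval point for a
M = T.matrix('M')
V = T.matrix('V')   # eval point for M
b = T.dot(a, M)
# push the eval points through the graph with the R-operator;
# symbolically this equals dot(u, M) + dot(a, V)
rop_b = theano.gradient.Rop(b, [a, M], [u, V])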
def get_grad_op(self):
"""
getter method for self._grad_op
@@ -227,21 +292,41 @@ class OpFromGraph(gof.Op):
self._recompute_grad_op()
return self._grad_op
def get_rop_op(self):
"""
getter method for self._rop_op
"""
if not self._rop_op_is_cached:
self._recompute_rop_op()
return self._rop_op
def set_rop_overrides(self, rop_overrides):
"""
Set R_op overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set R_op overrides
"""
self._rop_op = rop_overrides
self._rop_op_is_cached = False
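A hedged sketch (not part of the diff) of the setter/getter pair: an override can also be set after construction, and the combined Rop graph is only materialized on first use; the lambda below is illustrative:

import theano
import theano.tensor as T

x, y = T.vectors('xy')
op_add = theano.OpFromGraph([x, y], [x + y])
# function-form override: receives (inputs, eval_points) and returns one Rop per output
op_add.set_rop_overrides(lambda inps, epts: [epts[0] + epts[1]])
rop_op = op_add.get_rop_op()  # itself an OpFromGraph taking the inputs followed by their eval points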
def set_grad_overrides(self, grad_overrides):
"""
Set gradient overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set gradient overrides
"""
self._grad_op = grad_overrides
self._grad_op_is_cached = False
def R_op(self, inputs, eval_points):
if not self._rop_op_is_cached:
self._recompute_rop_op()
return self._rop_op(*(list(inputs) + list(eval_points)), return_list=True)
def grad(self, inputs, output_grads):
if not self._grad_op_is_cached:
self._recompute_grad_op()
return self._grad_op(*(list(inputs) + list(output_grads)), return_list=True)
if self._grad_op is undef:
return [None for _ in self.input_types]
return self._grad_op(*(list(inputs) + list(output_grads)))
def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types):
@@ -298,7 +383,7 @@ class OpFromGraph(gof.Op):
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in izip(outputs, variables):
# TODO: when function's output-borrowing semantics are correct,
# we won't need this copy anymore
output[0] = variable.copy()
@@ -168,6 +168,47 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(xv * 1.5, dwv)
assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
@test_params
def test_rop(self, cls_ofg):
a = T.vector()
M = T.matrix()
b = T.dot(a, M)
op_matmul = cls_ofg([a, M], [b])
x = T.vector()
W = T.matrix()
y = op_matmul(x, W)
du = T.vector()
dv = T.Rop(y, x, du)
fn = function([x, W, du], dv)
xval = numpy.random.rand(16).astype(config.floatX)
Wval = numpy.random.rand(16, 16).astype(config.floatX)
duval = numpy.random.rand(16).astype(config.floatX)
dvval = numpy.dot(duval, Wval)
dvval2 = fn(xval, Wval, duval)
print(dvval, dvval2)
assert numpy.allclose(dvval2, dvval)
@test_params
def test_rop_override(self, cls_ofg):
x, y = T.vectors('xy')
def ro(inps, epts):
x, y = inps
u, v = epts
return [u * y * 2. + x * v * 1.5]
op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro)
xx, yy = T.vector('xx'), T.vector('yy')
zz = op_mul(xx, yy)
du, dv = T.vector('du'), T.vector('dv')
dw = T.Rop(zz, [xx, yy], [du, dv])
fn = function([xx, yy, du, dv], dw)
vals = numpy.random.rand(4, 32).astype(config.floatX)
dwval = fn(*vals)
assert numpy.allclose(
dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)
@test_params
def test_nested(self, cls_ofg):
x, y = T.vectors('xy')