Commit b9854ee7 authored by abergeron, committed by GitHub

Merge pull request #5641 from khaotik/ofg_lop

L_op for OpFromGraph
......@@ -42,15 +42,39 @@ class OpFromGraph(gof.Op):
grad_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
This argument is mutually exclusive with lop_overrides.
``'default'`` : Do not override, use default grad() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads"
accept inputs as the same order and types of ``inputs`` and ``output_grads``
arguments as one would specify in grad() method.
callable : similar to OpFromGraph instance, must return list of
:class:`Variable <theano.gof.Variable>`.
callable : Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
Must return list of :class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
This argument is mutually exclusive with ``grad_overrides``.
``'default'`` : Do not override, use default L_op() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs``, ``outputs`` and ``output_grads``
arguments as one would specify in grad() method.
callable : Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
Must return list of :class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
......@@ -66,11 +90,12 @@ class OpFromGraph(gof.Op):
``'default'`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads"
arguments as one would specify in grad() method.
accept inputs as the same order and types of ``inputs`` and ``eval_points``
arguments as one would specify in R_op() method.
callable : similar to OpFromGraph instance, must return list of
:class:`Variable <theano.gof.Variable>`.
callable : Should take two args: ``inputs`` and ``eval_points``.
Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
Must return list of :class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
......@@ -96,7 +121,6 @@ class OpFromGraph(gof.Op):
local_outputs)
- c_code() to remove the double overhead?
- grad() make it support DisconnectedType and the new interface
- extend grad() to L_op
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates.
- add test with constant as input or inside the inner graph.
......@@ -116,10 +140,10 @@ class OpFromGraph(gof.Op):
- ``inline=True`` will cause better runtime optimization at the cost
of compilation time. Currently only works with ``fast_compile`` or
``fast_run`` mode.
- It's recommanded to provide pure functions (no side effects like
setting global variable) as callable(s). The callable(s) supplied
for overriding gradient/rop will be called only once at the first
call to grad/R_op, and will be converted to OpFromGraph instances.
- For overriding, it's recommended to provide pure functions (no side
effects like setting global variable) as callable(s). The callable(s)
supplied for overriding gradient/rop will be called only once at the
first call to grad/R_op, and will be converted to OpFromGraph instances.
Examples
--------
......@@ -171,6 +195,13 @@ class OpFromGraph(gof.Op):
fn(2., 3., 4.) # [1., 8., 3.]
"""
TYPE_ERR_MSG = ("L_op/gradient override should be (single or list of)"
"'default' | OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got %s")
STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
' of DisconnectedType or NullType, got %s')
LOP_TYPE_ERR_MSG = 'L_op type can only be "grad" or "lop", got %s.'
OV_INP_LEN_ERR_MSG = 'expect overrider with %d inputs, got %d'
@staticmethod
def _filter_grad_var(grad, inp):
......@@ -182,9 +213,9 @@ class OpFromGraph(gof.Op):
# a grad() call could return instance of NullType() or DisconnectedType()
# which cannot be directly used in OfG
#
# Since we always use an OfG instance as self._grad_op, the current
# Since we always use an OfG instance as self._lop_op, the current
# workaround is to "remember" the special cases of the gradient and
# replace them after self._grad_op is called.
# replace them after self._lop_op is called.
#
# This helper function changes invalid types into a filtered_var,
# and provides a overrider_var to be replaced at grad() call
......@@ -212,10 +243,12 @@ class OpFromGraph(gof.Op):
return inpJ, None
def __init__(
self, inputs, outputs,
inline=False,
grad_overrides='default', rop_overrides='default',
name=None, **kwargs
self, inputs, outputs,
inline=False,
lop_overrides='default',
grad_overrides='default',
rop_overrides='default',
name=None, **kwargs
):
if not isinstance(outputs, list):
raise TypeError('outputs must be list, got %s' % type(outputs))
......@@ -251,7 +284,18 @@ class OpFromGraph(gof.Op):
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
self.set_grad_overrides(grad_overrides)
if lop_overrides != 'default':
if grad_overrides != 'default':
raise ValueError('lop_overrides and grad_overrides are mutually exclusive')
else:
self.set_lop_overrides(lop_overrides)
self._lop_type = 'lop'
elif grad_overrides != 'default':
self.set_lop_overrides(grad_overrides)
self._lop_type = 'grad'
else:
self.set_lop_overrides('default')
self._lop_type = 'lop'
self.set_rop_overrides(rop_overrides)
if name is not None:
......@@ -272,21 +316,42 @@ class OpFromGraph(gof.Op):
return '%(name)s{inline=%(is_inline)s}' % locals()
@theano.change_flags(compute_test_value='off')
def _recompute_grad_op(self):
def _recompute_lop_op(self):
'''
converts self._grad_op from user supplied form to type(self) instance
converts self._lop_op from user supplied form to type(self) instance
'''
local_inputs = self.local_inputs
local_outputs = self.local_outputs
inp_len = len(local_inputs)
grad_op = self._grad_op
if isinstance(grad_op, OpFromGraph):
if not self._grad_op_is_cached:
self._grad_op_is_cached = True
self._grad_op_stypes_l = [None] * inp_len
return
lop_op = self._lop_op
if isinstance(lop_op, OpFromGraph):
if self._lop_op_is_cached:
return
assert self._lop_type in ['lop', 'grad'],\
self.LOP_TYPE_ERR_MSG % self._lop_type
if self._lop_type == 'grad':
needed_ninps = inp_len + len(local_outputs)
ninps = len(lop_op.local_inputs)
if needed_ninps != ninps:
raise ValueError(
self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
# make a wrapper callable
def lop_op(inps, grads):
return self._lop_op(*(inps + grads))
elif self._lop_type == 'lop':
# OfG can be directly used in L_op format
needed_ninps = inp_len + 2 * len(local_outputs)
ninps = len(lop_op.local_inputs)
if needed_ninps != ninps:
raise ValueError(
self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
self._lop_op_is_cached = True
self._lop_op_stypes_l = [None] * inp_len
self._lop_op.kwargs['on_unused_input'] = 'ignore'
return
output_grads = [out_t() for out_t in self.output_types]
fn_grad = partial(
......@@ -297,26 +362,28 @@ class OpFromGraph(gof.Op):
null_gradients='return',
known_grads=OrderedDict(izip(local_outputs, output_grads)))
TYPE_ERR_MSG = ("Gradient override should be (single or list of)"
"'default' | OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got %s")
STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
' of DisconnectedType or NullType, got %s')
# we need to convert _grad_op into an OfG instance
if grad_op == 'default':
assert self._lop_type in ['lop', 'grad'],\
self.LOP_TYPE_ERR_MSG % self._lop_type
if self._lop_type == 'lop':
callable_args = (local_inputs, local_outputs, output_grads)
elif self._lop_type == 'grad':
callable_args = (local_inputs, output_grads)
# we need to convert _lop_op into an OfG instance
if lop_op == 'default':
gdefaults_l = fn_grad(wrt=local_inputs)
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(gdefaults_l, local_inputs)])
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
elif isinstance(grad_op, Variable):
if isinstance(grad_op.type, (DisconnectedType, NullType)):
elif isinstance(lop_op, Variable):
if isinstance(lop_op.type, (DisconnectedType, NullType)):
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [grad_op.type() for _ in range(inp_len)]
all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
else:
raise ValueError(STYPE_ERR_MSG % grad_op.type)
elif isinstance(grad_op, list):
goverrides_l = grad_op
raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
elif isinstance(lop_op, list):
goverrides_l = lop_op
if len(goverrides_l) != inp_len:
raise ValueError(
'Need to override %d gradients, got %d' % (
......@@ -340,40 +407,41 @@ class OpFromGraph(gof.Op):
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(fn_gov.type())
else:
raise ValueError(STYPE_ERR_MSG % fn_gov.type)
raise ValueError(self.STYPE_ERR_MSG % fn_gov.type)
else:
if not hasattr(fn_gov, '__call__'):
raise TypeError(TYPE_ERR_MSG % fn_gov)
if not callable(fn_gov):
raise TypeError(self.TYPE_ERR_MSG % fn_gov)
gov, gov_ov = OpFromGraph._filter_grad_var(
fn_gov(local_inputs, output_grads), inp)
fn_gov(*callable_args), inp)
all_grads_l.append(gov)
all_grads_ov_l.append(gov_ov)
else:
# callable case
if not hasattr(grad_op, '__call__'):
raise TypeError(TYPE_ERR_MSG % grad_op)
goverrides_l = grad_op(local_inputs, output_grads)
if not callable(lop_op):
raise TypeError(self.TYPE_ERR_MSG % lop_op)
goverrides_l = lop_op(*callable_args)
if not isinstance(goverrides_l, list):
raise TypeError(
'Gradient overriding function should return a list, '
'Gradient/L_op overriding function should return a list, '
'got "%s"' % type(goverrides_l))
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp)
for grad, inp in izip(goverrides_l, local_inputs)])
if len(all_grads_l) != len(local_inputs):
raise ValueError(
'Gradient overriding function should return list of '
'Gradient/L_op overriding function should return list of '
'%d outputs, got %d' % (inp_len, len(all_grads_l)))
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
self._grad_op = type(self)(
inputs=local_inputs + output_grads,
self._lop_op = type(self)(
inputs=local_inputs + local_outputs + output_grads,
outputs=all_grads_l,
inline=self.is_inline,
name=(None if self.name is None else self.name + '_grad'),
name=(None if self.name is None else self.name + '_' + self._lop_type),
on_unused_input='ignore')
self._grad_op_stypes_l = all_grads_ov_l
self._grad_op_is_cached = True
self._lop_op_stypes_l = all_grads_ov_l
self._lop_op_is_cached = True
self._lop_type = 'lop'
@theano.change_flags(compute_test_value='off')
def _recompute_rop_op(self):
......@@ -448,14 +516,14 @@ class OpFromGraph(gof.Op):
else:
raise ValueError(STYPE_ERR_MSG % fn_rov.type)
else:
if not hasattr(fn_rov, '__call__'):
if not callable(fn_rov):
raise TypeError(TYPE_ERR_MSG % fn_rov)
rov, rov_ov = OpFromGraph._filter_rop_var(
fn_rov(local_inputs, eval_points), out)
all_rops_l.append(rov)
all_rops_ov_l.append(rov_ov)
else:
if not hasattr(rop_op, '__call__'):
if not callable(rop_op):
raise TypeError(TYPE_ERR_MSG % rop_op)
roverrides_l = rop_op(local_inputs, eval_points)
if not isinstance(roverrides_l, list):
......@@ -482,13 +550,13 @@ class OpFromGraph(gof.Op):
self._rop_op_stypes_l = all_rops_ov_l
self._rop_op_is_cached = True
def get_grad_op(self):
def get_lop_op(self):
"""
getter method for self._grad_op
getter method for self._lop_op
"""
if not self._grad_op_is_cached:
self._recompute_grad_op()
return self._grad_op
if not self._lop_op_is_cached:
self._recompute_lop_op()
return self._lop_op
def get_rop_op(self):
"""
......@@ -501,11 +569,24 @@ class OpFromGraph(gof.Op):
def set_grad_overrides(self, grad_overrides):
"""
Set gradient overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set gradient overrides
This will completely remove any previously set L_op/gradient overrides
"""
self._lop_op = grad_overrides
self._lop_op_is_cached = False
self._lop_type = 'grad'
self._lop_is_default = (grad_overrides == 'default')
def set_lop_overrides(self, lop_overrides):
"""
self._grad_op = grad_overrides
self._grad_op_is_cached = False
Set L_op overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set L_op/gradient overrides
"""
self._lop_op = lop_overrides
self._lop_op_is_cached = False
self._lop_type = 'lop'
self._lop_is_default = (lop_overrides == 'default')
def set_rop_overrides(self, rop_overrides):
"""
......@@ -515,15 +596,17 @@ class OpFromGraph(gof.Op):
"""
self._rop_op = rop_overrides
self._rop_op_is_cached = False
def grad(self, inputs, output_grads):
if not self._grad_op_is_cached:
self._recompute_grad_op()
ret_ofg_l = self._grad_op(
*(list(inputs) + list(output_grads)), return_list=True)
self._rop_is_default = (rop_overrides == 'default')
def L_op(self, inputs, outputs, output_grads):
if not self._lop_op_is_cached:
self._recompute_lop_op()
inps = list(inputs) + list(outputs) + list(output_grads)
ret_ofg_l = self._lop_op(
*inps, return_list=True)
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._grad_op_stypes_l)]
ret_ofg_l, self._lop_op_stypes_l)]
return ret_l
def R_op(self, inputs, eval_points):
......@@ -559,14 +642,14 @@ class OpFromGraph(gof.Op):
cpmat_self = io_connection_pattern(
self.local_inputs, self.local_outputs)
grad_op = self.get_grad_op()
lop_op = self.get_lop_op()
cpmat_grad = io_connection_pattern(
grad_op.local_inputs[inp_len:],
grad_op.local_outputs)
lop_op.local_inputs[inp_len:],
lop_op.local_outputs)
# cpmat_self |= cpmat_grad.T
# cpmat_self &= out_is_disconnected
for i, t in enumerate(self._grad_op_stypes_l):
for i, t in enumerate(self._lop_op_stypes_l):
if t is not None:
if isinstance(t.type, DisconnectedType):
for o in range(out_len):
......
......@@ -163,7 +163,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
w, b = T.vectors('wb')
# we make the 3rd gradient default (no override)
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
op_linear = cls_ofg(
[x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb])
......@@ -191,6 +192,33 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert isinstance(dw2.type, NullType)
assert isinstance(db2.type, DisconnectedType)
@test_params
def test_lop_override(self, cls_ofg):
    # Verify that `lop_overrides` replaces the default L_op of an
    # OpFromGraph, for both supported override forms: a plain callable
    # and another OpFromGraph instance.
    x = T.vector()
    y = 1. / (1. + T.exp(-x))

    def lop_ov(inps, outs, grads):
        # Callable override with the L_op signature
        # (inputs, outputs, output_grads). It deliberately returns
        # TWICE the true sigmoid gradient y*(1-y)*dE/dy, so the test
        # can distinguish the override from the default L_op.
        y_, = outs
        dedy_, = grads
        return [2. * y_ * (1. - y_) * dedy_]

    y_, dedy = T.vector(), T.vector()
    # Same override expressed as an OpFromGraph: its inputs are laid out
    # as inputs + outputs + output_grads, matching the L_op convention.
    op_lop_ov = cls_ofg([x, y_, dedy], [2. * y_ * (1. - y_) * dedy])

    xx = T.vector()
    yy1 = T.sum(T.nnet.sigmoid(xx))
    # Reference value: twice the analytic gradient, computed without
    # any override.
    gyy1 = 2. * T.grad(yy1, xx)

    for ov in [lop_ov, op_lop_ov]:
        op = cls_ofg([x], [y], lop_overrides=ov)
        yy2 = T.sum(op(xx))
        gyy2 = T.grad(yy2, xx)
        fn = function([xx], [gyy1, gyy2])
        # Random input; both gradient expressions must agree numerically.
        xval = np.random.rand(32).astype(config.floatX)
        y1val, y2val = fn(xval)
        assert np.allclose(y1val, y2val)
@test_params
def test_rop(self, cls_ofg):
a = T.vector()
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment