Commit 55bb0f69 authored by khaotik

grad -> L_op for OpFromGraph

Parent 404cea07
@@ -47,7 +47,8 @@ class OpFromGraph(gof.Op):
         OpFromGraph instance : Override with another OpFromGraph, should
             accept inputs as the same order and types of "inputs" and "output_grads"
-            arguments as one would specify in grad() method.
+            arguments as one would specify in grad() method. This argument is
+            mutually exclusive with lop_overrides.
         callable : similar to OpFromGraph instance, must return list of
             :class:`Variable <theano.gof.Variable>`.
@@ -60,6 +61,13 @@ class OpFromGraph(gof.Op):
             :class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
             a specific input, length of list must be equal to number of inputs.
+    lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
+        Defaults to ``'default'``.
+        Similar to ``grad_overrides``, except callables should accept inputs
+        in the same order and types as the "inputs", "outputs", and
+        "output_grads" arguments of L_op(). This argument is mutually
+        exclusive with grad_overrides.
     rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
         Defaults to ``default``.
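For reference, a minimal sketch of how the two override styles differ in callable signature under this commit; the variable and function names here are illustrative, not from the committed code:

```python
import theano.tensor as T
from theano import OpFromGraph

x, w = T.vectors('xw')

# grad-style override: callable receives (inputs, output_grads)
def grad_ov(inps, grads):
    x_, w_ = inps
    g, = grads
    return [g * w_, g * x_]

# L_op-style override: callable additionally receives the outputs
def lop_ov(inps, outs, grads):
    x_, w_ = inps
    y_, = outs
    g, = grads
    return [g * w_, g * x_]

op_g = OpFromGraph([x, w], [x * w], grad_overrides=grad_ov)
op_l = OpFromGraph([x, w], [x * w], lop_overrides=lop_ov)
```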
@@ -96,7 +104,6 @@ class OpFromGraph(gof.Op):
        local_outputs)
     - c_code() to remove the double overhead?
     - grad() make it support DisconnectedType and the new interface
-    - extend grad() to L_op
     - add support for NullType and DisconnectedType when R_op supports them
     - check how it works with updates.
     - add test with constant as input or inside the inner graph.
@@ -116,10 +123,10 @@ class OpFromGraph(gof.Op):
     - ``inline=True`` will cause better runtime optimization at the cost
       of compilation time. Currently only works with ``fast_compile`` or
       ``fast_run`` mode.
-    - It's recommanded to provide pure functions (no side effects like
-      setting global variable) as callable(s). The callable(s) supplied
-      for overriding gradient/rop will be called only once at the first
-      call to grad/R_op, and will be converted to OpFromGraph instances.
+    - For overriding, it's recommended to provide pure functions (no side
+      effects like setting global variables) as callable(s). The callable(s)
+      supplied for overriding gradient/rop will be called only once at the
+      first call to grad/R_op, and will be converted to OpFromGraph instances.

     Examples
     --------
@@ -171,6 +178,13 @@ class OpFromGraph(gof.Op):
     fn(2., 3., 4.)  # [1., 8., 3.]

     """
+    TYPE_ERR_MSG = ("L_op/gradient override should be (single or list of) "
+                    "'default' | OpFromGraph | callable | Variable "
+                    "with NullType or DisconnectedType, got %s")
+    STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
+                     ' of DisconnectedType or NullType, got %s')
+    LOP_TYPE_ERR_MSG = 'L_op type can only be "grad" or "lop", got %s.'
+    OV_INP_LEN_ERR_MSG = 'expect overrider with %d inputs, got %d'

     @staticmethod
     def _filter_grad_var(grad, inp):
@@ -182,9 +196,9 @@ class OpFromGraph(gof.Op):
         # a grad() call could return instance of NullType() or DisconnectedType()
         # which cannot be directly used in OfG
         #
-        # Since we always use an OfG instance as self._grad_op, the current
+        # Since we always use an OfG instance as self._lop_op, the current
         # workaround is to "remember" the special cases of the gradient and
-        # replace them after self._grad_op is called.
+        # replace them after self._lop_op is called.
         #
         # This helper function changes invalid types into a filtered_var,
         # and provides a overrider_var to be replaced at grad() call
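The workaround described in that comment can be hard to follow from the diff alone; here is a simplified sketch of the filtering idea (not the exact committed helper):

```python
from theano.gradient import DisconnectedType, NullType

def filter_grad_var(grad, inp):
    # special-type grads cannot live inside an OfG graph, so substitute
    # a same-shaped placeholder and remember the original marker, which
    # is restored when grad()/L_op() is actually called
    if isinstance(grad.type, (DisconnectedType, NullType)):
        return inp.zeros_like(), grad
    return grad, None
```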
@@ -212,10 +226,12 @@ class OpFromGraph(gof.Op):
         return inpJ, None

     def __init__(
             self, inputs, outputs,
             inline=False,
-            grad_overrides='default', rop_overrides='default',
+            lop_overrides='default',
+            grad_overrides='default',
+            rop_overrides='default',
             name=None, **kwargs
     ):
         if not isinstance(outputs, list):
             raise TypeError('outputs must be list, got %s' % type(outputs))
@@ -251,7 +267,18 @@ class OpFromGraph(gof.Op):
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
-        self.set_grad_overrides(grad_overrides)
+        cond = int(lop_overrides == 'default') * 2 + int(grad_overrides == 'default')
+        if cond == 0:
+            raise ValueError('lop_overrides and grad_overrides are mutually exclusive')
+        elif cond == 1:
+            self.set_lop_overrides(lop_overrides)
+            self._lop_type = 'lop'
+        elif cond == 2:
+            self.set_lop_overrides(grad_overrides)
+            self._lop_type = 'grad'
+        else:
+            self.set_lop_overrides('default')
+            self._lop_type = 'lop'
         self.set_rop_overrides(rop_overrides)
         if name is not None:
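The dispatch above packs both "is default" flags into a single integer `cond`; written out explicitly, it is equivalent to this sketch:

```python
lop_given = lop_overrides != 'default'
grad_given = grad_overrides != 'default'
if lop_given and grad_given:    # cond == 0: both supplied
    raise ValueError('lop_overrides and grad_overrides are mutually exclusive')
elif lop_given:                 # cond == 1: new-style L_op override
    self.set_lop_overrides(lop_overrides)
    self._lop_type = 'lop'
elif grad_given:                # cond == 2: legacy grad-style override
    self.set_lop_overrides(grad_overrides)
    self._lop_type = 'grad'
else:                           # cond == 3: no override at all
    self.set_lop_overrides('default')
    self._lop_type = 'lop'
```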
@@ -272,21 +299,38 @@ class OpFromGraph(gof.Op):
         return '%(name)s{inline=%(is_inline)s}' % locals()

     @theano.change_flags(compute_test_value='off')
-    def _recompute_grad_op(self):
+    def _recompute_lop_op(self):
         '''
-        converts self._grad_op from user supplied form to type(self) instance
+        converts self._lop_op from user supplied form to type(self) instance
         '''
         local_inputs = self.local_inputs
         local_outputs = self.local_outputs
         inp_len = len(local_inputs)
-        grad_op = self._grad_op
-        if isinstance(grad_op, OpFromGraph):
-            if not self._grad_op_is_cached:
-                self._grad_op_is_cached = True
-                self._grad_op_stypes_l = [None] * inp_len
-                return
+        lop_op = self._lop_op
+        if isinstance(lop_op, OpFromGraph):
+            # OfG can be directly used if in L_op format
+            if self._lop_type == 'grad':
+                needed_ninps = inp_len + len(local_outputs)
+                ninps = len(lop_op.local_inputs)
+                if needed_ninps != ninps:
+                    raise ValueError(
+                        self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
+                # make a wrapper callable
+                lop_op = lambda inps, grads: self._lop_op(*(inps + grads))  # noqa: E731
+            elif self._lop_type == 'lop':
+                needed_ninps = inp_len + 2 * len(local_outputs)
+                ninps = len(lop_op.local_inputs)
+                if needed_ninps != ninps:
+                    raise ValueError(
+                        self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
+                if not self._lop_op_is_cached:
+                    self._lop_op_is_cached = True
+                    self._lop_op_stypes_l = [None] * inp_len
+                    return
+            else:
+                raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)

         output_grads = [out_t() for out_t in self.output_types]
         fn_grad = partial(
@@ -297,26 +341,28 @@ class OpFromGraph(gof.Op):
             null_gradients='return',
             known_grads=OrderedDict(izip(local_outputs, output_grads)))
-        TYPE_ERR_MSG = ("Gradient override should be (single or list of)"
-                        "'default' | OpFromGraph | callable | Variable "
-                        "with NullType or DisconnectedType, got %s")
-        STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
-                         ' of DisconnectedType or NullType, got %s')
-        # we need to convert _grad_op into an OfG instance
-        if grad_op == 'default':
+        if self._lop_type == 'lop':
+            callable_args = (local_inputs, local_outputs, output_grads)
+        elif self._lop_type == 'grad':
+            callable_args = (local_inputs, output_grads)
+        else:
+            raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)
+
+        # we need to convert _lop_op into an OfG instance
+        if lop_op == 'default':
             gdefaults_l = fn_grad(wrt=local_inputs)
             all_grads_l, all_grads_ov_l = izip(
                 *[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(gdefaults_l, local_inputs)])
             all_grads_l = list(all_grads_l)
             all_grads_ov_l = list(all_grads_ov_l)
-        elif isinstance(grad_op, Variable):
-            if isinstance(grad_op.type, (DisconnectedType, NullType)):
+        elif isinstance(lop_op, Variable):
+            if isinstance(lop_op.type, (DisconnectedType, NullType)):
                 all_grads_l = [inp.zeros_like() for inp in local_inputs]
-                all_grads_ov_l = [grad_op.type() for _ in range(inp_len)]
+                all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
             else:
-                raise ValueError(STYPE_ERR_MSG % grad_op.type)
-        elif isinstance(grad_op, list):
-            goverrides_l = grad_op
+                raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
+        elif isinstance(lop_op, list):
+            goverrides_l = lop_op
             if len(goverrides_l) != inp_len:
                 raise ValueError(
                     'Need to override %d gradients, got %d' % (
@@ -340,40 +386,41 @@ class OpFromGraph(gof.Op):
                         all_grads_l.append(inp.zeros_like())
                         all_grads_ov_l.append(fn_gov.type())
                     else:
-                        raise ValueError(STYPE_ERR_MSG % fn_gov.type)
+                        raise ValueError(self.STYPE_ERR_MSG % fn_gov.type)
                 else:
-                    if not hasattr(fn_gov, '__call__'):
-                        raise TypeError(TYPE_ERR_MSG % fn_gov)
+                    if not callable(fn_gov):
+                        raise TypeError(self.TYPE_ERR_MSG % fn_gov)
                     gov, gov_ov = OpFromGraph._filter_grad_var(
-                        fn_gov(local_inputs, output_grads), inp)
+                        fn_gov(*callable_args), inp)
                     all_grads_l.append(gov)
                     all_grads_ov_l.append(gov_ov)
         else:
             # callable case
-            if not hasattr(grad_op, '__call__'):
-                raise TypeError(TYPE_ERR_MSG % grad_op)
-            goverrides_l = grad_op(local_inputs, output_grads)
+            if not callable(lop_op):
+                raise TypeError(self.TYPE_ERR_MSG % lop_op)
+            goverrides_l = lop_op(*callable_args)
             if not isinstance(goverrides_l, list):
                 raise TypeError(
-                    'Gradient overriding function should return a list, '
+                    'Gradient/L_op overriding function should return a list, '
                     'got "%s"' % type(goverrides_l))
             all_grads_l, all_grads_ov_l = izip(
                 *[OpFromGraph._filter_grad_var(grad, inp)
                   for grad, inp in izip(goverrides_l, local_inputs)])
             if len(all_grads_l) != len(local_inputs):
                 raise ValueError(
-                    'Gradient overriding function should return list of '
+                    'Gradient/L_op overriding function should return list of '
                     '%d outputs, got %d' % (inp_len, len(all_grads_l)))
         all_grads_l = list(all_grads_l)
         all_grads_ov_l = list(all_grads_ov_l)
-        self._grad_op = type(self)(
-            inputs=local_inputs + output_grads,
+        self._lop_op = type(self)(
+            inputs=local_inputs + local_outputs + output_grads,
             outputs=all_grads_l,
             inline=self.is_inline,
-            name=(None if self.name is None else self.name + '_grad'),
+            name=(None if self.name is None else self.name + '_lop'),
             on_unused_input='ignore')
-        self._grad_op_stypes_l = all_grads_ov_l
-        self._grad_op_is_cached = True
+        self._lop_op_stypes_l = all_grads_ov_l
+        self._lop_op_is_cached = True
+        self._lop_type = 'lop'

     @theano.change_flags(compute_test_value='off')
     def _recompute_rop_op(self):
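After `_recompute_lop_op` runs, `self._lop_op` is itself an OpFromGraph whose inputs follow the new L_op layout, `inputs + outputs + output_grads`. A hand-built equivalent for the default gradient of a two-input product op; this is a sketch mirroring what the code above constructs, with hypothetical variable names:

```python
import theano.tensor as T
from theano import OpFromGraph

x, w = T.vectors('xw')          # the wrapped op's inputs
y = T.vector('y')               # placeholder for its forward output
g = T.vector('g')               # placeholder for the incoming output gradient

# roughly what _recompute_lop_op builds for OpFromGraph([x, w], [x * w]):
lop_ofg = OpFromGraph(
    [x, w, y, g],               # local_inputs + local_outputs + output_grads
    [g * w, g * x],             # gradients w.r.t. x and w
    on_unused_input='ignore')   # y is unused by the default gradient
```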
@@ -448,14 +495,14 @@ class OpFromGraph(gof.Op):
                     else:
                         raise ValueError(STYPE_ERR_MSG % fn_rov.type)
                 else:
-                    if not hasattr(fn_rov, '__call__'):
+                    if not callable(fn_rov):
                         raise TypeError(TYPE_ERR_MSG % fn_rov)
                     rov, rov_ov = OpFromGraph._filter_rop_var(
                         fn_rov(local_inputs, eval_points), out)
                     all_rops_l.append(rov)
                     all_rops_ov_l.append(rov_ov)
         else:
-            if not hasattr(rop_op, '__call__'):
+            if not callable(rop_op):
                 raise TypeError(TYPE_ERR_MSG % rop_op)
             roverrides_l = rop_op(local_inputs, eval_points)
             if not isinstance(roverrides_l, list):
@@ -482,13 +529,13 @@ class OpFromGraph(gof.Op):
         self._rop_op_stypes_l = all_rops_ov_l
         self._rop_op_is_cached = True

-    def get_grad_op(self):
+    def get_lop_op(self):
         """
-        getter method for self._grad_op
+        getter method for self._lop_op
         """
-        if not self._grad_op_is_cached:
-            self._recompute_grad_op()
-        return self._grad_op
+        if not self._lop_op_is_cached:
+            self._recompute_lop_op()
+        return self._lop_op

     def get_rop_op(self):
         """
@@ -501,11 +548,24 @@ class OpFromGraph(gof.Op):
     def set_grad_overrides(self, grad_overrides):
         """
         Set gradient overrides, see help(theano.OpFromGraph) for syntax
-        This will completely remove any previously set gradient overrides
+        This will completely remove any previously set L_op/gradient overrides
         """
-        self._grad_op = grad_overrides
-        self._grad_op_is_cached = False
+        self._lop_op = grad_overrides
+        self._lop_op_is_cached = False
+        self._lop_type = 'grad'
+        self._lop_is_default = (grad_overrides == 'default')
+
+    def set_lop_overrides(self, lop_overrides):
+        """
+        Set L_op overrides, see help(theano.OpFromGraph) for syntax
+        This will completely remove any previously set L_op/gradient overrides
+        """
+        self._lop_op = lop_overrides
+        self._lop_op_is_cached = False
+        self._lop_type = 'lop'
+        self._lop_is_default = (lop_overrides == 'default')

     def set_rop_overrides(self, rop_overrides):
         """
@@ -515,15 +575,17 @@ class OpFromGraph(gof.Op):
         """
         self._rop_op = rop_overrides
         self._rop_op_is_cached = False
+        self._rop_is_default = (rop_overrides == 'default')

-    def grad(self, inputs, output_grads):
-        if not self._grad_op_is_cached:
-            self._recompute_grad_op()
-        ret_ofg_l = self._grad_op(
-            *(list(inputs) + list(output_grads)), return_list=True)
+    def L_op(self, inputs, outputs, output_grads):
+        if not self._lop_op_is_cached:
+            self._recompute_lop_op()
+        inps = list(inputs) + list(outputs) + list(output_grads)
+        ret_ofg_l = self._lop_op(
+            *inps, return_list=True)
         ret_l = [
             ret_ofg if ov is None else ov for ret_ofg, ov in izip(
-                ret_ofg_l, self._grad_op_stypes_l)]
+                ret_ofg_l, self._lop_op_stypes_l)]
         return ret_l

     def R_op(self, inputs, eval_points):
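Since `L_op` replaces `grad`, differentiating through an OpFromGraph now also feeds the node's forward outputs to the override machinery. A minimal end-to-end sketch; the expected values assume default float64 inputs:

```python
import theano
import theano.tensor as T
from theano import OpFromGraph

x, w = T.vectors('xw')
op = OpFromGraph([x, w], [x * w])   # no overrides: default L_op

xx, ww = T.vectors('ab')
z = T.sum(op(xx, ww))
gx, gw = theano.grad(z, [xx, ww])   # dispatches to OpFromGraph.L_op

f = theano.function([xx, ww], [gx, gw])
# f([2.], [3.]) should return approximately [array([3.]), array([2.])]
```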
@@ -559,14 +623,14 @@ class OpFromGraph(gof.Op):
         cpmat_self = io_connection_pattern(
             self.local_inputs, self.local_outputs)

-        grad_op = self.get_grad_op()
+        grad_op = self.get_lop_op()
         cpmat_grad = io_connection_pattern(
             grad_op.local_inputs[inp_len:],
             grad_op.local_outputs)

         # cpmat_self |= cpmat_grad.T
         # cpmat_self &= out_is_disconnected
-        for i, t in enumerate(self._grad_op_stypes_l):
+        for i, t in enumerate(self._lop_op_stypes_l):
             if t is not None:
                 if isinstance(t.type, DisconnectedType):
                     for o in range(out_len):
@@ -619,6 +683,17 @@ class OpFromGraph(gof.Op):
         # we wont need this copy anymore
         output[0] = variable.copy()

+    def copy(self):
+        '''
+        Make a shallow copy; gradient/R_op overrides will remain the same objects
+        '''
+        cpy = type(self)(self.local_inputs, self.local_outputs, inline=self.is_inline)
+        if not self._lop_is_default:
+            cpy.set_lop_overrides(self.get_lop_op())
+        if not self._rop_is_default:
+            cpy.set_rop_overrides(self.get_rop_op())
+        return cpy
+

 @gof.local_optimizer([OpFromGraph])
 def inline_ofg_expansion(node):
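A short usage sketch of the new `copy()`; the override callable is hypothetical. Non-default L_op/R_op overrides are carried over, and because the copy is shallow the inner graph variables (and any already-converted override ops) are shared rather than duplicated:

```python
import theano.tensor as T
from theano import OpFromGraph

x, w = T.vectors('xw')
my_lop_fn = lambda inps, outs, grads: [grads[0] * inps[1], grads[0] * inps[0]]

op = OpFromGraph([x, w], [x * w], lop_overrides=my_lop_fn)
op2 = op.copy()     # same inner graph variables, same override object
```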
......
@@ -163,7 +163,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
         w, b = T.vectors('wb')
         # we make the 3rd gradient default (no override)
-        op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
+        op_linear = cls_ofg(
+            [x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
         xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
         zz = T.sum(op_linear(xx, ww, bb))
         dx, dw, db = T.grad(zz, [xx, ww, bb])
......