Commit 55bb0f69 authored by khaotik

grad -> L_op for OpFromGraph

Parent 404cea07
@@ -47,7 +47,8 @@ class OpFromGraph(gof.Op):
         OpFromGraph instance : Override with another OpFromGraph, should
             accept inputs as the same order and types of "inputs" and "output_grads"
-            arguments as one would specify in grad() method.
+            arguments as one would specify in grad() method. This argument is
+            mutually exclusive with lop_overrides.
         callable : similar to OpFromGraph instance, must return list of
             :class:`Variable <theano.gof.Variable>`.
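To ground the "OpFromGraph instance" case above, here is a minimal sketch (not part of this diff; all names are illustrative) of overriding a gradient with another OpFromGraph whose inputs are the forward inputs followed by the output gradients:

```python
import theano
import theano.tensor as T

x, w = T.vectors('xw')
og = T.vector('og')  # incoming gradient w.r.t. the single output

# Override graph: takes (inputs..., output_grads...) and returns one gradient
# per input; here it happens to encode the true gradient of x * w.
grad_ofg = theano.OpFromGraph([x, w, og], [og * w, og * x])

op_mul = theano.OpFromGraph([x, w], [x * w], grad_overrides=grad_ofg)
gx, gw = T.grad(T.sum(op_mul(x, w)), [x, w])
```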
@@ -60,6 +61,13 @@ class OpFromGraph(gof.Op):
         :class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
         a specific input, length of list must be equal to number of inputs.
+    lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
+        Defaults to ``'default'``.
+        Similar to ``grad_overrides``, except callables should accept inputs in
+        the same order and types as "inputs", "outputs", and "output_grads" in
+        L_op. This argument is mutually exclusive with grad_overrides.
     rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
         Defaults to ``default``.
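The callable form of the new argument, for contrast (again a sketch, reusing `x` and `w` from the previous example): an L_op-style callable receives `(inputs, outputs, output_grads)`, one list more than the grad-style `(inputs, output_grads)`:

```python
def custom_lop(inps, outs, out_grads):
    x_, w_ = inps
    y_, = outs      # the forward output, not available to a grad-style callable
    g_, = out_grads
    return [g_ * w_, g_ * x_]

op_mul2 = theano.OpFromGraph([x, w], [x * w], lop_overrides=custom_lop)
```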
@@ -96,7 +104,6 @@ class OpFromGraph(gof.Op):
                          local_outputs)
     - c_code() to remove the double overhead?
     - grad() make it support DisconnectedType and the new interface
-    - extend grad() to L_op
     - add support for NullType and DisconnectedType when R_op supports them
     - check how it works with updates.
     - add test with constant as input or inside the inner graph.
@@ -116,10 +123,10 @@ class OpFromGraph(gof.Op):
     - ``inline=True`` will cause better runtime optimization at the cost
       of compilation time. Currently only works with ``fast_compile`` or
       ``fast_run`` mode.
-    - It's recommanded to provide pure functions (no side effects like
-      setting global variable) as callable(s). The callable(s) supplied
-      for overriding gradient/rop will be called only once at the first
-      call to grad/R_op, and will be converted to OpFromGraph instances.
+    - For overriding, it's recommended to provide pure functions (no side
+      effects like setting global variables) as callable(s). The callable(s)
+      supplied for overriding gradient/rop will be called only once at the
+      first call to grad/R_op, and will be converted to OpFromGraph instances.

     Examples
     --------
@@ -171,6 +178,13 @@ class OpFromGraph(gof.Op):
     fn(2., 3., 4.)  # [1., 8., 3.]

     """
+    TYPE_ERR_MSG = ("L_op/gradient override should be (single or list of) "
+                    "'default' | OpFromGraph | callable | Variable "
+                    "with NullType or DisconnectedType, got %s")
+    STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
+                     ' of DisconnectedType or NullType, got %s')
+    LOP_TYPE_ERR_MSG = 'L_op type can only be "grad" or "lop", got %s.'
+    OV_INP_LEN_ERR_MSG = 'expect overrider with %d inputs, got %d'
+
     @staticmethod
     def _filter_grad_var(grad, inp):
@@ -182,9 +196,9 @@ class OpFromGraph(gof.Op):
         # a grad() call could return instance of NullType() or DisconnectedType()
         # which cannot be directly used in OfG
         #
-        # Since we always use an OfG instance as self._grad_op, the current
+        # Since we always use an OfG instance as self._lop_op, the current
         # workaround is to "remember" the special cases of the gradient and
-        # replace them after self._grad_op is called.
+        # replace them after self._lop_op is called.
         #
         # This helper function changes invalid types into a filtered_var,
         # and provides an overrider_var to be replaced at grad() call
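The workaround described in this comment, in miniature (a simplified, hypothetical helper, not the diff's exact code):

```python
from theano.gradient import DisconnectedType, NullType

def _filter_grad_var_sketch(grad, inp):
    # -> (filtered_var, overrider_var)
    if isinstance(grad.type, (DisconnectedType, NullType)):
        # A well-typed placeholder goes into the inner OfG graph; the special
        # variable is remembered and swapped back in at grad() time.
        return inp.zeros_like(), grad
    return grad, None
```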
@@ -212,10 +226,12 @@ class OpFromGraph(gof.Op):
         return inpJ, None

     def __init__(
-            self, inputs, outputs,
-            inline=False,
-            grad_overrides='default', rop_overrides='default',
-            name=None, **kwargs
+            self, inputs, outputs,
+            inline=False,
+            lop_overrides='default',
+            grad_overrides='default',
+            rop_overrides='default',
+            name=None, **kwargs
     ):
         if not isinstance(outputs, list):
             raise TypeError('outputs must be list, got %s' % type(outputs))
@@ -251,7 +267,18 @@ class OpFromGraph(gof.Op):
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
-        self.set_grad_overrides(grad_overrides)
+        cond = int(lop_overrides == 'default') * 2 + int(grad_overrides == 'default')
+        if cond == 0:
+            raise ValueError(
+                'lop_overrides and grad_overrides are mutually exclusive')
+        elif cond == 1:
+            self.set_lop_overrides(lop_overrides)
+            self._lop_type = 'lop'
+        elif cond == 2:
+            self.set_lop_overrides(grad_overrides)
+            self._lop_type = 'grad'
+        else:
+            self.set_lop_overrides('default')
+            self._lop_type = 'lop'
         self.set_rop_overrides(rop_overrides)

         if name is not None:
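An aside to make the bit-packed dispatch above easier to audit; this is the truth table the arithmetic encodes (commentary only):

```python
# cond = int(lop_overrides == 'default') * 2 + int(grad_overrides == 'default')
# cond == 0: both overrides supplied -> ValueError (mutually exclusive)
# cond == 1: only lop_overrides      -> use it,       _lop_type = 'lop'
# cond == 2: only grad_overrides     -> use it,       _lop_type = 'grad'
# cond == 3: neither supplied        -> default L_op, _lop_type = 'lop'
```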
@@ -272,21 +299,38 @@ class OpFromGraph(gof.Op):
         return '%(name)s{inline=%(is_inline)s}' % locals()

     @theano.change_flags(compute_test_value='off')
-    def _recompute_grad_op(self):
+    def _recompute_lop_op(self):
         '''
-        converts self._grad_op from user supplied form to type(self) instance
+        converts self._lop_op from user supplied form to type(self) instance
         '''
         local_inputs = self.local_inputs
         local_outputs = self.local_outputs
         inp_len = len(local_inputs)
-        grad_op = self._grad_op
-        if isinstance(grad_op, OpFromGraph):
-            if not self._grad_op_is_cached:
-                self._grad_op_is_cached = True
-                self._grad_op_stypes_l = [None] * inp_len
-            return
+        lop_op = self._lop_op
+        if isinstance(lop_op, OpFromGraph):
+            # an OfG instance can be used directly only if it is in L_op format
+            if self._lop_type == 'grad':
+                needed_ninps = inp_len + len(local_outputs)
+                ninps = len(lop_op.local_inputs)
+                if needed_ninps != ninps:
+                    raise ValueError(
+                        self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
+                # make a wrapper callable
+                lop_op = lambda inps, grads: self._lop_op(*(inps + grads))  # noqa: E731
+            elif self._lop_type == 'lop':
+                needed_ninps = inp_len + 2 * len(local_outputs)
+                ninps = len(lop_op.local_inputs)
+                if needed_ninps != ninps:
+                    raise ValueError(
+                        self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
+                if not self._lop_op_is_cached:
+                    self._lop_op_is_cached = True
+                    self._lop_op_stypes_l = [None] * inp_len
+                return
+            else:
+                raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)

         output_grads = [out_t() for out_t in self.output_types]
         fn_grad = partial(
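A note on the lambda above: an OpFromGraph instance is applied to a flat list of variables, while the conversion code later in this method calls overrides with argument lists. The wrapper bridges the two conventions; in sketch form (`ofg_instance` stands in for the grad-format override):

```python
# ofg_instance(v1, ..., vn)  vs.  override_fn(inputs, output_grads)
wrapped = lambda inps, grads: ofg_instance(*(inps + grads))  # noqa: E731
# `wrapped` now matches the grad-style callable convention and is converted
# below like any other callable override.
```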
@@ -297,26 +341,28 @@ class OpFromGraph(gof.Op):
             null_gradients='return',
             known_grads=OrderedDict(izip(local_outputs, output_grads)))

-        TYPE_ERR_MSG = ("Gradient override should be (single or list of)"
-                        "'default' | OpFromGraph | callable | Variable "
-                        "with NullType or DisconnectedType, got %s")
-        STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
-                         ' of DisconnectedType or NullType, got %s')
-        # we need to convert _grad_op into an OfG instance
-        if grad_op == 'default':
+        if self._lop_type == 'lop':
+            callable_args = (local_inputs, local_outputs, output_grads)
+        elif self._lop_type == 'grad':
+            callable_args = (local_inputs, output_grads)
+        else:
+            raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)
+        # we need to convert _lop_op into an OfG instance
+        if lop_op == 'default':
             gdefaults_l = fn_grad(wrt=local_inputs)
             all_grads_l, all_grads_ov_l = izip(
                 *[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(gdefaults_l, local_inputs)])
             all_grads_l = list(all_grads_l)
             all_grads_ov_l = list(all_grads_ov_l)
-        elif isinstance(grad_op, Variable):
-            if isinstance(grad_op.type, (DisconnectedType, NullType)):
+        elif isinstance(lop_op, Variable):
+            if isinstance(lop_op.type, (DisconnectedType, NullType)):
                 all_grads_l = [inp.zeros_like() for inp in local_inputs]
-                all_grads_ov_l = [grad_op.type() for _ in range(inp_len)]
+                all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
             else:
-                raise ValueError(STYPE_ERR_MSG % grad_op.type)
-        elif isinstance(grad_op, list):
-            goverrides_l = grad_op
+                raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
+        elif isinstance(lop_op, list):
+            goverrides_l = lop_op
             if len(goverrides_l) != inp_len:
                 raise ValueError(
                     'Need to override %d gradients, got %d' % (
@@ -340,40 +386,41 @@ class OpFromGraph(gof.Op):
                     all_grads_l.append(inp.zeros_like())
                     all_grads_ov_l.append(fn_gov.type())
                 else:
-                    raise ValueError(STYPE_ERR_MSG % fn_gov.type)
+                    raise ValueError(self.STYPE_ERR_MSG % fn_gov.type)
             else:
-                if not hasattr(fn_gov, '__call__'):
-                    raise TypeError(TYPE_ERR_MSG % fn_gov)
+                if not callable(fn_gov):
+                    raise TypeError(self.TYPE_ERR_MSG % fn_gov)
                 gov, gov_ov = OpFromGraph._filter_grad_var(
-                    fn_gov(local_inputs, output_grads), inp)
+                    fn_gov(*callable_args), inp)
                 all_grads_l.append(gov)
                 all_grads_ov_l.append(gov_ov)
         else:
             # callable case
-            if not hasattr(grad_op, '__call__'):
-                raise TypeError(TYPE_ERR_MSG % grad_op)
-            goverrides_l = grad_op(local_inputs, output_grads)
+            if not callable(lop_op):
+                raise TypeError(self.TYPE_ERR_MSG % lop_op)
+            goverrides_l = lop_op(*callable_args)
             if not isinstance(goverrides_l, list):
                 raise TypeError(
-                    'Gradient overriding function should return a list, '
+                    'Gradient/L_op overriding function should return a list, '
                     'got "%s"' % type(goverrides_l))
             all_grads_l, all_grads_ov_l = izip(
                 *[OpFromGraph._filter_grad_var(grad, inp)
                   for grad, inp in izip(goverrides_l, local_inputs)])
             if len(all_grads_l) != len(local_inputs):
                 raise ValueError(
-                    'Gradient overriding function should return list of '
+                    'Gradient/L_op overriding function should return list of '
                     '%d outputs, got %d' % (inp_len, len(all_grads_l)))
             all_grads_l = list(all_grads_l)
             all_grads_ov_l = list(all_grads_ov_l)
-        self._grad_op = type(self)(
-            inputs=local_inputs + output_grads,
+        self._lop_op = type(self)(
+            inputs=local_inputs + local_outputs + output_grads,
             outputs=all_grads_l,
             inline=self.is_inline,
-            name=(None if self.name is None else self.name + '_grad'),
+            name=(None if self.name is None else self.name + '_lop'),
             on_unused_input='ignore')
-        self._grad_op_stypes_l = all_grads_ov_l
-        self._grad_op_is_cached = True
+        self._lop_op_stypes_l = all_grads_ov_l
+        self._lop_op_is_cached = True
+        self._lop_type = 'lop'

     @theano.change_flags(compute_test_value='off')
     def _recompute_rop_op(self):
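Worth noting in the constructor call above: the cached `_lop_op` is now an OpFromGraph over `local_inputs + local_outputs + output_grads`, where the grad-era version used only `local_inputs + output_grads`. With illustrative sizes:

```python
# An op with inputs (x, w) and one output y caches a _lop_op such that
#   _lop_op(x, w, y, y_grad) -> [x_grad, w_grad]
# i.e. m inputs + n outputs + n output gradients in, m gradients out.
```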
@@ -448,14 +495,14 @@ class OpFromGraph(gof.Op):
                 else:
                     raise ValueError(STYPE_ERR_MSG % fn_rov.type)
             else:
-                if not hasattr(fn_rov, '__call__'):
+                if not callable(fn_rov):
                     raise TypeError(TYPE_ERR_MSG % fn_rov)
                 rov, rov_ov = OpFromGraph._filter_rop_var(
                     fn_rov(local_inputs, eval_points), out)
                 all_rops_l.append(rov)
                 all_rops_ov_l.append(rov_ov)
         else:
-            if not hasattr(rop_op, '__call__'):
+            if not callable(rop_op):
                 raise TypeError(TYPE_ERR_MSG % rop_op)
             roverrides_l = rop_op(local_inputs, eval_points)
             if not isinstance(roverrides_l, list):
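The `hasattr(obj, '__call__')` to `callable(obj)` changes here and in the L_op path are behavior-preserving; `callable` is simply the idiomatic spelling on the Python versions Theano supports:

```python
f = lambda v: v
assert callable(f) and hasattr(f, '__call__')  # equivalent checks here
```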
@@ -482,13 +529,13 @@ class OpFromGraph(gof.Op):
         self._rop_op_stypes_l = all_rops_ov_l
         self._rop_op_is_cached = True

-    def get_grad_op(self):
+    def get_lop_op(self):
         """
-        getter method for self._grad_op
+        getter method for self._lop_op
         """
-        if not self._grad_op_is_cached:
-            self._recompute_grad_op()
-        return self._grad_op
+        if not self._lop_op_is_cached:
+            self._recompute_lop_op()
+        return self._lop_op

     def get_rop_op(self):
         """
@@ -501,11 +548,24 @@ class OpFromGraph(gof.Op):
     def set_grad_overrides(self, grad_overrides):
         """
         Set gradient overrides, see help(theano.OpFromGraph) for syntax
-        This will completely remove any previously set gradient overrides
+        This will completely remove any previously set L_op/gradient overrides
         """
+        self._lop_op = grad_overrides
+        self._lop_op_is_cached = False
+        self._lop_type = 'grad'
+        self._lop_is_default = (grad_overrides == 'default')
+
+    def set_lop_overrides(self, lop_overrides):
+        """
+        Set L_op overrides, see help(theano.OpFromGraph) for syntax
+        This will completely remove any previously set L_op/gradient overrides
+        """
-        self._grad_op = grad_overrides
-        self._grad_op_is_cached = False
+        self._lop_op = lop_overrides
+        self._lop_op_is_cached = False
+        self._lop_type = 'lop'
+        self._lop_is_default = (lop_overrides == 'default')

     def set_rop_overrides(self, rop_overrides):
         """
@@ -515,15 +575,18 @@ class OpFromGraph(gof.Op):
         """
         self._rop_op = rop_overrides
         self._rop_op_is_cached = False
+        self._rop_is_default = (rop_overrides == 'default')

-    def grad(self, inputs, output_grads):
-        if not self._grad_op_is_cached:
-            self._recompute_grad_op()
-        ret_ofg_l = self._grad_op(
-            *(list(inputs) + list(output_grads)), return_list=True)
+    def L_op(self, inputs, outputs, output_grads):
+        if not self._lop_op_is_cached:
+            self._recompute_lop_op()
+        inps = list(inputs) + list(outputs) + list(output_grads)
+        ret_ofg_l = self._lop_op(
+            *inps, return_list=True)
         ret_l = [
             ret_ofg if ov is None else ov for ret_ofg, ov in izip(
-                ret_ofg_l, self._grad_op_stypes_l)]
+                ret_ofg_l, self._lop_op_stypes_l)]
         return ret_l

     def R_op(self, inputs, eval_points):
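The practical gain of L_op over grad() is visible in the signature: the forward `outputs` are passed in, so an override can reuse them instead of recomputing. A classic illustration (a sketch, not from the diff) is exp, whose derivative is its own output:

```python
import theano
import theano.tensor as T

x = T.vector('x')

def exp_lop(inps, outs, out_grads):
    y, = outs          # forward result, unavailable to a grad-style override
    g, = out_grads
    return [g * y]     # d/dx exp(x) = exp(x) = y, no recomputation needed

op_exp = theano.OpFromGraph([x], [T.exp(x)], lop_overrides=exp_lop)
gx = T.grad(T.sum(op_exp(x)), x)
```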
cpmat_self = io_connection_pattern(
self.local_inputs, self.local_outputs)
grad_op = self.get_grad_op()
grad_op = self.get_lop_op()
cpmat_grad = io_connection_pattern(
grad_op.local_inputs[inp_len:],
grad_op.local_outputs)
# cpmat_self |= cpmat_grad.T
# cpmat_self &= out_is_disconnected
for i, t in enumerate(self._grad_op_stypes_l):
for i, t in enumerate(self._lop_op_stypes_l):
if t is not None:
if isinstance(t.type, DisconnectedType):
for o in range(out_len):
......@@ -619,6 +683,17 @@ class OpFromGraph(gof.Op):
# we wont need this copy anymore
output[0] = variable.copy()
def copy(self):
'''
Make a shallow copy, gradient/R_op overrides will remain same objects
'''
cpy = type(self)(self.local_inputs, self.local_outputs, inline=self.is_inline)
if not self._lop_is_default:
cpy.set_lop_overrides(self.get_lop_op())
if not self._rop_is_default:
cpy.set_rop_overrides(self.get_rop_op())
return cpy
@gof.local_optimizer([OpFromGraph])
def inline_ofg_expansion(node):
......
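And a quick sketch of the new copy() semantics (reusing `op_exp` from the sketch above): the copy is shallow, so a non-default override is carried over as the same cached object:

```python
op2 = op_exp.copy()
assert op2 is not op_exp
assert op2.get_lop_op() is op_exp.get_lop_op()  # same override object
```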
@@ -163,7 +163,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
         w, b = T.vectors('wb')
         # we make the 3rd gradient default (no override)
-        op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
+        op_linear = cls_ofg(
+            [x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
         xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
         zz = T.sum(op_linear(xx, ww, bb))
         dx, dw, db = T.grad(zz, [xx, ww, bb])