Commit c0f7334c authored by khaotik

change R_op overriding to new format

Parent 8e8758a3
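This commit brings R_op overriding in line with the per-output format already used for gradient overrides: each output's override may be `Ellipsis` (keep the default R_op result), `0` (zeros), `None` (NullType), or a callable taking `(inputs, eval_points)`. A minimal usage sketch; the keyword `rop_overrides` is assumed from later Theano releases and may differ in this intermediate revision:

    import theano
    import theano.tensor as T
    from theano.compile.builders import OpFromGraph

    x, y = T.scalars('x', 'y')

    def custom_rop(inps, eval_points):
        # Jacobian of f(x, y) = x * y applied to direction (u, v): y*u + x*v
        x_, y_ = inps
        u, v = eval_points
        return [y_ * u + x_ * v]

    # hypothetical kwarg name; a per-output list such as [custom_rop]
    # or [Ellipsis] is the new format this commit introduces
    op = OpFromGraph([x, y], [x * y], rop_overrides=custom_rop)

    f = op(x, y)
    u, v = T.scalars('u', 'v')
    rop = theano.gradient.Rop(f, [x, y], [u, v])
    fn = theano.function([x, y, u, v], rop)
    print(fn(2.0, 3.0, 1.0, 0.0))  # 3.0, the y*u component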
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from functools import reduce from functools import reduce, partial
from collections import OrderedDict from collections import OrderedDict
import theano import theano
...@@ -62,7 +62,7 @@ class OpFromGraph(gof.Op): ...@@ -62,7 +62,7 @@ class OpFromGraph(gof.Op):
Defaults to ``None``. Defaults to ``None``.
``None`` : No value, gives NullType() ``None`` : No value, gives NullType()
``0`` : zero value, gives DisconnectedType() ``0`` : zero value, gives zeros_like(...)
``...`` : Do not override, use default R_op() result ``...`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should OpFromGraph instance : Override with another OpFromGraph, should
@@ -92,14 +92,15 @@ class OpFromGraph(gof.Op):
           local_outputs)
     - c_code() to remove the double overhead?
     - grad() make it support DisconnectedType and the new interface
-    - extend to lop_overrides?
+    - extend grad() to L_op
+    - add support for NullType and DisconnectedType when R_op supports them
     - check how it works with updates.
     - add test with constant as input or inside the inner graph.
     - Add support for the GPU? Probably just need an opt to remove transfer
     - Add support to pickle this Op.
     - Add support/test with random generator
     - Add optimization to removing unused inputs/outputs
-    - Add optimization to work inplace when not inline
+    - Add optimization to work inplace on inputs when not inline

     Notes
     -----
@@ -113,7 +114,7 @@ class OpFromGraph(gof.Op):
     ``fast_run`` mode.
     - It's recommanded to provide pure functions (no side effects like
       setting global variable) as callable(s). The callable(s) supplied
-      for overrding gradient/rop will be called only once at the first
+      for overriding gradient/rop will be called only once at the first
       call to grad/R_op, and will be converted to OpFromGraph instances.

     Examples
@@ -176,25 +177,35 @@ class OpFromGraph(gof.Op):
         # grad: gradient Variable
         # inp: the corresponding input of gradient Variable
         #
-        # Some Variable types cannot be used directly as OfG output such as
-        # NullType, or DisconnectedType.
-        #
-        # However a grad() call could return these types
+        # a grad() call could return instance of NullType() or DisconnectedType()
+        # which cannot be directly used in OfG
         #
         # Since we always use an OfG instance as self._grad_op, the current
         # workaround is to "remember" the special cases of the gradient and
         # replace them after self._grad_op is called.
         #
-        # This helper function changes invalid types into a filtered_type,
-        # and provides a overrider_type to be replaced at grad() call
+        # This helper function changes invalid types into a filtered_var,
+        # and provides a overrider_var to be replaced at grad() call
         #
         # For now, this converts NullType or DisconnectedType into zeros_like.
-        # other types are unmodified with overrider_type -> None
+        # other types are unmodified: overrider_var -> None
         if isinstance(grad.type, (NullType, DisconnectedType)):
-            return inp.zeros_like(), grad.type
+            return inp.zeros_like(), grad
         else:
             return grad, None

+    @staticmethod
+    def _filter_rop_var(inpJ, out):
+        # mostly similar to _filter_grad_var
+        if isinstance(inpJ.type, NullType):
+            return out.zeros_like(), inpJ
+        if isinstance(inpJ.type, DisconnectedType):
+            # since R_op does not have DisconnectedType yet, we will just
+            # make them zeros.
+            return out.zeros_like(), None
+        else:
+            return inpJ, None
+
     def __init__(
             self, inputs, outputs,
             inline=False,
@@ -202,7 +213,7 @@ class OpFromGraph(gof.Op):
             name=None, **kwargs
     ):
         if not isinstance(outputs, list):
-            raise TypeError('outputs must be list, got %s' % outputs, outputs)
+            raise TypeError('outputs must be list, got %s' % type(outputs), outputs)
         for i in inputs + outputs:
             if not isinstance(i, gof.Variable):
                 raise TypeError(
@@ -263,21 +274,26 @@ class OpFromGraph(gof.Op):
         if isinstance(grad_op, OpFromGraph):
             self._grad_op_is_cached = True
-            self._grad_op_overrides_l = [None] * len(self.local_inputs)
+            self._grad_op_overrides_l = [None] * inp_len
             return

         output_grads = [out_t() for out_t in self.output_types]
+        fn_grad = partial(
+            theano.gradient.grad,
+            cost=None,
+            disconnected_inputs='ignore',
+            return_disconnected='Disconnected',
+            null_gradients='return',
+            known_grads=OrderedDict(izip(local_outputs, output_grads)))
         TYPE_ERR_MSG = 'Gradient override should be (single or list of)' \
             'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
         # we need to convert _grad_op into an OfG instance
         if grad_op is Ellipsis:
-            self._grad_op_tflags = bytes(inp_len)
-            all_grads_l = theano.gradient.grad(
-                cost=None,
-                known_grads=OrderedDict(izip(local_outputs, output_grads)),
-                wrt=local_inputs,
-                disconnected_inputs='ignore')
+            gdefaults_l = fn_grad(wrt=local_inputs)
             all_grads_ov_l = [None] * inp_len
+            all_grads_l, all_grads_ov_l = izip(
+                *[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(gdefaults_l, local_inputs)])
         elif grad_op is None:
             all_grads_l = [inp.zeros_like() for inp in local_inputs]
             all_grads_ov_l = [self.ofg_null_t()] * inp_len
@@ -285,7 +301,7 @@ class OpFromGraph(gof.Op):
             all_grads_l = [inp.zeros_like() for inp in local_inputs]
             all_grads_ov_l = [self.ofg_discon_t()] * inp_len
         elif isinstance(grad_op, list):
-            goverrides_l = self._grad_op
+            goverrides_l = grad_op
             if len(goverrides_l) != inp_len:
                 raise ValueError(
                     'Need to override %d gradients, got %d' % (
@@ -293,18 +309,15 @@ class OpFromGraph(gof.Op):
             # compute non-overriding downsteam grads from upstreams grads
             # it's normal some input may be disconnected, thus the 'ignore'
             wrt_l = [lin for lin, gov in izip(
-                self.local_inputs, goverrides_l) if gov is Ellipsis]
-            gdefaults = iter(theano.gradient.grad(
-                cost=None,
-                known_grads=OrderedDict(izip(self.local_outputs, output_grads)),
-                wrt=wrt_l,
-                disconnected_inputs='ignore') if wrt_l else [])
+                local_inputs, goverrides_l) if gov is Ellipsis]
+            gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
             # combine overriding gradients
             all_grads_l = []
             all_grads_ov_l = []
-            for i, (inp, fn_gov) in enumerate(izip(local_inputs, goverrides_l)):
+            for inp, fn_gov in izip(local_inputs, goverrides_l):
                 if fn_gov is Ellipsis:
-                    gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
+                    gnext, gnext_ov = OpFromGraph._filter_grad_var(
+                        next(gdefaults), inp)
                     all_grads_l.append(gnext)
                     all_grads_ov_l.append(gnext_ov)
                 elif fn_gov is 0:
@@ -330,13 +343,14 @@ class OpFromGraph(gof.Op):
                     'Gradient overriding function should return a list, '
                     'got "%s"' % type(goverrides_l))
             all_grads_l, all_grads_ov_l = izip(
-                *[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(goverrides_l, local_inputs)])
+                *[OpFromGraph._filter_grad_var(
+                    grad, inp) for grad, inp in izip(goverrides_l, local_inputs)])
             if len(all_grads_l) != len(local_inputs):
                 raise ValueError(
                     'Gradient overriding function should return list of '
                     '%d outputs, got %d' % (inp_len, len(all_grads_l)))
         all_grads_l = list(all_grads_l)
         all_grads_ov_l = list(all_grads_ov_l)
         self._grad_op = type(self)(
             inputs=local_inputs + output_grads,
             outputs=all_grads_l,
@@ -347,65 +361,92 @@ class OpFromGraph(gof.Op):
         self._grad_op_is_cached = True

     def _recompute_rop_op(self):
+        local_inputs = self.local_inputs
+        local_outputs = self.local_outputs
+        out_len = len(local_outputs)
+        rop_op = self._rop_op
+
         if isinstance(self._rop_op, OpFromGraph):
             self._rop_op_is_cached = True
+            self._rop_op_overrides_l = [None] * out_len
             return
-        eval_points = [inp_t() for inp_t in self.input_types]
-        if self._rop_op is None:
-            self._rop_op = []
-        if isinstance(self._rop_op, list):
-            roverrides_l = self._rop_op
-            if len(roverrides_l) > len(self.local_outputs):
+
+        eval_points = [inp_t() for inp_t in self.input_types]
+        fn_rop = partial(
+            theano.gradient.Rop,
+            wrt=local_inputs,
+            eval_points=eval_points)
+        TYPE_ERR_MSG = 'R_op overrides should be (single or list of)' \
+            'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
+        if rop_op is Ellipsis:
+            all_rops_l = fn_rop(f=local_outputs)
+            all_rops_ov_l = [None] * out_len
+        elif rop_op is None:
+            all_rops_l = [out.zeros_like() for out in local_outputs]
+            all_rops_ov_l = [self.ofg_null_t()] * out_len
+        elif rop_op is 0:
+            all_rops_l = [out.zeros_like() for out in local_outputs]
+            all_rops_ov_l = [None] * out_len
+        elif isinstance(rop_op, list):
+            roverrides_l = rop_op
+            if len(roverrides_l) != out_len:
                 raise ValueError(
-                    'Can override %d gradients at most, got %d' % (
-                        len(self.local_onputs), len(roverrides_l)),
-                    roverrides_l)
-            if len(roverrides_l) < len(self.local_outputs):
-                roverrides_l += [None] * (
-                    len(self.local_outputs) - len(roverrides_l))
+                    'Need to override %d Rop, got %d' % (
+                        out_len, len(roverrides_l)), roverrides_l)
             # get outputs that does not have Rop override
             odefaults_l = [
-                lo for lo, rov in izip(self.local_outputs, roverrides_l)
-                if not rov]
-            rdefaults_li = theano.gradient.Rop(
-                f=odefaults_l,
-                wrt=self.local_inputs,
-                eval_points=eval_points
-            )
-            rdefaults = iter(rdefaults_li if odefaults_l else [])
+                lo for lo, rov in izip(local_outputs, roverrides_l)
+                if rov is Ellipsis]
+            rdefaults_l = fn_rop(f=odefaults_l)
+            rdefaults = iter(rdefaults_l if odefaults_l else [])
             # combine overriding Rops
             all_rops_l = []
-            for out, rov in izip(self.local_outputs, roverrides_l):
-                if rov is None:
-                    all_rops_l.append(next(rdefaults))
-                elif rov is undef:
-                    all_rops_l.append(
-                        out.zeros_like().astype(theano.config.floatX))
+            all_rops_ov_l = []
+            for out, fn_rov in izip(local_outputs, roverrides_l):
+                if fn_rov is Ellipsis:
+                    rnext, rnext_ov = OpFromGraph._filter_rop_var(
+                        next(rdefaults), out)
+                    all_rops_l.append(rnext)
+                    all_rops_ov_l.append(rnext_ov)
+                elif fn_rov is 0:
+                    all_rops_l.append(out.zeros_like())
+                    all_rops_ov_l.append(None)
+                elif fn_rov is None:
+                    all_rops_l.append(out.zeros_like())
+                    all_rops_ov_l.append(self.ofg_null_t())
                 else:
-                    all_rops_l.append(rov(self.local_inputs, eval_points))
-        elif self._rop_op is undef:
-            all_rops_l = [
-                out.zeros_like().astype(theano.config.floatX)
-                for out in self.local_outputs]
+                    if not hasattr(fn_rov, '__call__'):
+                        raise TypeError(TYPE_ERR_MSG % fn_rov)
+                    rov, rov_ov = OpFromGraph._filter_rop_var(
+                        fn_rov(local_inputs, eval_points), out)
+                    all_rops_l.append(rov)
+                    all_rops_ov_l.append(rov_ov)
         else:
-            all_rops_l = self._rop_op(self.local_inputs, eval_points)
-            if not isinstance(all_rops_l, (tuple, list)):
-                all_rops_l = [all_rops_l]
-            if len(all_rops_l) != len(self.local_outputs):
+            if not hasattr(rop_op, '__call__'):
+                raise TypeError(TYPE_ERR_MSG % rop_op)
+            roverrides_l = rop_op(local_inputs, eval_points)
+            if not isinstance(roverrides_l, list):
+                raise TypeError(
+                    'Rop overriding function should return a list, '
+                    'got "%s"' % type(roverrides_l))
+            all_rops_l, all_rops_ov_l = izip(
+                *[OpFromGraph._filter_rop_var(
+                    rop, out) for rop, out in izip(roverrides_l, local_outputs)])
+            if len(all_rops_l) != out_len:
                 raise ValueError(
                     'Rop overriding function %s should return list of '
                     '%d outputs, got %d' % (
-                        self._rop_op,
-                        len(self.local_outputs),
-                        len(all_rops_l)),
-                    self._rop_op)
+                        self._rop_op, out_len,
+                        len(all_rops_l)), rop_op)
+        all_rops_l = list(all_rops_l)
+        all_rops_ov_l = list(all_rops_ov_l)
         self._rop_op = type(self)(
-            inputs=self.local_inputs + eval_points,
+            inputs=local_inputs + eval_points,
             outputs=all_rops_l,
             inline=self.is_inline,
             name=(None if self.name is None else self.name + '_rop'),
             on_unused_input='ignore')
+        self._rop_op_overrides_l = all_rops_ov_l
         self._rop_op_is_cached = True

     def get_grad_op(self):
@@ -447,10 +488,9 @@ class OpFromGraph(gof.Op):
         self._recompute_rop_op()
         ret_ofg_l = self._rop_op(
             *(list(inputs) + list(eval_points)), return_list=True)
-        ret_l = [{
-            self.TFLAG_NULL_T: self.ofg_null_t(),
-            self.TFLAG_DISCON_T: self.ofg_discon_t()
-        }[flag] if flag else ret_ofg for ret_ofg, flag in izip(ret_ofg_l, self._grad_tflags)]
+        ret_l = [
+            ret_ofg if ov is None else ov for ret_ofg, ov in izip(
+                ret_ofg_l, self._rop_op_overrides_l)]
         return ret_l
     def grad(self, inputs, output_grads):
@@ -459,10 +499,10 @@ class OpFromGraph(gof.Op):
         ret_ofg_l = self._grad_op(
             *(list(inputs) + list(output_grads)), return_list=True)
         ret_l = [
-            ret_ofg if ov is None else ov for ret_ofg, ov in izip(ret_ofg_l, self._grad_op_overrides_l)]
+            ret_ofg if ov is None else ov for ret_ofg, ov in izip(
+                ret_ofg_l, self._grad_op_overrides_l)]
         return ret_l

     def make_node(self, *inputs):
         num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
         if len(inputs) != num_expected_inps:
...
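The heart of the change is a "remember and replace" pattern: NullType and DisconnectedType results cannot be outputs of the inner OpFromGraph, so `_filter_grad_var` and `_filter_rop_var` emit a `zeros_like` placeholder and remember the real variable in an overrides list, which `grad()` and `R_op()` then swap back in. A standalone sketch with illustrative names (not the module's own code):

    def filter_var(var, template, is_special):
        # placeholder goes into the inner graph; the special result is
        # remembered for substitution at grad()/R_op() time
        if is_special(var):
            return template.zeros_like(), var
        return var, None

    def substitute(inner_results, overrides):
        # swap each placeholder result back for the remembered variable
        return [res if ov is None else ov
                for res, ov in zip(inner_results, overrides)]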