提交 c0f7334c authored 作者: khaotik's avatar khaotik

change R_op overriding to new format

上级 8e8758a3
"""Define new Ops from existing Ops"""
from __future__ import absolute_import, print_function, division
from functools import reduce
from functools import reduce, partial
from collections import OrderedDict
import theano
......@@ -62,7 +62,7 @@ class OpFromGraph(gof.Op):
Defaults to ``None``.
``None`` : No value, gives NullType()
``0`` : zero value, gives DisconnectedType()
``0`` : zero value, gives zeros_like(...)
``...`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should
......@@ -92,14 +92,15 @@ class OpFromGraph(gof.Op):
local_outputs)
- c_code() to remove the double overhead?
- grad() make it support DisconnectedType and the new interface
- extend to lop_overrides?
- extend grad() to L_op
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
- Add support/test with random generator
- Add optimization to remove unused inputs/outputs
- Add optimization to work inplace when not inline
- Add optimization to work inplace on inputs when not inline
Notes
-----
......@@ -113,7 +114,7 @@ class OpFromGraph(gof.Op):
``fast_run`` mode.
- It's recommended to provide pure functions (no side effects like
setting global variable) as callable(s). The callable(s) supplied
for overrding gradient/rop will be called only once at the first
for overriding gradient/rop will be called only once at the first
call to grad/R_op, and will be converted to OpFromGraph instances.
Examples
......@@ -176,25 +177,35 @@ class OpFromGraph(gof.Op):
# grad: gradient Variable
# inp: the corresponding input of gradient Variable
#
# Some Variable types cannot be used directly as OfG output such as
# NullType, or DisconnectedType.
#
# However a grad() call could return these types
# a grad() call could return an instance of NullType() or DisconnectedType()
# which cannot be directly used in OfG
#
# Since we always use an OfG instance as self._grad_op, the current
# workaround is to "remember" the special cases of the gradient and
# replace them after self._grad_op is called.
#
# This helper function changes invalid types into a filtered_type,
# and provides a overrider_type to be replaced at grad() call
# This helper function changes invalid types into a filtered_var,
# and provides an overrider_var to be replaced at grad() call
#
# For now, this converts NullType or DisconnectedType into zeros_like.
# other types are unmodified with overrider_type -> None
# other types are unmodified: overrider_var -> None
if isinstance(grad.type, (NullType, DisconnectedType)):
return inp.zeros_like(), grad.type
return inp.zeros_like(), grad
else:
return grad, None
@staticmethod
def _filter_rop_var(inpJ, out):
    """Turn one R_op result into a ``(filtered_var, overrider_var)`` pair.

    Counterpart of ``_filter_grad_var`` for R_op results.

    Parameters
    ----------
    inpJ
        The R_op result variable to inspect; only its ``.type`` is examined.
    out
        The output the R_op result corresponds to; ``out.zeros_like()`` is
        used as the stand-in when ``inpJ`` has a special type.

    Returns
    -------
    tuple
        ``(filtered_var, overrider_var)``:

        - ``NullType`` result: ``(out.zeros_like(), inpJ)`` -- the original
          variable is kept as the overrider.
        - ``DisconnectedType`` result: ``(out.zeros_like(), None)``.
        - anything else: ``(inpJ, None)`` -- passed through unchanged.
    """
    var_type = inpJ.type
    if not isinstance(var_type, (NullType, DisconnectedType)):
        # Ordinary variable: nothing to filter, nothing to override.
        return inpJ, None
    # NullType is remembered through the overrider. R_op does not have
    # DisconnectedType yet, so a disconnected result is simply made zeros
    # with no overrider.
    overrider_var = inpJ if isinstance(var_type, NullType) else None
    return out.zeros_like(), overrider_var
def __init__(
self, inputs, outputs,
inline=False,
......@@ -202,7 +213,7 @@ class OpFromGraph(gof.Op):
name=None, **kwargs
):
if not isinstance(outputs, list):
raise TypeError('outputs must be list, got %s' % outputs, outputs)
raise TypeError('outputs must be list, got %s' % type(outputs), outputs)
for i in inputs + outputs:
if not isinstance(i, gof.Variable):
raise TypeError(
......@@ -263,21 +274,26 @@ class OpFromGraph(gof.Op):
if isinstance(grad_op, OpFromGraph):
self._grad_op_is_cached = True
self._grad_op_overrides_l = [None] * len(self.local_inputs)
self._grad_op_overrides_l = [None] * inp_len
return
output_grads = [out_t() for out_t in self.output_types]
fn_grad = partial(
theano.gradient.grad,
cost=None,
disconnected_inputs='ignore',
return_disconnected='Disconnected',
null_gradients='return',
known_grads=OrderedDict(izip(local_outputs, output_grads)))
TYPE_ERR_MSG = 'Gradient override should be (single or list of)' \
'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
# we need to convert _grad_op into an OfG instance
if grad_op is Ellipsis:
self._grad_op_tflags = bytes(inp_len)
all_grads_l = theano.gradient.grad(
cost=None,
known_grads=OrderedDict(izip(local_outputs, output_grads)),
wrt=local_inputs,
disconnected_inputs='ignore')
gdefaults_l = fn_grad(wrt=local_inputs)
all_grads_ov_l = [None] * inp_len
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(gdefaults_l, local_inputs)])
elif grad_op is None:
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_null_t()] * inp_len
......@@ -285,7 +301,7 @@ class OpFromGraph(gof.Op):
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_discon_t()] * inp_len
elif isinstance(grad_op, list):
goverrides_l = self._grad_op
goverrides_l = grad_op
if len(goverrides_l) != inp_len:
raise ValueError(
'Need to override %d gradients, got %d' % (
......@@ -293,18 +309,15 @@ class OpFromGraph(gof.Op):
# compute non-overriding downstream grads from upstream grads
# it's normal some input may be disconnected, thus the 'ignore'
wrt_l = [lin for lin, gov in izip(
self.local_inputs, goverrides_l) if gov is Ellipsis]
gdefaults = iter(theano.gradient.grad(
cost=None,
known_grads=OrderedDict(izip(self.local_outputs, output_grads)),
wrt=wrt_l,
disconnected_inputs='ignore') if wrt_l else [])
local_inputs, goverrides_l) if gov is Ellipsis]
gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
# combine overriding gradients
all_grads_l = []
all_grads_ov_l = []
for i, (inp, fn_gov) in enumerate(izip(local_inputs, goverrides_l)):
for inp, fn_gov in izip(local_inputs, goverrides_l):
if fn_gov is Ellipsis:
gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
gnext, gnext_ov = OpFromGraph._filter_grad_var(
next(gdefaults), inp)
all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov)
elif fn_gov is 0:
......@@ -330,13 +343,14 @@ class OpFromGraph(gof.Op):
'Gradient overriding function should return a list, '
'got "%s"' % type(goverrides_l))
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(goverrides_l, local_inputs)])
*[OpFromGraph._filter_grad_var(
grad, inp) for grad, inp in izip(goverrides_l, local_inputs)])
if len(all_grads_l) != len(local_inputs):
raise ValueError(
'Gradient overriding function should return list of '
'%d outputs, got %d' % (inp_len, len(all_grads_l)))
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
self._grad_op = type(self)(
inputs=local_inputs + output_grads,
outputs=all_grads_l,
......@@ -347,65 +361,92 @@ class OpFromGraph(gof.Op):
self._grad_op_is_cached = True
def _recompute_rop_op(self):
local_inputs = self.local_inputs
local_outputs = self.local_outputs
out_len = len(local_outputs)
rop_op = self._rop_op
if isinstance(self._rop_op, OpFromGraph):
self._rop_op_is_cached = True
self._rop_op_overrides_l = [None] * out_len
return
eval_points = [inp_t() for inp_t in self.input_types]
if self._rop_op is None:
self._rop_op = []
if isinstance(self._rop_op, list):
roverrides_l = self._rop_op
if len(roverrides_l) > len(self.local_outputs):
eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(
theano.gradient.Rop,
wrt=local_inputs,
eval_points=eval_points)
TYPE_ERR_MSG = 'R_op overrides should be (single or list of)' \
'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
if rop_op is Ellipsis:
all_rops_l = fn_rop(f=local_outputs)
all_rops_ov_l = [None] * out_len
elif rop_op is None:
all_rops_l = [out.zeros_like() for out in local_outputs]
all_rops_ov_l = [self.ofg_null_t()] * out_len
elif rop_op is 0:
all_rops_l = [out.zeros_like() for out in local_outputs]
all_rops_ov_l = [None] * out_len
elif isinstance(rop_op, list):
roverrides_l = rop_op
if len(roverrides_l) != out_len:
raise ValueError(
'Can override %d gradients at most, got %d' % (
len(self.local_onputs), len(roverrides_l)),
roverrides_l)
if len(roverrides_l) < len(self.local_outputs):
roverrides_l += [None] * (
len(self.local_outputs) - len(roverrides_l))
'Need to override %d Rop, got %d' % (
out_len, len(roverrides_l)), roverrides_l)
# get outputs that do not have an Rop override
odefaults_l = [
lo for lo, rov in izip(self.local_outputs, roverrides_l)
if not rov]
rdefaults_li = theano.gradient.Rop(
f=odefaults_l,
wrt=self.local_inputs,
eval_points=eval_points
)
rdefaults = iter(rdefaults_li if odefaults_l else [])
lo for lo, rov in izip(local_outputs, roverrides_l)
if rov is Ellipsis]
rdefaults_l = fn_rop(f=odefaults_l)
rdefaults = iter(rdefaults_l if odefaults_l else [])
# combine overriding Rops
all_rops_l = []
for out, rov in izip(self.local_outputs, roverrides_l):
if rov is None:
all_rops_l.append(next(rdefaults))
elif rov is undef:
all_rops_l.append(
out.zeros_like().astype(theano.config.floatX))
all_rops_ov_l = []
for out, fn_rov in izip(local_outputs, roverrides_l):
if fn_rov is Ellipsis:
rnext, rnext_ov = OpFromGraph._filter_rop_var(
next(rdefaults), out)
all_rops_l.append(rnext)
all_rops_ov_l.append(rnext_ov)
elif fn_rov is 0:
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(None)
elif fn_rov is None:
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(self.ofg_null_t())
else:
all_rops_l.append(rov(self.local_inputs, eval_points))
elif self._rop_op is undef:
all_rops_l = [
out.zeros_like().astype(theano.config.floatX)
for out in self.local_outputs]
if not hasattr(fn_rov, '__call__'):
raise TypeError(TYPE_ERR_MSG % fn_rov)
rov, rov_ov = OpFromGraph._filter_rop_var(
fn_rov(local_inputs, eval_points), out)
all_rops_l.append(rov)
all_rops_ov_l.append(rov_ov)
else:
all_rops_l = self._rop_op(self.local_inputs, eval_points)
if not isinstance(all_rops_l, (tuple, list)):
all_rops_l = [all_rops_l]
if len(all_rops_l) != len(self.local_outputs):
if not hasattr(rop_op, '__call__'):
raise TypeError(TYPE_ERR_MSG % rop_op)
roverrides_l = rop_op(local_inputs, eval_points)
if not isinstance(roverrides_l, list):
raise TypeError(
'Rop overriding function should return a list, '
'got "%s"' % type(roverrides_l))
all_rops_l, all_rops_ov_l = izip(
*[OpFromGraph._filter_rop_var(
rop, out) for rop, out in izip(roverrides_l, local_outputs)])
if len(all_rops_l) != out_len:
raise ValueError(
'Rop overriding function %s should return list of '
'%d outputs, got %d' % (
self._rop_op,
len(self.local_outputs),
len(all_rops_l)),
self._rop_op)
self._rop_op, out_len,
len(all_rops_l)), rop_op)
all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l)
self._rop_op = type(self)(
inputs=self.local_inputs + eval_points,
inputs=local_inputs + eval_points,
outputs=all_rops_l,
inline=self.is_inline,
name=(None if self.name is None else self.name + '_rop'),
on_unused_input='ignore')
self._rop_op_overrides_l = all_rops_ov_l
self._rop_op_is_cached = True
def get_grad_op(self):
......@@ -447,10 +488,9 @@ class OpFromGraph(gof.Op):
self._recompute_rop_op()
ret_ofg_l = self._rop_op(
*(list(inputs) + list(eval_points)), return_list=True)
ret_l = [{
self.TFLAG_NULL_T: self.ofg_null_t(),
self.TFLAG_DISCON_T: self.ofg_discon_t()
}[flag] if flag else ret_ofg for ret_ofg, flag in izip(ret_ofg_l, self._grad_tflags)]
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._rop_op_overrides_l)]
return ret_l
def grad(self, inputs, output_grads):
......@@ -459,10 +499,10 @@ class OpFromGraph(gof.Op):
ret_ofg_l = self._grad_op(
*(list(inputs) + list(output_grads)), return_list=True)
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(ret_ofg_l, self._grad_op_overrides_l)]
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._grad_op_overrides_l)]
return ret_l
def make_node(self, *inputs):
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
if len(inputs) != num_expected_inps:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论