Commit 8e8758a3 authored by khaotik

better handling for NullType and DisconnectedType

major changes: - The self._grad_op now only returns zeros_like() for special types like NullType() or DisconnectedType() - a call to grad() will further replace the returned zero tensors with the special types - proposed gradient override interface: (single value, or a list of the following) Ellipsis -> <no_override> (since Python 2 does not support the `[...]` syntax, this may result in uglier code in Python 2) None -> NullType() int(0) -> DisconnectedType() OpFromGraph instance or callable -> <override> minor changes: - various typo/bug fixes. notes: - This commit breaks OpFromGraph.R_op, which is expected to be fixed in upcoming commits.
Parent a6e5cd74
"""Define new Ops from existing Ops"""
from __future__ import absolute_import, print_function, division
from functools import reduce
from collections import OrderedDict
import theano
from theano import gof
......@@ -9,7 +10,8 @@ from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared, optdb
from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern
from theano.gof.utils import undef
from theano.gof.null_type import NullType
from theano.gradient import DisconnectedType
class OpFromGraph(gof.Op):
......@@ -38,11 +40,12 @@ class OpFromGraph(gof.Op):
``False`` : will use a pre-compiled function inside.
grad_overrides : single or list of {None, undef, OpFromGraph, callable}, optional
grad_overrides : single or list of {0, None, Ellipsis, OpFromGraph, callable}, optional
Defaults to ``None``.
``None`` : Do not override gradient
theano.utils.undef : No gradient will be used (zero)
``None`` : No value, gives NullType()
``0`` : zero value, gives DisconnectedType()
``...`` : Do not override, use default grad() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads"
......@@ -53,14 +56,15 @@ class OpFromGraph(gof.Op):
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
a specific input.
a specific input, length of list must be equal to number of inputs.
rop_overrides : single or list of {None, undef, OpFromGraph, callable}, optional
rop_overrides : single or list of {0, None, Ellipsis, OpFromGraph, callable}, optional
Defaults to ``None``.
``None`` : Do not override gradient
``None`` : No value, gives NullType()
``0`` : zero value, gives DisconnectedType()
``...`` : Do not override, use default R_op() result
theano.utils.undef : No gradient will be used (zero)
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads"
arguments as one would specify in grad() method.
......@@ -70,7 +74,7 @@ class OpFromGraph(gof.Op):
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds
to a specific output of R_op.
to a specific output of R_op, length of list must be equal to number of outputs.
name : string, optional
A name for debugging purposes
......@@ -88,12 +92,14 @@ class OpFromGraph(gof.Op):
local_outputs)
- c_code() to remove the double overhead?
- grad() make it support DisconnectedType and the new interface
- extend to lop_overrides?
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
- Add support/test with random generator
- Add optimization to removing unused inputs/outputs
- Add optimization to work inplace when not inline
Notes
-----
......@@ -117,7 +123,7 @@ class OpFromGraph(gof.Op):
.. code-block:: python
from theano import function, tensor
from theano import function, OpFromGraph, tensor
x, y, z = tensor.scalars('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e])
......@@ -144,25 +150,59 @@ class OpFromGraph(gof.Op):
.. code-block:: python
from thenao import funciton, OpFromGraph, tensor, grad
from theano import function, OpFromGraph, tensor, grad
x, y, z = tensor.scalars('xyz')
e = x + y * z
def rescale_dy(inps, grads):
x, y, z = inps
g = grads
g, = grads
return z*2
op = OpFromGraph(
[x, y, z], [e], grad_overrides=[None, rescale_dy, None])
[x, y, z], [e], grad_overrides=[Ellipsis, rescale_dy, Ellipsis]
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
fn = function([x, y, z], [dx, dy, dz])
# the graident wrt y is now doubled
# the gradient wrt y is now doubled
fn(2., 3., 4.) # [1., 8., 3.]
"""
def __init__(self, inputs, outputs, inline=False, grad_overrides=None, rop_overrides=None, name=None, **kwargs):
ofg_null_t = NullType(why_null='ofg_overridden')
ofg_discon_t = DisconnectedType()
@staticmethod
def _filter_grad_var(grad, inp):
    """Split a gradient Variable into an OfG-safe output plus an override marker.

    Parameters
    ----------
    grad : Variable
        A gradient Variable, as produced by a grad() call.
    inp : Variable
        The input Variable this gradient corresponds to.

    Returns
    -------
    (filtered_var, overrider_var) : tuple
        ``filtered_var`` is always usable as an OpFromGraph output;
        ``overrider_var`` is the special type to re-apply later, or None.

    Variables of NullType or DisconnectedType cannot be used directly as
    OfG outputs, yet grad() may legitimately return them.  Since an OfG
    instance is always used as self._grad_op, the workaround is to
    "remember" these special cases here and substitute them back after
    self._grad_op is called: the special gradient is replaced by a
    zeros_like() placeholder and its original type is returned as the
    override marker.  Any other gradient passes through unmodified with
    a None marker.
    """
    special = isinstance(grad.type, (NullType, DisconnectedType))
    if not special:
        return grad, None
    # placeholder tensor now; grad() re-applies grad.type afterwards
    return inp.zeros_like(), grad.type
def __init__(
self, inputs, outputs,
inline=False,
grad_overrides=Ellipsis, rop_overrides=Ellipsis,
name=None, **kwargs
):
if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs)
raise TypeError('outputs must be list, got %s' % outputs, outputs)
for i in inputs + outputs:
if not isinstance(i, gof.Variable):
raise TypeError(
......@@ -216,64 +256,94 @@ class OpFromGraph(gof.Op):
return '%(name)s{inline=%(is_inline)s}' % locals()
# NOTE(review): this span is a flattened web diff — pre-change and
# post-change lines of _recompute_grad_op are interleaved with no +/-
# markers, so it is NOT valid Python as rendered.  Code lines are kept
# byte-identical below; only comments were added.  Recover the real
# method body from the repository before editing.
def _recompute_grad_op(self):
if isinstance(self._grad_op, OpFromGraph):
local_inputs = self.local_inputs
local_outputs = self.local_outputs
inp_len = len(local_inputs)
grad_op = self._grad_op
# already an OfG instance: nothing to build, no per-input overrides
if isinstance(grad_op, OpFromGraph):
self._grad_op_is_cached = True
self._grad_op_overrides_l = [None] * len(self.local_inputs)
return
# symbolic upstream gradients, one per output type
output_grads = [out_t() for out_t in self.output_types]
if self._grad_op is None:
self._grad_op = []
# we need to convert a list/function into an OfG instance
if isinstance(self._grad_op, list):
output_grads = [out_t() for out_t in self.output_types]
TYPE_ERR_MSG = 'Gradient override should be (single or list of)' \
'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
# we need to convert _grad_op into an OfG instance
if grad_op is Ellipsis:
# NOTE(review): bytes(inp_len) is a zero-filled byte string only on
# Python 3; on Python 2 it is the str "<inp_len>" — verify intent.
self._grad_op_tflags = bytes(inp_len)
all_grads_l = theano.gradient.grad(
cost=None,
known_grads=OrderedDict(izip(local_outputs, output_grads)),
wrt=local_inputs,
disconnected_inputs='ignore')
all_grads_ov_l = [None] * inp_len
elif grad_op is None:
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_null_t()] * inp_len
# NOTE(review): `is 0` is an identity test against an int literal;
# it works only by CPython small-int caching — should be `== 0`.
elif grad_op is 0:
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_discon_t()] * inp_len
elif isinstance(grad_op, list):
goverrides_l = self._grad_op
if len(goverrides_l) > len(self.local_inputs):
if len(goverrides_l) != inp_len:
raise ValueError(
'Can override %d gradients at most, got %d' % (
len(self.local_inputs), len(goverrides_l)),
self.goverrides_l)
if len(goverrides_l) < len(self.local_inputs):
goverrides_l += [None] * (
len(self.local_inputs) - len(goverrides_l))
wrt_l = [lin for lin, gov in
izip(self.local_inputs, goverrides_l) if not gov]
'Need to override %d gradients, got %d' % (
inp_len, len(goverrides_l)), goverrides_l)
# compute non-overriding downstream grads from upstream grads;
# it's normal some input may be disconnected, thus the 'ignore'
wrt_l = [lin for lin, gov in izip(
self.local_inputs, goverrides_l) if gov is Ellipsis]
# NOTE(review): izip is Python-2 itertools; its import is not
# visible in this view — presumably from six/theano.compat.
gdefaults = iter(theano.gradient.grad(
cost=None,
known_grads=dict(izip(self.local_outputs, output_grads)),
known_grads=OrderedDict(izip(self.local_outputs, output_grads)),
wrt=wrt_l,
disconnected_inputs='ignore') if wrt_l else [])
# combine overriding gradients
all_grads_l = []
for inp, gov in izip(self.local_inputs, goverrides_l):
if gov is None:
all_grads_l.append(next(gdefaults))
elif gov is undef:
all_grads_l.append(
inp.zeros_like().astype(theano.config.floatX))
all_grads_ov_l = []
for i, (inp, fn_gov) in enumerate(izip(local_inputs, goverrides_l)):
if fn_gov is Ellipsis:
gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov)
# NOTE(review): same `is 0` identity-comparison concern as above
elif fn_gov is 0:
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(self.ofg_discon_t())
elif fn_gov is None:
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(self.ofg_null_t())
else:
all_grads_l.append(gov(self.local_inputs, output_grads))
elif self._grad_op is undef:
all_grads_l = [
inp.zeros_like().astype(theano.config.floatX)
for inp in self.local_inputs]
if not hasattr(fn_gov, '__call__'):
raise TypeError(TYPE_ERR_MSG % fn_gov)
gov, gov_ov = OpFromGraph._filter_grad_var(
fn_gov(local_inputs, output_grads), inp)
all_grads_l.append(gov)
all_grads_ov_l.append(gov_ov)
else:
all_grads_l = self._grad_op(self.local_inputs, output_grads)
if not isinstance(all_grads_l, (tuple, list)):
all_grads_l = [all_grads_l]
if len(all_grads_l) != len(self.local_inputs):
# callable case
if not hasattr(grad_op, '__call__'):
raise TypeError(TYPE_ERR_MSG % grad_op)
goverrides_l = grad_op(local_inputs, output_grads)
if not isinstance(goverrides_l, list):
raise TypeError(
'Gradient overriding function should return a list, '
'got "%s"' % type(goverrides_l))
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(goverrides_l, local_inputs)])
if len(all_grads_l) != len(local_inputs):
raise ValueError(
'Gradient overriding function %s should return list of '
'%d outputs, got %d' % (
self._grad_op, len(self.local_inputs), len(all_grads_l)),
self._grad_op
)
'Gradient overriding function should return list of '
'%d outputs, got %d' % (inp_len, len(all_grads_l)))
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
# wrap the computed gradient graph as another OpFromGraph, and cache
# the per-input override markers for grad() to re-apply later
self._grad_op = type(self)(
inputs=self.local_inputs + output_grads,
inputs=local_inputs + output_grads,
outputs=all_grads_l,
inline=self.is_inline,
name=(None if self.name is None else self.name + '_grad'),
on_unused_input='ignore')
self._grad_op_overrides_l = all_grads_ov_l
self._grad_op_is_cached = True
def _recompute_rop_op(self):
......@@ -375,12 +445,23 @@ class OpFromGraph(gof.Op):
# NOTE(review): flattened diff — the first `return` is the pre-change
# line, retained byte-identical alongside the post-change body below.
def R_op(self, inputs, eval_points):
if not self._rop_op_is_cached:
self._recompute_rop_op()
return self._rop_op(*(list(inputs) + list(eval_points)), return_list=True)
ret_ofg_l = self._rop_op(
*(list(inputs) + list(eval_points)), return_list=True)
# NOTE(review): TFLAG_NULL_T / TFLAG_DISCON_T are not defined anywhere
# in this view, and reading self._grad_tflags (a *grad* attribute) from
# R_op looks wrong; the commit message states this commit knowingly
# breaks OpFromGraph.R_op — confirm against the follow-up commits.
ret_l = [{
self.TFLAG_NULL_T: self.ofg_null_t(),
self.TFLAG_DISCON_T: self.ofg_discon_t()
}[flag] if flag else ret_ofg for ret_ofg, flag in izip(ret_ofg_l, self._grad_tflags)]
return ret_l
# NOTE(review): flattened diff — the first `return` is the pre-change
# line, retained byte-identical alongside the post-change body below.
def grad(self, inputs, output_grads):
if not self._grad_op_is_cached:
self._recompute_grad_op()
return self._grad_op(*(list(inputs) + list(output_grads)), return_list=True)
ret_ofg_l = self._grad_op(
*(list(inputs) + list(output_grads)), return_list=True)
# substitute the special-typed variables (NullType/DisconnectedType)
# that _recompute_grad_op replaced with zeros_like placeholders
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(ret_ofg_l, self._grad_op_overrides_l)]
return ret_l
def make_node(self, *inputs):
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
......
......@@ -160,7 +160,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
w, b = T.vectors('wb')
# we make the 3rd gradient default (no override)
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2])
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, Ellipsis])
xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb])
......@@ -281,21 +281,19 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
[True, False, True]]
assert results == expect_result
# NOTE(review): flattened diff — the decorated two-arg `def` and the
# trailing `cls_ofg,` argument are pre-change lines; the one-arg `def`
# and `OpFromGraph)` are the post-change replacements.
@test_params
def test_infer_shape(self, cls_ofg):
def test_infer_shape(self):
# infer_shape does not need to be tested against the inline case,
# since the Op is removed during the optimization phase
x = T.matrix('x')
y = T.matrix('y')
o1 = x + y
o2 = x * y
op_graph = cls_ofg([x, y], [o1, o2])
op_graph = OpFromGraph([x, y], [o1, o2])
q = T.matrix('q')
p = T.matrix('p')
# we don't want check_topo for inline ops
# since the inline op is replaced during optimization
self._compile_and_check([q, p],
op_graph(q, p),
[np.ones([3, 4], dtype=config.floatX),
np.ones([3, 4], dtype=config.floatX)],
cls_ofg,
check_topo=not op_graph.is_inline)
OpFromGraph)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment