Commit 8e8758a3, authored by khaotik

better handling for NullType and DisconnectedType

major changes:
- self._grad_op now only returns zeros_like() for special types like NullType() or DisconnectedType()
- a call to grad() will further replace the returned zero tensors with the special types
- proposed gradient override interface (single, or a list of, the values below):
  - Ellipsis -> <no_override>
    (since Python 2 does not support the `[...]` syntax, this may result in uglier code in Python 2)
  - None -> NullType()
  - int(0) -> DisconnectedType()
  - OpFromGraph instance or callable -> <override>

minor changes:
- various typo/bug fixes

notes:
- This commit breaks OpFromGraph.R_op, which is expected to be fixed in upcoming commits.
Parent a6e5cd74
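
Before the diff itself, the proposed override interface can be shown end to end. The sketch below is adapted from the docstring example later in this diff and assumes the post-commit API; it is illustrative, not part of the commit:

    from theano import function, OpFromGraph, tensor, grad

    x, y, z = tensor.scalars('xyz')
    e = x + y * z

    def rescale_dy(inps, grads):
        x, y, z = inps
        g, = grads
        return z * 2

    # per-input overrides: Ellipsis keeps the default grad() result for
    # that input, a callable replaces it; None would give NullType() and
    # 0 would give DisconnectedType() instead
    op = OpFromGraph(
        [x, y, z], [e],
        grad_overrides=[Ellipsis, rescale_dy, Ellipsis])
    e2 = op(x, y, z)
    dx, dy, dz = grad(e2, [x, y, z])
    fn = function([x, y, z], [dx, dy, dz])
    fn(2., 3., 4.)  # [1., 8., 3.]: the gradient wrt y is doubled
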
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from functools import reduce from functools import reduce
from collections import OrderedDict
import theano import theano
from theano import gof from theano import gof
...@@ -9,7 +10,8 @@ from theano.compile.function_module import orig_function ...@@ -9,7 +10,8 @@ from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared, optdb from theano.compile import SharedVariable, rebuild_collect_shared, optdb
from theano.gof import ops_with_inner_function from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
from theano.gof.utils import undef from theano.gof.null_type import NullType
from theano.gradient import DisconnectedType
 class OpFromGraph(gof.Op):
@@ -38,11 +40,12 @@ class OpFromGraph(gof.Op):
         ``False`` : will use a pre-compiled function inside.
-    grad_overrides : single or list of {None, undef, OpFromGraph, callable}, optional
+    grad_overrides : single or list of {0, None, Ellipsis, OpFromGraph, callable}, optional
         Defaults to ``None``.
-        ``None`` : Do not override gradient
-        theano.utils.undef : No gradient will be used (zero)
+        ``None`` : No value, gives NullType()
+        ``0`` : zero value, gives DisconnectedType()
+        ``...`` : Do not override, use default grad() result
         OpFromGraph instance : Override with another OpFromGraph, should
             accept inputs as the same order and types of "inputs" and "output_grads"
@@ -53,14 +56,15 @@ class OpFromGraph(gof.Op):
         list: Each OpFromGraph/callable must return a single
             :class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
-            a specific input.
+            a specific input, length of list must be equal to number of inputs.
-    rop_overrides : single or list of {None, undef, OpFromGraph, callable}, optional
+    rop_overrides : single or list of {0, None, Ellipsis, OpFromGraph, callable}, optional
         Defaults to ``None``.
-        ``None`` : Do not override gradient
-        theano.utils.undef : No gradient will be used (zero)
+        ``None`` : No value, gives NullType()
+        ``0`` : zero value, gives DisconnectedType()
+        ``...`` : Do not override, use default R_op() result
         OpFromGraph instance : Override with another OpFromGraph, should
             accept inputs as the same order and types of "inputs" and "output_grads"
             arguments as one would specify in grad() method.
@@ -70,7 +74,7 @@ class OpFromGraph(gof.Op):
         list: Each OpFromGraph/callable must return a single
             :class:`Variable <theano.gof.Variable>`. Each list element corresponds
-            to a specific output of R_op.
+            to a specific output of R_op, length of list must be equal to number of outputs.
     name : string, optional
         A name for debugging purposes
@@ -88,12 +92,14 @@ class OpFromGraph(gof.Op):
                                 local_outputs)
     - c_code() to remove the double overhead?
     - grad() make it support DisconnectedType and the new interface
+    - extend to lop_overrides?
     - check how it works with updates.
     - add test with constant as input or inside the inner graph.
     - Add support for the GPU? Probably just need an opt to remove transfer
     - Add support to pickle this Op.
     - Add support/test with random generator
     - Add optimization to removing unused inputs/outputs
+    - Add optimization to work inplace when not inline

     Notes
     -----
@@ -117,7 +123,7 @@ class OpFromGraph(gof.Op):
     .. code-block:: python

-        from theano import function, tensor
+        from theano import function, OpFromGraph, tensor

         x, y, z = tensor.scalars('xyz')
         e = x + y * z
         op = OpFromGraph([x, y, z], [e])
@@ -144,25 +150,59 @@ class OpFromGraph(gof.Op):
     .. code-block:: python

-        from thenao import funciton, OpFromGraph, tensor, grad
+        from theano import function, OpFromGraph, tensor, grad

         x, y, z = tensor.scalars('xyz')
         e = x + y * z

         def rescale_dy(inps, grads):
             x, y, z = inps
-            g = grads
+            g, = grads
             return z*2

         op = OpFromGraph(
-            [x, y, z], [e], grad_overrides=[None, rescale_dy, None])
+            [x, y, z], [e], grad_overrides=[Ellipsis, rescale_dy, Ellipsis])
         e2 = op(x, y, z)
         dx, dy, dz = grad(e2, [x, y, z])
         fn = function([x, y, z], [dx, dy, dz])
-        # the graident wrt y is now doubled
+        # the gradient wrt y is now doubled
         fn(2., 3., 4.) # [1., 8., 3.]
     """
-    def __init__(self, inputs, outputs, inline=False, grad_overrides=None, rop_overrides=None, name=None, **kwargs):
+    ofg_null_t = NullType(why_null='ofg_overridden')
+    ofg_discon_t = DisconnectedType()
+
+    @staticmethod
+    def _filter_grad_var(grad, inp):
+        # Returns (filtered_var, overrider_var)
+        # Args:
+        #     grad: gradient Variable
+        #     inp: the corresponding input of the gradient Variable
+        #
+        # Some Variable types cannot be used directly as an OfG output,
+        # such as NullType or DisconnectedType.
+        #
+        # However, a grad() call could return these types.
+        #
+        # Since we always use an OfG instance as self._grad_op, the current
+        # workaround is to "remember" the special cases of the gradient and
+        # replace them after self._grad_op is called.
+        #
+        # This helper function changes invalid types into a filtered_type,
+        # and provides an overrider_type to be replaced at the grad() call.
+        #
+        # For now, this converts NullType or DisconnectedType into zeros_like;
+        # other types are unmodified, with overrider_type -> None.
+        if isinstance(grad.type, (NullType, DisconnectedType)):
+            return inp.zeros_like(), grad.type
+        else:
+            return grad, None
+
+    def __init__(
+            self, inputs, outputs,
+            inline=False,
+            grad_overrides=Ellipsis, rop_overrides=Ellipsis,
+            name=None, **kwargs
+    ):
         if not isinstance(outputs, list):
-            raise TypeError('outputs must be list', outputs)
+            raise TypeError('outputs must be list, got %s' % outputs, outputs)
         for i in inputs + outputs:
             if not isinstance(i, gof.Variable):
                 raise TypeError(
@@ -216,64 +256,94 @@ class OpFromGraph(gof.Op):
         return '%(name)s{inline=%(is_inline)s}' % locals()
     def _recompute_grad_op(self):
-        if isinstance(self._grad_op, OpFromGraph):
+        local_inputs = self.local_inputs
+        local_outputs = self.local_outputs
+        inp_len = len(local_inputs)
+        grad_op = self._grad_op
+        if isinstance(grad_op, OpFromGraph):
             self._grad_op_is_cached = True
+            self._grad_op_overrides_l = [None] * len(self.local_inputs)
             return
-        output_grads = [out_t() for out_t in self.output_types]
-        if self._grad_op is None:
-            self._grad_op = []
-        # we need to convert a list/function into an OfG instance
-        if isinstance(self._grad_op, list):
+        output_grads = [out_t() for out_t in self.output_types]
+        TYPE_ERR_MSG = 'Gradient override should be (single or list of) ' \
+            'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
+        # we need to convert _grad_op into an OfG instance
+        if grad_op is Ellipsis:
+            self._grad_op_tflags = bytes(inp_len)
+            all_grads_l = theano.gradient.grad(
+                cost=None,
+                known_grads=OrderedDict(izip(local_outputs, output_grads)),
+                wrt=local_inputs,
+                disconnected_inputs='ignore')
+            all_grads_ov_l = [None] * inp_len
+        elif grad_op is None:
+            all_grads_l = [inp.zeros_like() for inp in local_inputs]
+            all_grads_ov_l = [self.ofg_null_t()] * inp_len
+        elif grad_op is 0:
+            all_grads_l = [inp.zeros_like() for inp in local_inputs]
+            all_grads_ov_l = [self.ofg_discon_t()] * inp_len
+        elif isinstance(grad_op, list):
             goverrides_l = self._grad_op
-            if len(goverrides_l) > len(self.local_inputs):
+            if len(goverrides_l) != inp_len:
                 raise ValueError(
-                    'Can override %d gradients at most, got %d' % (
-                        len(self.local_inputs), len(goverrides_l)),
-                    self.goverrides_l)
-            if len(goverrides_l) < len(self.local_inputs):
-                goverrides_l += [None] * (
-                    len(self.local_inputs) - len(goverrides_l))
-            wrt_l = [lin for lin, gov in
-                     izip(self.local_inputs, goverrides_l) if not gov]
+                    'Need to override %d gradients, got %d' % (
+                        inp_len, len(goverrides_l)), goverrides_l)
             # compute non-overriding downstream grads from upstream grads
             # it's normal some input may be disconnected, thus the 'ignore'
+            wrt_l = [lin for lin, gov in izip(
+                self.local_inputs, goverrides_l) if gov is Ellipsis]
             gdefaults = iter(theano.gradient.grad(
                 cost=None,
-                known_grads=dict(izip(self.local_outputs, output_grads)),
+                known_grads=OrderedDict(izip(self.local_outputs, output_grads)),
                 wrt=wrt_l,
                 disconnected_inputs='ignore') if wrt_l else [])
             # combine overriding gradients
             all_grads_l = []
-            for inp, gov in izip(self.local_inputs, goverrides_l):
-                if gov is None:
-                    all_grads_l.append(next(gdefaults))
-                elif gov is undef:
-                    all_grads_l.append(
-                        inp.zeros_like().astype(theano.config.floatX))
+            all_grads_ov_l = []
+            for i, (inp, fn_gov) in enumerate(izip(local_inputs, goverrides_l)):
+                if fn_gov is Ellipsis:
+                    gnext, gnext_ov = OpFromGraph._filter_grad_var(
+                        next(gdefaults), inp)
+                    all_grads_l.append(gnext)
+                    all_grads_ov_l.append(gnext_ov)
+                elif fn_gov is 0:
+                    all_grads_l.append(inp.zeros_like())
+                    all_grads_ov_l.append(self.ofg_discon_t())
+                elif fn_gov is None:
+                    all_grads_l.append(inp.zeros_like())
+                    all_grads_ov_l.append(self.ofg_null_t())
                 else:
-                    all_grads_l.append(gov(self.local_inputs, output_grads))
-        elif self._grad_op is undef:
-            all_grads_l = [
-                inp.zeros_like().astype(theano.config.floatX)
-                for inp in self.local_inputs]
+                    if not hasattr(fn_gov, '__call__'):
+                        raise TypeError(TYPE_ERR_MSG % fn_gov)
+                    gov, gov_ov = OpFromGraph._filter_grad_var(
+                        fn_gov(local_inputs, output_grads), inp)
+                    all_grads_l.append(gov)
+                    all_grads_ov_l.append(gov_ov)
         else:
-            all_grads_l = self._grad_op(self.local_inputs, output_grads)
-            if not isinstance(all_grads_l, (tuple, list)):
-                all_grads_l = [all_grads_l]
-            if len(all_grads_l) != len(self.local_inputs):
-                raise ValueError(
-                    'Gradient overriding function %s should return list of '
-                    '%d outputs, got %d' % (
-                        self._grad_op, len(self.local_inputs), len(all_grads_l)),
-                    self._grad_op
-                )
+            # callable case
+            if not hasattr(grad_op, '__call__'):
+                raise TypeError(TYPE_ERR_MSG % grad_op)
+            goverrides_l = grad_op(local_inputs, output_grads)
+            if not isinstance(goverrides_l, list):
+                raise TypeError(
+                    'Gradient overriding function should return a list, '
+                    'got "%s"' % type(goverrides_l))
+            all_grads_l, all_grads_ov_l = izip(
+                *[OpFromGraph._filter_grad_var(grad, inp)
+                  for grad, inp in izip(goverrides_l, local_inputs)])
+            if len(all_grads_l) != len(local_inputs):
+                raise ValueError(
+                    'Gradient overriding function should return list of '
+                    '%d outputs, got %d' % (inp_len, len(all_grads_l)))
+            all_grads_l = list(all_grads_l)
+            all_grads_ov_l = list(all_grads_ov_l)
         self._grad_op = type(self)(
-            inputs=self.local_inputs + output_grads,
+            inputs=local_inputs + output_grads,
             outputs=all_grads_l,
             inline=self.is_inline,
             name=(None if self.name is None else self.name + '_grad'),
             on_unused_input='ignore')
+        self._grad_op_overrides_l = all_grads_ov_l
         self._grad_op_is_cached = True
     def _recompute_rop_op(self):
@@ -375,12 +445,23 @@ class OpFromGraph(gof.Op):
     def R_op(self, inputs, eval_points):
         if not self._rop_op_is_cached:
             self._recompute_rop_op()
-        return self._rop_op(*(list(inputs) + list(eval_points)), return_list=True)
+        ret_ofg_l = self._rop_op(
+            *(list(inputs) + list(eval_points)), return_list=True)
+        ret_l = [{
+            self.TFLAG_NULL_T: self.ofg_null_t(),
+            self.TFLAG_DISCON_T: self.ofg_discon_t()
+        }[flag] if flag else ret_ofg
+            for ret_ofg, flag in izip(ret_ofg_l, self._grad_tflags)]
+        return ret_l

     def grad(self, inputs, output_grads):
         if not self._grad_op_is_cached:
             self._recompute_grad_op()
-        return self._grad_op(*(list(inputs) + list(output_grads)), return_list=True)
+        ret_ofg_l = self._grad_op(
+            *(list(inputs) + list(output_grads)), return_list=True)
+        ret_l = [
+            ret_ofg if ov is None else ov
+            for ret_ofg, ov in izip(ret_ofg_l, self._grad_op_overrides_l)]
+        return ret_l

     def make_node(self, *inputs):
         num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
...
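
Before the test changes, it is worth spelling out what the special values mean at grad() time. A hypothetical check of the None -> NullType() semantics follows; the exact failure mode (NullTypeGradError) is my reading of the intended behavior, and since this commit notes that R_op is still broken, only grad() is exercised:

    import theano
    import theano.tensor as T
    from theano import OpFromGraph

    x, y = T.scalars('xy')
    # Ellipsis keeps the default gradient wrt x; None overrides the
    # gradient wrt y with NullType() per the new interface
    op = OpFromGraph([x, y], [x * y], grad_overrides=[Ellipsis, None])
    z = op(x, y)
    dx = theano.grad(z, x)  # usable default gradient, dz/dx = y
    try:
        theano.grad(z, y)   # the NullType gradient should make this fail
    except theano.gradient.NullTypeGradError as e:
        print('gradient wrt y is undefined:', e)
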
@@ -160,7 +160,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
         w, b = T.vectors('wb')
         # we make the 3rd gradient default (no override)
-        op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2])
+        op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, Ellipsis])
         xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
         zz = T.sum(op_linear(xx, ww, bb))
         dx, dw, db = T.grad(zz, [xx, ww, bb])
@@ -281,21 +281,19 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
                    [True, False, True]]
         assert results == expect_result

-    @test_params
-    def test_infer_shape(self, cls_ofg):
+    def test_infer_shape(self):
+        # infer_shape does not need to be tested against the inline case,
+        # since the Op is removed during the optimization phase
         x = T.matrix('x')
         y = T.matrix('y')
         o1 = x + y
         o2 = x * y
-        op_graph = cls_ofg([x, y], [o1, o2])
+        op_graph = OpFromGraph([x, y], [o1, o2])
         q = T.matrix('q')
         p = T.matrix('p')
-        # we don't want check_topo for inline ops
-        # since the inline op is replaced during optimization
         self._compile_and_check([q, p],
                                 op_graph(q, p),
                                 [np.ones([3, 4], dtype=config.floatX),
                                  np.ones([3, 4], dtype=config.floatX)],
-                                cls_ofg,
-                                check_topo=not op_graph.is_inline)
+                                OpFromGraph)
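
The core bookkeeping this commit introduces, swapping special gradient types for zeros inside the inner OpFromGraph and restoring them in grad(), can also be mirrored in plain Python. Every name below (Var and the stub type classes) is a hypothetical stand-in for the Theano machinery, purely to illustrate the control flow of _filter_grad_var and the replacement loop in grad():

    class NullType(object):          # stub, not theano.gof.null_type.NullType
        def __call__(self):
            return 'NullType_var'

    class DisconnectedType(object):  # stub, not theano.gradient.DisconnectedType
        def __call__(self):
            return 'DisconnectedType_var'

    class Var(object):               # stub for a Theano Variable
        def __init__(self, type_):
            self.type = type_

        def zeros_like(self):
            return 'zeros'           # a valid inner-graph output placeholder

    def filter_grad_var(grad_var, inp):
        # mirrors OpFromGraph._filter_grad_var: special types cannot be
        # outputs of the inner OfG, so return zeros and remember the type
        if isinstance(grad_var.type, (NullType, DisconnectedType)):
            return inp.zeros_like(), grad_var.type
        return grad_var, None

    def restore(ret_l, overrides_l):
        # mirrors the replacement loop in OpFromGraph.grad(): re-apply the
        # remembered special type over the inner op's zero output
        return [ret if ov is None else ov()
                for ret, ov in zip(ret_l, overrides_l)]

    inp = Var(type_=object())
    null_grad = Var(type_=NullType())
    filtered, ov = filter_grad_var(null_grad, inp)
    print(filtered)                   # 'zeros' goes into the inner OfG
    print(restore([filtered], [ov]))  # ['NullType_var'] after replacement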