Commit ea66f4a7 authored by khaotik

OpFromGraph improvements

- Use explicit NullType() or DisconnectedType() instances for OpFromGraph special types; 'default' means no gradient/R_op override
- connection_pattern now takes gradient overrides into account
- Revert hackish changes made to theano/gradient.py
- Bug fix for theano.gof.graph.io_connection_pattern with an isolated output
- Bug fix for _recompute_grad/rop_op being called twice on an OfG instance
Parent b5685f95
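For orientation, a minimal, hedged sketch of the override syntax this commit introduces (variable names are illustrative, assuming the post-commit API):

    import theano.tensor as T
    from theano import OpFromGraph
    from theano.gradient import DisconnectedType

    x, y = T.scalars('x', 'y')
    # 'default' (formerly Ellipsis) keeps Theano's symbolic gradient
    # for x; the explicit DisconnectedType() instance (formerly the
    # literal 0) marks y's gradient as disconnected.
    op = OpFromGraph([x, y], [x * y],
                     grad_overrides=['default', DisconnectedType()()])

theano/compile/builders.py: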
"""Define new Ops from existing Ops"""
from __future__ import absolute_import, print_function, division
from __future__ import absolute_import, division
from functools import reduce, partial
from collections import OrderedDict
@@ -8,7 +8,7 @@ from theano import gof
from theano.compat import izip
from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared, optdb
from theano.gof import ops_with_inner_function
from theano.gof import Variable, ops_with_inner_function
from theano.gof.graph import io_connection_pattern
from theano.gof.null_type import NullType
from theano.gradient import DisconnectedType
@@ -40,12 +40,10 @@ class OpFromGraph(gof.Op):
``False`` : will use a pre-compiled function inside.
grad_overrides : single or list of {0, None, Ellipsis, OpFromGraph, callable}, optional
Defaults to ``None``.
grad_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
``None`` : No value, gives NullType()
``0`` : zero value, gives DisconnectedType()
``...`` : Do not override, use default grad() result
``'default'`` : Do not override, use default grad() result
OpFromGraph instance : Override with another OpFromGraph; it should
accept inputs in the same order and with the same types as "inputs" and "output_grads"
@@ -54,16 +52,18 @@ class OpFromGraph(gof.Op):
callable : similar to OpFromGraph instance, must return list of
:class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds to the gradient of
a specific input; the length of the list must equal the number of inputs.
rop_overrides : single or list of {0, None, Ellipsis, OpFromGraph, callable}, optional
Defaults to ``None``.
rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
``None`` : No value, gives NullType()
``0`` : zero value, gives zeros_like(...)
``...`` : Do not override, use default R_op() result
``'default'`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph; it should
accept inputs in the same order and with the same types as "inputs" and "eval_points"
@@ -72,6 +72,10 @@ class OpFromGraph(gof.Op):
callable : similar to OpFromGraph instance, must return list of
:class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as zero since DisconnectedType is not yet supported in R_op
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds
to a specific output of R_op; the length of the list must equal the number of outputs.
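(The docstring gives no R_op example, so here is a hedged sketch of a whole-graph callable override. It assumes, mirroring the grad example below, that the callable receives the inner inputs and the eval points as two lists; the override shown happens to compute the true R_op of x**2.)

    import theano
    import theano.tensor as T
    from theano import OpFromGraph

    x = T.vector('x')

    def custom_rop(inps, eval_points):
        xi, = inps
        ex, = eval_points
        return [2. * xi * ex]  # directional derivative of x**2

    op = OpFromGraph([x], [x ** 2], rop_overrides=custom_rop)
    u = T.vector('u')
    jv = theano.gradient.Rop(op(x), x, u)  # routed through custom_rop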
@@ -159,7 +163,7 @@ class OpFromGraph(gof.Op):
g, = grads
return z*2
op = OpFromGraph(
[x, y, z], [e], grad_overrides=[Ellipsis, rescale_dy, Ellipsis])
[x, y, z], [e], grad_overrides=['default', rescale_dy, 'default'])
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
fn = function([x, y, z], [dx, dy, dz])
@@ -167,8 +171,6 @@ class OpFromGraph(gof.Op):
fn(2., 3., 4.) # [1., 8., 3.]
"""
ofg_null_t = NullType(why_null='ofg_overridden')
ofg_discon_t = DisconnectedType()
@staticmethod
def _filter_grad_var(grad, inp):
@@ -190,7 +192,10 @@
# For now, this converts NullType or DisconnectedType into zeros_like.
# Other types are unmodified: overrider_var -> None
if isinstance(grad.type, (NullType, DisconnectedType)):
return inp.zeros_like(), grad
if hasattr(inp, 'zeros_like'):
return inp.zeros_like(), grad
else:
return theano.tensor.constant(0.), grad
else:
return grad, None
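A hedged illustration of this helper's contract, using made-up values (_filter_grad_var is the staticmethod shown above):

    import theano.tensor as T
    from theano import OpFromGraph
    from theano.gradient import DisconnectedType

    inp = T.vector('inp')
    special = DisconnectedType()()
    val, marker = OpFromGraph._filter_grad_var(special, inp)
    # val is inp.zeros_like(), usable inside the inner grad graph;
    # marker keeps the special-type variable for later substitution
    g = T.vector('g')
    val2, marker2 = OpFromGraph._filter_grad_var(g, inp)
    # ordinary gradients pass through: val2 is g, marker2 is None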
@@ -209,7 +214,7 @@
def __init__(
self, inputs, outputs,
inline=False,
grad_overrides=Ellipsis, rop_overrides=Ellipsis,
grad_overrides='default', rop_overrides='default',
name=None, **kwargs
):
if not isinstance(outputs, list):
@@ -267,14 +272,19 @@
return '%(name)s{inline=%(is_inline)s}' % locals()
def _recompute_grad_op(self):
'''
Converts self._grad_op from its user-supplied form to a type(self) instance.
'''
local_inputs = self.local_inputs
local_outputs = self.local_outputs
inp_len = len(local_inputs)
grad_op = self._grad_op
if isinstance(grad_op, OpFromGraph):
self._grad_op_is_cached = True
self._grad_op_overrides_l = [None] * inp_len
if not self._grad_op_is_cached:
self._grad_op_is_cached = True
self._grad_op_stypes_l = [None] * inp_len
return
output_grads = [out_t() for out_t in self.output_types]
@@ -286,20 +296,24 @@
null_gradients='return',
known_grads=OrderedDict(izip(local_outputs, output_grads)))
TYPE_ERR_MSG = 'Gradient override should be (single or list of)' \
'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
TYPE_ERR_MSG = ("Gradient override should be (single or list of)"
"'default' | OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got %s")
STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
' of DisconnectedType or NullType, got %s')
# we need to convert _grad_op into an OfG instance
if grad_op is Ellipsis:
if grad_op == 'default':
gdefaults_l = fn_grad(wrt=local_inputs)
all_grads_ov_l = [None] * inp_len
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(gdefaults_l, local_inputs)])
elif grad_op is None:
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_null_t()] * inp_len
elif type(grad_op) is int and grad_op == 0:
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [self.ofg_discon_t()] * inp_len
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
elif isinstance(grad_op, Variable):
if isinstance(grad_op.type, (DisconnectedType, NullType)):
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [grad_op.type() for _ in range(inp_len)]
else:
raise ValueError(STYPE_ERR_MSG % grad_op.type)
elif isinstance(grad_op, list):
goverrides_l = grad_op
if len(goverrides_l) != inp_len:
@@ -309,23 +323,23 @@
# compute non-overridden downstream grads from upstream grads
# it's normal that some inputs may be disconnected, hence the 'ignore'
wrt_l = [lin for lin, gov in izip(
local_inputs, goverrides_l) if gov is Ellipsis]
local_inputs, goverrides_l) if gov == 'default']
gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
# combine overriding gradients
all_grads_l = []
all_grads_ov_l = []
for inp, fn_gov in izip(local_inputs, goverrides_l):
if fn_gov is Ellipsis:
if fn_gov == 'default':
gnext, gnext_ov = OpFromGraph._filter_grad_var(
next(gdefaults), inp)
all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov)
elif fn_gov is 0:
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(self.ofg_discon_t())
elif fn_gov is None:
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(self.ofg_null_t())
elif isinstance(fn_gov, Variable):
if isinstance(fn_gov.type, (DisconnectedType, NullType)):
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(fn_gov.type())
else:
raise ValueError(STYPE_ERR_MSG % fn_gov.type)
else:
if not hasattr(fn_gov, '__call__'):
raise TypeError(TYPE_ERR_MSG % fn_gov)
@@ -343,8 +357,8 @@
'Gradient overriding function should return a list, '
'got "%s"' % type(goverrides_l))
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(
grad, inp) for grad, inp in izip(goverrides_l, local_inputs)])
*[OpFromGraph._filter_grad_var(grad, inp)
for grad, inp in izip(goverrides_l, local_inputs)])
if len(all_grads_l) != len(local_inputs):
raise ValueError(
'Gradient overriding function should return list of '
@@ -357,18 +371,23 @@
inline=self.is_inline,
name=(None if self.name is None else self.name + '_grad'),
on_unused_input='ignore')
self._grad_op_overrides_l = all_grads_ov_l
self._grad_op_stypes_l = all_grads_ov_l
self._grad_op_is_cached = True
def _recompute_rop_op(self):
'''
Converts self._rop_op from its user-supplied form to a type(self) instance.
'''
local_inputs = self.local_inputs
local_outputs = self.local_outputs
out_len = len(local_outputs)
rop_op = self._rop_op
if isinstance(self._rop_op, OpFromGraph):
self._rop_op_is_cached = True
self._rop_op_overrides_l = [None] * out_len
if isinstance(rop_op, OpFromGraph):
if not self._rop_op_is_cached:
self._rop_op_is_cached = True
self._rop_op_stypes_l = [None] * out_len
return
eval_points = [inp_t() for inp_t in self.input_types]
@@ -376,17 +395,26 @@
theano.gradient.Rop,
wrt=local_inputs,
eval_points=eval_points)
TYPE_ERR_MSG = 'R_op overrides should be (single or list of)' \
'OpFromGraph | Ellipsis | None | 0 | callable, got %s'
if rop_op is Ellipsis:
all_rops_l = fn_rop(f=local_outputs)
all_rops_ov_l = [None] * out_len
elif rop_op is None:
all_rops_l = [out.zeros_like() for out in local_outputs]
all_rops_ov_l = [self.ofg_null_t()] * out_len
elif rop_op is 0:
all_rops_l = [out.zeros_like() for out in local_outputs]
all_rops_ov_l = [None] * out_len
TYPE_ERR_MSG = ("R_op overrides should be (single or list of)"
"OpFromGraph | 'default' | None | 0 | callable, got %s")
STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
' of DisconnectedType or NullType, got %s')
if rop_op == 'default':
rdefaults_l = fn_rop(f=local_outputs)
all_rops_l, all_rops_ov_l = izip(
*[OpFromGraph._filter_rop_var(rop, out) for rop,
out in izip(rdefaults_l, local_outputs)])
all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l)
elif isinstance(rop_op, Variable):
if isinstance(rop_op.type, NullType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [rop_op.type() for _ in range(out_len)]
elif isinstance(rop_op.type, DisconnectedType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [None] * out_len
else:
raise ValueError(STYPE_ERR_MSG % rop_op.type)
elif isinstance(rop_op, list):
roverrides_l = rop_op
if len(roverrides_l) != out_len:
@@ -396,24 +424,27 @@
# get outputs that do not have an R_op override
odefaults_l = [
lo for lo, rov in izip(local_outputs, roverrides_l)
if rov is Ellipsis]
if rov == 'default']
rdefaults_l = fn_rop(f=odefaults_l)
rdefaults = iter(rdefaults_l if odefaults_l else [])
# combine overriding Rops
all_rops_l = []
all_rops_ov_l = []
for out, fn_rov in izip(local_outputs, roverrides_l):
if fn_rov is Ellipsis:
if fn_rov == 'default':
rnext, rnext_ov = OpFromGraph._filter_rop_var(
next(rdefaults), out)
all_rops_l.append(rnext)
all_rops_ov_l.append(rnext_ov)
elif fn_rov is 0:
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(None)
elif fn_rov is None:
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(self.ofg_null_t())
elif isinstance(fn_rov, Variable):
if isinstance(fn_rov.type, NullType):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(fn_rov.type())
elif isinstance(fn_rov.type, DisconnectedType):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(None)
else:
raise ValueError(STYPE_ERR_MSG % fn_rov.type)
else:
if not hasattr(fn_rov, '__call__'):
raise TypeError(TYPE_ERR_MSG % fn_rov)
@@ -446,7 +477,7 @@
inline=self.is_inline,
name=(None if self.name is None else self.name + '_rop'),
on_unused_input='ignore')
self._rop_op_overrides_l = all_rops_ov_l
self._rop_op_stypes_l = all_rops_ov_l
self._rop_op_is_cached = True
def get_grad_op(self):
@@ -465,15 +496,6 @@
self._recompute_rop_op()
return self._rop_op
def set_rop_overrides(self, rop_overrides):
"""
Set R_op overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set R_op overrides
"""
self._rop_op = rop_overrides
self._rop_op_is_cached = False
def set_grad_overrides(self, grad_overrides):
"""
Set gradient overrides, see help(theano.OpFromGraph) for syntax
@@ -483,15 +505,14 @@
self._grad_op = grad_overrides
self._grad_op_is_cached = False
def R_op(self, inputs, eval_points):
if not self._rop_op_is_cached:
self._recompute_rop_op()
ret_ofg_l = self._rop_op(
*(list(inputs) + list(eval_points)), return_list=True)
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._rop_op_overrides_l)]
return ret_l
def set_rop_overrides(self, rop_overrides):
"""
Set R_op overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set R_op overrides
"""
self._rop_op = rop_overrides
self._rop_op_is_cached = False
def grad(self, inputs, output_grads):
if not self._grad_op_is_cached:
@@ -500,13 +521,24 @@
*(list(inputs) + list(output_grads)), return_list=True)
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._grad_op_overrides_l)]
ret_ofg_l, self._grad_op_stypes_l)]
return ret_l
def R_op(self, inputs, eval_points):
if not self._rop_op_is_cached:
self._recompute_rop_op()
ret_ofg_l = self._rop_op(
*(list(inputs) + list(eval_points)), return_list=True)
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._rop_op_stypes_l)]
return ret_l
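Both grad() and R_op() use the same substitution idiom: any non-None entry in the cached _grad_op_stypes_l / _rop_op_stypes_l list replaces the value computed by the inner op at that position. A plain-Python paraphrase with hypothetical values:

    computed = ['g0', 'g1', 'g2']   # outputs of the inner grad/rop op
    markers = [None, 'NULL', None]  # e.g. self._grad_op_stypes_l
    result = [c if m is None else m
              for c, m in zip(computed, markers)]
    assert result == ['g0', 'NULL', 'g2']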
def make_node(self, *inputs):
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
if len(inputs) != num_expected_inps:
raise ValueError("Expected %d inputs, got %d" % (num_expected_inps, len(inputs)))
raise ValueError(
"Expected %d inputs, got %d" % (num_expected_inps, len(inputs)))
inputs = [inp_t.filter_variable(inp) for inp, inp_t in izip(inputs, self.input_types)]
apply_node = gof.Apply(
self, list(inputs) + self.shared_inputs,
@@ -520,9 +552,32 @@
Return connection pattern of subfgraph defined by inputs and outputs.
"""
return io_connection_pattern(
inp_len = len(self.local_inputs)
out_len = len(self.local_outputs)
cpmat_self = io_connection_pattern(
self.local_inputs, self.local_outputs)
grad_op = self.get_grad_op()
cpmat_grad = io_connection_pattern(
grad_op.local_inputs[inp_len:],
grad_op.local_outputs)
# cpmat_self |= cpmat_grad.T
# cpmat_self &= out_is_disconnected
for i, t in enumerate(self._grad_op_stypes_l):
if t is not None:
if isinstance(t.type, DisconnectedType):
for o in range(out_len):
cpmat_self[i][o] = False
for o in range(out_len):
cpmat_self[i][o] |= cpmat_grad[o][i]
# TODO: if DisconnectedType is ever implemented for R_op,
# self._rop_op_stypes_l and self._rop_op should be taken into
# account in connection_pattern
return list(map(list, cpmat_self))
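A hedged end-to-end check of the new connection_pattern behaviour (assuming the override API above): overriding the second input's gradient with a DisconnectedType instance should clear that input's row of the pattern.

    import theano.tensor as T
    from theano import OpFromGraph
    from theano.gradient import DisconnectedType

    x, y = T.scalars('x', 'y')
    op = OpFromGraph([x, y], [x + y],
                     grad_overrides=['default', DisconnectedType()()])
    node = op(x, y).owner
    print(op.connection_pattern(node))  # expected: [[True], [False]]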
def infer_shape(self, node, shapes):
out_shp = theano.scan_module.scan_utils.infer_shape(
self.local_outputs,
......
theano/compile/tests/test_builders.py:
@@ -162,7 +162,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
w, b = T.vectors('wb')
# we make the 3rd gradient default (no override)
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, Ellipsis])
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, 'default'])
xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb])
@@ -176,7 +176,9 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
# NullType and DisconnectedType
op_linear2 = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, None, 0])
op_linear2 = cls_ofg(
[x, w, b], [x * w + b],
grad_overrides=[go1, NullType()(), DisconnectedType()()])
zz2 = T.sum(op_linear2(xx, ww, bb))
dx2, dw2, db2 = T.grad(
zz2, [xx, ww, bb],
@@ -205,7 +207,6 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
duval = np.random.rand(16).astype(config.floatX)
dvval = np.dot(duval, Wval)
dvval2 = fn(xval, Wval, duval)
print(dvval, dvval2)
assert np.allclose(dvval2, dvval)
@test_params
......
theano/gof/graph.py:
@@ -1099,7 +1099,10 @@ def io_connection_pattern(inputs, outputs):
# connection patterns of the individual outputs
global_connection_pattern = [[] for o in range(len(inputs))]
for out in outputs:
out_connection_pattern = connect_pattern_by_var[out]
out_connection_pattern = connect_pattern_by_var.get(out)
if out_connection_pattern is None:
# the output is completely isolated from inputs
out_connection_pattern = [False] * len(inputs)
for i in range(len(inputs)):
global_connection_pattern[i].append(out_connection_pattern[i])
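A hedged sketch of the failing case this guards against: an output with no path from any input (here a fresh constant) never gets an entry in connect_pattern_by_var, so its column now defaults to all-False.

    import theano.tensor as T
    from theano.gof.graph import io_connection_pattern

    x = T.scalar('x')
    iso = T.constant(1.0)  # output isolated from x
    pattern = io_connection_pattern([x], [2 * x, iso])
    # with this fix: pattern == [[True, False]]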
......
theano/gradient.py:
@@ -1233,12 +1233,6 @@ def _populate_grad_dict(var_to_app_to_idx,
actually_connected = \
not isinstance(ig.type, DisconnectedType)
if isinstance(node.op, theano.OpFromGraph):
ov = node.op._grad_op_overrides_l[i]
if ov is not None:
connected &= not isinstance(
ov.type, DisconnectedType)
if actually_connected and not connected:
msg = ("%s.grad returned %s of type %s for input %d."
" Expected DisconnectedType instance based on "
......