提交 3518b38f authored 作者: khaotik's avatar khaotik

mixed changes for OpFromGraph

- correct docs about overriding - add test for lop_override - fix lop_override - cleanup debugging code
上级 55bb0f69
...@@ -46,12 +46,13 @@ class OpFromGraph(gof.Op): ...@@ -46,12 +46,13 @@ class OpFromGraph(gof.Op):
``'default'`` : Do not override, use default grad() result ``'default'`` : Do not override, use default grad() result
OpFromGraph instance : Override with another OpFromGraph, should OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads" accept inputs as the same order and types of ``inputs`` and ``output_grads``
arguments as one would specify in grad() method. This argument is mutually arguments as one would specify in grad() method. This argument is mutually
exclusive to lop_overrides. exclusive to lop_overrides.
callable : similar to OpFromGraph instance, must return list of callable : Should take two args: ``inputs`` and ``output_grads``.
:class:`Variable <theano.gof.Variable>`. Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
Must return list of :class:`Variable <theano.gof.Variable>`.
Variable : Variable :
``NullType() instance`` : Treat as non-differentiable ``NullType() instance`` : Treat as non-differentiable
...@@ -64,9 +65,24 @@ class OpFromGraph(gof.Op): ...@@ -64,9 +65,24 @@ class OpFromGraph(gof.Op):
lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``. Defaults to ``'default'``.
Similar to ``grad_overrides``, except callables should accept inputs as the ``'default'`` : Do not override, use default L_op() result
same order and types of "inputs", "outputs", and "output_grads" in L_op. This
argument is mutually exclusive with grad_overrides. OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs``, ``outputs`` and ``output_grads``
arguments as one would specify in grad() method. This argument is mutually exclusive with
``grad_overrides``.
callable : Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
Must return list of :class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``default``. Defaults to ``default``.
...@@ -74,11 +90,12 @@ class OpFromGraph(gof.Op): ...@@ -74,11 +90,12 @@ class OpFromGraph(gof.Op):
``'default'`` : Do not override, use default R_op() result ``'default'`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads" accept inputs as the same order and types of ``inputs`` and ``eval_points``
arguments as one would specify in grad() method. arguments as one would specify in R_op() method.
callable : similar to OpFromGraph instance, must return list of callable : Should take two args: ``inputs`` and ``eval_points``.
:class:`Variable <theano.gof.Variable>`. Each argument is expected to be a list of :class:`Variable <theano.gof.Variable>`.
Must return list of :class:`Variable <theano.gof.Variable>`.
Variable : Variable :
``NullType() instance`` : Treat as non-differentiable ``NullType() instance`` : Treat as non-differentiable
...@@ -123,7 +140,7 @@ class OpFromGraph(gof.Op): ...@@ -123,7 +140,7 @@ class OpFromGraph(gof.Op):
- ``inline=True`` will cause better runtime optimization at the cost - ``inline=True`` will cause better runtime optimization at the cost
of compilation time. Currently only works with ``fast_compile`` or of compilation time. Currently only works with ``fast_compile`` or
``fast_run`` mode. ``fast_run`` mode.
- For overriding, it's recommanded to provide pure functions (no side - For overriding, it's recommended to provide pure functions (no side
effects like setting global variable) as callable(s). The callable(s) effects like setting global variable) as callable(s). The callable(s)
supplied for overriding gradient/rop will be called only once at the supplied for overriding gradient/rop will be called only once at the
first call to grad/R_op, and will be converted to OpFromGraph instances. first call to grad/R_op, and will be converted to OpFromGraph instances.
...@@ -269,7 +286,7 @@ class OpFromGraph(gof.Op): ...@@ -269,7 +286,7 @@ class OpFromGraph(gof.Op):
self.output_types = [out.type for out in outputs] self.output_types = [out.type for out in outputs]
cond = int(lop_overrides == 'default') * 2 + int(grad_overrides == 'default') cond = int(lop_overrides == 'default') * 2 + int(grad_overrides == 'default')
if cond == 0: if cond == 0:
raise ValueError('lop_overrides and rop_overrides are mutually exclusive') raise ValueError('lop_overrides and grad_overrides are mutually exclusive')
elif cond == 1: elif cond == 1:
self.set_lop_overrides(lop_overrides) self.set_lop_overrides(lop_overrides)
self._lop_type = 'lop' self._lop_type = 'lop'
...@@ -310,7 +327,8 @@ class OpFromGraph(gof.Op): ...@@ -310,7 +327,8 @@ class OpFromGraph(gof.Op):
lop_op = self._lop_op lop_op = self._lop_op
if isinstance(lop_op, OpFromGraph): if isinstance(lop_op, OpFromGraph):
# OfG can be directly used if in L_op format if self._lop_op_is_cached:
return
if self._lop_type == 'grad': if self._lop_type == 'grad':
needed_ninps = inp_len + len(local_outputs) needed_ninps = inp_len + len(local_outputs)
ninps = len(lop_op.local_inputs) ninps = len(lop_op.local_inputs)
...@@ -318,16 +336,17 @@ class OpFromGraph(gof.Op): ...@@ -318,16 +336,17 @@ class OpFromGraph(gof.Op):
raise ValueError( raise ValueError(
self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps)) self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
# make a wrapper callable # make a wrapper callable
lop_op = lambda inps, grads: self._lop_op(*(inps + grads)) # noqa: 731 lop_op = lambda inps, grads: self._lop_op(*(inps + grads)) # noqa: 731
elif self._lop_type == 'lop': elif self._lop_type == 'lop':
# OfG can be directly used in L_op format
needed_ninps = inp_len + 2 * len(local_outputs) needed_ninps = inp_len + 2 * len(local_outputs)
ninps = len(lop_op.local_inputs) ninps = len(lop_op.local_inputs)
if needed_ninps != ninps: if needed_ninps != ninps:
raise ValueError( raise ValueError(
self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps)) self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
if not self._lop_op_is_cached: self._lop_op_is_cached = True
self._lop_op_is_cached = True self._lop_op_stypes_l = [None] * inp_len
self._lop_op_stypes_l = [None] * inp_len self._lop_op.kwargs['on_unused_input'] = 'ignore'
return return
else: else:
raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type) raise ValueError(self.LOP_TYPE_ERR_MSG % self._lop_type)
...@@ -416,7 +435,7 @@ class OpFromGraph(gof.Op): ...@@ -416,7 +435,7 @@ class OpFromGraph(gof.Op):
inputs=local_inputs + local_outputs + output_grads, inputs=local_inputs + local_outputs + output_grads,
outputs=all_grads_l, outputs=all_grads_l,
inline=self.is_inline, inline=self.is_inline,
name=(None if self.name is None else self.name + '_lop'), name=(None if self.name is None else self.name + '_' + self._lop_type),
on_unused_input='ignore') on_unused_input='ignore')
self._lop_op_stypes_l = all_grads_ov_l self._lop_op_stypes_l = all_grads_ov_l
self._lop_op_is_cached = True self._lop_op_is_cached = True
...@@ -580,7 +599,6 @@ class OpFromGraph(gof.Op): ...@@ -580,7 +599,6 @@ class OpFromGraph(gof.Op):
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
if not self._lop_op_is_cached: if not self._lop_op_is_cached:
self._recompute_lop_op() self._recompute_lop_op()
print('L_op', inputs, outputs, output_grads)
inps = list(inputs) + list(outputs) + list(output_grads) inps = list(inputs) + list(outputs) + list(output_grads)
ret_ofg_l = self._lop_op( ret_ofg_l = self._lop_op(
*inps, return_list=True) *inps, return_list=True)
...@@ -600,7 +618,6 @@ class OpFromGraph(gof.Op): ...@@ -600,7 +618,6 @@ class OpFromGraph(gof.Op):
return ret_l return ret_l
def make_node(self, *inputs): def make_node(self, *inputs):
print('make_node', inputs)
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs) num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
if len(inputs) != num_expected_inps: if len(inputs) != num_expected_inps:
raise ValueError( raise ValueError(
...@@ -683,17 +700,6 @@ class OpFromGraph(gof.Op): ...@@ -683,17 +700,6 @@ class OpFromGraph(gof.Op):
# we wont need this copy anymore # we wont need this copy anymore
output[0] = variable.copy() output[0] = variable.copy()
def copy(self):
'''
Make a shallow copy, gradient/R_op overrides will remain same objects
'''
cpy = type(self)(self.local_inputs, self.local_outputs, inline=self.is_inline)
if not self._lop_is_default:
cpy.set_lop_overrides(self.get_lop_op())
if not self._rop_is_default:
cpy.set_rop_overrides(self.get_rop_op())
return cpy
@gof.local_optimizer([OpFromGraph]) @gof.local_optimizer([OpFromGraph])
def inline_ofg_expansion(node): def inline_ofg_expansion(node):
......
...@@ -192,6 +192,33 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -192,6 +192,33 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert isinstance(dw2.type, NullType) assert isinstance(dw2.type, NullType)
assert isinstance(db2.type, DisconnectedType) assert isinstance(db2.type, DisconnectedType)
@test_params
def test_lop_override(self, cls_ofg):
x = T.vector()
y = 1. / (1. + T.exp(-x))
def lop_ov(inps, outs, grads):
y_, = outs
dedy_, = grads
return [2. * y_ * (1. - y_) * dedy_]
y_, dedy = T.vector(), T.vector()
op_lop_ov = cls_ofg([x, y_, dedy], [2. * y_ * (1. - y_) * dedy])
xx = T.vector()
yy1 = T.sum(T.nnet.sigmoid(xx))
gyy1 = 2. * T.grad(yy1, xx)
for ov in [lop_ov, op_lop_ov]:
op = cls_ofg([x], [y], lop_overrides=ov)
yy2 = T.sum(op(xx))
gyy2 = T.grad(yy2, xx)
fn = function([xx], [gyy1, gyy2])
xval = np.random.rand(32).astype(config.floatX)
y1val, y2val = fn(xval)
assert np.allclose(y1val, y2val)
@test_params @test_params
def test_rop(self, cls_ofg): def test_rop(self, cls_ofg):
a = T.vector() a = T.vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论