提交 c0fda9c0 authored 作者: khaotik's avatar khaotik 提交者: khaotik

grad_overrides now use syntax as in docstring

上级 7608602e
......@@ -69,7 +69,7 @@ from theano.compile import (
Mode,
predefined_modes, predefined_linkers, predefined_optimizers,
FunctionMaker, function, function_dump,
OpFromGraph, OpFromGrpahInline, OpFromGraphPrecompiled, op_from_graph
OpFromGraph, OpFromGraphInline, OpFromGraphPrecompiled, op_from_graph,
ProfileStats,
Param, shared, as_op)
......
......@@ -69,7 +69,7 @@ class OpFromGraphBase(gof.Op):
def grad(self, inputs, output_grads):
if self.cached_grad_ops:
return self.grad_ops(inputs+output_grads)
return self.grad_ops(inputs, output_grads)
grad_inps = self.internal_inputs + output_grads
upstream_grads = dict(izip(self.internal_outputs, output_grads))
......@@ -84,17 +84,18 @@ class OpFromGraphBase(gof.Op):
# to compute the gradient, so we ignore them.
gs = [go if go else type(self)(
grad_inps,
theano.gradient.grad(
cost=None,
known_grads=upstream_grads,
wrt=[inp],
disconnected_inputs='ignore'),
on_unused_input='ignore'
(lambda g: g if g else (lambda *a:None))(
theano.gradient.grad(
cost=None,
known_grads=upstream_grads,
wrt=[inp],
disconnected_inputs='ignore')
), on_unused_input='ignore'
) for go, inp in izip(grad_ops_l, self.internal_inputs)]
# since OpFromGraphBase only accepts and outputs list,
# since OpFromGraphBase only accepts input sequence,
# additional filtering is needed
grad_ops = lambda inps:[
(go(inps) if ov else go(*inps))
grad_ops = lambda inps,grds:[
(go(inps, grds) if ov else go(*(inps+grds)))
for go, ov in izip(gs, grad_ops_l)]
else:
grad_ops = grad_ops_l
......@@ -113,10 +114,10 @@ class OpFromGraphBase(gof.Op):
grad_ops_l.append(type(self)(grad_inps,
[g],
on_unused_input='ignore'))
grad_ops = lambda inps:[go(*inps) for go in grad_ops_l]
grad_ops = lambda inps, grds:[go(*(inps+grds)) for go in grad_ops_l]
self.grad_ops = grad_ops
self.cached_grad_ops = True
return grad_ops(inputs+output_grads)
return grad_ops(inputs, output_grads)
def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types):
......@@ -191,6 +192,8 @@ class OpFromGraphInline(OpFromGraphBase):
@gof.local_optimizer([OpFromGraphInline])
def inline_ofg_expansion(node):
""" This optimization expands internal graph of OpFromGraphInline
"""
op = node.op
if not isinstance(op, OpFromGraphInline):
return False
......@@ -205,6 +208,8 @@ optdb.register(
gof.opt.in2out(inline_ofg_expansion),
0.5, 'fast_compile', 'fast_run')
# Since OpFromGraphPrecompiled contains a Theano compiled function,
# we should let DebugMode know about it
ops_with_inner_function[OpFromGraphPrecompiled] = 'fn'
# for backward compatibility
......@@ -227,16 +232,18 @@ def op_from_graph(
inputs: list of variables
outputs: list of variables
inline: bool
inline: bool, optional
if True, will cause the Op's original graph being used during
compilation, otherwise will use a pre-compiled function inside.
grad_overrides: None | function | list of (None|function)
grad_overrides: None | function | list of (None|function), optional
Used to override default gradient routine.
Overriding function must take two list as inputs: original inputs
and upstream gradients
If is None, will use default gradient routine.
If is function, must return list of Variable.
If is list, each function must return a single Variable. The order
Overriding function(s) must take two list of variable as inputs,
the original inputs and upstream gradients
For different `grad_overrides`:
- `None` : will use default gradient routine.
- function : must return list of Variable.
- list : each function must return a single Variable. The order
of the list must corresponds to inputs
Notes
......@@ -263,7 +270,7 @@ def op_from_graph(
invisible to the user. They can be as input to the node or in the
inner graph.
- We support unused inputs. This is needed for the grad.
- inline=True will cause better optimization at the cost of longer
- `inline=True` will cause better runtime optimization at the cost of longer
compilation, only works with optimizer "fast_run" or "fast_compile"
Examples
......@@ -307,12 +314,12 @@ def op_from_graph(
x, y, z = inps
g = grads
return z*2
op = op_from_graph(
[x, y, z], [e], grad_overrides=[None, rescale_dy, None])
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
fn = function([x, y, z], [dx, dy, dz])
# the graident wrt y is now doubled
fn(2., 3., 4.) # [1., 8., 3.]
"""
......
......@@ -124,12 +124,13 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_grad_override(self, cls_ofg):
x,y = T.vectors('xy')
def go(args):
x, y, g = args
def go(inps, gs):
x, y = inps
g = gs[0]
return [g*y*2, g*x*1.5]
# no override is coverd in "grad" test
# no override case is coverd in "grad" test
# single override
# single override case
op_mul = cls_ofg([x, y], [x*y], grad_overrides=go)
xx,yy = T.vector('xx'), T.vector('yy')
zz = T.sum(op_mul(xx,yy))
......@@ -141,13 +142,15 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert numpy.allclose(yv*2, dxv)
assert numpy.allclose(xv*1.5, dyv)
# list override
def go1(args):
x, w, b, g = args
# list override case
def go1(inps, gs):
x, w, b = inps
g = gs[0]
return g*w*2
def go2(args):
x, w, b, g = args
def go2(inps, gs):
x, w, b = inps
g = gs[0]
return g*x*1.5
w, b = T.vectors('wb')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论