Commit 8d9fa9e5 authored by khaotik, committed by khaotik

cleaner grad() method for OpFromGraph

Parent 788e8bac
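For readers of the new interface: a minimal sketch of the override contract the rewritten grad() expects. All names below are illustrative; `op_from_graph` is the factory defined later in this file, and a single `grad_overrides` functor takes the original inputs and the upstream gradients as two lists, returning one downstream gradient per input.

    import theano
    import theano.tensor as T

    x, y = T.scalars('x', 'y')

    def my_grad_op(inputs, ups_grads):
        # hypothetical override: one downstream gradient per input
        xi, yi = inputs
        g, = ups_grads
        return [g * yi, g * xi]  # matches the default grads of x * y

    prod_op = op_from_graph([x, y], [x * y], grad_overrides=my_grad_op)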
@@ -37,25 +37,33 @@ class OpFromGraph(gof.Op):
             replace=dict(izip(
                 self.shared_inputs, shared_vars)),
             copy_inputs_over=False)
-        (internal_inputs, internal_outputs,
+        (local_inputs, local_outputs,
          [clone_d, update_d, update_expr, shared_inputs]) = new
-        assert len(internal_inputs) == len(inputs) + len(self.shared_inputs)
-        assert len(internal_outputs) == len(outputs)
+        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
+        assert len(local_outputs) == len(outputs)
         assert not update_d
         assert not update_expr
         assert not shared_inputs
-        self.internal_inputs = internal_inputs
-        self.internal_outputs = internal_outputs
+        self.local_inputs = local_inputs
+        self.local_outputs = local_outputs
         self.inputs = inputs
         self.outputs = outputs
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
-        # used to cache gradient for subgraph
-        self.grad_ops = grad_overrides
-        # should be True after 1st call to grad()
-        self.cached_grad_ops = False
+        # grad_op: a functor of the form:
+        #
+        #     def grad_op(inputs: list, ups_grads: list):
+        #         return dns_grads: list
+        #
+        # This is used to cache the gradient for the subgraph;
+        # in __init__ it is simply set to grad_overrides.
+        #
+        # grad_op is built on the 1st call to grad(),
+        # after which grad_op_is_cached should be True.
+        self.grad_op = grad_overrides
+        self.grad_op_is_cached = False

     def __eq__(self, other):
         # TODO: recognize a copy
@@ -66,45 +74,45 @@ class OpFromGraph(gof.Op):
         return hash(type(self))

     def grad(self, inputs, output_grads):
-        if self.cached_grad_ops:
-            return self.grad_ops(inputs, output_grads)
-        upstream_grads = dict(izip(self.internal_outputs, output_grads))
-        if self.grad_ops is None:
-            self.grad_ops = []
-        grad_ops_l = self.grad_ops
-        if isinstance(grad_ops_l, list):
-            if len(grad_ops_l) > len(self.internal_inputs):
+        if self.grad_op_is_cached:
+            return self.grad_op(inputs, output_grads)
+        if self.grad_op is None:
+            self.grad_op = []
+        # we need to convert a list into a single functor
+        if isinstance(self.grad_op, list):
+            grad_op_l = self.grad_op
+            if len(grad_op_l) > len(self.local_inputs):
                 raise ValueError(
                     'Can override %d gradients at most, got %d' % (
-                        len(self.internal_inputs), len(grad_ops_l)))
-            if len(grad_ops_l) < len(self.internal_inputs):
-                grad_ops_l += [None] * (
-                    len(self.internal_inputs) - len(grad_ops_l))
-            # It is normal if some inputs are not needed in order
-            # to compute the gradient, so we ignore them.
-            gs = [go if go else type(self)(
-                self.internal_inputs + output_grads,
-                (lambda g: g if g else (lambda *a: None))(
-                    theano.gradient.grad(
-                        cost=None,
-                        known_grads=upstream_grads,
-                        wrt=[inp],
-                        disconnected_inputs='ignore')
-                ), on_unused_input='ignore'
-            ) for go, inp in izip(grad_ops_l, self.internal_inputs)]
-            # since OpFromGraphBase only accepts input sequence,
-            # additional filtering is needed
-            def grad_ops(inps, grds):
-                # nonlocal gs, grad_ops_l
-                return [(go(inps, grds) if ov else go(*(inps + grds)))
-                        for go, ov in izip(gs, grad_ops_l)]
-            self.grad_ops = grad_ops
-        else:
-            grad_ops = grad_ops_l
-        self.cached_grad_ops = True
-        return grad_ops(inputs, output_grads)
+                        len(self.local_inputs), len(grad_op_l)))
+            if len(grad_op_l) < len(self.local_inputs):
+                grad_op_l += [None] * (
+                    len(self.local_inputs) - len(grad_op_l))
+            wrt = [self.local_inputs[i] for i, go in
+                   enumerate(grad_op_l) if not go]
+            # compute non-overriding downstream gradients from upstream grads;
+            # it's normal that some inputs are disconnected, thus the 'ignore'
+            ups_grads_d = dict(izip(self.local_outputs, output_grads))
+            nat_dns_grads = iter(theano.gradient.grad(
+                cost=None,
+                known_grads=ups_grads_d,
+                wrt=wrt,
+                disconnected_inputs='ignore'))
+            # combine overriding gradients
+            dns_grads_l = [
+                go(self.local_inputs, output_grads) if go
+                else next(nat_dns_grads) for go in grad_op_l]
+            grad_ofg = type(self)(
+                inputs=self.local_inputs + output_grads,
+                outputs=dns_grads_l,
+                inline=self.is_inline, on_unused_input='ignore')

+            def grad_op(inps, grds):
+                return grad_ofg(*(list(inps) + list(grds)))
+            self.grad_op = grad_op
+        self.grad_op_is_cached = True
+        return self.grad_op(inputs, output_grads)

     def make_node(self, *inputs):
         for input, type in zip(inputs, self.input_types):
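The rewritten grad() above seeds `theano.gradient.grad` with the upstream gradients via `known_grads`, then wraps the resulting downstream expressions in a second OpFromGraph (`grad_ofg`) so that later calls reuse the cached op. A standalone sketch of that caching pattern, with hypothetical variable names:

    import theano
    import theano.tensor as T

    x = T.scalar('x')
    out = x ** 2
    g_out = T.scalar('g_out')  # placeholder for the upstream gradient

    # downstream gradient of x, seeded with the placeholder, as grad() does
    dns_grads = theano.gradient.grad(
        cost=None, known_grads={out: g_out},
        wrt=[x], disconnected_inputs='ignore')

    # wrap the gradient expressions in their own OpFromGraph for reuse
    grad_ofg = op_from_graph([x, g_out], dns_grads)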
@@ -115,8 +123,8 @@ class OpFromGraph(gof.Op):
         apply_node = gof.Apply(
             self, list(inputs) + self.shared_inputs,
             [type() for type in self.output_types])
-        apply_node.internal_inputs = self.internal_inputs
-        apply_node.internal_outputs = self.internal_outputs
+        apply_node.local_inputs = self.local_inputs
+        apply_node.local_outputs = self.local_outputs
         return apply_node

     def connection_pattern(self, node):
@@ -125,12 +133,12 @@ class OpFromGraph(gof.Op):
         """
         return io_connection_pattern(
-            self.internal_inputs, self.internal_outputs)
+            self.local_inputs, self.local_outputs)

     def infer_shape(self, node, shapes):
         out_shp = theano.scan_module.scan_utils.infer_shape(
-            self.internal_outputs,
-            self.internal_inputs,
+            self.local_outputs,
+            self.local_inputs,
             shapes)

         # Clone the output shapes so that they are computed from outer inputs.
@@ -140,7 +148,7 @@ class OpFromGraph(gof.Op):
         # But doing it multiple times could duplicate common subgraphs between
         # each shape call. The Theano optimizer will clean this up later, but
         # it asks extra work of the optimizer.
-        repl = dict(zip(self.internal_inputs, node.inputs))
+        repl = dict(zip(self.local_inputs, node.inputs))
         cloned = theano.clone(reduce(tuple.__add__, out_shp), replace=repl)
         ret = []
         used = 0
@@ -153,8 +161,8 @@ class OpFromGraph(gof.Op):
     def prepare_node(self, node, storage_map, compute_map, impl):
         if not hasattr(self, "fn") and impl == 'py':
-            self.fn = orig_function(self.internal_inputs,
-                                    self.internal_outputs,
+            self.fn = orig_function(self.local_inputs,
+                                    self.local_outputs,
                                     **self.kwargs)

     def perform(self, node, inputs, outputs):
@@ -165,6 +173,7 @@ class OpFromGraph(gof.Op):
         # we won't need this copy anymore
         output[0] = variable.copy()

 @gof.local_optimizer([OpFromGraph])
 def inline_ofg_expansion(node):
     """
@@ -178,9 +187,9 @@ def inline_ofg_expansion(node):
     if not op.is_inline:
         return False
     return theano.clone(
-        op.internal_outputs, {
+        op.local_outputs, {
             u: v for u, v in izip(
-                node.op.internal_inputs, node.inputs)})
+                node.op.local_inputs, node.inputs)})

 optdb.register(
     'inline_ofg_expansion',
@@ -191,6 +200,7 @@ optdb.register(
 # we should let DebugMode know about it
 ops_with_inner_function[OpFromGraph] = 'fn'

+# API for OpFromGraph
 def op_from_graph(
     inputs, outputs, inline=False, grad_overrides=None, **kwargs
@@ -214,7 +224,7 @@ def op_from_graph(
     grad_overrides: None | function | list of (None|function), optional
         Used to override the default gradient routine.
         Overriding function(s) must take two lists of variables as inputs:
-        the original inputs and upstream gradients
+        the original inputs and ups gradients
        For different `grad_overrides`:
        - `None` : will use default gradient routine.
@@ -225,8 +235,8 @@ def op_from_graph(
     TODO:
        - examples for a multi-layer mlp. where?
        - __hash__, __eq__ otherwise won't merge, try
-         gof.opt.is_same_graph_with_merge(op1.internal_outputs, op2,
-         internal_outputs)
+         gof.opt.is_same_graph_with_merge(op1.local_outputs, op2,
+         local_outputs)
       - c_code() to remove the double overhead?
       - grad(): make it support DisconnectedType and the new interface
       - check how it works with updates.
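A hedged usage sketch of the list form of `grad_overrides` documented above, where a `None` entry falls back to the default routine; names and values are illustrative only:

    import theano
    import theano.tensor as T

    x, y = T.scalars('x', 'y')

    # override the gradient w.r.t. x only; y keeps the default d(x*y)/dy = x
    op = op_from_graph(
        [x, y], [x * y],
        grad_overrides=[lambda inps, grds: grds[0] * 0.5, None])

    z = op(x, y)
    gx, gy = theano.grad(z, [x, y])
    f = theano.function([x, y], [gx, gy])
    print(f(2., 3.))  # expected: [0.5, 2.0], overridden gx and default gy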