Commit 8d9fa9e5, authored by khaotik, committed by khaotik

cleaner grad() method for OpFromGraph

Parent 788e8bac
@@ -37,25 +37,33 @@ class OpFromGraph(gof.Op):
             replace=dict(izip(
                 self.shared_inputs, shared_vars)),
             copy_inputs_over=False)
-        (internal_inputs, internal_outputs,
+        (local_inputs, local_outputs,
          [clone_d, update_d, update_expr, shared_inputs]) = new
-        assert len(internal_inputs) == len(inputs) + len(self.shared_inputs)
-        assert len(internal_outputs) == len(outputs)
+        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
+        assert len(local_outputs) == len(outputs)
         assert not update_d
         assert not update_expr
         assert not shared_inputs
-        self.internal_inputs = internal_inputs
-        self.internal_outputs = internal_outputs
+        self.local_inputs = local_inputs
+        self.local_outputs = local_outputs
         self.inputs = inputs
         self.outputs = outputs
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
-        # used to cache gradient for subgraph
-        self.grad_ops = grad_overrides
-        # should be True after 1st call to grad()
-        self.cached_grad_ops = False
+        # grad_op: a functor of the form
+        #
+        #     def grad_op(inputs: list, ups_grads: list):
+        #         return dns_grads: list
+        #
+        # It is used to cache the gradient for the subgraph;
+        # in __init__ it is simply set to grad_overrides.
+        #
+        # The actual grad_op should be built on the 1st call to grad(),
+        # after which grad_op_is_cached should be True.
+        self.grad_op = grad_overrides
+        self.grad_op_is_cached = False
 
     def __eq__(self, other):
         # TODO: recognize a copy
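(Reviewer note, not part of the patch: the functor form documented in the comment above amounts to the sketch below. The names and the toy x * w gradient are illustrative assumptions, not taken from the diff.)

def example_grad_op(inputs, ups_grads):
    # receives the op's inputs and the upstream gradients as lists and
    # returns one downstream gradient per input, also as a list
    x, w = inputs          # hypothetical op with two inputs
    g_out, = ups_grads     # single output, single upstream gradient
    return [g_out * w, g_out * x]   # gradients of x * w w.r.t. x and w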
@@ -66,45 +74,45 @@ class OpFromGraph(gof.Op):
         return hash(type(self))
 
     def grad(self, inputs, output_grads):
-        if self.cached_grad_ops:
-            return self.grad_ops(inputs, output_grads)
-        upstream_grads = dict(izip(self.internal_outputs, output_grads))
-        if self.grad_ops is None:
-            self.grad_ops = []
-        grad_ops_l = self.grad_ops
-        if isinstance(grad_ops_l, list):
-            if len(grad_ops_l) > len(self.internal_inputs):
-                raise ValueError(
-                    'Can override %d gradients at most, got %d' % (
-                        len(self.internal_inputs), len(grad_ops_l)))
-            if len(grad_ops_l) < len(self.internal_inputs):
-                grad_ops_l += [None] * (
-                    len(self.internal_inputs) - len(grad_ops_l))
-            # It is normal if some inputs are not needed in order
-            # to compute the gradient, so we ignore them.
-            gs = [go if go else type(self)(
-                self.internal_inputs + output_grads,
-                (lambda g: g if g else (lambda *a: None))(
-                    theano.gradient.grad(
-                        cost=None,
-                        known_grads=upstream_grads,
-                        wrt=[inp],
-                        disconnected_inputs='ignore')
-                ), on_unused_input='ignore'
-            ) for go, inp in izip(grad_ops_l, self.internal_inputs)]
-
-            # since OpFromGraphBase only accepts input sequence,
-            # additional filtering is needed
-            def grad_ops(inps, grds):
-                # nonlocal gs, grad_ops_l
-                return [(go(inps, grds) if ov else go(*(inps + grds)))
-                        for go, ov in izip(gs, grad_ops_l)]
-            self.grad_ops = grad_ops
-        else:
-            grad_ops = grad_ops_l
-        self.cached_grad_ops = True
-        return grad_ops(inputs, output_grads)
+        if self.grad_op_is_cached:
+            return self.grad_op(inputs, output_grads)
+        if self.grad_op is None:
+            self.grad_op = []
+
+        # we need to convert a list into a single functor
+        if isinstance(self.grad_op, list):
+            grad_op_l = self.grad_op
+            if len(grad_op_l) > len(self.local_inputs):
+                raise ValueError(
+                    'Can override %d gradients at most, got %d' % (
+                        len(self.local_inputs), len(grad_op_l)))
+            if len(grad_op_l) < len(self.local_inputs):
+                grad_op_l += [None] * (
+                    len(self.local_inputs) - len(grad_op_l))
+            wrt = [self.local_inputs[i] for i, go in
+                   enumerate(grad_op_l) if not go]
+            # compute non-overridden downstream gradients from upstream grads;
+            # it is normal for some inputs to be disconnected, hence 'ignore'
+            ups_grads_d = dict(izip(self.local_outputs, output_grads))
+            nat_dns_grads = iter(theano.gradient.grad(
+                cost=None,
+                known_grads=ups_grads_d,
+                wrt=wrt,
+                disconnected_inputs='ignore'))
+            # combine with the overriding gradients
+            dns_grads_l = [
+                go(self.local_inputs, output_grads) if go else next(nat_dns_grads) for go in grad_op_l]
+            grad_ofg = type(self)(
+                inputs=self.local_inputs + output_grads,
+                outputs=dns_grads_l,
+                inline=self.is_inline, on_unused_input='ignore')
+
+            def grad_op(inps, grds):
+                return grad_ofg(*(list(inps) + list(grds)))
+            self.grad_op = grad_op
+        self.grad_op_is_cached = True
+        return self.grad_op(inputs, output_grads)
 
     def make_node(self, *inputs):
         for input, type in zip(inputs, self.input_types):
@@ -115,8 +123,8 @@ class OpFromGraph(gof.Op):
         apply_node = gof.Apply(
             self, list(inputs) + self.shared_inputs,
             [type() for type in self.output_types])
-        apply_node.internal_inputs = self.internal_inputs
-        apply_node.internal_outputs = self.internal_outputs
+        apply_node.local_inputs = self.local_inputs
+        apply_node.local_outputs = self.local_outputs
         return apply_node
 
     def connection_pattern(self, node):
@@ -125,12 +133,12 @@ class OpFromGraph(gof.Op):
         """
         return io_connection_pattern(
-            self.internal_inputs, self.internal_outputs)
+            self.local_inputs, self.local_outputs)
 
     def infer_shape(self, node, shapes):
         out_shp = theano.scan_module.scan_utils.infer_shape(
-            self.internal_outputs,
-            self.internal_inputs,
+            self.local_outputs,
+            self.local_inputs,
             shapes)
 
         # Clone the output shape so that shape are computed from outer inputs.
@@ -140,7 +148,7 @@ class OpFromGraph(gof.Op):
         # But doing it multiple time could duplicate common subgraph between
         # each shape call. Theano optimizer will clean this up later, but this
         # will ask extra work to the optimizer.
-        repl = dict(zip(self.internal_inputs, node.inputs))
+        repl = dict(zip(self.local_inputs, node.inputs))
         cloned = theano.clone(reduce(tuple.__add__, out_shp), replace=repl)
         ret = []
         used = 0
@@ -153,8 +161,8 @@ class OpFromGraph(gof.Op):
 
     def prepare_node(self, node, storage_map, compute_map, impl):
         if not hasattr(self, "fn") and impl == 'py':
-            self.fn = orig_function(self.internal_inputs,
-                                    self.internal_outputs,
+            self.fn = orig_function(self.local_inputs,
+                                    self.local_outputs,
                                     **self.kwargs)
 
     def perform(self, node, inputs, outputs):
@@ -165,6 +173,7 @@ class OpFromGraph(gof.Op):
         # we wont need this copy anymore
         output[0] = variable.copy()
 
+
 @gof.local_optimizer([OpFromGraph])
 def inline_ofg_expansion(node):
     """
@@ -178,9 +187,9 @@ def inline_ofg_expansion(node):
     if not op.is_inline:
         return False
     return theano.clone(
-        op.internal_outputs, {
+        op.local_outputs, {
             u: v for u, v in izip(
-                node.op.internal_inputs, node.inputs)})
+                node.op.local_inputs, node.inputs)})
 
 optdb.register(
     'inline_ofg_expansion',
@@ -191,6 +200,7 @@ optdb.register(
 # we should let DebugMode know about it
 ops_with_inner_function[OpFromGraph] = 'fn'
 
+
 # API for OpFromGraph
 def op_from_graph(
         inputs, outputs, inline=False, grad_overrides=None, **kwargs
@@ -214,7 +224,7 @@ def op_from_graph(
     grad_overrides: None | function | list of (None|function), optional
         Used to override default gradient routine.
        Overriding function(s) must take two list of variable as inputs,
-        the original inputs and upstream gradients
+        the original inputs and ups gradients
         For different `grad_overrides`:
         - `None` : will use default gradient routine.
@@ -225,8 +235,8 @@ def op_from_graph(
     TODO:
         - examples for a multi-layer mlp. where?
         - __hash__, __eq__ otherwise won't merge, try
-          gof.opt.is_same_graph_with_merge(op1.internal_outputs, op2,
-          internal_outputs)
+          gof.opt.is_same_graph_with_merge(op1.local_outputs, op2,
+          local_outputs)
         - c_code() to remove the double overhead?
         - grad() make it support DisconnectedType and the new interface
        - check how it works with updates.
......
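(Reviewer note, not part of the patch: a minimal end-to-end sketch of how the grad_overrides hook could be used once this lands. The import path, the variable names, and the deliberately fake gradient are assumptions for illustration only.)

import theano
import theano.tensor as T
from theano.compile.builders import op_from_graph  # assumed import path


x = T.vector('x')


def fake_grad(inputs, ups_grads):
    # single-functor form: gets the op's inputs and the upstream gradients
    # as lists and returns one downstream gradient per input
    inp, = inputs
    g_out, = ups_grads
    # deliberately wrong gradient (ignores the true 2 * x factor),
    # just to show that grad() is routed through the override
    return [g_out * T.ones_like(inp)]


square_op = op_from_graph([x], [x ** 2], grad_overrides=fake_grad)

y = square_op(x)
g = theano.grad(y.sum(), x)      # uses fake_grad instead of the default
f = theano.function([x], g)
print(f([1.0, 2.0, 3.0]))        # all ones, instead of [2, 4, 6]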