提交 9a9cd0ee authored 作者: Frederic's avatar Frederic

Make OpFromGraph support shared variable.

上级 bb703bd8
...@@ -9,8 +9,6 @@ from theano.compile.mode import * ...@@ -9,8 +9,6 @@ from theano.compile.mode import *
from theano.compile.io import * from theano.compile.io import *
from theano.compile.builders import *
from theano.compile.module import * from theano.compile.module import *
from theano.compile.debugmode import DebugMode from theano.compile.debugmode import DebugMode
...@@ -25,4 +23,6 @@ from theano.compile.sharedvalue import (shared, shared_constructor, ...@@ -25,4 +23,6 @@ from theano.compile.sharedvalue import (shared, shared_constructor,
SharedVariable) SharedVariable)
from theano.compile.pfunc import pfunc, Param, rebuild_collect_shared from theano.compile.pfunc import pfunc, Param, rebuild_collect_shared
from theano.compile.builders import *
from theano.compile.function import function from theano.compile.function import function
from theano import gof from theano import gof
from theano import gradient as G from theano import gradient as G
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function from theano.gof import ops_with_inner_function
...@@ -10,7 +11,7 @@ class OpFromGraph(gof.Op): ...@@ -10,7 +11,7 @@ class OpFromGraph(gof.Op):
The signature is similar to theano.function() and the resulting The signature is similar to theano.function() and the resulting
`Op` perform will do the same operation as:: `Op` perform will do the same operation as::
function(inputs, outputs, **kwargs) orig_function(inputs, outputs, **kwargs)
Example: Example:
x, y, z = tensor.scalars('xyz') x, y, z = tensor.scalars('xyz')
...@@ -21,12 +22,19 @@ class OpFromGraph(gof.Op): ...@@ -21,12 +22,19 @@ class OpFromGraph(gof.Op):
fn = function([x, y, z], [e2]) fn = function([x, y, z], [e2])
TODO: -examples TODO: - examples
- support shared var
- __hash__, __eq__ otherwise won't merge - __hash__, __eq__ otherwise won't merge
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs - opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
:note:
We support unused inputs. This is needed for the grad.
We support shared variables in the inner graph. This is automatic and
invisible to the user.
""" """
def __init__(self, inputs, outputs, **kwargs): def __init__(self, inputs, outputs, **kwargs):
...@@ -39,12 +47,27 @@ class OpFromGraph(gof.Op): ...@@ -39,12 +47,27 @@ class OpFromGraph(gof.Op):
if 'updates' in kwargs: if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs') raise TypeError('updates are not allowed in kwargs')
shared_inputs = [var for var in gof.graph.inputs(outputs) # To support correctly shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)] if isinstance(var, SharedVariable)]
if shared_inputs: used_inputs = [var for var in gof.graph.inputs(outputs)
raise NotImplementedError( if not isinstance(var, gof.Constant)]
"OpFromGraph do not support SharedVariable in the inner graph") shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs,
shared_vars)),
copy_inputs_over=False)
(new_inputs, new_outputs,
[clone_d, update_d, update_expr, shared_inputs]) = new
assert len(new_inputs) == len(inputs) + len(self.shared_inputs)
assert len(new_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not shared_inputs
self.new_inputs = new_inputs
self.new_outputs = new_outputs
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.kwargs = kwargs self.kwargs = kwargs
...@@ -66,14 +89,16 @@ class OpFromGraph(gof.Op): ...@@ -66,14 +89,16 @@ class OpFromGraph(gof.Op):
raise TypeError("Wrong type, expected %s but got %s" raise TypeError("Wrong type, expected %s but got %s"
% (type, input.type)) % (type, input.type))
return gof.Apply(self, return gof.Apply(self,
inputs, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
ret = super(OpFromGraph, self).make_thunk(node, storage_map, ret = super(OpFromGraph, self).make_thunk(node, storage_map,
compute_map, no_recycling) compute_map, no_recycling)
if not hasattr(self, "fn"): if not hasattr(self, "fn"):
self.fn = orig_function(self.inputs, self.outputs, **self.kwargs) self.fn = orig_function(self.new_inputs,
self.new_outputs,
**self.kwargs)
return ret return ret
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -94,8 +119,9 @@ class OpFromGraph(gof.Op): ...@@ -94,8 +119,9 @@ class OpFromGraph(gof.Op):
grad_ops = self.grad_ops grad_ops = self.grad_ops
else: else:
gs = G.grad(cost=None, gs = G.grad(cost=None,
known_grads=dict(zip(self.outputs, output_grads)), known_grads=dict(zip(self.new_outputs, output_grads)),
wrt=self.inputs, disconnected_inputs='ignore') wrt=self.new_inputs,
disconnected_inputs='ignore')
grad_ops = [] grad_ops = []
for g in gs: for g in gs:
...@@ -104,7 +130,7 @@ class OpFromGraph(gof.Op): ...@@ -104,7 +130,7 @@ class OpFromGraph(gof.Op):
else: else:
# It is normal if some inputs are not needed in order # It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them. # to compute the gradient, so we ignore them.
grad_ops.append(OpFromGraph(self.inputs + output_grads, grad_ops.append(OpFromGraph(self.new_inputs + output_grads,
[g], [g],
on_unused_input='ignore')) on_unused_input='ignore'))
self.grad_ops = grad_ops self.grad_ops = grad_ops
......
...@@ -73,8 +73,6 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -73,8 +73,6 @@ class T_OpFromGraph(unittest.TestCase):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX)) s = shared(numpy.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
self.assertRaises(NotImplementedError, OpFromGraph, [x, y, z], [e], mode='FAST_RUN')
return
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN') op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
f = op(x, y, z) - op(y, z, x) # (1+3*5=array of 16) - (3+1*5=array of 8) f = op(x, y, z) - op(y, z, x) # (1+3*5=array of 16) - (3+1*5=array of 8)
fn = function([x, y, z], f) fn = function([x, y, z], f)
...@@ -86,6 +84,26 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -86,6 +84,26 @@ class T_OpFromGraph(unittest.TestCase):
assert numpy.allclose(8.0, fn(xv, yv, zv)) assert numpy.allclose(8.0, fn(xv, yv, zv))
assert numpy.allclose(8.0, fn(xv, yv, zv)) assert numpy.allclose(8.0, fn(xv, yv, zv))
def test_shared_grad(self):
x, y, z = T.matrices('xyz')
s = shared(numpy.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s
op = OpFromGraph([x, y, z], [e], mode='FAST_RUN')
f = op(x, y, z)
f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f)
xv = numpy.ones((2, 2), dtype=config.floatX)
yv = numpy.ones((2, 2), dtype=config.floatX) * 3
zv = numpy.ones((2, 2), dtype=config.floatX) * 5
assert numpy.allclose(11.0 + s.get_value(), fn(xv, yv, zv))
# grad against the shared variable
f = op(x, y, z)
f = f - T.grad(T.sum(f), s)
fn = function([x, y, z], f)
assert numpy.allclose(15.0 + s.get_value(),
fn(xv, yv, zv))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论