提交 788e8bac authored 作者: khaotik's avatar khaotik 提交者: khaotik

refactor into single OpFromGraph

上级 9b75b7ae
...@@ -69,7 +69,7 @@ from theano.compile import ( ...@@ -69,7 +69,7 @@ from theano.compile import (
Mode, Mode,
predefined_modes, predefined_linkers, predefined_optimizers, predefined_modes, predefined_linkers, predefined_optimizers,
FunctionMaker, function, function_dump, FunctionMaker, function, function_dump,
OpFromGraph, OpFromGraphInline, OpFromGraphPrecompiled, op_from_graph, OpFromGraph, op_from_graph,
ProfileStats, ProfileStats,
Param, shared, as_op) Param, shared, as_op)
......
...@@ -11,13 +11,13 @@ from theano.gof import ops_with_inner_function ...@@ -11,13 +11,13 @@ from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
class OpFromGraphBase(gof.Op): class OpFromGraph(gof.Op):
""" """
base class for Ops with custom inner graph class for Ops with user-defined inner graph
""" """
# NOTE: if you make a subclass of this, make sure to add it under: # NOTE: if you make a subclass of this, make sure to add a test for it under:
# theano/compile/tests/test_builders.py # theano/compile/tests/test_builders.py
def __init__(self, inputs, outputs, grad_overrides=None, **kwargs): def __init__(self, inputs, outputs, inline=False, grad_overrides=None, **kwargs):
if not isinstance(outputs, list): if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs) raise TypeError('outputs must be list', outputs)
for i in inputs + outputs: for i in inputs + outputs:
...@@ -26,6 +26,7 @@ class OpFromGraphBase(gof.Op): ...@@ -26,6 +26,7 @@ class OpFromGraphBase(gof.Op):
'inputs and outputs must be Variable instances', i) 'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs or 'givens' in kwargs: if 'updates' in kwargs or 'givens' in kwargs:
raise TypeError('updates and givens are not allowed here') raise TypeError('updates and givens are not allowed here')
self.is_inline = inline
# To correctly support shared variables the inner fct should # To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient. # not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs) self.shared_inputs = [var for var in gof.graph.inputs(outputs)
...@@ -68,19 +69,22 @@ class OpFromGraphBase(gof.Op): ...@@ -68,19 +69,22 @@ class OpFromGraphBase(gof.Op):
if self.cached_grad_ops: if self.cached_grad_ops:
return self.grad_ops(inputs, output_grads) return self.grad_ops(inputs, output_grads)
grad_inps = self.internal_inputs + output_grads
upstream_grads = dict(izip(self.internal_outputs, output_grads)) upstream_grads = dict(izip(self.internal_outputs, output_grads))
if self.grad_ops is not None: if self.grad_ops is None:
self.grad_ops = []
grad_ops_l = self.grad_ops grad_ops_l = self.grad_ops
if isinstance(grad_ops_l, list): if isinstance(grad_ops_l, list):
assert len(grad_ops_l) <= len(self.internal_inputs) if len(grad_ops_l) > len(self.internal_inputs):
raise ValueError(
'Can override %d gradients at most, got %d' % (
len(self.internal_inputs), len(grad_ops_l)))
if len(grad_ops_l) < len(self.internal_inputs): if len(grad_ops_l) < len(self.internal_inputs):
grad_ops_l += [None] * ( grad_ops_l += [None] * (
len(self.internal_inputs) - len(grad_ops_l)) len(self.internal_inputs) - len(grad_ops_l))
# It is normal if some inputs are not needed in order # It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them. # to compute the gradient, so we ignore them.
gs = [go if go else type(self)( gs = [go if go else type(self)(
grad_inps, self.internal_inputs + output_grads,
(lambda g: g if g else (lambda *a:None))( (lambda g: g if g else (lambda *a:None))(
theano.gradient.grad( theano.gradient.grad(
cost=None, cost=None,
...@@ -96,27 +100,9 @@ class OpFromGraphBase(gof.Op): ...@@ -96,27 +100,9 @@ class OpFromGraphBase(gof.Op):
# nonlocal gs, grad_ops_l # nonlocal gs, grad_ops_l
return [(go(inps, grds) if ov else go(*(inps + grds))) return [(go(inps, grds) if ov else go(*(inps + grds)))
for go, ov in izip(gs, grad_ops_l)] for go, ov in izip(gs, grad_ops_l)]
else:
grad_ops = grad_ops_l
self.grad_ops = grad_ops self.grad_ops = grad_ops
else: else:
gs = theano.gradient.grad( grad_ops = grad_ops_l
cost=None,
known_grads=upstream_grads,
wrt=self.internal_inputs,
disconnected_inputs='ignore')
grad_ops_l = []
for g in gs:
if g is None:
grad_ops_l.append(lambda *args: None)
else:
grad_ops_l.append(type(self)(
grad_inps, [g], on_unused_input='ignore'))
def grad_ops(inps, grds):
# nonlocal grad_ops_l
return [go(*(inps + grds)) for go in grad_ops_l]
self.grad_ops = grad_ops
self.cached_grad_ops = True self.cached_grad_ops = True
return grad_ops(inputs, output_grads) return grad_ops(inputs, output_grads)
...@@ -165,14 +151,6 @@ class OpFromGraphBase(gof.Op): ...@@ -165,14 +151,6 @@ class OpFromGraphBase(gof.Op):
return ret return ret
def perform(self, node, inputs, outputs):
raise NotImplementedError()
class OpFromGraphPrecompiled(OpFromGraphBase):
"""
The Op's inner graph is compiled into a theano function.
"""
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
if not hasattr(self, "fn") and impl == 'py': if not hasattr(self, "fn") and impl == 'py':
self.fn = orig_function(self.internal_inputs, self.fn = orig_function(self.internal_inputs,
...@@ -187,43 +165,33 @@ class OpFromGraphPrecompiled(OpFromGraphBase): ...@@ -187,43 +165,33 @@ class OpFromGraphPrecompiled(OpFromGraphBase):
# we won't need this copy anymore # we won't need this copy anymore
output[0] = variable.copy() output[0] = variable.copy()
@gof.local_optimizer([OpFromGraph])
class OpFromGraphInline(OpFromGraphBase): def inline_ofg_expansion(node):
"""
The Op's inner graph is expanded into the outer graph at compile time
""" """
def perform(self, node, inputs, outputs): This optimization expands internal graph of OpFromGraph.
raise RuntimeError(
type(self).__name__ + ' is not supposed to be executed at runtime')
@gof.local_optimizer([OpFromGraphInline]) Doing so can improve optimization at the cost of compilation speed.
def inline_ofg_expansion(node):
""" This optimization expands internal graph of OpFromGraphInline
""" """
op = node.op op = node.op
if not isinstance(op, OpFromGraphInline): if not isinstance(op, OpFromGraph):
return False return False
outputs = theano.clone( if not op.is_inline:
return False
return theano.clone(
op.internal_outputs, { op.internal_outputs, {
u: v for u, v in izip( u: v for u, v in izip(
node.op.internal_inputs, node.inputs)}) node.op.internal_inputs, node.inputs)})
return outputs
optdb.register( optdb.register(
'inline_ofg_expansion', 'inline_ofg_expansion',
gof.opt.in2out(inline_ofg_expansion), gof.opt.in2out(inline_ofg_expansion),
0.5, 'fast_compile', 'fast_run') 0.5, 'fast_compile', 'fast_run')
# Since OpFromGraphPrecompiled contains a Theano compiled function, # Since OpFromGraph contains a Theano compiled function,
# we should let DebugMode know about it # we should let DebugMode know about it
ops_with_inner_function[OpFromGraphPrecompiled] = 'fn' ops_with_inner_function[OpFromGraph] = 'fn'
# for backward compatibility
OpFromGraph = OpFromGraphPrecompiled
# API for OpFromGraph
# API for OpFromGraph*
def op_from_graph( def op_from_graph(
inputs, outputs, inline=False, grad_overrides=None, **kwargs inputs, outputs, inline=False, grad_overrides=None, **kwargs
): ):
...@@ -254,9 +222,6 @@ def op_from_graph( ...@@ -254,9 +222,6 @@ def op_from_graph(
- list : each function must return a single Variable. The order - list : each function must return a single Variable. The order
of the list must correspond to the inputs of the list must correspond to the inputs
Notes
-----
TODO: TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try - __hash__, __eq__ otherwise won't merge, try
...@@ -274,12 +239,14 @@ def op_from_graph( ...@@ -274,12 +239,14 @@ def op_from_graph(
Notes Notes
----- -----
- We support shared variables in the inner graph. This is automatic and - We support shared variables in the inner graph. This is automatic
invisible to the user. They can be used as input to the node or in the and invisible to the user. They can be used as input to the node or in
inner graph. the inner graph.
- We support unused inputs. This is needed for the grad. - We support unused inputs. This is needed for the grad.
- `inline=True` will cause better runtime optimization at the cost of - `inline=True` will cause better runtime optimization at the cost
longer compilation, only works with optimizer "fast_run" or "fast_compile" of compilation time. Like the "inline" keyword in C, this is merely a
suggestion to the compiler, which is not guaranteed to be honored. Currently only
works with "fast_compile" or "fast_run" mode.
Examples Examples
-------- --------
...@@ -331,9 +298,5 @@ def op_from_graph( ...@@ -331,9 +298,5 @@ def op_from_graph(
fn(2., 3., 4.) # [1., 8., 3.] fn(2., 3., 4.) # [1., 8., 3.]
""" """
if inline and theano.config.optimizer in ['fast_run', 'fast_compile']: return OpFromGraph(
cls_opfromgraph = OpFromGraphInline inputs, outputs, inline=inline, grad_overrides=grad_overrides, **kwargs)
else:
cls_opfromgraph = OpFromGraphPrecompiled
return cls_opfromgraph(
inputs, outputs, grad_overrides=grad_overrides, **kwargs)
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from functools import partial
import numpy as np import numpy as np
from theano import config, shared from theano import config, shared
...@@ -8,12 +9,12 @@ from theano.compile import function ...@@ -8,12 +9,12 @@ from theano.compile import function
from theano import tensor as T from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraphInline, OpFromGraphPrecompiled from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools from theano.tests import unittest_tools
test_params = unittest_tools.parameterized.expand( test_params = unittest_tools.parameterized.expand(
[(OpFromGraphInline,), (OpFromGraphPrecompiled,)]) [(OpFromGraph,), (partial(OpFromGraph, inline=True),)])
class T_OpFromGraph(unittest_tools.InferShapeTester): class T_OpFromGraph(unittest_tools.InferShapeTester):
...@@ -135,11 +136,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -135,11 +136,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zz = T.sum(op_mul(xx, yy)) zz = T.sum(op_mul(xx, yy))
dx, dy = T.grad(zz, [xx, yy]) dx, dy = T.grad(zz, [xx, yy])
fn = function([xx, yy], [dx, dy]) fn = function([xx, yy], [dx, dy])
xv = numpy.random.rand(16).astype(config.floatX) xv = np.random.rand(16).astype(config.floatX)
yv = numpy.random.rand(16).astype(config.floatX) yv = np.random.rand(16).astype(config.floatX)
dxv, dyv = fn(xv, yv) dxv, dyv = fn(xv, yv)
assert numpy.allclose(yv * 2, dxv) assert np.allclose(yv * 2, dxv)
assert numpy.allclose(xv * 1.5, dyv) assert np.allclose(xv * 1.5, dyv)
# list override case # list override case
def go1(inps, gs): def go1(inps, gs):
...@@ -159,13 +160,13 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -159,13 +160,13 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zz = T.sum(op_linear(xx, ww, bb)) zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb]) dx, dw, db = T.grad(zz, [xx, ww, bb])
fn = function([xx, ww, bb], [dx, dw, db]) fn = function([xx, ww, bb], [dx, dw, db])
xv = numpy.random.rand(16).astype(config.floatX) xv = np.random.rand(16).astype(config.floatX)
wv = numpy.random.rand(16).astype(config.floatX) wv = np.random.rand(16).astype(config.floatX)
bv = numpy.random.rand(16).astype(config.floatX) bv = np.random.rand(16).astype(config.floatX)
dxv, dwv, dbv = fn(xv, wv, bv) dxv, dwv, dbv = fn(xv, wv, bv)
assert numpy.allclose(wv * 2, dxv) assert np.allclose(wv * 2, dxv)
assert numpy.allclose(xv * 1.5, dwv) assert np.allclose(xv * 1.5, dwv)
assert numpy.allclose(numpy.ones(16, dtype=config.floatX), dbv) assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
@test_params @test_params
def test_nested(self, cls_ofg): def test_nested(self, cls_ofg):
...@@ -178,11 +179,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -178,11 +179,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
xx2, yy2 = op_ift(*op_ft(xx, yy)) xx2, yy2 = op_ift(*op_ft(xx, yy))
fn = function([xx, yy], [xx2, yy2]) fn = function([xx, yy], [xx2, yy2])
xv = numpy.random.rand(16).astype(config.floatX) xv = np.random.rand(16).astype(config.floatX)
yv = numpy.random.rand(16).astype(config.floatX) yv = np.random.rand(16).astype(config.floatX)
xv2, yv2 = fn(xv, yv) xv2, yv2 = fn(xv, yv)
assert numpy.allclose(xv, xv2) assert np.allclose(xv, xv2)
assert numpy.allclose(yv, yv2) assert np.allclose(yv, yv2)
@test_params @test_params
def test_connection_pattern(self, cls_ofg): def test_connection_pattern(self, cls_ofg):
...@@ -239,10 +240,9 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -239,10 +240,9 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
p = T.matrix('p') p = T.matrix('p')
# we don't want check_topo for inline ops # we don't want check_topo for inline ops
# since the inline op is replaced during optimization # since the inline op is replaced during optimization
is_compile = not issubclass(cls_ofg, OpFromGraphInline)
self._compile_and_check([q, p], self._compile_and_check([q, p],
op_graph(q, p), op_graph(q, p),
[np.ones([3, 4], dtype=config.floatX), [np.ones([3, 4], dtype=config.floatX),
np.ones([3, 4], dtype=config.floatX)], np.ones([3, 4], dtype=config.floatX)],
cls_ofg, cls_ofg,
check_topo=is_compile) check_topo=not op_graph.is_inline)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论