Commit 788e8bac authored by khaotik, committed by khaotik

refactor into single OpFromGraph

Parent 9b75b7ae
......@@ -69,7 +69,7 @@ from theano.compile import (
Mode,
predefined_modes, predefined_linkers, predefined_optimizers,
FunctionMaker, function, function_dump,
OpFromGraph, OpFromGraphInline, OpFromGraphPrecompiled, op_from_graph,
OpFromGraph, op_from_graph,
ProfileStats,
Param, shared, as_op)
......
......@@ -11,13 +11,13 @@ from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern
class OpFromGraphBase(gof.Op):
class OpFromGraph(gof.Op):
"""
base class for Ops with custom inner graph
class for Ops with user-defined inner graph
"""
# NOTE: if you make a subclass of this, make sure add it under:
# theano/compile/tests/test_builders.py
def __init__(self, inputs, outputs, grad_overrides=None, **kwargs):
# NOTE: if you make a subclass of this, make sure to add a test for it under:
# theano/compile/tests/test_builders.py
def __init__(self, inputs, outputs, inline=False, grad_overrides=None, **kwargs):
if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs)
for i in inputs + outputs:
......@@ -26,6 +26,7 @@ class OpFromGraphBase(gof.Op):
'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs or 'givens' in kwargs:
raise TypeError('updates and givens are not allowed here')
self.is_inline = inline
# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs)
......@@ -68,55 +69,40 @@ class OpFromGraphBase(gof.Op):
if self.cached_grad_ops:
return self.grad_ops(inputs, output_grads)
grad_inps = self.internal_inputs + output_grads
upstream_grads = dict(izip(self.internal_outputs, output_grads))
if self.grad_ops is not None:
grad_ops_l = self.grad_ops
if isinstance(grad_ops_l, list):
assert len(grad_ops_l) <= len(self.internal_inputs)
if len(grad_ops_l) < len(self.internal_inputs):
grad_ops_l += [None] * (
len(self.internal_inputs) - len(grad_ops_l))
# It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them.
gs = [go if go else type(self)(
grad_inps,
(lambda g: g if g else (lambda *a:None))(
theano.gradient.grad(
cost=None,
known_grads=upstream_grads,
wrt=[inp],
disconnected_inputs='ignore')
), on_unused_input='ignore'
) for go, inp in izip(grad_ops_l, self.internal_inputs)]
# since OpFromGraphBase only accepts input sequence,
# additional filtering is needed
def grad_ops(inps, grds):
# nonlocal gs, grad_ops_l
return [(go(inps, grds) if ov else go(*(inps + grds)))
for go, ov in izip(gs, grad_ops_l)]
else:
grad_ops = grad_ops_l
self.grad_ops = grad_ops
else:
gs = theano.gradient.grad(
cost=None,
known_grads=upstream_grads,
wrt=self.internal_inputs,
disconnected_inputs='ignore')
grad_ops_l = []
for g in gs:
if g is None:
grad_ops_l.append(lambda *args: None)
else:
grad_ops_l.append(type(self)(
grad_inps, [g], on_unused_input='ignore'))
if self.grad_ops is None:
self.grad_ops = []
grad_ops_l = self.grad_ops
if isinstance(grad_ops_l, list):
if len(grad_ops_l) > len(self.internal_inputs):
raise ValueError(
'Can override %d gradients at most, got %d' % (
len(self.internal_inputs), len(grad_ops_l)))
if len(grad_ops_l) < len(self.internal_inputs):
grad_ops_l += [None] * (
len(self.internal_inputs) - len(grad_ops_l))
# It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them.
gs = [go if go else type(self)(
self.internal_inputs + output_grads,
(lambda g: g if g else (lambda *a:None))(
theano.gradient.grad(
cost=None,
known_grads=upstream_grads,
wrt=[inp],
disconnected_inputs='ignore')
), on_unused_input='ignore'
) for go, inp in izip(grad_ops_l, self.internal_inputs)]
# since OpFromGraph only accepts an input sequence,
# additional filtering is needed
def grad_ops(inps, grds):
# nonlocal grad_ops_l
return [go(*(inps + grds)) for go in grad_ops_l]
# nonlocal gs, grad_ops_l
return [(go(inps, grds) if ov else go(*(inps + grds)))
for go, ov in izip(gs, grad_ops_l)]
self.grad_ops = grad_ops
else:
grad_ops = grad_ops_l
self.cached_grad_ops = True
return grad_ops(inputs, output_grads)
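# A minimal usage sketch (illustration only, not part of this commit) of the
# single-callable form of grad_overrides handled above: the callable receives
# (inputs, output_grads) and returns one gradient Variable per input.
# All names below are hypothetical.
import theano.tensor as T
from theano.compile.builders import op_from_graph

x, y = T.scalars('x', 'y')

def mul_grad(inps, grads):
    # product rule for z = x * y: dz/dx = g * y, dz/dy = g * x
    g, = grads
    return [g * inps[1], g * inps[0]]

op_mul = op_from_graph([x, y], [x * y], grad_overrides=mul_grad)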
......@@ -165,14 +151,6 @@ class OpFromGraphBase(gof.Op):
return ret
def perform(self, node, inputs, outputs):
raise NotImplementedError()
class OpFromGraphPrecompiled(OpFromGraphBase):
"""
The Op's inner graph is compiled into a theano function.
"""
def prepare_node(self, node, storage_map, compute_map, impl):
if not hasattr(self, "fn") and impl == 'py':
self.fn = orig_function(self.internal_inputs,
......@@ -187,43 +165,33 @@ class OpFromGraphPrecompiled(OpFromGraphBase):
# we wont need this copy anymore
output[0] = variable.copy()
class OpFromGraphInline(OpFromGraphBase):
"""
The Op's inner graph is expanded into the outer graph at compile time
@gof.local_optimizer([OpFromGraph])
def inline_ofg_expansion(node):
"""
def perform(self, node, inputs, outputs):
raise RuntimeError(
type(self).__name__ + ' is not supposed to be executed at runtime')
This optimization expands the internal graph of OpFromGraph.
@gof.local_optimizer([OpFromGraphInline])
def inline_ofg_expansion(node):
""" This optimization expands internal graph of OpFromGraphInline
Doing so can improve optimization at the cost of compilation speed.
"""
op = node.op
if not isinstance(op, OpFromGraphInline):
if not isinstance(op, OpFromGraph):
return False
if not op.is_inline:
return False
outputs = theano.clone(
return theano.clone(
op.internal_outputs, {
u: v for u, v in izip(
node.op.internal_inputs, node.inputs)})
return outputs
optdb.register(
'inline_ofg_expansion',
gof.opt.in2out(inline_ofg_expansion),
0.5, 'fast_compile', 'fast_run')
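# A minimal sketch (illustration only, not part of this commit) showing the
# effect of the expansion pass registered above: with inline=True and the
# 'fast_run' or 'fast_compile' optimizer active, the OpFromGraph node is
# replaced by its inner graph and no longer appears in the compiled function.
# Variable names below are hypothetical.
import theano
import theano.tensor as T
from theano.compile.builders import OpFromGraph, op_from_graph

x, y = T.scalars('x', 'y')
op_add = op_from_graph([x, y], [x + y], inline=True)
f = theano.function([x, y], op_add(x, y))
# after optimization no OpFromGraph node should remain in the graph
assert not any(isinstance(n.op, OpFromGraph)
               for n in f.maker.fgraph.toposort())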
# Since OpFromGraphPrecompiled contains a Theano compiled function,
# Since OpFromGraph contains a Theano compiled function,
# we should let DebugMode know about it
ops_with_inner_function[OpFromGraphPrecompiled] = 'fn'
ops_with_inner_function[OpFromGraph] = 'fn'
# for backward compatibility
OpFromGraph = OpFromGraphPrecompiled
# API for OpFromGraph*
# API for OpFromGraph
def op_from_graph(
inputs, outputs, inline=False, grad_overrides=None, **kwargs
):
......@@ -254,9 +222,6 @@ def op_from_graph(
- list : each function must return a single Variable. The order
of the list must correspond to inputs
Notes
-----
TODO:
- examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try
......@@ -274,12 +239,14 @@ def op_from_graph(
Notes
-----
- We support shared variables in the inner graph. This is automatic and
invisible to the user. They can be as input to the node or in the
inner graph.
- We support shared variables in the inner graph. This is automatic
and invisible to the user. They can be used as inputs to the node
or appear in the inner graph.
- We support unused inputs. This is needed for the grad.
- `inline=True` will cause better runtime optimization at the cost of
longer compilation, only works with optimizer "fast_run" or "fast_compile"
- `inline=True` will cause better runtime optimization at the cost
of compilation time. Like the "inline" keyword in C, this is merely
a suggestion to the compiler and is not guaranteed. Currently it
only works with the "fast_compile" or "fast_run" optimizer.
Examples
--------
......@@ -331,9 +298,5 @@ def op_from_graph(
fn(2., 3., 4.) # [1., 8., 3.]
"""
if inline and theano.config.optimizer in ['fast_run', 'fast_compile']:
cls_opfromgraph = OpFromGraphInline
else:
cls_opfromgraph = OpFromGraphPrecompiled
return cls_opfromgraph(
inputs, outputs, grad_overrides=grad_overrides, **kwargs)
return OpFromGraph(
inputs, outputs, inline=inline, grad_overrides=grad_overrides, **kwargs)
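# A short sketch (illustration only, not part of this commit) of the list
# form of grad_overrides documented above: one entry per input, where None
# means "fall back to the default symbolic gradient for that input".
# All names below are hypothetical.
import theano.tensor as T
from theano.compile.builders import op_from_graph

x, w, b = T.vectors('x', 'w', 'b')

def dx_override(inps, grads):
    # custom gradient w.r.t. x only; each list entry must return a
    # single Variable (see the docstring above)
    return grads[0] * 2

op_linear = op_from_graph([x, w, b], [x * w + b],
                          grad_overrides=[dx_override, None, None])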
from __future__ import absolute_import, print_function, division
from functools import partial
import numpy as np
from theano import config, shared
......@@ -8,12 +9,12 @@ from theano.compile import function
from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams
from theano.compile.builders import OpFromGraphInline, OpFromGraphPrecompiled
from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools
test_params = unittest_tools.parameterized.expand(
[(OpFromGraphInline,), (OpFromGraphPrecompiled,)])
[(OpFromGraph,), (partial(OpFromGraph, inline=True),)])
class T_OpFromGraph(unittest_tools.InferShapeTester):
......@@ -135,11 +136,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zz = T.sum(op_mul(xx, yy))
dx, dy = T.grad(zz, [xx, yy])
fn = function([xx, yy], [dx, dy])
xv = numpy.random.rand(16).astype(config.floatX)
yv = numpy.random.rand(16).astype(config.floatX)
xv = np.random.rand(16).astype(config.floatX)
yv = np.random.rand(16).astype(config.floatX)
dxv, dyv = fn(xv, yv)
assert numpy.allclose(yv * 2, dxv)
assert numpy.allclose(xv * 1.5, dyv)
assert np.allclose(yv * 2, dxv)
assert np.allclose(xv * 1.5, dyv)
# list override case
def go1(inps, gs):
......@@ -159,13 +160,13 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zz = T.sum(op_linear(xx, ww, bb))
dx, dw, db = T.grad(zz, [xx, ww, bb])
fn = function([xx, ww, bb], [dx, dw, db])
xv = numpy.random.rand(16).astype(config.floatX)
wv = numpy.random.rand(16).astype(config.floatX)
bv = numpy.random.rand(16).astype(config.floatX)
xv = np.random.rand(16).astype(config.floatX)
wv = np.random.rand(16).astype(config.floatX)
bv = np.random.rand(16).astype(config.floatX)
dxv, dwv, dbv = fn(xv, wv, bv)
assert numpy.allclose(wv * 2, dxv)
assert numpy.allclose(xv * 1.5, dwv)
assert numpy.allclose(numpy.ones(16, dtype=config.floatX), dbv)
assert np.allclose(wv * 2, dxv)
assert np.allclose(xv * 1.5, dwv)
assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
@test_params
def test_nested(self, cls_ofg):
......@@ -178,11 +179,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
xx2, yy2 = op_ift(*op_ft(xx, yy))
fn = function([xx, yy], [xx2, yy2])
xv = numpy.random.rand(16).astype(config.floatX)
yv = numpy.random.rand(16).astype(config.floatX)
xv = np.random.rand(16).astype(config.floatX)
yv = np.random.rand(16).astype(config.floatX)
xv2, yv2 = fn(xv, yv)
assert numpy.allclose(xv, xv2)
assert numpy.allclose(yv, yv2)
assert np.allclose(xv, xv2)
assert np.allclose(yv, yv2)
@test_params
def test_connection_pattern(self, cls_ofg):
......@@ -239,10 +240,9 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
p = T.matrix('p')
# we don't want check_topo for inline ops
# since the inline op is replaced during optimization
is_compile = not issubclass(cls_ofg, OpFromGraphInline)
self._compile_and_check([q, p],
op_graph(q, p),
[np.ones([3, 4], dtype=config.floatX),
np.ones([3, 4], dtype=config.floatX)],
cls_ofg,
check_topo=is_compile)
check_topo=not op_graph.is_inline)