提交 47548c0a authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5255 from khaotik/new_opfromgraph

Upgraded OpFromGraph with inline support and gradient override
...@@ -68,7 +68,8 @@ from theano.compile import ( ...@@ -68,7 +68,8 @@ from theano.compile import (
SymbolicOutput, Out, SymbolicOutput, Out,
Mode, Mode,
predefined_modes, predefined_linkers, predefined_optimizers, predefined_modes, predefined_linkers, predefined_optimizers,
FunctionMaker, function, function_dump, OpFromGraph, FunctionMaker, function, function_dump,
OpFromGraph,
ProfileStats, ProfileStats,
Param, shared, as_op) Param, shared, as_op)
......
from __future__ import absolute_import, print_function, division """Define new Ops from existing Ops"""
from __future__ import absolute_import, division, print_function
from functools import reduce, partial
from collections import OrderedDict
import theano import theano
from theano import gof from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared from theano.compile import SharedVariable, rebuild_collect_shared, optdb
from theano.gof import ops_with_inner_function from theano.gof import Variable, ops_with_inner_function
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
from theano.gof.null_type import NullType
from functools import reduce from theano.gradient import DisconnectedType
class OpFromGraph(gof.Op): class OpFromGraph(gof.Op):
""" """
This creates an `Op` from inputs and outputs lists of variables. This creates an ``Op`` from inputs and outputs lists of variables.
The signature is similar to :func:`theano.function <theano.function>`
The signature is similar to theano.function() and the resulting and the resulting ``Op``'s perform will do the same operation as::
`Op`'s perform will do the same operation as::
orig_function(inputs, outputs, **kwargs) orig_function(inputs, outputs, **kwargs)
TODO: Currently does not support ``updates`` or ``givens`` argument.
Parameters
----------
inputs: list of :class:`Variable <theano.gof.Variable>`
outputs: list of :class:`Variable <theano.gof.Variable>`
inline: bool, optional
Defaults to ``False``
``True`` : Cause the Op's original graph being used during
compilation, the Op will not be visible in the compiled
graph but rather its internal graph.
``False`` : will use a pre-compiled function inside.
grad_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
``'default'`` : Do not override, use default grad() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads"
arguments as one would specify in grad() method.
callable : similar to OpFromGraph instance, must return list of
:class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``default``.
``'default'`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of "inputs" and "output_grads"
arguments as one would specify in grad() method.
callable : similar to OpFromGraph instance, must return list of
:class:`Variable <theano.gof.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as zero since DisconnectedType is not yet supported in R_op
list: Each OpFromGraph/callable must return a single
:class:`Variable <theano.gof.Variable>`. Each list element corresponds
to a specific output of R_op, length of list must be equal to number of outputs.
name : string, optional
A name for debugging purposes
\*\*kwargs : optional
Check
:func:`orig_function <theano.compile.function_module.orig_function>`
for more arguments, only works when not inline.
.. TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try - __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, gof.opt.is_same_graph_with_merge(op1.local_outputs, op2,
new_outputs) local_outputs)
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
- extend grad() to L_op
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates. - check how it works with updates.
- add test with constant as input or inside the inner graph. - add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer - Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op. - Add support to pickle this Op.
- Add support/test with random generator - Add support/test with random generator
- Add optimization to removing unused inputs/outputs
- Add optimization to work inplace on inputs when not inline
Notes Notes
----- -----
- We support shared variables in the inner graph. This is automatic and - We support shared variables in the inner graph. This is automatic
invisible to the user. They can be as input to the node or in the and invisible to the user. They can be as input to the node or in
inner graph. the inner graph.
- We support unused inputs. This is needed for the grad. - We support unused inputs. This is needed for the grad.
- We support nested OpFromGraph.
- ``inline=True`` will cause better runtime optimization at the cost
of compilation time. Currently only works with ``fast_compile`` or
``fast_run`` mode.
- It's recommanded to provide pure functions (no side effects like
setting global variable) as callable(s). The callable(s) supplied
for overriding gradient/rop will be called only once at the first
call to grad/R_op, and will be converted to OpFromGraph instances.
Examples Examples
-------- --------
...@@ -70,42 +151,112 @@ class OpFromGraph(gof.Op): ...@@ -70,42 +151,112 @@ class OpFromGraph(gof.Op):
e2 = op(x, y, z) + op(z, y, x) e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2]) fn = function([x, y, z], [e2])
Example 3 override gradient
.. code-block:: python
from theano import function, OpFromGraph, tensor, grad
x, y, z = tensor.scalars('xyz')
e = x + y * z
def rescale_dy(inps, grads):
x, y, z = inps
g, = grads
return z*2
op = OpFromGraph(
[x, y, z], [e], grad_overrides=['default', rescale_dy, 'default']
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
fn = function([x, y, z], [dx, dy, dz])
# the gradient wrt y is now doubled
fn(2., 3., 4.) # [1., 8., 3.]
""" """
def __init__(self, inputs, outputs, **kwargs): @staticmethod
def _filter_grad_var(grad, inp):
# Returns (filtered_var, overrider_var)
# Args:
# grad: gradient Variable
# inp: the corresponding input of gradient Variable
#
# a grad() call could return instance of NullType() or DisconnectedType()
# which cannot be directly used in OfG
#
# Since we always use an OfG instance as self._grad_op, the current
# workaround is to "remember" the special cases of the gradient and
# replace them after self._grad_op is called.
#
# This helper function changes invalid types into a filtered_var,
# and provides a overrider_var to be replaced at grad() call
#
# For now, this converts NullType or DisconnectedType into zeros_like.
# other types are unmodified: overrider_var -> None
if isinstance(grad.type, (NullType, DisconnectedType)):
if hasattr(inp, 'zeros_like'):
return inp.zeros_like(), grad
else:
return theano.tensor.constant(0.), grad
else:
return grad, None
@staticmethod
def _filter_rop_var(inpJ, out):
# mostly similar to _filter_grad_var
if isinstance(inpJ.type, NullType):
return out.zeros_like(), inpJ
if isinstance(inpJ.type, DisconnectedType):
# since R_op does not have DisconnectedType yet, we will just
# make them zeros.
return out.zeros_like(), None
else:
return inpJ, None
def __init__(
self, inputs, outputs,
inline=False,
grad_overrides='default', rop_overrides='default',
name=None, **kwargs
):
if not isinstance(outputs, list): if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs) raise TypeError('outputs must be list, got %s' % type(outputs))
for i in inputs + outputs: for i in inputs + outputs:
if not isinstance(i, gof.Variable): if not isinstance(i, gof.Variable):
raise TypeError( raise TypeError(
'inputs and outputs must be Variable instances', i) 'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs or 'givens' in kwargs: if 'updates' in kwargs or 'givens' in kwargs:
raise TypeError('updates and givens are not allowed in kwargs') raise TypeError('updates and givens are not allowed here')
self.is_inline = inline
# To support correctly shared variables the inner fct should # To correctly support shared variables the inner fct should
# not see them. Otherwise their is problem with the gradient. # not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs) self.shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)] if isinstance(var, SharedVariable)]
shared_vars = [var.type() for var in self.shared_inputs] shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars, new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
replace=dict(izip(self.shared_inputs, replace=dict(izip(
shared_vars)), self.shared_inputs, shared_vars)),
copy_inputs_over=False) copy_inputs_over=False)
(new_inputs, new_outputs, (local_inputs, local_outputs,
[clone_d, update_d, update_expr, shared_inputs]) = new [clone_d, update_d, update_expr, shared_inputs]) = new
assert len(new_inputs) == len(inputs) + len(self.shared_inputs) assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(new_outputs) == len(outputs) assert len(local_outputs) == len(outputs)
assert not update_d assert not update_d
assert not update_expr assert not update_expr
assert not shared_inputs assert not shared_inputs
self.new_inputs = new_inputs self.local_inputs = local_inputs
self.new_outputs = new_outputs self.local_outputs = local_outputs
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.kwargs = kwargs self.kwargs = kwargs
self.input_types = [input.type for input in inputs] self.input_types = [inp.type for inp in inputs]
self.output_types = [output.type for output in outputs] self.output_types = [out.type for out in outputs]
self.set_grad_overrides(grad_overrides)
self.set_rop_overrides(rop_overrides)
if name is not None:
assert isinstance(name, str), 'name must be None or string object'
self.name = name
def __eq__(self, other): def __eq__(self, other):
# TODO: recognize a copy # TODO: recognize a copy
...@@ -115,40 +266,323 @@ class OpFromGraph(gof.Op): ...@@ -115,40 +266,323 @@ class OpFromGraph(gof.Op):
# TODO: use internal variables in hash # TODO: use internal variables in hash
return hash(type(self)) return hash(type(self))
def make_node(self, *inputs): def __str__(self):
for input, type in zip(inputs, self.input_types): name = self.__class__.__name__ if self.name is None else self.name
if not type == input.type: is_inline = self.is_inline
raise TypeError("Wrong type, expected %s but got %s" % return '%(name)s{inline=%(is_inline)s}' % locals()
(type, input.type))
return gof.Apply(self, def _recompute_grad_op(self):
list(inputs) + self.shared_inputs, '''
[type() for type in self.output_types]) converts self._grad_op from user supplied form to type(self) instance
'''
local_inputs = self.local_inputs
local_outputs = self.local_outputs
inp_len = len(local_inputs)
grad_op = self._grad_op
if isinstance(grad_op, OpFromGraph):
if not self._grad_op_is_cached:
self._grad_op_is_cached = True
self._grad_op_stypes_l = [None] * inp_len
return
output_grads = [out_t() for out_t in self.output_types]
fn_grad = partial(
theano.gradient.grad,
cost=None,
disconnected_inputs='ignore',
return_disconnected='Disconnected',
null_gradients='return',
known_grads=OrderedDict(izip(local_outputs, output_grads)))
TYPE_ERR_MSG = ("Gradient override should be (single or list of)"
"'default' | OpFromGraph | callable | Variable "
"with NullType or DisconnectedType, got %s")
STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
' of DisconnectedType or NullType, got %s')
# we need to convert _grad_op into an OfG instance
if grad_op == 'default':
gdefaults_l = fn_grad(wrt=local_inputs)
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp) for grad, inp in izip(gdefaults_l, local_inputs)])
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
elif isinstance(grad_op, Variable):
if isinstance(grad_op.type, (DisconnectedType, NullType)):
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [grad_op.type() for _ in range(inp_len)]
else:
raise ValueError(STYPE_ERR_MSG % grad_op.type)
elif isinstance(grad_op, list):
goverrides_l = grad_op
if len(goverrides_l) != inp_len:
raise ValueError(
'Need to override %d gradients, got %d' % (
inp_len, len(goverrides_l)), goverrides_l)
# compute non-overriding downsteam grads from upstreams grads
# it's normal some input may be disconnected, thus the 'ignore'
wrt_l = [lin for lin, gov in izip(
local_inputs, goverrides_l) if gov == 'default']
gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
# combine overriding gradients
all_grads_l = []
all_grads_ov_l = []
for inp, fn_gov in izip(local_inputs, goverrides_l):
if fn_gov == 'default':
gnext, gnext_ov = OpFromGraph._filter_grad_var(
next(gdefaults), inp)
all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov)
elif isinstance(fn_gov, Variable):
if isinstance(fn_gov.type, (DisconnectedType, NullType)):
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(fn_gov.type())
else:
raise ValueError(STYPE_ERR_MSG % fn_gov.type)
else:
if not hasattr(fn_gov, '__call__'):
raise TypeError(TYPE_ERR_MSG % fn_gov)
gov, gov_ov = OpFromGraph._filter_grad_var(
fn_gov(local_inputs, output_grads), inp)
all_grads_l.append(gov)
all_grads_ov_l.append(gov_ov)
else:
# callable case
if not hasattr(grad_op, '__call__'):
raise TypeError(TYPE_ERR_MSG % grad_op)
goverrides_l = grad_op(local_inputs, output_grads)
if not isinstance(goverrides_l, list):
raise TypeError(
'Gradient overriding function should return a list, '
'got "%s"' % type(goverrides_l))
all_grads_l, all_grads_ov_l = izip(
*[OpFromGraph._filter_grad_var(grad, inp)
for grad, inp in izip(goverrides_l, local_inputs)])
if len(all_grads_l) != len(local_inputs):
raise ValueError(
'Gradient overriding function should return list of '
'%d outputs, got %d' % (inp_len, len(all_grads_l)))
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
self._grad_op = type(self)(
inputs=local_inputs + output_grads,
outputs=all_grads_l,
inline=self.is_inline,
name=(None if self.name is None else self.name + '_grad'),
on_unused_input='ignore')
self._grad_op_stypes_l = all_grads_ov_l
self._grad_op_is_cached = True
def _recompute_rop_op(self):
'''
converts self._rop_op from user supplied form to type(self) instance
'''
local_inputs = self.local_inputs
local_outputs = self.local_outputs
out_len = len(local_outputs)
rop_op = self._rop_op
if isinstance(rop_op, OpFromGraph):
if not self._rop_op_is_cached:
self._rop_op_is_cached = True
self._rop_op_stypes_l = [None] * out_len
return
eval_points = [inp_t() for inp_t in self.input_types]
fn_rop = partial(
theano.gradient.Rop,
wrt=local_inputs,
eval_points=eval_points)
TYPE_ERR_MSG = ("R_op overrides should be (single or list of)"
"OpFromGraph | 'default' | None | 0 | callable, got %s")
STYPE_ERR_MSG = ('Overriding Variable instance can only have type'
' of DisconnectedType or NullType, got %s')
if rop_op == 'default':
rdefaults_l = fn_rop(f=local_outputs)
all_rops_l, all_rops_ov_l = izip(
*[OpFromGraph._filter_rop_var(rop, out) for rop,
out in izip(rdefaults_l, local_outputs)])
all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l)
elif isinstance(rop_op, Variable):
if isinstance(rop_op.type, NullType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [rop_op.type() for _ in range(out_len)]
elif isinstance(rop_op.type, DisconnectedType):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [None] * out_len
else:
raise ValueError(STYPE_ERR_MSG % rop_op.type)
elif isinstance(rop_op, list):
roverrides_l = rop_op
if len(roverrides_l) != out_len:
raise ValueError(
'Need to override %d Rop, got %d' % (
out_len, len(roverrides_l)), roverrides_l)
# get outputs that does not have Rop override
odefaults_l = [
lo for lo, rov in izip(local_outputs, roverrides_l)
if rov == 'default']
rdefaults_l = fn_rop(f=odefaults_l)
rdefaults = iter(rdefaults_l if odefaults_l else [])
# combine overriding Rops
all_rops_l = []
all_rops_ov_l = []
for out, fn_rov in izip(local_outputs, roverrides_l):
if fn_rov == 'default':
rnext, rnext_ov = OpFromGraph._filter_rop_var(
next(rdefaults), out)
all_rops_l.append(rnext)
all_rops_ov_l.append(rnext_ov)
elif isinstance(fn_rov, Variable):
if isinstance(fn_rov.type, NullType):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(fn_rov.type())
if isinstance(fn_rov.type, DisconnectedType):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(None)
else:
raise ValueError(STYPE_ERR_MSG % fn_rov.type)
else:
if not hasattr(fn_rov, '__call__'):
raise TypeError(TYPE_ERR_MSG % fn_rov)
rov, rov_ov = OpFromGraph._filter_rop_var(
fn_rov(local_inputs, eval_points), out)
all_rops_l.append(rov)
all_rops_ov_l.append(rov_ov)
else:
if not hasattr(rop_op, '__call__'):
raise TypeError(TYPE_ERR_MSG % rop_op)
roverrides_l = rop_op(local_inputs, eval_points)
if not isinstance(roverrides_l, list):
raise TypeError(
'Rop overriding function should return a list, '
'got "%s"' % type(roverrides_l))
all_rops_l, all_rops_ov_l = izip(
*[OpFromGraph._filter_rop_var(
rop, out) for rop, out in izip(roverrides_l, local_outputs)])
if len(all_rops_l) != out_len:
raise ValueError(
'Rop overriding function %s should return list of '
'%d outputs, got %d' % (
self._rop_op, out_len,
len(all_rops_l)), rop_op)
all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l)
self._rop_op = type(self)(
inputs=local_inputs + eval_points,
outputs=all_rops_l,
inline=self.is_inline,
name=(None if self.name is None else self.name + '_rop'),
on_unused_input='ignore')
self._rop_op_stypes_l = all_rops_ov_l
self._rop_op_is_cached = True
def get_grad_op(self):
"""
getter method for self._grad_op
"""
if not self._grad_op_is_cached:
self._recompute_grad_op()
return self._grad_op
def prepare_node(self, node, storage_map, compute_map, impl): def get_rop_op(self):
if not hasattr(self, "fn") and impl == 'py': """
self.fn = orig_function(self.new_inputs, getter method for self._rop_op
self.new_outputs, """
**self.kwargs) if not self._rop_op_is_cached:
self._recompute_rop_op()
return self._rop_op
def perform(self, node, inputs, outputs): def set_grad_overrides(self, grad_overrides):
variables = self.fn(*inputs) """
assert len(variables) == len(outputs) Set gradient overrides, see help(theano.OpFromGraph) for syntax
for output, variable in zip(outputs, variables): This will completely remove any previously set gradient overrides
# TODO: when function's output-borrowing semantics are correct,
# we wont need this copy anymore """
output[0] = variable.copy() self._grad_op = grad_overrides
self._grad_op_is_cached = False
def set_rop_overrides(self, rop_overrides):
"""
Set R_op overrides, see help(theano.OpFromGraph) for syntax
This will completely remove any previously set R_op overrides
"""
self._rop_op = rop_overrides
self._rop_op_is_cached = False
def grad(self, inputs, output_grads):
if not self._grad_op_is_cached:
self._recompute_grad_op()
ret_ofg_l = self._grad_op(
*(list(inputs) + list(output_grads)), return_list=True)
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._grad_op_stypes_l)]
return ret_l
def R_op(self, inputs, eval_points):
if not self._rop_op_is_cached:
self._recompute_rop_op()
ret_ofg_l = self._rop_op(
*(list(inputs) + list(eval_points)), return_list=True)
ret_l = [
ret_ofg if ov is None else ov for ret_ofg, ov in izip(
ret_ofg_l, self._rop_op_stypes_l)]
return ret_l
def make_node(self, *inputs):
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
if len(inputs) != num_expected_inps:
raise ValueError(
"Expected %d inputs, got %d" % (num_expected_inps, len(inputs)))
inputs = [inp_t.filter_variable(inp) for inp, inp_t in izip(inputs, self.input_types)]
apply_node = gof.Apply(
self, list(inputs) + self.shared_inputs,
[type() for type in self.output_types])
apply_node.local_inputs = self.local_inputs
apply_node.local_outputs = self.local_outputs
return apply_node
def connection_pattern(self, node): def connection_pattern(self, node):
""" """
Return connection pattern of subfgraph defined by inputs and outputs. Return connection pattern of subfgraph defined by inputs and outputs.
""" """
return io_connection_pattern(self.new_inputs, self.new_outputs) inp_len = len(self.local_inputs)
out_len = len(self.local_outputs)
cpmat_self = io_connection_pattern(
self.local_inputs, self.local_outputs)
grad_op = self.get_grad_op()
cpmat_grad = io_connection_pattern(
grad_op.local_inputs[inp_len:],
grad_op.local_outputs)
# cpmat_self |= cpmat_grad.T
# cpmat_self &= out_is_disconnected
for i, t in enumerate(self._grad_op_stypes_l):
if t is not None:
if isinstance(t.type, DisconnectedType):
for o in range(out_len):
cpmat_self[i][o] = False
for o in range(out_len):
cpmat_self[i][o] |= cpmat_grad[o][i]
# TODO in case DisconnectedType is implemented for R_op,
# self._rop_op_stypes_l self._rop_op should considered for
# connection_pattern
return list(map(list, cpmat_self))
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
out_shp = theano.scan_module.scan_utils.infer_shape(self.new_outputs, out_shp = theano.scan_module.scan_utils.infer_shape(
self.new_inputs, self.local_outputs,
shapes) self.local_inputs,
shapes)
# Clone the output shape so that shape are computed from outer inputs. # Clone the output shape so that shape are computed from outer inputs.
# Note: # Note:
...@@ -157,7 +591,7 @@ class OpFromGraph(gof.Op): ...@@ -157,7 +591,7 @@ class OpFromGraph(gof.Op):
# But doing it multiple time could duplicate common subgraph between # But doing it multiple time could duplicate common subgraph between
# each shape call. Theano optimizer will clean this up later, but this # each shape call. Theano optimizer will clean this up later, but this
# will ask extra work to the optimizer. # will ask extra work to the optimizer.
repl = dict(zip(self.new_inputs, node.inputs)) repl = dict(zip(self.local_inputs, node.inputs))
cloned = theano.clone(reduce(tuple.__add__, out_shp), replace=repl) cloned = theano.clone(reduce(tuple.__add__, out_shp), replace=repl)
ret = [] ret = []
used = 0 used = 0
...@@ -168,30 +602,46 @@ class OpFromGraph(gof.Op): ...@@ -168,30 +602,46 @@ class OpFromGraph(gof.Op):
return ret return ret
def grad(self, inputs, output_grads): def prepare_node(self, node, storage_map, compute_map, impl):
if hasattr(self, "grad_ops"): if not hasattr(self, "fn") and impl == 'py':
grad_ops = self.grad_ops self.fn = orig_function(self.local_inputs,
else: self.local_outputs,
gs = theano.gradient.grad(cost=None, **self.kwargs)
known_grads=dict(izip(self.new_outputs, self.fn.trust_input = True
output_grads)),
wrt=self.new_inputs, def perform(self, node, inputs, outputs):
disconnected_inputs='ignore') variables = self.fn(*inputs)
assert len(variables) == len(outputs)
grad_ops = [] for output, variable in izip(outputs, variables):
for g in gs: # TODO: when function's output-borrowing semantics are correct,
if g is None: # we wont need this copy anymore
grad_ops.append(lambda *args: None) output[0] = variable.copy()
else:
# It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them.
grad_ops.append(OpFromGraph(self.new_inputs + output_grads,
[g],
on_unused_input='ignore'))
self.grad_ops = grad_ops
return [go(*(inputs + output_grads)) for go in grad_ops]
# Since OpFromGraph contains a Theano compiled function, we should let @gof.local_optimizer([OpFromGraph])
# DebugMode know about it def inline_ofg_expansion(node):
"""
This optimization expands internal graph of OpFromGraph.
Only performed if node.op.is_inline == True
Doing so can improve optimization at the cost of compilation speed.
"""
op = node.op
if not isinstance(op, OpFromGraph):
return False
if not op.is_inline:
return False
return theano.clone(
op.local_outputs, {
u: v for u, v in izip(
node.op.local_inputs, node.inputs)})
# We want to run this before the first merge optimizer
# and before the first scan optimizer.
optdb.register(
'inline_ofg_expansion',
gof.opt.in2out(inline_ofg_expansion),
-0.01, 'fast_compile', 'fast_run')
# Since OpFromGraph contains a Theano compiled function,
# we should let DebugMode know about it
ops_with_inner_function[OpFromGraph] = 'fn' ops_with_inner_function[OpFromGraph] = 'fn'
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from functools import partial
import numpy as np import numpy as np
from theano import config, shared from theano import config, shared
from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
from theano.compile import function from theano.compile import function
from theano import tensor as T from theano import tensor as T
...@@ -12,13 +15,17 @@ from theano.compile.builders import OpFromGraph ...@@ -12,13 +15,17 @@ from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools from theano.tests import unittest_tools
test_params = unittest_tools.parameterized.expand(
[(OpFromGraph,), (partial(OpFromGraph, inline=True),)])
class T_OpFromGraph(unittest_tools.InferShapeTester): class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_straightforward(self): @test_params
def test_straightforward(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
...@@ -32,10 +39,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -32,10 +39,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
assert np.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
def test_size_changes(self): @test_params
def test_size_changes(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = T.dot(x, y) e = T.dot(x, y)
op = OpFromGraph([x, y], [e]) op = cls_ofg([x, y], [e])
f = op(x, op(y, z)) f = op(x, op(y, z))
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = np.ones((2, 3), dtype=config.floatX) xv = np.ones((2, 3), dtype=config.floatX)
...@@ -48,10 +56,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -48,10 +56,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert res.shape == (2, 5) assert res.shape == (2, 5)
assert np.all(180.0 == res) assert np.all(180.0 == res)
def test_grad(self): @test_params
def test_grad(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
...@@ -60,10 +69,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -60,10 +69,11 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.all(11.0 == fn(xv, yv, zv)) assert np.all(11.0 == fn(xv, yv, zv))
def test_grad_grad(self): @test_params
def test_grad_grad(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
...@@ -73,11 +83,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -73,11 +83,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
assert np.allclose(6.0, fn(xv, yv, zv)) assert np.allclose(6.0, fn(xv, yv, zv))
def test_shared(self): @test_params
def test_shared(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(np.random.rand(2, 2).astype(config.floatX)) s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
# (1+3*5=array of 16) - (3+1*5=array of 8) # (1+3*5=array of 16) - (3+1*5=array of 8)
f = op(x, y, z) - op(y, z, x) f = op(x, y, z) - op(y, z, x)
...@@ -90,11 +101,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -90,11 +101,12 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(8.0, fn(xv, yv, zv)) assert np.allclose(8.0, fn(xv, yv, zv))
assert np.allclose(8.0, fn(xv, yv, zv)) assert np.allclose(8.0, fn(xv, yv, zv))
def test_shared_grad(self): @test_params
def test_shared_grad(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
s = shared(np.random.rand(2, 2).astype(config.floatX)) s = shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s e = x + y * z + s
op = OpFromGraph([x, y, z], [e]) op = cls_ofg([x, y, z], [e])
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(T.sum(f), y) f = f - T.grad(T.sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
...@@ -110,13 +122,146 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -110,13 +122,146 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(15.0 + s.get_value(), assert np.allclose(15.0 + s.get_value(),
fn(xv, yv, zv)) fn(xv, yv, zv))
def test_connection_pattern(self): @test_params
def test_grad_override(self, cls_ofg):
    """Check gradient overrides in all supported forms: a callable,
    an OpFromGraph instance, a per-input list (with 'default'), and
    NullType/DisconnectedType sentinels."""
    x, y = T.vectors('xy')

    def go(inps, gs):
        # Custom gradient of z = x * y, deliberately scaled (2x, 1.5x)
        # so an applied override is distinguishable from the default.
        x, y = inps
        g, = gs
        return [g * y * 2, g * x * 1.5]

    dedz = T.vector('dedz')
    op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))

    op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
    op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)

    # single override case (function or OfG instance)
    xx, yy = T.vector('xx'), T.vector('yy')
    for candidate in (op_mul, op_mul2):
        loss = T.sum(candidate(xx, yy))
        dx, dy = T.grad(loss, [xx, yy])
        fn = function([xx, yy], [dx, dy])
        xv = np.random.rand(16).astype(config.floatX)
        yv = np.random.rand(16).astype(config.floatX)
        dxv, dyv = fn(xv, yv)
        assert np.allclose(yv * 2, dxv)
        assert np.allclose(xv * 1.5, dyv)

    # list override case: one override per input
    def go1(inps, gs):
        x, w, b = inps
        g = gs[0]
        return g * w * 2

    def go2(inps, gs):
        x, w, b = inps
        g = gs[0]
        return g * x * 1.5

    w, b = T.vectors('wb')
    # the third gradient (w.r.t. b) is left at the default
    op_linear = cls_ofg([x, w, b], [x * w + b],
                        grad_overrides=[go1, go2, 'default'])
    xx, ww, bb = T.vector('xx'), T.vector('yy'), T.vector('bb')
    loss = T.sum(op_linear(xx, ww, bb))
    dx, dw, db = T.grad(loss, [xx, ww, bb])
    fn = function([xx, ww, bb], [dx, dw, db])
    xv = np.random.rand(16).astype(config.floatX)
    wv = np.random.rand(16).astype(config.floatX)
    bv = np.random.rand(16).astype(config.floatX)
    dxv, dwv, dbv = fn(xv, wv, bv)
    assert np.allclose(wv * 2, dxv)
    assert np.allclose(xv * 1.5, dwv)
    assert np.allclose(np.ones(16, dtype=config.floatX), dbv)

    # NullType and DisconnectedType sentinels as overrides
    op_linear2 = cls_ofg(
        [x, w, b], [x * w + b],
        grad_overrides=[go1, NullType()(), DisconnectedType()()])
    loss2 = T.sum(op_linear2(xx, ww, bb))
    dx2, dw2, db2 = T.grad(
        loss2, [xx, ww, bb],
        return_disconnected='Disconnected',
        disconnected_inputs='ignore',
        null_gradients='return')
    assert isinstance(dx2.type, T.TensorType)
    assert dx2.ndim == 1
    assert isinstance(dw2.type, NullType)
    assert isinstance(db2.type, DisconnectedType)
@test_params
def test_rop(self, cls_ofg):
    """The default R-op of an OpFromGraph wrapping a vector-matrix dot
    must equal the R-op of the underlying graph: R_x[dot(x, W)](du) = dot(du, W)."""
    vec = T.vector()
    mat = T.matrix()
    op_matmul = cls_ofg([vec, mat], [T.dot(vec, mat)])

    x = T.vector()
    W = T.matrix()
    du = T.vector()
    dv = T.Rop(op_matmul(x, W), x, du)
    fn = function([x, W, du], dv)

    xval = np.random.rand(16).astype(config.floatX)
    Wval = np.random.rand(16, 16).astype(config.floatX)
    duval = np.random.rand(16).astype(config.floatX)
    assert np.allclose(fn(xval, Wval, duval), np.dot(duval, Wval))
@test_params
def test_rop_override(self, cls_ofg):
    """Check R-op overrides given either as a callable or as an
    OpFromGraph instance.

    Bug fix: the loop below iterated over ``[op_mul, op_mul2]`` but the
    body always called ``op_mul``, so the OfG-instance override
    (``op_mul2``) was never actually exercised.
    """
    x, y = T.vectors('xy')

    def ro(inps, epts):
        # Custom R-op of z = x * y along directions (u, v), deliberately
        # scaled (2x and 1.5x) so an applied override is detectable.
        x, y = inps
        u, v = epts
        return [u * y * 2. + x * v * 1.5]

    u, v = T.vectors('uv')
    op_mul_rop = cls_ofg([x, y, u, v], ro([x, y], [u, v]))
    op_mul = cls_ofg([x, y], [x * y], rop_overrides=ro)
    op_mul2 = cls_ofg([x, y], [x * y], rop_overrides=op_mul_rop)

    # single override case
    xx, yy = T.vector('xx'), T.vector('yy')
    du, dv = T.vector('du'), T.vector('dv')
    for op in [op_mul, op_mul2]:
        # was ``op_mul(xx, yy)``: op_mul2 went untested
        zz = op(xx, yy)
        dw = T.Rop(zz, [xx, yy], [du, dv])
        fn = function([xx, yy, du, dv], dw)
        vals = np.random.rand(4, 32).astype(config.floatX)
        dwval = fn(*vals)
        assert np.allclose(
            dwval, vals[0] * vals[3] * 1.5 + vals[1] * vals[2] * 2.)

    # TODO list override case
@test_params
def test_nested(self, cls_ofg):
    """Composing one OpFromGraph with another must round-trip: applying a
    forward transform then its inverse recovers the original inputs."""
    x, y = T.vectors('xy')
    # forward transform and its (halving) inverse
    op_ft = cls_ofg([x, y], [x + y, x - y])
    op_ift = cls_ofg([x, y], [(x + y) / 2, (x - y) / 2])

    xx, yy = T.vector('xx'), T.vector('yy')
    xx2, yy2 = op_ift(*op_ft(xx, yy))
    fn = function([xx, yy], [xx2, yy2])

    xv = np.random.rand(16).astype(config.floatX)
    yv = np.random.rand(16).astype(config.floatX)
    x_rec, y_rec = fn(xv, yv)
    assert np.allclose(xv, x_rec)
    assert np.allclose(yv, y_rec)
@test_params
def test_connection_pattern(self, cls_ofg):
# Basic case # Basic case
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
out1 = x * y out1 = x * y
out2 = y * z out2 = y * z
op1 = OpFromGraph([x, y, z], [out1, out2]) op1 = cls_ofg([x, y, z], [out1, out2])
results = op1.connection_pattern(None) results = op1.connection_pattern(None)
expect_result = [[True, False], expect_result = [[True, False],
[True, True], [True, True],
...@@ -128,7 +273,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -128,7 +273,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
m, n, p, q = T.matrices('mnpq') m, n, p, q = T.matrices('mnpq')
o1, o2 = op1(m, n, p) o1, o2 = op1(m, n, p)
out1, out2 = op1(o1, q, o2) out1, out2 = op1(o1, q, o2)
op2 = OpFromGraph([m, n, p, q], [out1, out2]) op2 = cls_ofg([m, n, p, q], [out1, out2])
results = op2.connection_pattern(None) results = op2.connection_pattern(None)
expect_result = [[True, False], expect_result = [[True, False],
...@@ -144,7 +289,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -144,7 +289,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
out1 = x + rv_u out1 = x + rv_u
out2 = y + 3 out2 = y + 3
out3 = 3 + rv_u out3 = 3 + rv_u
op3 = OpFromGraph([x, y], [out1, out2, out3]) op3 = cls_ofg([x, y], [out1, out2, out3])
results = op3.connection_pattern(None) results = op3.connection_pattern(None)
expect_result = [[True, False, False], expect_result = [[True, False, False],
...@@ -153,6 +298,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -153,6 +298,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert results == expect_result assert results == expect_result
def test_infer_shape(self): def test_infer_shape(self):
# test infer shape does not need to against inline case
# since the Op is remove during optimization phase
x = T.matrix('x') x = T.matrix('x')
y = T.matrix('y') y = T.matrix('y')
o1 = x + y o1 = x + y
......
...@@ -244,15 +244,14 @@ class PyDotFormatter(object): ...@@ -244,15 +244,14 @@ class PyDotFormatter(object):
# Inputs mapping # Inputs mapping
ext_inputs = [self.__node_id(x) for x in node.inputs] ext_inputs = [self.__node_id(x) for x in node.inputs]
int_inputs = [gf.__node_id(x) int_inputs = [gf.__node_id(x)
for x in node.op.fn.maker.fgraph.inputs] for x in node.op.local_inputs]
assert len(ext_inputs) == len(int_inputs) assert len(ext_inputs) == len(int_inputs)
h = format_map(zip(ext_inputs, int_inputs)) h = format_map(zip(ext_inputs, int_inputs))
pd_node.get_attributes()['subg_map_inputs'] = h pd_node.get_attributes()['subg_map_inputs'] = h
# Outputs mapping # Outputs mapping
ext_outputs = [self.__node_id(x) for x in node.outputs] ext_outputs = [self.__node_id(x) for x in node.outputs]
int_outputs = node.op.fn.maker.fgraph.outputs int_outputs = [gf.__node_id(x) for x in node.op.local_outputs]
int_outputs = [gf.__node_id(x) for x in int_outputs]
assert len(ext_outputs) == len(int_outputs) assert len(ext_outputs) == len(int_outputs)
h = format_map(zip(int_outputs, ext_outputs)) h = format_map(zip(int_outputs, ext_outputs))
pd_node.get_attributes()['subg_map_outputs'] = h pd_node.get_attributes()['subg_map_outputs'] = h
......
...@@ -1099,7 +1099,10 @@ def io_connection_pattern(inputs, outputs): ...@@ -1099,7 +1099,10 @@ def io_connection_pattern(inputs, outputs):
# connnection patterns of the individual outputs # connnection patterns of the individual outputs
global_connection_pattern = [[] for o in range(len(inputs))] global_connection_pattern = [[] for o in range(len(inputs))]
for out in outputs: for out in outputs:
out_connection_pattern = connect_pattern_by_var[out] out_connection_pattern = connect_pattern_by_var.get(out)
if out_connection_pattern is None:
# the output is completely isolated from inputs
out_connection_pattern = [False] * len(inputs)
for i in range(len(inputs)): for i in range(len(inputs)):
global_connection_pattern[i].append(out_connection_pattern[i]) global_connection_pattern[i].append(out_connection_pattern[i])
......
...@@ -340,14 +340,22 @@ class TestAutoName: ...@@ -340,14 +340,22 @@ class TestAutoName:
assert r2.auto_name == "auto_" + str(autoname_id + 1) assert r2.auto_name == "auto_" + str(autoname_id + 1)
def test_constant(self): def test_constant(self):
# Make sure the value we will use for the test aren't yet in the cache.
r1 = tensor.constant(1.5)
del tensor.constant_cache[r1.signature()]
r1 = tensor.constant(1.6)
del tensor.constant_cache[r1.signature()]
# Get counter value # Get counter value
autoname_id = next(Variable.__count__) autoname_id = next(Variable.__count__)
Variable.__count__ = count(autoname_id) Variable.__count__ = count(autoname_id)
r1 = tensor.constant(1.5) r1 = tensor.constant(1.5)
r2 = tensor.constant(1.5) r2 = tensor.constant(1.5)
assert r1.auto_name == "auto_" + str(autoname_id) assert r1.auto_name == "auto_" + str(autoname_id), (
r1.auto_name, "auto_" + str(autoname_id))
# We reuse the same variable # We reuse the same variable
assert r2.auto_name == "auto_" + str(autoname_id) assert r2.auto_name == "auto_" + str(autoname_id), (
r2.auto_name, "auto_" + str(autoname_id))
assert r1 is r2 assert r1 is r2
r3 = tensor.constant(1.6) r3 = tensor.constant(1.6)
......
...@@ -1201,58 +1201,56 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1201,58 +1201,56 @@ def _populate_grad_dict(var_to_app_to_idx,
is_zero = _is_zero(term) is_zero = _is_zero(term)
assert is_zero in ['yes', 'no', 'maybe'] assert is_zero in ['yes', 'no', 'maybe']
if is_zero == 'maybe': if is_zero == 'maybe':
msg = "%s.grad returned %s of type %s for input" msg = ("%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to " " %d. This input's only connections to "
msg += "the cost through this op are via " "the cost through this op are via "
msg += "integer-valued outputs so it should be " "integer-valued outputs so it should be "
msg += "NullType, DisconnectedType, or some form " "NullType, DisconnectedType, or some form "
msg += "of zeros. It is not NullType or " "of zeros. It is not NullType or "
msg += "DisconnectedType and theano can't " "DisconnectedType and theano can't "
msg += "simplify it to a constant, so it's not " "simplify it to a constant, so it's not "
msg += "verifiably zeros." "verifiably zeros.")
msg = msg % (str(node.op), str(term), msg %= (node.op, term, type(term), i)
str(type(term)), i)
elif is_zero == 'no':
if is_zero == 'no': msg = ("%s.grad returned %s of type %s for input"
msg = "%s.grad returned %s of type %s for input" " %d. Since this input is only connected "
msg += " %d. Since this input is only connected " "to integer-valued outputs, it should "
msg += "to integer-valued outputs, it should " "evaluate to zeros, but it evaluates to"
msg += "evaluate to zeros, but it evaluates to" "%s.")
msg += "%s."
msg %= (node.op, term, type(term), i,
msg % (node.op, term, type(term), i, theano.get_scalar_constant_value(term))
theano.get_scalar_constant_value(term))
raise ValueError(msg) raise ValueError(msg)
# Check that op.connection_pattern matches the connectivity # Check that op.connection_pattern matches the connectivity
# logic driving the op.grad method # logic driving the op.grad method
for i, packed in enumerate(zip(inputs, input_grads, for i, (ipt, ig, connected) in enumerate(
inputs_connected)): zip(inputs, input_grads, inputs_connected)
ipt, ig, connected = packed ):
actually_connected = \ actually_connected = \
not isinstance(ig.type, DisconnectedType) not isinstance(ig.type, DisconnectedType)
if actually_connected and not connected: if actually_connected and not connected:
msg = "%s.grad returned %s of type %s for input %d." msg = ("%s.grad returned %s of type %s for input %d."
msg += " Expected DisconnectedType instance based on " " Expected DisconnectedType instance based on "
msg += " the output of the op's connection_pattern " " the output of the op's connection_pattern "
msg += "method." "method.")
msg = msg % (str(node.op), str(ig), str(ig.type), i) msg %= (str(node.op), str(ig), str(ig.type), i)
raise TypeError(msg) raise TypeError(msg)
if connected and not actually_connected: elif connected and not actually_connected:
msg = "%s.grad returned DisconnectedType for input" msg = "%s.grad returned DisconnectedType for input %d."
msg += " %d." msg %= (str(node.op), i)
msg = msg % (str(node.op), i)
if hasattr(node.op, 'connection_pattern'): if hasattr(node.op, 'connection_pattern'):
msg += ' Its connection_pattern method does not' msg += (' Its connection_pattern method does not'
msg += ' allow this.' ' allow this.')
raise TypeError(msg) raise TypeError(msg)
else: else:
msg += ' You may want to implement a ' msg += (' You may want to implement a '
msg += 'connection_pattern method for it.' 'connection_pattern method for it.')
warnings.warn(msg) warnings.warn(msg)
# cache the result # cache the result
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment