提交 0845ddc3 authored 作者: lamblin's avatar lamblin

Merge pull request #1407 from nouiz/fix_test_p33

Fix test p33
...@@ -286,6 +286,41 @@ can be achieved as follows: ...@@ -286,6 +286,41 @@ can be achieved as follows:
# Inputs : [array(0.0)] # Inputs : [array(0.0)]
# Outputs: [array(nan)] # Outputs: [array(nan)]
To help understand what is happening in your graph, you can
disable the ``local_elemwise_fusion`` and all ``inplace``
optimizations. The first is a speed optimization that merges elemwise
operations together. This makes it harder to know which particular
elemwise causes the problem. The second optimization makes some ops
overwrite their input. So, if an op creates a bad output, you
won't be able to see the input that was overwritten in the ``post_fun``
function. To disable those optimizations (with a Theano version after
0.6rc3), define the MonitorMode like this:
.. code-block:: python
    mode = theano.compile.MonitorMode(post_func=detect_nan).excluding(
        'local_elemwise_fusion', 'inplace')
f = theano.function([x], [theano.tensor.log(x) * x],
mode=mode)
.. note::
The Theano flags ``optimizer_including``, ``optimizer_excluding``
and ``optimizer_requiring`` aren't used by the MonitorMode, they
are used only by the ``default`` mode. You can't use the ``default``
mode with MonitorMode, as you need to define what you monitor.
To be sure all inputs of the node are available during the call to
``post_func``, you also must disable the garbage collector. Otherwise,
the execution of the node can garbage collect its inputs that aren't
needed anymore by the Theano function. This can be done with the Theano
flag:
.. code-block:: cfg
allow_gc=False
.. TODO: documentation for link.WrapLinkerMany .. TODO: documentation for link.WrapLinkerMany
......
...@@ -20,7 +20,8 @@ class MonitorMode(Mode): ...@@ -20,7 +20,8 @@ class MonitorMode(Mode):
For an example of such a use case, see doc/tutorial/debug_faq.txt. For an example of such a use case, see doc/tutorial/debug_faq.txt.
""" """
def __init__(self, pre_func=None, post_func=None, optimizer='fast_run'): def __init__(self, pre_func=None, post_func=None,
optimizer='default', linker=None):
""" """
Constructor. Constructor.
...@@ -35,11 +36,21 @@ class MonitorMode(Mode): ...@@ -35,11 +36,21 @@ class MonitorMode(Mode):
:param optimizer: The optimizer to use. One may use for instance :param optimizer: The optimizer to use. One may use for instance
'fast_compile' to skip optimizations. 'fast_compile' to skip optimizations.
:param linker: DO NOT USE. This mode uses its own linker.        :param linker: DO NOT USE. This mode uses its own linker.
The parameter is needed to allow selecting optimizers to use.            The parameter is needed to allow selecting optimizers to use.
""" """
self.pre_func = pre_func self.pre_func = pre_func
self.post_func = post_func self.post_func = post_func
wrap_linker = theano.gof.WrapLinkerMany([theano.gof.OpWiseCLinker()], wrap_linker = theano.gof.WrapLinkerMany([theano.gof.OpWiseCLinker()],
[self.eval]) [self.eval])
if optimizer is 'default':
optimizer = theano.config.optimizer
if (linker is not None and
not isinstance(linker.mode, MonitorMode)):
raise Exception("MonitorMode can only use its own linker! You "
"should not provide one.", linker)
super(MonitorMode, self).__init__(wrap_linker, optimizer=optimizer) super(MonitorMode, self).__init__(wrap_linker, optimizer=optimizer)
def eval(self, i, node, fn): def eval(self, i, node, fn):
...@@ -51,3 +62,21 @@ class MonitorMode(Mode): ...@@ -51,3 +62,21 @@ class MonitorMode(Mode):
fn() fn()
if self.post_func is not None: if self.post_func is not None:
self.post_func(i, node, fn) self.post_func(i, node, fn)
def including(self, *tags):
ret = super(MonitorMode, self).including(*tags)
ret.pre_func = self.pre_func
ret.post_func = self.post_func
return ret
def excluding(self, *tags):
ret = super(MonitorMode, self).excluding(*tags)
ret.pre_func = self.pre_func
ret.post_func = self.post_func
return ret
def requiring(self, *tags):
ret = super(MonitorMode, self).requiring(*tags)
ret.pre_func = self.pre_func
ret.post_func = self.post_func
return ret
...@@ -439,6 +439,11 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -439,6 +439,11 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
and not isinstance(no_default_updates, list): and not isinstance(no_default_updates, list):
raise TypeError("no_default_update should be either a boolean or a list") raise TypeError("no_default_update should be either a boolean or a list")
if len(updates) > 0 and any(isinstance(v, Variable)
for v in iter_over_pairs(updates)):
raise ValueError(
"The updates parameter must an OrderedDict/dict or a list of list/tuple with 2 elements")
# transform params into theano.compile.In objects. # transform params into theano.compile.In objects.
inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast)
for p in params] for p in params]
......
...@@ -25,3 +25,67 @@ def test_detect_nan(): ...@@ -25,3 +25,67 @@ def test_detect_nan():
post_func=detect_nan)) post_func=detect_nan))
f(0) # log(0) * 0 = -inf * 0 = NaN f(0) # log(0) * 0 = -inf * 0 = NaN
assert nan_detected[0] assert nan_detected[0]
def test_optimizer():
"""
Test that we can remove optimizer
"""
nan_detected = [False]
def detect_nan(i, node, fn):
for output in fn.outputs:
if numpy.isnan(output[0]).any():
print '*** NaN detected ***'
theano.printing.debugprint(node)
print 'Inputs : %s' % [input[0] for input in fn.inputs]
print 'Outputs: %s' % [output[0] for output in fn.outputs]
nan_detected[0] = True
break
x = theano.tensor.dscalar('x')
mode = theano.compile.MonitorMode(post_func=detect_nan)
mode = mode.excluding('fusion')
f = theano.function([x], [theano.tensor.log(x) * x],
mode=mode)
# Test that the fusion wasn't done
assert len(f.maker.fgraph.nodes) == 2
f(0) # log(0) * 0 = -inf * 0 = NaN
# Test that we still detect the nan
assert nan_detected[0]
def test_not_inplace():
"""
Test that we can remove optimizers including inplace optimizers
"""
nan_detected = [False]
def detect_nan(i, node, fn):
for output in fn.outputs:
if numpy.isnan(output[0]).any():
print '*** NaN detected ***'
theano.printing.debugprint(node)
print 'Inputs : %s' % [input[0] for input in fn.inputs]
print 'Outputs: %s' % [output[0] for output in fn.outputs]
nan_detected[0] = True
break
x = theano.tensor.vector('x')
mode = theano.compile.MonitorMode(post_func=detect_nan)
#mode = mode.excluding('fusion', 'inplace')
mode = mode.excluding('local_elemwise_fusion',
'inplace_elemwise_optimizer')
o = theano.tensor.outer(x, x)
out = theano.tensor.log(o) * o
f = theano.function([x], [out],
mode=mode)
# Test that the fusion wasn't done
assert len(f.maker.fgraph.nodes) == 5
assert not f.maker.fgraph.toposort()[-1].op.destroy_map
f([0, 0]) # log(0) * 0 = -inf * 0 = NaN
# Test that we still detect the nan
assert nan_detected[0]
...@@ -16,6 +16,12 @@ import itertools ...@@ -16,6 +16,12 @@ import itertools
import distutils.sysconfig import distutils.sysconfig
importlib = None
try:
import importlib
except ImportError:
pass
import numpy.distutils # TODO: TensorType should handle this import numpy.distutils # TODO: TensorType should handle this
import theano import theano
...@@ -278,6 +284,9 @@ def dlimport(fullpath, suffix=None): ...@@ -278,6 +284,9 @@ def dlimport(fullpath, suffix=None):
sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily) sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
try: try:
if importlib is not None:
if hasattr(importlib, "invalidate_caches"):
importlib.invalidate_caches()
rval = __import__(module_name, {}, {}, [module_name]) rval = __import__(module_name, {}, {}, [module_name])
if not rval: if not rval:
raise Exception('__import__ failed', fullpath) raise Exception('__import__ failed', fullpath)
......
"""WRITEME""" """WRITEME"""
from copy import copy
import sys
import traceback
import theano import theano
from theano.gof import utils from theano.gof import utils
from theano.gof import graph from theano.gof import graph
from theano.gof.type import Type from theano.gof.type import Type
import sys, traceback
from copy import copy
__excepthook = sys.excepthook __excepthook = sys.excepthook
def log_thunk_trace(value, f=sys.stderr): def log_thunk_trace(value, f=sys.stderr):
"""Log theano's diagnostic stack trace for an exception """Log theano's diagnostic stack trace for an exception
raised by raise_with_op. raised by raise_with_op.
...@@ -110,6 +112,7 @@ def raise_with_op(op, exc_info=None): ...@@ -110,6 +112,7 @@ def raise_with_op(op, exc_info=None):
raise_with_op.print_thunk_trace = False raise_with_op.print_thunk_trace = False
class Linker(object): class Linker(object):
"""WRITEME""" """WRITEME"""
...@@ -132,10 +135,11 @@ class Linker(object): ...@@ -132,10 +135,11 @@ class Linker(object):
print new_e.data # 3.0 print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
""" """
raise utils.MethodNotDefined("make_thunk", type(self), self.__class__.__name__) raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__)
## DELETEME ## ## DELETEME ##
def make_function(self, unpack_single = True, **kwargs): def make_function(self, unpack_single=True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the outputs fgraph used by this L{Linker} and returns values corresponding the the outputs
...@@ -155,6 +159,7 @@ class Linker(object): ...@@ -155,6 +159,7 @@ class Linker(object):
length 1 will be returned. length 1 will be returned.
""" """
thunk, inputs, outputs = self.make_thunk(**kwargs) thunk, inputs, outputs = self.make_thunk(**kwargs)
def execute(*args): def execute(*args):
def e_arity(takes, got): def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \ return 'Function call takes exactly %i %s (%i given)' \
...@@ -165,7 +170,8 @@ class Linker(object): ...@@ -165,7 +170,8 @@ class Linker(object):
variable.data = arg variable.data = arg
thunk() thunk()
if unpack_single: if unpack_single:
return utils.to_return_values([variable.data for variable in outputs]) return utils.to_return_values([variable.data
for variable in outputs])
else: else:
return [variable.data for variable in outputs] return [variable.data for variable in outputs]
execute.thunk = thunk execute.thunk = thunk
...@@ -177,6 +183,7 @@ class Linker(object): ...@@ -177,6 +183,7 @@ class Linker(object):
def schedule(self, fgraph): def schedule(self, fgraph):
return fgraph.toposort() return fgraph.toposort()
#TODO: Move this class to the compile module, where it is used (and for which it exists). #TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object): class Container(object):
"""This class joins a variable with its computed value. """This class joins a variable with its computed value.
...@@ -228,8 +235,10 @@ class Container(object): ...@@ -228,8 +235,10 @@ class Container(object):
kwargs['strict'] = True kwargs['strict'] = True
if self.allow_downcast is not None: if self.allow_downcast is not None:
kwargs['allow_downcast'] = self.allow_downcast kwargs['allow_downcast'] = self.allow_downcast
if hasattr(self.type,'filter_inplace'): if hasattr(self.type, 'filter_inplace'):
self.storage[0] = self.type.filter_inplace(value, self.storage[0], **kwargs) self.storage[0] = self.type.filter_inplace(value,
self.storage[0],
**kwargs)
else: else:
self.storage[0] = self.type.filter(value, **kwargs) self.storage[0] = self.type.filter(value, **kwargs)
...@@ -238,8 +247,10 @@ class Container(object): ...@@ -238,8 +247,10 @@ class Container(object):
raise raise
data = property(__get__, __set__) data = property(__get__, __set__)
value = property(__get__, __set__) value = property(__get__, __set__)
def __str__(self): def __str__(self):
return "<" + str(self.storage[0]) + ">" return "<" + str(self.storage[0]) + ">"
def __repr__(self): def __repr__(self):
return "<" + repr(self.storage[0]) + ">" return "<" + repr(self.storage[0]) + ">"
...@@ -285,7 +296,6 @@ def map_storage(fgraph, order, input_storage, output_storage): ...@@ -285,7 +296,6 @@ def map_storage(fgraph, order, input_storage, output_storage):
for r, storage in zip(fgraph.outputs, output_storage): for r, storage in zip(fgraph.outputs, output_storage):
storage_map[r] = storage storage_map[r] = storage
thunks = []
for node in order: for node in order:
for r in node.inputs: for r in node.inputs:
if r not in storage_map: if r not in storage_map:
...@@ -302,6 +312,7 @@ def map_storage(fgraph, order, input_storage, output_storage): ...@@ -302,6 +312,7 @@ def map_storage(fgraph, order, input_storage, output_storage):
return input_storage, output_storage, storage_map return input_storage, output_storage, storage_map
def streamline(fgraph, thunks, order, post_thunk_old_storage=None, def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
no_recycling=None, profiler=None, nice_errors=True): no_recycling=None, profiler=None, nice_errors=True):
"""WRITEME """WRITEME
...@@ -335,14 +346,16 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -335,14 +346,16 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
if post_thunk_old_storage: if post_thunk_old_storage:
if len(thunks) != len(post_thunk_old_storage): if len(thunks) != len(post_thunk_old_storage):
raise ValueError('Length of thunks and post_thunk_old_storage must match', raise ValueError(
'Length of thunks and post_thunk_old_storage must match',
(len(thunks), len(post_thunk_old_storage))) (len(thunks), len(post_thunk_old_storage)))
def streamline_default_f(): def streamline_default_f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
try: try:
for thunk, node, old_storage in zip(thunks, order, post_thunk_old_storage): for thunk, node, old_storage in zip(thunks, order,
post_thunk_old_storage):
thunk() thunk()
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
...@@ -351,6 +364,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -351,6 +364,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
f = streamline_default_f f = streamline_default_f
elif nice_errors: elif nice_errors:
thunk_node_list = zip(thunks, order) thunk_node_list = zip(thunks, order)
def streamline_nice_errors_f(): def streamline_nice_errors_f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
...@@ -371,16 +385,18 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -371,16 +385,18 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
f = streamline_fast_f f = streamline_fast_f
return f return f
class LocalLinker(Linker): class LocalLinker(Linker):
"""WRITEME """WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run a Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node. thunk associated with each node.
""" """
def make_thunk(self, profiler = None, input_storage = None, output_storage = None): def make_thunk(self, profiler=None, input_storage=None,
return self.make_all(profiler = profiler, output_storage=None):
input_storage = input_storage, return self.make_all(profiler=profiler,
output_storage = output_storage)[:3] input_storage=input_storage,
output_storage=output_storage)[:3]
def make_all(self, profiler, input_storage, output_storage): def make_all(self, profiler, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function! # By convention, subclasses of LocalLinker should implement this function!
...@@ -391,7 +407,9 @@ class LocalLinker(Linker): ...@@ -391,7 +407,9 @@ class LocalLinker(Linker):
# 3. output storage # 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1) # 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1) # 5. order: list of nodes, in the order they will be run by the function in (1)
raise utils.MethodNotDefined("make_all", type(self), self.__class__.__name__) raise utils.MethodNotDefined("make_all", type(self),
self.__class__.__name__)
def gc_helper(node_list): def gc_helper(node_list):
""" """
...@@ -413,6 +431,7 @@ def gc_helper(node_list): ...@@ -413,6 +431,7 @@ def gc_helper(node_list):
computed.add(output) computed.add(output)
return computed, last_user return computed, last_user
class PerformLinker(LocalLinker): class PerformLinker(LocalLinker):
"""WRITEME """WRITEME
...@@ -445,7 +464,7 @@ class PerformLinker(LocalLinker): ...@@ -445,7 +464,7 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler=None, input_storage=None, output_storage=None):
""" """
:param profiler: WRITEME :param profiler: WRITEME
:param input_storage: WRITEME :param input_storage: WRITEME
...@@ -460,7 +479,6 @@ class PerformLinker(LocalLinker): ...@@ -460,7 +479,6 @@ class PerformLinker(LocalLinker):
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage) input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage)
compute_map = {} compute_map = {}
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
...@@ -511,6 +529,7 @@ class PerformLinker(LocalLinker): ...@@ -511,6 +529,7 @@ class PerformLinker(LocalLinker):
[Container(output, storage, True) for output, storage in zip(fgraph.outputs, output_storage)], \ [Container(output, storage, True) for output, storage in zip(fgraph.outputs, output_storage)], \
thunks, order thunks, order
def add_clear_storage(f, computed, storage_map): def add_clear_storage(f, computed, storage_map):
def clear_storage(): def clear_storage():
for c in computed: for c in computed:
...@@ -591,11 +610,13 @@ class WrapLinker(Linker): ...@@ -591,11 +610,13 @@ class WrapLinker(Linker):
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(self.linkers, self.wrapper).accept(fgraph, no_recycling) return type(self)(self.linkers, self.wrapper).accept(fgraph,
no_recycling)
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling = no_recycling
self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers] self.linkers = [linker.accept(fgraph, no_recycling)
for linker in self.linkers]
return self return self
def pre(self, f, inputs, order, thunk_groups): def pre(self, f, inputs, order, thunk_groups):
...@@ -614,7 +635,8 @@ class WrapLinker(Linker): ...@@ -614,7 +635,8 @@ class WrapLinker(Linker):
order_list0 = order_lists[0] order_list0 = order_lists[0]
for order_list in order_lists[1:]: for order_list in order_lists[1:]:
if not order_list0 == order_list: if not order_list0 == order_list:
raise Exception("All linkers to WrapLinker should execute operations in the same order.") raise Exception(
"All linkers to WrapLinker should execute operations in the same order.")
inputs0 = input_lists[0] inputs0 = input_lists[0]
outputs0 = output_lists[0] outputs0 = output_lists[0]
...@@ -631,13 +653,15 @@ class WrapLinker(Linker): ...@@ -631,13 +653,15 @@ class WrapLinker(Linker):
wrapper = self.wrapper wrapper = self.wrapper
pre = self.pre pre = self.pre
def f(): def f():
for inputs in input_lists[1:]: for inputs in input_lists[1:]:
for input1, input2 in zip(inputs0, inputs): for input1, input2 in zip(inputs0, inputs):
input2.storage[0] = copy(input1.storage[0]) input2.storage[0] = copy(input1.storage[0])
for x in to_reset: for x in to_reset:
x[0] = None x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups) pre(self, [input.data for input in input_lists[0]],
order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)): for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try: try:
wrapper(i, node, *thunks) wrapper(i, node, *thunks)
...@@ -647,6 +671,7 @@ class WrapLinker(Linker): ...@@ -647,6 +671,7 @@ class WrapLinker(Linker):
return f, inputs0, outputs0 return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers): def WrapLinkerMany(linkers, wrappers):
""" """
Variant on WrapLinker that runs a series of wrapper functions instead of Variant on WrapLinker that runs a series of wrapper functions instead of
......
...@@ -22,6 +22,8 @@ from theano.compat import cmp ...@@ -22,6 +22,8 @@ from theano.compat import cmp
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
def memodict(f): def memodict(f):
""" Memoization decorator for a function taking a single argument """ """ Memoization decorator for a function taking a single argument """
class memodict(defaultdict): class memodict(defaultdict):
...@@ -41,6 +43,7 @@ def make_depends(): ...@@ -41,6 +43,7 @@ def make_depends():
if ainp.owner)) if ainp.owner))
return depends return depends
def make_dependence_cmp(): def make_dependence_cmp():
""" Create a comparator to represent the dependence of nodes in a graph """ """ Create a comparator to represent the dependence of nodes in a graph """
...@@ -53,18 +56,27 @@ def make_dependence_cmp(): ...@@ -53,18 +56,27 @@ def make_dependence_cmp():
Returns negative number if b depends on a Returns negative number if b depends on a
Returns 0 otherwise Returns 0 otherwise
""" """
if depends((a, b)): return 1 if depends((a, b)):
if depends((b, a)): return -1 return 1
if depends((b, a)):
return -1
return 0 return 0
return dependence return dependence
def reverse_dict(d): def reverse_dict(d):
""" Reverses direction of dependence dict """Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) >>> reverse_dict(d)
{1: ('a',), 2: ('a', 'b'), 3: ('b',)} {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order is not deterministic. As we iterate over the
    input dict, the output of this function depends on that
    order, so the output order should be considered
    non-deterministic.
""" """
result = {} result = {}
for key in d: for key in d:
...@@ -72,6 +84,7 @@ def reverse_dict(d): ...@@ -72,6 +84,7 @@ def reverse_dict(d):
result[val] = result.get(val, tuple()) + (key, ) result[val] = result.get(val, tuple()) + (key, )
return result return result
def _toposort(edges): def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices) """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
...@@ -106,6 +119,7 @@ def _toposort(edges): ...@@ -106,6 +119,7 @@ def _toposort(edges):
raise ValueError("Input has cycles") raise ValueError("Input has cycles")
return L return L
def posort(l, *cmps): def posort(l, *cmps):
""" Partially ordered sort with multiple comparators """ Partially ordered sort with multiple comparators
...@@ -156,6 +170,7 @@ def posort(l, *cmps): ...@@ -156,6 +170,7 @@ def posort(l, *cmps):
return _toposort(comes_after) return _toposort(comes_after)
def sort_apply_nodes(inputs, outputs, cmps): def sort_apply_nodes(inputs, outputs, cmps):
""" Order a graph of apply nodes according to a list of comparators """ Order a graph of apply nodes according to a list of comparators
...@@ -178,6 +193,7 @@ def sort_apply_nodes(inputs, outputs, cmps): ...@@ -178,6 +193,7 @@ def sort_apply_nodes(inputs, outputs, cmps):
return posort(list_of_nodes(inputs, outputs), *cmps) return posort(list_of_nodes(inputs, outputs), *cmps)
def sort_schedule_fn(*cmps): def sort_schedule_fn(*cmps):
""" Make a schedule function from comparators """ Make a schedule function from comparators
...@@ -186,11 +202,13 @@ def sort_schedule_fn(*cmps): ...@@ -186,11 +202,13 @@ def sort_schedule_fn(*cmps):
""" """
dependence = make_dependence_cmp() dependence = make_dependence_cmp()
cmps = (dependence,) + cmps cmps = (dependence,) + cmps
def schedule(fgraph): def schedule(fgraph):
""" Order nodes in a FunctionGraph """ """ Order nodes in a FunctionGraph """
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps) return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule return schedule
def key_to_cmp(key): def key_to_cmp(key):
def key_cmp(a, b): def key_cmp(a, b):
return cmp(key(a), key(b)) return cmp(key(a), key(b))
......
...@@ -7,6 +7,7 @@ from theano.gof.graph import io_toposort ...@@ -7,6 +7,7 @@ from theano.gof.graph import io_toposort
from theano.gof.python25 import any from theano.gof.python25 import any
from theano.compat import cmp from theano.compat import cmp
def test_dependence(): def test_dependence():
dependence = make_dependence_cmp() dependence = make_dependence_cmp()
...@@ -30,7 +31,10 @@ def test_sort_apply_nodes(): ...@@ -30,7 +31,10 @@ def test_sort_apply_nodes():
def test_reverse_dict(): def test_reverse_dict():
d = {'a': (1, 2), 'b': (2, 3), 'c': ()} d = {'a': (1, 2), 'b': (2, 3), 'c': ()}
assert reverse_dict(d) == {1: ('a',), 2: ('a', 'b'), 3: ('b',)} # Python 3.3 enable by default random hash for dict.
# This change the order of traversal, so this can give 2 outputs
assert (reverse_dict(d) == {1: ('a',), 2: ('a', 'b'), 3: ('b',)} or
reverse_dict(d) == {1: ('a',), 2: ('b', 'a'), 3: ('b',)})
def test__toposort(): def test__toposort():
...@@ -44,7 +48,7 @@ def test__toposort(): ...@@ -44,7 +48,7 @@ def test__toposort():
def test_posort_easy(): def test_posort_easy():
nodes = "asdfghjkl" nodes = "asdfghjkl"
def cmp(a, b): def mycmp(a, b):
if a < b: if a < b:
return -1 return -1
elif a > b: elif a > b:
...@@ -52,7 +56,7 @@ def test_posort_easy(): ...@@ -52,7 +56,7 @@ def test_posort_easy():
else: else:
return 0 return 0
assert posort(nodes, cmp) == list("adfghjkls") assert posort(nodes, mycmp) == list("adfghjkls")
def test_posort(): def test_posort():
......
...@@ -360,7 +360,7 @@ class T_softplus_opts(unittest.TestCase): ...@@ -360,7 +360,7 @@ class T_softplus_opts(unittest.TestCase):
f(numpy.random.rand(54).astype(config.floatX)) f(numpy.random.rand(54).astype(config.floatX))
def test_log1msigm_to_softplus(self): def test_log1msigm_to_softplus(self):
x = T.vector() x = T.matrix()
out = T.log(1 - sigmoid(x)) out = T.log(1 - sigmoid(x))
f = theano.function([x], out, mode=self.m) f = theano.function([x], out, mode=self.m)
...@@ -369,7 +369,29 @@ class T_softplus_opts(unittest.TestCase): ...@@ -369,7 +369,29 @@ class T_softplus_opts(unittest.TestCase):
assert isinstance(topo[0].op.scalar_op, assert isinstance(topo[0].op.scalar_op,
theano.tensor.nnet.sigm.ScalarSoftplus) theano.tensor.nnet.sigm.ScalarSoftplus)
assert isinstance(topo[1].op.scalar_op, theano.scalar.Neg) assert isinstance(topo[1].op.scalar_op, theano.scalar.Neg)
f(numpy.random.rand(54).astype(config.floatX)) f(numpy.random.rand(54, 11).astype(config.floatX))
# Same test with a flatten
out = T.log(1 - T.flatten(sigmoid(x)))
f = theano.function([x], out, mode=self.m)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert isinstance(topo[0].op, T.Flatten)
assert isinstance(topo[1].op.scalar_op,
theano.tensor.nnet.sigm.ScalarSoftplus)
assert isinstance(topo[2].op.scalar_op, theano.scalar.Neg)
f(numpy.random.rand(54, 11).astype(config.floatX))
# Same test with a reshape
out = T.log(1 - sigmoid(x).reshape([x.size]))
f = theano.function([x], out, mode=self.m)
topo = f.maker.fgraph.toposort()
#assert len(topo) == 3
assert any(isinstance(node.op, T.Reshape) for node in topo)
assert any(isinstance(getattr(node.op, 'scalar_op', None),
theano.tensor.nnet.sigm.ScalarSoftplus)
for node in topo)
f(numpy.random.rand(54, 11).astype(config.floatX))
def test_log1pexp_to_softplus(self): def test_log1pexp_to_softplus(self):
m = theano.config.mode m = theano.config.mode
......
...@@ -273,8 +273,8 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -273,8 +273,8 @@ def inplace_elemwise_optimizer_op(OP):
return inplace_elemwise_optimizer return inplace_elemwise_optimizer
inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise) inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise)
compile.optdb.register('inplace_opt', inplace_elemwise_optimizer, 75, compile.optdb.register('inplace_opt', inplace_elemwise_optimizer, 75,
'inplace_elemwise_optimizer',
'fast_run', 'inplace') 'fast_run', 'inplace')
...@@ -2385,6 +2385,27 @@ def local_div_switch_sink(node): ...@@ -2385,6 +2385,27 @@ def local_div_switch_sink(node):
return False return False
################
# Flatten Opts #
################
@register_canonicalize
@register_stabilize
@gof.local_optimizer([])
def local_flatten_lift(node):
"""
Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))
This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
"""
if (isinstance(node.op, T.Flatten) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1):
f = node.op(node.inputs[0].owner.inputs[0])
e = node.inputs[0].owner.op(f)
return [e]
################## ##################
# Reshape opts # # Reshape opts #
################## ##################
...@@ -2415,6 +2436,26 @@ def local_reshape_chain(node): ...@@ -2415,6 +2436,26 @@ def local_reshape_chain(node):
return False return False
register_canonicalize(local_reshape_chain) register_canonicalize(local_reshape_chain)
@register_canonicalize
@register_stabilize
@gof.local_optimizer([])
def local_reshape_lift(node):
"""
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
"""
if (isinstance(node.op, T.Reshape) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1):
r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
e = node.inputs[0].owner.op(r)
return [e]
if 0: if 0:
# TODO: Test that this optimziation works. # TODO: Test that this optimziation works.
@register_canonicalize @register_canonicalize
......
...@@ -3988,6 +3988,35 @@ def test_local_div_to_inv(): ...@@ -3988,6 +3988,35 @@ def test_local_div_to_inv():
assert numpy.allclose(out_val, 0.5) assert numpy.allclose(out_val, 0.5)
def test_local_flatten_lift():
for i in range(1, 4):
op = tensor.Flatten(i)
x = tensor.tensor4()
out = op(T.exp(x))
assert out.ndim == i
mode = compile.mode.get_default_mode()
mode = mode.including('local_flatten_lift')
f = theano.function([x], out, mode=mode)
f(numpy.random.rand(5, 4, 3, 2).astype(config.floatX))
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, tensor.Flatten)
assert isinstance(topo[1].op, tensor.Elemwise)
def test_local_reshape_lift():
x = tensor.tensor4()
out = T.exp(x).reshape([x.size])
assert out.ndim == 1
mode = compile.mode.get_default_mode()
mode = mode.including('local_reshape_lift')
f = theano.function([x], out, mode=mode)
f(numpy.random.rand(5, 4, 3, 2).astype(config.floatX))
topo = f.maker.fgraph.toposort()
assert isinstance(topo[-2].op, tensor.Reshape)
assert isinstance(topo[-1].op, tensor.Elemwise)
class Test_lift_transpose_through_dot(unittest.TestCase): class Test_lift_transpose_through_dot(unittest.TestCase):
def simple_optimize(self, g): def simple_optimize(self, g):
out2in(opt.local_useless_elemwise).optimize(g) out2in(opt.local_useless_elemwise).optimize(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论