提交 3578c80c authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 for compile/debugmode.py

上级 90e512d6
...@@ -7,7 +7,10 @@ from __future__ import print_function ...@@ -7,7 +7,10 @@ from __future__ import print_function
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import copy, sys, copy_reg, gc import copy
import sys
import copy_reg
import gc
from itertools import izip from itertools import izip
import numpy import numpy
...@@ -16,10 +19,9 @@ import theano ...@@ -16,10 +19,9 @@ import theano
from theano import gof from theano import gof
from theano.compat import get_unbound_function, product as itertools_product from theano.compat import get_unbound_function, product as itertools_product
from theano.compat.six import StringIO from theano.compat.six import StringIO
from theano.gof import (FunctionGraph, graph, utils, link, from theano.gof import (graph, utils, link,
ops_with_inner_function) ops_with_inner_function)
from theano.gof.link import raise_with_op from theano.gof.link import raise_with_op
from theano.gof.cc import CLinker
from theano.configparser import (config, AddConfigVar, BoolParam, IntParam, from theano.configparser import (config, AddConfigVar, BoolParam, IntParam,
StrParam) StrParam)
from theano.compile.function_module import ( from theano.compile.function_module import (
...@@ -57,8 +59,8 @@ AddConfigVar('DebugMode.check_strides', ...@@ -57,8 +59,8 @@ AddConfigVar('DebugMode.check_strides',
AddConfigVar('DebugMode.warn_input_not_reused', AddConfigVar('DebugMode.warn_input_not_reused',
("Generate a warning when destroy_map or view_map says that an " ("Generate a warning when destroy_map or view_map says that an "
"op works inplace, but the op did not reuse the input for its output." "op works inplace, but the op did not reuse the input for its "
), "output."),
BoolParam(True), BoolParam(True),
in_c_key=False) in_c_key=False)
...@@ -148,7 +150,7 @@ class BadThunkOutput(DebugModeError): ...@@ -148,7 +150,7 @@ class BadThunkOutput(DebugModeError):
def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()): def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()):
"""Initialize members""" """Initialize members"""
DebugModeError.__init__(self) # to be compatible with python2.4 super(BadThunkOutput, self).__init__()
self.r = r self.r = r
self.thunk1 = thunk1 self.thunk1 = thunk1
self.val1 = val1 self.val1 = val1
...@@ -173,8 +175,10 @@ class BadThunkOutput(DebugModeError): ...@@ -173,8 +175,10 @@ class BadThunkOutput(DebugModeError):
print(" op :", self.offending_op(), file=sio) print(" op :", self.offending_op(), file=sio)
print(" Outputs Type:", self.r.type, file=sio) print(" Outputs Type:", self.r.type, file=sio)
print(" Outputs Shape:", getattr(self.val1, 'shape', None), file=sio) print(" Outputs Shape:", getattr(self.val1, 'shape', None), file=sio)
print(" Outputs Strides:", getattr(self.val1, 'strides', None), file=sio) print(" Outputs Strides:", getattr(self.val1, 'strides', None),
print(" Inputs Type :", [i.type for i in self.r.owner.inputs], file=sio) file=sio)
print(" Inputs Type :", [i.type for i in self.r.owner.inputs],
file=sio)
print(" Inputs Shape:", [getattr(val, 'shape', None) print(" Inputs Shape:", [getattr(val, 'shape', None)
for val in self.inputs_val], file=sio) for val in self.inputs_val], file=sio)
print(" Inputs Strides:", [getattr(val, 'strides', None) print(" Inputs Strides:", [getattr(val, 'strides', None)
...@@ -226,7 +230,7 @@ class BadOptimization(DebugModeError): ...@@ -226,7 +230,7 @@ class BadOptimization(DebugModeError):
def __init__(self, old_r, new_r, old_r_val, new_r_val, reason, def __init__(self, old_r, new_r, old_r_val, new_r_val, reason,
old_graph, new_graph): old_graph, new_graph):
"""Initialize members""" """Initialize members"""
DebugModeError.__init__(self) # to be compatible with python2.4 super(BadOptimization, self).__init__()
self.old_r = old_r self.old_r = old_r
self.new_r = new_r self.new_r = new_r
self.old_r_val = old_r_val self.old_r_val = old_r_val
...@@ -287,24 +291,20 @@ class BadOptimization(DebugModeError): ...@@ -287,24 +291,20 @@ class BadOptimization(DebugModeError):
ov = numpy.asarray(self.old_r_val) ov = numpy.asarray(self.old_r_val)
nv = numpy.asarray(self.new_r_val) nv = numpy.asarray(self.new_r_val)
ssio = StringIO() ssio = StringIO()
print(" Max Abs Diff: ", numpy.max(numpy.absolute(nv - abs_diff = numpy.absolute(nv - ov)
ov)), file=ssio) print(" Max Abs Diff: ", numpy.max(abs_diff), file=ssio)
print(" Mean Abs Diff: ", numpy.mean(numpy.absolute(nv - print(" Mean Abs Diff: ", numpy.mean(abs_diff), file=ssio)
ov)), file=ssio) print(" Median Abs Diff: ", numpy.median(abs_diff), file=ssio)
print(" Median Abs Diff: ", numpy.median(numpy.absolute( print(" Std Abs Diff: ", numpy.std(abs_diff), file=ssio)
nv - ov)), file=ssio) arg_max_val = numpy.argmax(abs_diff)
print(" Std Abs Diff: ", numpy.std(numpy.absolute(
nv - ov)), file=ssio)
arg_max_val = numpy.argmax(numpy.absolute(nv - ov))
values_at_max = (nv.flatten()[arg_max_val], values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val]) ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio) print(" Value at Max Diff: ", values_at_max, file=ssio)
# N.B. the maximum(..., 1e-8) protects against div by 0 when # N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0 # nv == ov == 0
reldiff = (numpy.absolute(nv - ov) reldiff = (abs_diff
/ numpy.maximum( / numpy.maximum(numpy.absolute(nv) + numpy.absolute(ov),
numpy.absolute(nv) + numpy.absolute(ov),
1e-8)) 1e-8))
print(" Max Rel Diff: ", numpy.max(reldiff), file=ssio) print(" Max Rel Diff: ", numpy.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", numpy.mean(reldiff), file=ssio) print(" Mean Rel Diff: ", numpy.mean(reldiff), file=ssio)
...@@ -325,8 +325,10 @@ class BadOptimization(DebugModeError): ...@@ -325,8 +325,10 @@ class BadOptimization(DebugModeError):
print(" New Graph:", file=sio) print(" New Graph:", file=sio)
print(self.new_graph, file=sio) print(self.new_graph, file=sio)
print("", file=sio) print("", file=sio)
print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1", file=sio) print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1",
print(" or even tensor.cmp_sloppy=2 for less-strict comparison", file=sio) file=sio)
print(" or even tensor.cmp_sloppy=2 for less-strict comparison",
file=sio)
return sio.getvalue() return sio.getvalue()
...@@ -334,8 +336,7 @@ class BadDestroyMap(DebugModeError): ...@@ -334,8 +336,7 @@ class BadDestroyMap(DebugModeError):
"""Exception: Some perform() or c_code() modified an input that """Exception: Some perform() or c_code() modified an input that
wasn't in the destroy_map""" wasn't in the destroy_map"""
def __init__(self, node, idx, old_val, new_val, perform): def __init__(self, node, idx, old_val, new_val, perform):
#super(BadDestroyMap, self).__init__() super(BadDestroyMap, self).__init__()
DebugModeError.__init__(self) # to be compatible with python2.4
self.node = node self.node = node
self.idx = idx self.idx = idx
self.old_val = old_val self.old_val = old_val
...@@ -351,26 +352,38 @@ class BadDestroyMap(DebugModeError): ...@@ -351,26 +352,38 @@ class BadDestroyMap(DebugModeError):
print(" destroy_map:", getattr(self.node.op, print(" destroy_map:", getattr(self.node.op,
'destroy_map', {}), file=sio) 'destroy_map', {}), file=sio)
print(" changed input idx:", self.idx, file=sio) print(" changed input idx:", self.idx, file=sio)
print(" changed input type:", self.node.inputs[self.idx].type, file=sio) print(" changed input type:", self.node.inputs[self.idx].type,
file=sio)
print(" repr (old val):", repr(self.old_val), file=sio) print(" repr (old val):", repr(self.old_val), file=sio)
print(" repr (new val):", repr(self.new_val), file=sio) print(" repr (new val):", repr(self.new_val), file=sio)
try: try:
npy_old_val = numpy.asarray(self.old_val) npy_old_val = numpy.asarray(self.old_val)
npy_new_val = numpy.asarray(self.new_val) npy_new_val = numpy.asarray(self.new_val)
print(" value dtype (new <space> old):", npy_new_val.dtype, npy_old_val.dtype, file=sio) print(" value dtype (new <space> old):", npy_new_val.dtype,
print(" value shape (new <space> old):", npy_new_val.shape, npy_old_val.shape, file=sio) npy_old_val.dtype, file=sio)
print(" value min (new <space> old):", npy_new_val.min(), npy_old_val.min(), file=sio) print(" value shape (new <space> old):", npy_new_val.shape,
print(" value max (new <space> old):", npy_new_val.max(), npy_old_val.max(), file=sio) npy_old_val.shape, file=sio)
print(" value min (new <space> old):", npy_new_val.min(),
npy_old_val.min(), file=sio)
print(" value max (new <space> old):", npy_new_val.max(),
npy_old_val.max(), file=sio)
delta = npy_new_val - npy_old_val delta = npy_new_val - npy_old_val
print(" value min (new-old):", delta.min(), file=sio) print(" value min (new-old):", delta.min(), file=sio)
print(" value max (new-old):", delta.max(), file=sio) print(" value max (new-old):", delta.max(), file=sio)
print(" value argmin (new-old):", numpy.unravel_index(delta.argmin(), npy_new_val.shape), file=sio) print(" value argmin (new-old):",
print(" value argmax (new-old):", numpy.unravel_index(delta.argmax(), npy_new_val.shape), file=sio) numpy.unravel_index(delta.argmin(), npy_new_val.shape),
print(" location of first 10 mismatches:", numpy.transpose(numpy.nonzero(delta))[:10], file=sio) file=sio)
print(" value argmax (new-old):",
numpy.unravel_index(delta.argmax(), npy_new_val.shape),
file=sio)
print(" location of first 10 mismatches:",
numpy.transpose(numpy.nonzero(delta))[:10], file=sio)
print("", file=sio) print("", file=sio)
except Exception as e: except Exception as e:
print("(Numpy-hints failed with: %s)" % str(e), file=sio) print("(Numpy-hints failed with: %s)" % str(e), file=sio)
print(" Hint: this can also be caused by a deficient values_eq_approx() or __eq__() implementation [which compared input values]", file=sio) print(" Hint: this can also be caused by a deficient "
"values_eq_approx() or __eq__() implementation "
"[which compared input values]", file=sio)
return sio.getvalue() return sio.getvalue()
...@@ -379,8 +392,7 @@ class BadViewMap(DebugModeError): ...@@ -379,8 +392,7 @@ class BadViewMap(DebugModeError):
that wasn't in the view_map""" that wasn't in the view_map"""
def __init__(self, node, output_idx, out_storage, def __init__(self, node, output_idx, out_storage,
in_alias_idx=None, out_alias_idx=None): in_alias_idx=None, out_alias_idx=None):
#super(BadViewMap, self).__init__() super(BadViewMap, self).__init__()
DebugModeError.__init__(self) # to be compatible with python2.4
self.node = node self.node = node
self.output_idx = output_idx self.output_idx = output_idx
self.out_storage = out_storage self.out_storage = out_storage
...@@ -425,8 +437,7 @@ class InvalidValueError(DebugModeError): ...@@ -425,8 +437,7 @@ class InvalidValueError(DebugModeError):
the Type of that output""" the Type of that output"""
def __init__(self, r, v, client_node=None, hint='none', def __init__(self, r, v, client_node=None, hint='none',
specific_hint='none'): specific_hint='none'):
#super(InvalidValueError, self).__init__() super(InvalidValueError, self).__init__()
DebugModeError.__init__(self) # to be compatible with python2.4
self.r = r self.r = r
self.v = v self.v = v
self.client_node = client_node self.client_node = client_node
...@@ -454,7 +465,8 @@ class InvalidValueError(DebugModeError): ...@@ -454,7 +465,8 @@ class InvalidValueError(DebugModeError):
client_node = self.client_node client_node = self.client_node
hint = self.hint hint = self.hint
specific_hint = self.specific_hint specific_hint = self.specific_hint
context = debugprint(r, prefix=' ', depth=12, file=StringIO()).getvalue() context = debugprint(r, prefix=' ', depth=12,
file=StringIO()).getvalue()
return """InvalidValueError return """InvalidValueError
type(variable) = %(type_r)s type(variable) = %(type_r)s
variable = %(r)s variable = %(r)s
...@@ -512,7 +524,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -512,7 +524,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
and their associated printed ids and their associated printed ids
:param print_type: whether to print the Variable type after the other infos :param print_type: whether to print the Variable type after the other infos
:param file: file-like object to which to print :param file: file-like object to which to print
:param print_destroy_map: whether to print the op destroy_map after other info :param print_destroy_map: whether to print the op destroy_map after
other info
:param print_view_map: whether to print the op view_map after other info :param print_view_map: whether to print the op view_map after other info
:param order: If not empty will print the index in the toposort. :param order: If not empty will print the index in the toposort.
:param ids: How do we print the identifier of the variable :param ids: How do we print the identifier of the variable
...@@ -592,7 +605,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -592,7 +605,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
already_printed = a in done # get_id_str put it in the dict already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a) id_str = get_id_str(a)
if profile == None or a not in profile.apply_time: if profile is None or a not in profile.apply_time:
if len(a.outputs) == 1: if len(a.outputs) == 1:
print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op, print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
id_str, id_str,
...@@ -617,7 +630,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -617,7 +630,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100 tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
if len(a.outputs) == 1: if len(a.outputs) == 1:
print('%s%s %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\ print("%s%s %s%s '%s' %s %s %s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op, % (prefix, a.op,
id_str, id_str,
type_str, type_str,
...@@ -629,7 +643,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -629,7 +643,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
tot_time, tot_time,
tot_time_percent), file=file) tot_time_percent), file=file)
else: else:
print('%s%s.%i %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\ print("%s%s.%i %s%s '%s' %s %s %s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op, % (prefix, a.op,
a.outputs.index(r), a.outputs.index(r),
id_str, type_str, id_str, type_str,
...@@ -652,14 +667,15 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -652,14 +667,15 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
new_prefix_child = prefix_child + ' ' new_prefix_child = prefix_child + ' '
if hasattr(i, 'owner') and hasattr(i.owner, 'op'): if hasattr(i, 'owner') and hasattr(i.owner, 'op'):
if isinstance(i.owner.op, theano.scan_module.scan_op.Scan): if isinstance(i.owner.op,
theano.scan_module.scan_op.Scan):
scan_ops.append(i) scan_ops.append(i)
debugprint(i, new_prefix, depth=depth - 1, done=done, debugprint(i, new_prefix, depth=depth - 1, done=done,
print_type=print_type, file=file, order=order, print_type=print_type, file=file, order=order,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops, prefix_child=new_prefix_child,
profile=profile) scan_ops=scan_ops, profile=profile)
else: else:
# this is an input variable # this is an input variable
...@@ -679,7 +695,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -679,7 +695,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
:param accept_inplace: are inplace ops permitted in the original graph? :param accept_inplace: are inplace ops permitted in the original graph?
:type accept_inplace: Bool :type accept_inplace: Bool
:rtype: `FunctionGraph` :rtype: `FunctionGraph`
:returns: a new FunctionGraph with a cloned graph, with debugging `Feature` instances already installed. :returns: a new FunctionGraph with a cloned graph, with debugging
`Feature` instances already installed.
""" """
orig_inputs = [spec.variable for spec in input_specs] orig_inputs = [spec.variable for spec in input_specs]
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
...@@ -687,15 +704,15 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -687,15 +704,15 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
equivalence_tracker = _VariableEquivalenceTracker() equivalence_tracker = _VariableEquivalenceTracker()
fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs, fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs,
features=[equivalence_tracker])
# DestroyHandler may not be needed yet, as there is usually no # DestroyHandler may not be needed yet, as there is usually no
# inplace operation in the graph at this stage. DestroyHandler # inplace operation in the graph at this stage. DestroyHandler
# will be installed by an optimization after canonicalization, # will be installed by an optimization after canonicalization,
# before the inplace operations are applied. # before the inplace operations are applied. This results in a big
# This results in a big speed gain. # speed gain.
#
# If inplace operations are accepted and present, however, # If inplace operations are accepted and present, however,
# DestroyHandler will be inserted in the loop below. # DestroyHandler will be inserted in the loop below.
# features=[equivalence_tracker, gof.DestroyHandler(do_imports_on_attach=False)])
features=[equivalence_tracker])
if not accept_inplace: if not accept_inplace:
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
...@@ -711,7 +728,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -711,7 +728,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
break break
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, fgraph.inputs) fgraph.attach_feature(Supervisor(
input for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or (hasattr(fgraph, 'destroyers') if not (spec.mutable or (hasattr(fgraph, 'destroyers')
and fgraph.destroyers(input))))) and fgraph.destroyers(input)))))
...@@ -797,8 +815,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -797,8 +815,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
continue continue
if not may_share: if not may_share:
_logger.warning("Optimization Warning: input idx %d marked " _logger.warning("Optimization Warning: input idx %d marked "
"as viewed but new memory allocated by node '%s'", "as viewed but new memory allocated by node "
ii[0], str(node)) "'%s'", ii[0], str(node))
for r_idx, r in enumerate(node.inputs): for r_idx, r in enumerate(node.inputs):
if not r.type.values_eq(r_vals[r], storage_map[r][0]): if not r.type.values_eq(r_vals[r], storage_map[r][0]):
...@@ -808,10 +826,12 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -808,10 +826,12 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
# ok, we expected r to be destroyed # ok, we expected r to be destroyed
if node in active_nodes: if node in active_nodes:
if dr_vals.get(r, (0, node))[1] is not node: if dr_vals.get(r, (0, node))[1] is not node:
# bad: there should only be one active node that destroys any variable # bad: there should only be one active node
# that destroys any variable
raise Exception('failure in topological ordering') raise Exception('failure in topological ordering')
if clobber_dr_vals: if clobber_dr_vals:
dr_vals[r] = (storage_map[r][0], node) # no copy, this is the last use of this variable # no copy, this is the last use of this variable
dr_vals[r] = (storage_map[r][0], node)
# make sure that dr_vals[r] doens't get used again # make sure that dr_vals[r] doens't get used again
storage_map[r][0] = data_destroyed storage_map[r][0] = data_destroyed
else: else:
...@@ -823,8 +843,10 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -823,8 +843,10 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
def _check_viewmap(node, storage_map): def _check_viewmap(node, storage_map):
""" """
This functions raises a BadViewMap exception when it detects the following: This functions raises a BadViewMap exception when it detects the
- output node storages aliased to input storage, with no declaration in view_map following:
- output node storages aliased to input storage, with no declaration
in view_map
- if not aliased to an input, check if two outputs are aliased together - if not aliased to an input, check if two outputs are aliased together
and used subsequently in the graph and used subsequently in the graph
""" """
...@@ -1062,16 +1084,13 @@ def _find_bad_optimizations2(order, reasons, r_vals): ...@@ -1062,16 +1084,13 @@ def _find_bad_optimizations2(order, reasons, r_vals):
return return
checked_variables.add(r) checked_variables.add(r)
# (recursively) first check all the variables that could make r look bad: # (recursively) first check all the variables that could make
# r look bad:
list_of_vars = [old_r for (reason, old_r, olds, news) in reasons[r]] list_of_vars = [old_r for (reason, old_r, olds, news) in reasons[r]]
if (None is not r.owner): if (None is not r.owner):
list_of_vars += r.owner.inputs list_of_vars += r.owner.inputs
for var_that_could_make_r_look_bad in \ for var_that_could_make_r_look_bad in list_of_vars:
list_of_vars:
# backport
#[old_r for (reason, old_r, olds, news) in reasons[r]] \
#+ ([] if (None is r.owner) else r.owner.inputs):
check_variable(var_that_could_make_r_look_bad) check_variable(var_that_could_make_r_look_bad)
check_variable_norec(r) check_variable_norec(r)
...@@ -1087,8 +1106,8 @@ _find_bad_optimizations = _find_bad_optimizations0 ...@@ -1087,8 +1106,8 @@ _find_bad_optimizations = _find_bad_optimizations0
def _get_preallocated_maps(node, thunk, prealloc_modes, def_val, def _get_preallocated_maps(node, thunk, prealloc_modes, def_val,
storage_map, r_vals, dr_vals, perform, active_order_set, storage_map, r_vals, dr_vals, perform,
inplace_outs, init_outputs): active_order_set, inplace_outs, init_outputs):
'''Preallocate outputs in different memory layouts''' '''Preallocate outputs in different memory layouts'''
# To avoid circular imports # To avoid circular imports
...@@ -1310,8 +1329,8 @@ def _get_preallocated_maps(node, thunk, prealloc_modes, def_val, ...@@ -1310,8 +1329,8 @@ def _get_preallocated_maps(node, thunk, prealloc_modes, def_val,
def _check_preallocated_output(node, thunk, prealloc_modes, def_val, def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
storage_map, r_vals, dr_vals, perform, active_order_set, storage_map, r_vals, dr_vals, perform,
inplace_outs, init_outputs): active_order_set, inplace_outs, init_outputs):
'''Try to apply thunk() on different output storages''' '''Try to apply thunk() on different output storages'''
# If node has an inner compiled Theano function with mode DebugMode, # If node has an inner compiled Theano function with mode DebugMode,
...@@ -1378,7 +1397,8 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val, ...@@ -1378,7 +1397,8 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
# Check outputs # Check outputs
for r in node.outputs: for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(
r, storage_map[r][0],
hint=thunk_name, hint=thunk_name,
specific_hint=r.type.value_validity_msg( specific_hint=r.type.value_validity_msg(
storage_map[r][0])) storage_map[r][0]))
...@@ -1393,10 +1413,13 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val, ...@@ -1393,10 +1413,13 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
for r in node.outputs: for r in node.outputs:
if not check_eq(r, r_vals[r], storage_map[r][0]): if not check_eq(r, r_vals[r], storage_map[r][0]):
# TODO: indicate it is not a C/Py problem # TODO: indicate it is not a C/Py problem
inputs_val = [storage_map[inp][0] for inp in r.owner.inputs] inputs_val = [storage_map[inp][0] for inp in
r.owner.inputs]
raise BadThunkOutput(r, raise BadThunkOutput(r,
thunk1='Reference value', val1=r_vals[r], thunk1='Reference value',
thunk2=thunk_name, val2=storage_map[r][0], val1=r_vals[r],
thunk2=thunk_name,
val2=storage_map[r][0],
inputs_val=inputs_val) inputs_val=inputs_val)
# Clear storage_map # Clear storage_map
...@@ -1455,8 +1478,6 @@ class _FunctionGraphEvent(object): ...@@ -1455,8 +1478,6 @@ class _FunctionGraphEvent(object):
str(self.op), str(self.op),
str(self.idx), str(self.idx),
msg]) msg])
# backport
# str(len(self.node.inputs)) if (self.op != 'output') else ''])
else: else:
return str(self.__dict__) return str(self.__dict__)
...@@ -1475,7 +1496,8 @@ class _FunctionGraphEvent(object): ...@@ -1475,7 +1496,8 @@ class _FunctionGraphEvent(object):
class _VariableEquivalenceTracker(object): class _VariableEquivalenceTracker(object):
"""A FunctionGraph Feature that keeps tabs on an FunctionGraph and tries to detect problems.""" """A FunctionGraph Feature that keeps tabs on an FunctionGraph and
tries to detect problems."""
fgraph = None fgraph = None
"""WRITEME""" """WRITEME"""
...@@ -1524,7 +1546,6 @@ class _VariableEquivalenceTracker(object): ...@@ -1524,7 +1546,6 @@ class _VariableEquivalenceTracker(object):
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node, self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason)) reason=reason))
# print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
self.active_nodes.remove(node) self.active_nodes.remove(node)
...@@ -1534,7 +1555,6 @@ class _VariableEquivalenceTracker(object): ...@@ -1534,7 +1555,6 @@ class _VariableEquivalenceTracker(object):
self.event_list.append(_FunctionGraphEvent('import', node, self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason)) reason=reason))
# print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
self.active_nodes.add(node) self.active_nodes.add(node)
...@@ -1554,7 +1574,6 @@ class _VariableEquivalenceTracker(object): ...@@ -1554,7 +1574,6 @@ class _VariableEquivalenceTracker(object):
self.replaced_by.setdefault(r, []) self.replaced_by.setdefault(r, [])
def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
# print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.event_list.append(_FunctionGraphEvent('change', node, self.event_list.append(_FunctionGraphEvent('change', node,
reason=str(reason), idx=i)) reason=str(reason), idx=i))
...@@ -1570,7 +1589,8 @@ class _VariableEquivalenceTracker(object): ...@@ -1570,7 +1589,8 @@ class _VariableEquivalenceTracker(object):
# N.B. compute the debugprint now, because future # N.B. compute the debugprint now, because future
# optimizations will change the graph # optimizations will change the graph
done = dict() done = dict()
self.reasons[new_r].append((reason, self.reasons[new_r].append(
(reason,
r, r,
debugprint(r, prefix=' ', depth=6, debugprint(r, prefix=' ', depth=6,
file=StringIO(), done=done).getvalue(), file=StringIO(), done=done).getvalue(),
...@@ -1647,26 +1667,28 @@ class _Linker(gof.link.LocalLinker): ...@@ -1647,26 +1667,28 @@ class _Linker(gof.link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler=None, input_storage=None def make_all(self, profiler=None, input_storage=None,
, output_storage=None): output_storage=None):
if 1:
# can't import at toplevel because of circular import TODO: # can't import at toplevel because of circular import TODO:
# don't do this ugly hacky way of setting the # don't do this ugly hacky way of setting the
# filter_checks_isfinite # filter_checks_isfinite
from theano.tensor import TensorType # to set filter_check_isfinite from theano.tensor import TensorType # to set filter_check_isfinite
fgraph = self.fgraph fgraph = self.fgraph
input_storage_ = input_storage input_storage_ = input_storage
output_storage_ = output_storage output_storage_ = output_storage
#order = self.schedule(fgraph)
# Compute a topological ordering that IGNORES the destroy_map of destructive Ops. # Compute a topological ordering that IGNORES the destroy_map
# This will be OK, because every thunk is evaluated on a copy of its input. # of destructive Ops. This will be OK, because every thunk is
order_outputs = copy.copy(fgraph.equivalence_tracker.all_variables_ever) # evaluated on a copy of its input.
fgraph_equiv = fgraph.equivalence_tracker
order_outputs = copy.copy(fgraph_equiv.all_variables_ever)
del fgraph_equiv
order_outputs.reverse() order_outputs.reverse()
order = graph.io_toposort(fgraph.inputs, order_outputs) order = graph.io_toposort(fgraph.inputs, order_outputs)
active_order = self.schedule(fgraph) # an ordering of just the active nodes # an ordering of just the active nodes
active_order = self.schedule(fgraph)
active_order_set = set(active_order) active_order_set = set(active_order)
# Disable no_recycling, in order to be able to use # Disable no_recycling, in order to be able to use
...@@ -1682,9 +1704,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1682,9 +1704,6 @@ class _Linker(gof.link.LocalLinker):
thunks_c = [] # c thunks thunks_c = [] # c thunks
for node in order: for node in order:
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
compute_map = {} compute_map = {}
for k in node.inputs: for k in node.inputs:
compute_map[k] = [True] compute_map[k] = [True]
...@@ -1696,7 +1715,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1696,7 +1715,8 @@ class _Linker(gof.link.LocalLinker):
# the compilation of some dependency is triggered there. # the compilation of some dependency is triggered there.
thunk_other = None thunk_other = None
if get_unbound_function(node.op.make_thunk) not in default_make_thunk: if (get_unbound_function(node.op.make_thunk) not in
default_make_thunk):
thunk = node.op.make_thunk(node, thunk = node.op.make_thunk(node,
storage_map, storage_map,
compute_map, compute_map,
...@@ -1725,7 +1745,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1725,7 +1745,8 @@ class _Linker(gof.link.LocalLinker):
# raises an not implemented exception), so in those cases we # raises an not implemented exception), so in those cases we
# consider that we don't have a python implementation # consider that we don't have a python implementation
if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and
node.op.perform.func_code != gof.op.PureOp.perform.func_code): (node.op.perform.func_code !=
gof.op.PureOp.perform.func_code)):
thunk = node.op.make_py_thunk(node, storage_map, compute_map, thunk = node.op.make_py_thunk(node, storage_map, compute_map,
no_recycling) no_recycling)
thunks_py.append(thunk) thunks_py.append(thunk)
...@@ -1739,7 +1760,9 @@ class _Linker(gof.link.LocalLinker): ...@@ -1739,7 +1760,9 @@ class _Linker(gof.link.LocalLinker):
elif thunks_c[-1] is None: elif thunks_c[-1] is None:
thunks_c[-1] = thunk_other thunks_c[-1] = thunk_other
else: else:
_logger.warn("We won't check the perform function of node '%s' but we will check its make_thunk function" % node) _logger.warn("We won't check the perform function "
"of node '%s' but we will check its "
"make_thunk function" % node)
thunks_py[-1] = thunk_other thunks_py[-1] = thunk_other
# Use self.no_recycling (that was passed in accept()) to always # Use self.no_recycling (that was passed in accept()) to always
...@@ -1747,7 +1770,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1747,7 +1770,8 @@ class _Linker(gof.link.LocalLinker):
# function's outputs. no_recycling_map will be used in f() below. # function's outputs. no_recycling_map will be used in f() below.
if self.no_recycling is True: if self.no_recycling is True:
no_recycling_map = storage_map.values() no_recycling_map = storage_map.values()
no_recycling_map = utils.difference(no_recycling_map, input_storage) no_recycling_map = utils.difference(no_recycling_map,
input_storage)
else: else:
no_recycling_map = [storage_map[r] for r in self.no_recycling no_recycling_map = [storage_map[r] for r in self.no_recycling
if r not in fgraph.inputs] if r not in fgraph.inputs]
...@@ -1784,14 +1808,17 @@ class _Linker(gof.link.LocalLinker): ...@@ -1784,14 +1808,17 @@ class _Linker(gof.link.LocalLinker):
# the evaluation of this function, even when the graph # the evaluation of this function, even when the graph
# has destructive ops in it # has destructive ops in it
# #
# This dictionary is used to populate the storage_map as necessary # This dictionary is used to populate the storage_map
# as necessary
r_vals = {} r_vals = {}
# dr_vals are the values taken by variables after being destroyed # dr_vals are the values taken by variables after
# being destroyed
dr_vals = {} dr_vals = {}
assert len(thunks_py) == len(order) assert len(thunks_py) == len(order)
# transfer the initial values from the storage_map to the r_vals # transfer the initial values from the storage_map to
# the r_vals
_logger.debug("DEBUGMODE: transfer initial values") _logger.debug("DEBUGMODE: transfer initial values")
# r_vals_initialized keeps track of the values that have # r_vals_initialized keeps track of the values that have
# actually been transferred from storage_map to r_vals # actually been transferred from storage_map to r_vals
...@@ -1803,17 +1830,20 @@ class _Linker(gof.link.LocalLinker): ...@@ -1803,17 +1830,20 @@ class _Linker(gof.link.LocalLinker):
# for a Generic object). We only want to raise # for a Generic object). We only want to raise
# an error if it is not valid. # an error if it is not valid.
if (storage_map[r][0] is None): if (storage_map[r][0] is None):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(
hint="Graph Input '%s' is missing" % str(r)) r, storage_map[r][0],
raise InvalidValueError(r, storage_map[r][0], hint=("Graph Input '%s' is missing" %
str(r)))
raise InvalidValueError(
r, storage_map[r][0],
hint=("Graph Input '%s' has invalid value " hint=("Graph Input '%s' has invalid value "
"%s" % (r, storage_map[r][0]))) "%s" % (r, storage_map[r][0])))
r_vals[r] = storage_map[r][0] r_vals[r] = storage_map[r][0]
storage_map[r][0] = None storage_map[r][0] = None
r_vals_initialized.append(r) r_vals_initialized.append(r)
# store preallocated outputs in another map, and test the thunks on # store preallocated outputs in another map, and test
# them as output storages. # the thunks on them as output storages.
init_outputs = {} init_outputs = {}
for r in storage_map: for r in storage_map:
if r in fgraph.outputs: if r in fgraph.outputs:
...@@ -1835,8 +1865,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1835,8 +1865,6 @@ class _Linker(gof.link.LocalLinker):
for i, (thunk_py, thunk_c, node) in enumerate(zip(thunks_py, for i, (thunk_py, thunk_c, node) in enumerate(zip(thunks_py,
thunks_c, thunks_c,
order)): order)):
this_node_destroyed_variables = set()
_logger.debug("%i - starting node %i %s", i, i, node) _logger.debug("%i - starting node %i %s", i, i, node)
# put a copy of each input into the storage_map # put a copy of each input into the storage_map
...@@ -1844,7 +1872,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1844,7 +1872,6 @@ class _Linker(gof.link.LocalLinker):
for r in node.inputs: for r in node.inputs:
assert isinstance(r, gof.Variable) assert isinstance(r, gof.Variable)
assert r in r_vals assert r in r_vals
# print >> sys.stderr,i, "DEBUGMODE: deepcopy input ", r
storage_map[r][0] = _lessbroken_deepcopy(r_vals[r]) storage_map[r][0] = _lessbroken_deepcopy(r_vals[r])
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(r, storage_map[r][0],
...@@ -1872,10 +1899,12 @@ class _Linker(gof.link.LocalLinker): ...@@ -1872,10 +1899,12 @@ class _Linker(gof.link.LocalLinker):
raise raise
opt = str(reason[0][0]) opt = str(reason[0][0])
msg = ( msg = (
"An optimization (probably %s ) inserted an apply node that raise an error." % opt + "An optimization (probably %s) inserted an "
"\nThe information we have about this optimizations is:" + str(reason[0][1]) + "apply node that raise an error." % opt +
"\n" + reason[0][2] + "\nThe information we have about this "
"\n\nThe original exception: \n" + str(e)) "optimizations is:" + str(reason[0][1]) +
"\n" + reason[0][2] +
"\n\nThe original exception: \n" + str(e))
new_e = e.__class__(msg) new_e = e.__class__(msg)
exc_type, exc_value, exc_trace = sys.exc_info() exc_type, exc_value, exc_trace = sys.exc_info()
exc_value = new_e exc_value = new_e
...@@ -1891,19 +1920,19 @@ class _Linker(gof.link.LocalLinker): ...@@ -1891,19 +1920,19 @@ class _Linker(gof.link.LocalLinker):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(r, storage_map[r][0],
hint='perform output', hint='perform output',
specific_hint=hint2) specific_hint=hint2)
warn_inp = config.DebugMode.warn_input_not_reused
py_inplace_outs = _check_inputs( py_inplace_outs = _check_inputs(
node, storage_map, r_vals, dr_vals, node, storage_map, r_vals, dr_vals,
active_order_set, active_order_set,
clobber_dr_vals=True, perform='py', clobber_dr_vals=True, perform='py',
warn_input_not_reused=config.DebugMode.warn_input_not_reused) warn_input_not_reused=warn_inp)
_check_viewmap(node, storage_map) _check_viewmap(node, storage_map)
# Retrieve each output from the storage_map # Retrieve each output from the storage_map.
# The return values of this first run will be the reference ones # The return values of this first run will be
# the reference ones
for r in node.outputs: for r in node.outputs:
assert r not in r_vals assert r not in r_vals
# print >> sys.stderr, i, "DEBUGMODE storing reference output %x" % id(storage_map[r][0])
r_vals[r] = storage_map[r][0] r_vals[r] = storage_map[r][0]
# clear the storage_map of outputs for the thunk_c # clear the storage_map of outputs for the thunk_c
storage_map[r][0] = None storage_map[r][0] = None
...@@ -1927,9 +1956,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1927,9 +1956,6 @@ class _Linker(gof.link.LocalLinker):
inplace_outs=py_inplace_outs, inplace_outs=py_inplace_outs,
init_outputs=init_outputs) init_outputs=init_outputs)
# print >> sys.stderr, i, "DEBUGMODE thunk_py %100s %50s %30s" % (node,
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.inputs],
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.outputs])
sys.stdout.flush() sys.stdout.flush()
if thunk_c: if thunk_c:
...@@ -1939,18 +1965,22 @@ class _Linker(gof.link.LocalLinker): ...@@ -1939,18 +1965,22 @@ class _Linker(gof.link.LocalLinker):
dmap = getattr(node.op, 'destroy_map', {}) dmap = getattr(node.op, 'destroy_map', {})
vmap = getattr(node.op, 'view_map', {}) vmap = getattr(node.op, 'view_map', {})
for i, r in enumerate(node.inputs): for i, r in enumerate(node.inputs):
# if thunk_py ran, and we still got this far, # if thunk_py ran, and we still got
# it means that the destroy_map of the Op (and view_map) are # this far, it means that the
# accurate # destroy_map of the Op (and view_map)
# so we can assume that inputs not marked as destroyed have in # are accurate so we can assume that
# fact not been destroyed. # inputs not marked as destroyed have
# Therefore... we only need to overwrite inputs that *have* # in fact not been destroyed.
# been marked as destroyed. # Therefore... we only need to
# Inputs marked as viewd are unsafe too, # overwrite inputs that *have* been
# because the corresponding output can # marked as destroyed. Inputs marked
# be destroyed. # as viewd are unsafe too, because the
if any(i in v for v in (dmap.values() + vmap.values())): # corresponding output can be
storage_map[r][0] = _lessbroken_deepcopy(r_vals[r]) # destroyed.
if any(i in v for v in (dmap.values() +
vmap.values())):
storage_map[r][0] = _lessbroken_deepcopy(
r_vals[r])
clobber = False clobber = False
...@@ -1969,10 +1999,12 @@ class _Linker(gof.link.LocalLinker): ...@@ -1969,10 +1999,12 @@ class _Linker(gof.link.LocalLinker):
raise raise
opt = str(reason[0][0]) opt = str(reason[0][0])
msg = ( msg = (
"An optimization (probably %s ) inserted an apply node that raise an error." % opt + "An optimization (probably %s) inserted "
"\nThe information we have about this optimizations is:" + str(reason[0][1]) + "an apply node that raise an error." % opt +
"\n" + reason[0][2] + "\nThe information we have about this "
"\n\nThe original exception: \n" + str(e)) "optimizations is:" + str(reason[0][1]) +
"\n" + reason[0][2] +
"\n\nThe original exception: \n" + str(e))
new_e = e.__class__(msg) new_e = e.__class__(msg)
exc_type, exc_value, exc_trace = sys.exc_info() exc_type, exc_value, exc_trace = sys.exc_info()
exc_value = new_e exc_value = new_e
...@@ -1982,21 +2014,26 @@ class _Linker(gof.link.LocalLinker): ...@@ -1982,21 +2014,26 @@ class _Linker(gof.link.LocalLinker):
for r in node.outputs: for r in node.outputs:
# check output values for type-correctness # check output values for type-correctness
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0], hint='c output') raise InvalidValueError(r, storage_map[r][0],
hint='c output')
if thunk_py: if thunk_py:
assert r in r_vals # because we put it in during the thunk_py branch # because we put it in during the
# check for stride correctness (may raise exception) # thunk_py branch
_check_strides_match(r_vals[r], assert r in r_vals
storage_map[r][0], # check for stride correctness (may
# raise exception)
_check_strides_match(
r_vals[r], storage_map[r][0],
self.maker.mode.require_matching_strides, self.maker.mode.require_matching_strides,
node.op) node.op)
warn_inp = config.DebugMode.warn_input_not_reused
c_inplace_outs = _check_inputs( c_inplace_outs = _check_inputs(
node, storage_map, r_vals, node, storage_map, r_vals,
dr_vals, active_order_set, dr_vals, active_order_set,
clobber_dr_vals=clobber, perform='c', clobber_dr_vals=clobber, perform='c',
warn_input_not_reused=config.DebugMode.warn_input_not_reused) warn_input_not_reused=warn_inp)
_check_viewmap(node, storage_map) _check_viewmap(node, storage_map)
...@@ -2006,11 +2043,14 @@ class _Linker(gof.link.LocalLinker): ...@@ -2006,11 +2043,14 @@ class _Linker(gof.link.LocalLinker):
# compares the version from thunk_py # compares the version from thunk_py
# (in r_vals) to the version produced # (in r_vals) to the version produced
# by thunk_c (in storage_map) # by thunk_c (in storage_map)
if not check_eq(r, r_vals[r], storage_map[r][0]): if not check_eq(r, r_vals[r],
inputs_val = [storage_map[inp][0] for inp in r.owner.inputs] storage_map[r][0]):
inputs_val = [storage_map[inp][0]
for inp in r.owner.inputs]
raise BadThunkOutput( raise BadThunkOutput(
r, thunk1='perform', val1=r_vals[r], r, thunk1='perform', val1=r_vals[r],
thunk2='c_code', val2=storage_map[r][0], thunk2='c_code',
val2=storage_map[r][0],
inputs_val=inputs_val) inputs_val=inputs_val)
else: else:
# retrieve each output from the storage_map # retrieve each output from the storage_map
...@@ -2021,6 +2061,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -2021,6 +2061,7 @@ class _Linker(gof.link.LocalLinker):
if self.maker.mode.check_preallocated_output: if self.maker.mode.check_preallocated_output:
prealloc_modes = \ prealloc_modes = \
self.maker.mode.check_preallocated_output self.maker.mode.check_preallocated_output
def thunk(): def thunk():
try: try:
thunk_c() thunk_c()
...@@ -2042,15 +2083,11 @@ class _Linker(gof.link.LocalLinker): ...@@ -2042,15 +2083,11 @@ class _Linker(gof.link.LocalLinker):
inplace_outs=c_inplace_outs, inplace_outs=c_inplace_outs,
init_outputs=init_outputs) init_outputs=init_outputs)
# print >> sys.stderr, i, "DEBUGMODE thunk_c %100s %50s %30s" % (node,
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.inputs],
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.outputs])
sys.stdout.flush() sys.stdout.flush()
# we're done with this thunk # we're done with this thunk
# clear everything out of the storage_map # clear everything out of the storage_map
for r in node.inputs: for r in node.inputs:
#print >> sys.stderr, i, "DEBUGMODE clearing input", r
storage_map[r][0] = None storage_map[r][0] = None
_logger.debug("%i - done with node", i) _logger.debug("%i - done with node", i)
...@@ -2059,7 +2096,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -2059,7 +2096,8 @@ class _Linker(gof.link.LocalLinker):
# But it is very slow and it is not sure it will help. # But it is very slow and it is not sure it will help.
gc.collect() gc.collect()
_find_bad_optimizations(order, fgraph.equivalence_tracker.reasons, _find_bad_optimizations(order,
fgraph.equivalence_tracker.reasons,
r_vals) r_vals)
##### #####
...@@ -2081,18 +2119,24 @@ class _Linker(gof.link.LocalLinker): ...@@ -2081,18 +2119,24 @@ class _Linker(gof.link.LocalLinker):
for r in r_vals: for r in r_vals:
if r.owner is None: if r.owner is None:
if r in fgraph.inputs: if r in fgraph.inputs:
assert storage_map[r] is input_storage[fgraph.inputs.index(r)] assert (storage_map[r] is
input_storage[fgraph.inputs.index(r)])
storage_map[r][0] = r_vals[r] storage_map[r][0] = r_vals[r]
# if an input was destroyed, the destroyed value should be returned # if an input was destroyed, the destroyed value
# should be returned
for r in dr_vals: for r in dr_vals:
assert dr_vals[r][0] is not None assert dr_vals[r][0] is not None
if r.owner is None: if r.owner is None:
assert r in fgraph.inputs assert r in fgraph.inputs
# HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE # HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION
if type(dr_vals[r][0]) in (numpy.ndarray, numpy.memmap) \ # TOOK PLACE
and dr_vals[r][0].dtype == storage_map[r][0].dtype \ if ((type(dr_vals[r][0]) in
and dr_vals[r][0].shape == storage_map[r][0].shape: (numpy.ndarray, numpy.memmap)) and
(dr_vals[r][0].dtype ==
storage_map[r][0].dtype) and
(dr_vals[r][0].shape ==
storage_map[r][0].shape)):
if len(dr_vals[r][0].shape): if len(dr_vals[r][0].shape):
storage_map[r][0][:] = dr_vals[r][0] storage_map[r][0][:] = dr_vals[r][0]
else: else:
...@@ -2111,10 +2155,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -2111,10 +2155,6 @@ class _Linker(gof.link.LocalLinker):
storage_map[r][0] = None storage_map[r][0] = None
raise raise
# print ""
# print output_storage
# print dr_vals
# print storage_map
for r in storage_map: for r in storage_map:
if (r.owner is None): if (r.owner is None):
if not r.type.is_valid_value(None): if not r.type.is_valid_value(None):
...@@ -2130,12 +2170,14 @@ class _Linker(gof.link.LocalLinker): ...@@ -2130,12 +2170,14 @@ class _Linker(gof.link.LocalLinker):
# so it will screw up if we are trying to use # so it will screw up if we are trying to use
# multiple modes at once. # multiple modes at once.
old_filter_checks_isfinite = TensorType.filter_checks_isfinite old_filter_checks_isfinite = TensorType.filter_checks_isfinite
TensorType.filter_checks_isfinite = self.maker.mode.check_isfinite TensorType.filter_checks_isfinite = \
self.maker.mode.check_isfinite
try: try:
return f() return f()
finally: finally:
# put back the filter_checks_isfinite # put back the filter_checks_isfinite
TensorType.filter_checks_isfinite = old_filter_checks_isfinite TensorType.filter_checks_isfinite = \
old_filter_checks_isfinite
return deco return deco
f = run_with_tensortype_filter_check(f) f = run_with_tensortype_filter_check(f)
...@@ -2143,12 +2185,12 @@ class _Linker(gof.link.LocalLinker): ...@@ -2143,12 +2185,12 @@ class _Linker(gof.link.LocalLinker):
f.allow_gc = True f.allow_gc = True
assert len(fgraph.inputs) == len(input_storage) assert len(fgraph.inputs) == len(input_storage)
assert len(fgraph.outputs) == len(output_storage) assert len(fgraph.outputs) == len(output_storage)
# print 'make_all returning output', [id(z) for z in output_storage] return (f,
return f, [link.Container(input, storage, readonly=False) [link.Container(input, storage, readonly=False)
for input, storage in zip(fgraph.inputs, input_storage)], \ for input, storage in zip(fgraph.inputs, input_storage)],
[link.Container(output, storage, readonly=True) [link.Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)], \ for output, storage in zip(fgraph.outputs, output_storage)],
thunks_py, order thunks_py, order)
_NODEFAULT = ['NODEFAULT'] _NODEFAULT = ['NODEFAULT']
...@@ -2170,25 +2212,29 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2170,25 +2212,29 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
:type outputs: a list of SymbolicOutput instances :type outputs: a list of SymbolicOutput instances outputs may
outputs may also be a single Variable (not a list), in which also be a single Variable (not a list), in
case the functions produced by FunctionMaker will return which case the functions produced by
their output value directly FunctionMaker will return their output value
directly
:param accept_inplace: True iff it is acceptable to have :param accept_inplace: True iff it is acceptable to have
inplace operations in the graph from the inputs to inplace operations in the graph from the inputs to
the outputs the outputs
:param on_unused_input: What to do if a variable in the 'inputs' list is :param on_unused_input: What to do if a variable in the
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'. 'inputs' list is not used in the
graph. Possible values are 'raise',
'warn', and 'ignore'.
:param output_keys: If the outputs argument for theano.function was a :param output_keys: If the outputs argument for
list, then output_keys is None. If the outputs argument was a dict, theano.function was a list, then
then output_keys is a sorted list of the keys from that dict. output_keys is None. If the outputs
argument was a dict, then output_keys is a
sorted list of the keys from that dict.
:note: this function sets TensorType.filter_checks_isfinite :note: this function sets TensorType.filter_checks_isfinite
when `mode.check_isfinite` is True when `mode.check_isfinite` is True
""" """
self.profile = profile self.profile = profile
# Handle the case where inputs and/or outputs is a single # Handle the case where inputs and/or outputs is a single
...@@ -2205,7 +2251,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2205,7 +2251,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs = [inputs] inputs = [inputs]
# Wrap them in In or Out instances if needed. # Wrap them in In or Out instances if needed.
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs) inputs = map(self.wrap_in, inputs),
outputs = map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] + _inputs = gof.graph.inputs([o.variable for o in outputs] +
[i.update for i in inputs [i.update for i in inputs
if getattr(i, 'update', False)]) if getattr(i, 'update', False)])
...@@ -2213,9 +2260,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2213,9 +2260,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# Check if some input variables are unused # Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input) self._check_unused_inputs(inputs, outputs, on_unused_input)
# Make a list of (SymbolicInput|SymblicInputKits, indices, [SymbolicInput,...]), one # Make a list of (SymbolicInput|SymblicInputKits, indices,
# tuple for each input. (See Function.indices for more details) # [SymbolicInput,...]), one tuple for each input. (See
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] # Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs)
for input in inputs]
# make the fgraph # make the fgraph
for i in xrange(mode.stability_patience): for i in xrange(mode.stability_patience):
...@@ -2226,58 +2275,53 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2226,58 +2275,53 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph # optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value compute_test_value_orig = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = theano.config.compute_test_value_opt theano.config.compute_test_value = \
theano.config.compute_test_value_opt
optimizer(fgraph) optimizer(fgraph)
theano.compile.function_module.insert_deepcopy(fgraph, inputs, theano.compile.function_module.insert_deepcopy(
outputs + additional_outputs) fgraph, inputs, outputs + additional_outputs)
finally: finally:
theano.config.compute_test_value = compute_test_value_orig theano.config.compute_test_value = compute_test_value_orig
if i: if i == 0:
fgraph0 = fgraph
else:
li = fgraph.equivalence_tracker.event_list li = fgraph.equivalence_tracker.event_list
l0 = fgraph0.equivalence_tracker.event_list l0 = fgraph0.equivalence_tracker.event_list
if li != l0 : if li != l0:
infolog = StringIO() infolog = StringIO()
print("WARNING: Optimization process is unstable...", file=infolog) print("WARNING: Optimization process is unstable...",
print(" (HINT: Ops that the nodes point to must compare equal)", file=infolog) file=infolog)
print("(event index) (one event trace) (other event trace)", file=infolog) print(" (HINT: Ops that the nodes point to must compare "
print("-----------------------------------------------------", file=infolog) "equal)", file=infolog)
print("(event index) (one event trace) (other event "
"trace)", file=infolog)
print("-------------------------------------------------"
"----", file=infolog)
for j in xrange(max(len(li), len(l0))): for j in xrange(max(len(li), len(l0))):
if j >= len(li): if j >= len(li):
print('trailing event in optimization 0 :', j, file=infolog) print('trailing event in optimization 0 :', j,
file=infolog)
print(' ', str(l0[j]), file=infolog) print(' ', str(l0[j]), file=infolog)
elif j >= len(l0): elif j >= len(l0):
print('trailing event in optimization', i, ':', j, file=infolog) print('trailing event in optimization', i, ':',
j, file=infolog)
print(' ', str(li[j]), file=infolog) print(' ', str(li[j]), file=infolog)
elif li[j] != l0[j]: elif li[j] != l0[j]:
print('non-equal optimization events', i, ':', j, file=infolog) print('non-equal optimization events', i, ':',
j, file=infolog)
print(' ', str(l0[j]), file=infolog) print(' ', str(l0[j]), file=infolog)
print(' ', str(li[j]), file=infolog) print(' ', str(li[j]), file=infolog)
#print >> infolog, "* ", j,
# if j < len(li):
# msg = str(li[j])
# else:
# msg = '-'
#print >> infolog, " ", msg
# if j < len(l0):
# msg = str(l0[j])
# else:
# msg = '-'
#print >> infolog, " ", msg
else: else:
pass pass
raise StochasticOrder(infolog.getvalue()) raise StochasticOrder(infolog.getvalue())
else: else:
if self.verbose: if self.verbose:
print("OPTCHECK: optimization", i, \ print("OPTCHECK: optimization", i,
"of", len(li), "events was stable.", file=sys.stderr) "of", len(li), "events was stable.",
else: file=sys.stderr)
fgraph0 = fgraph
del fgraph0
self.fgraph = fgraph self.fgraph = fgraph
# equivalence_tracker.printstuff()
linker = _Linker(self) linker = _Linker(self)
...@@ -2532,8 +2576,7 @@ class DebugMode(Mode): ...@@ -2532,8 +2576,7 @@ class DebugMode(Mode):
raise Exception("DebugMode can only use its own linker! You " raise Exception("DebugMode can only use its own linker! You "
"should not provide one.", linker) "should not provide one.", linker)
super(DebugMode, self).__init__( super(DebugMode, self).__init__(optimizer=optimizer,
optimizer=optimizer,
linker=linker) linker=linker)
if stability_patience is not None: if stability_patience is not None:
......
...@@ -38,7 +38,6 @@ whitelist_flake8 = [ ...@@ -38,7 +38,6 @@ whitelist_flake8 = [
"tests/test_tutorial.py", "tests/test_tutorial.py",
"tests/disturb_mem.py", "tests/disturb_mem.py",
"tests/unittest_tools.py", "tests/unittest_tools.py",
"compile/debugmode.py",
"compile/function.py", "compile/function.py",
"compile/pfunc.py", "compile/pfunc.py",
"compile/mode.py", "compile/mode.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论