提交 cba9c812，作者：Frédéric Bastien

Merge pull request #3130 from harlouci/flake8_gof

Flake8 gof
......@@ -177,7 +177,7 @@ def get_config_md5():
"""
all_opts = sorted([c for c in _config_var_list if c.in_c_key],
key=lambda cv: cv.fullname)
return theano.gof.cc.hash_from_code('\n'.join(
return theano.gof.utils.hash_from_code('\n'.join(
['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts]))
......
"""
Defines Linkers that deal with C implementations.
"""
from __future__ import print_function
# Python imports
from copy import copy
import os
import sys
from theano.compat import izip
import logging
import numpy
import theano
from theano import config
from theano.compat import PY3
from theano.compat import izip
from six import string_types, reraise
from six.moves import StringIO, xrange
from theano.gof.utils import MethodNotDefined
import theano
from theano import config
# Pick an MD5 helper suited to the running interpreter.  Python 3 strings
# must be encoded before hashing, while under Python 2 we additionally
# accept numpy arrays that do not expose the buffer interface directly.
if PY3:
    import hashlib

    def hash_from_code(msg):
        """Return an MD5 digest of `msg` usable as a Python module name.

        `msg` may be `str` or `bytes`; text is encoded first because
        hashlib.md5() requires a buffer-compatible object.
        """
        data = msg.encode() if isinstance(msg, str) else msg
        # Python 3 rejects module names that begin with a digit, so
        # prefix the hexadecimal digest with a letter.
        return 'm' + hashlib.md5(data).hexdigest()
else:
    import hashlib

    def hash_from_code(msg):
        """Return the MD5 hex digest of `msg` (a string or numpy.ndarray)."""
        try:
            return hashlib.md5(msg).hexdigest()
        except TypeError:
            # Arrays without direct buffer support: hash their raw buffer.
            assert isinstance(msg, numpy.ndarray)
            return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
    """Return the MD5 hash of the contents of the file at `file_path`.

    The file is read in binary mode so the digest is independent of any
    text-mode newline translation.
    """
    # Use a context manager so the handle is closed deterministically,
    # even if read() raises, instead of relying on garbage collection
    # (the original `open(...).read()` leaked the handle on non-CPython
    # interpreters).
    with open(file_path, 'rb') as f:
        return hash_from_code(f.read())
# Note that we need to do this before importing cutils, since when there is
# no theano cache dir initialized yet, importing cutils may require compilation
# of cutils_ext.
from theano.configparser import AddConfigVar, StrParam
AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc",
StrParam(""))
# gof imports
from theano.gof import graph
from theano.gof import link
from theano.gof import utils
from theano.gof import cmodule
from theano.gof.compilelock import get_lock, release_lock
from theano.gof.callcache import CallCache
from theano.gof import cmodule
AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc",
StrParam(""))
import logging
_logger = logging.getLogger("theano.gof.cc")
from theano.gof.callcache import CallCache
run_cthunk = None # Will be imported only when needed.
......@@ -314,9 +284,8 @@ def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name"""
if any([c != "output" and getattr(c.op, 'check_input',
config.check_input) for (c, _) in r.clients]) or (r.owner
and getattr(r.owner.op, 'check_input', True)):
config.check_input) for (c, _) in r.clients]) or (
r.owner and getattr(r.owner.op, 'check_input', True)):
c_declare = r.type.c_declare(name, sub, True)
else:
c_declare = r.type.c_declare(name, sub, False)
......@@ -840,11 +809,11 @@ class CLinker(link.Linker):
# FillMissing must disable some of them. Putting -ffast-math would
# make it disable all other parameter at the same time.
ret += ["-fno-math-errno",
#"-funsafe-math-optimizations",
#"-fno-signaling-nans",
#"-fcx-limited-range",
#"-fno-rounding-math",
#"-ffinite-math-only",
# "-funsafe-math-optimizations",
# "-fno-signaling-nans",
# "-fcx-limited-range",
# "-fno-rounding-math",
# "-ffinite-math-only",
# the current code generate label event if they are not used.
# Could use gcc attribute for those label only
......@@ -1335,7 +1304,6 @@ class CLinker(link.Linker):
preargs.remove('-DREPLACE_WITH_AMDLIBM')
if 'amdlibm' in libs:
libs.remove('amdlibm')
src_code = mod.code()
get_lock()
try:
_logger.debug("LOCATION %s", str(location))
......@@ -1371,8 +1339,8 @@ class CLinker(link.Linker):
code = self.instantiate_code(1 + len(self.args))
instantiate = cmodule.ExtFunction('instantiate', code,
method=cmodule.METH_VARARGS)
#['error_storage'] + argnames,
#local_dict = d,
# ['error_storage'] + argnames,
# local_dict = d,
# global_dict = {})
# Static methods that can run and destroy the struct built by
......@@ -1498,7 +1466,7 @@ class _CThunk(object):
global run_cthunk
if run_cthunk is None:
# Lazy import to avoid compilation when importing theano.
from theano.gof.cutils import run_cthunk
from theano.gof.cutils import run_cthunk # noqa
self.cthunk = cthunk
self.init_tasks = init_tasks
self.tasks = tasks
......@@ -1534,7 +1502,8 @@ class _CThunk(object):
exc_value.__thunk_trace__ = trace
except Exception:
print(('ERROR retrieving error_storage.'
' Was the error set in the c code?'), end=' ', file=sys.stderr)
'Was the error set in the c code?'),
end=' ', file=sys.stderr)
print(self.error_storage, file=sys.stderr)
raise
reraise(exc_type, exc_value, exc_trace)
......@@ -1641,8 +1610,8 @@ class OpWiseCLinker(link.LocalLinker):
for node in order:
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
post_thunk_old_storage.append(
[storage_map[input] for input in node.inputs
if ((input in computed) and
(input not in fgraph.outputs) and
node == last_user[input])])
......@@ -1741,12 +1710,12 @@ class DualLinker(link.Linker):
no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = (
link.PerformLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
link.PerformLinker(schedule=self.schedule).accept(
fgraph, no_recycling=no_recycling).make_all(**kwargs))
kwargs.pop('input_storage', None)
_f, i2, o2, thunks2, order2 = (
OpWiseCLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
OpWiseCLinker(schedule=self.schedule).accept(
fgraph, no_recycling=no_recycling).make_all(**kwargs))
def f():
for input1, input2 in izip(i1, i2):
......
"""Generate and compile C modules for Python,
"""
from __future__ import print_function
import atexit
import six.moves.cPickle as pickle
import logging
......@@ -15,12 +17,6 @@ import time
import platform
import distutils.sysconfig
importlib = None
try:
import importlib
except ImportError:
pass
import numpy.distutils # TODO: TensorType should handle this
import theano
......@@ -28,7 +24,7 @@ from theano.compat import PY3, decode, decode_iter
from six import b, BytesIO, StringIO, string_types, iteritems
from theano.gof.utils import flatten
from theano.configparser import config
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
from theano.misc.windows import (subprocess_Popen,
output_subprocess_Popen)
......@@ -38,7 +34,14 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('cmodule.mac_framework_link',
importlib = None
try:
import importlib
except ImportError:
pass
AddConfigVar(
'cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error",
BoolParam(False))
......@@ -136,7 +139,8 @@ class ExtFunction(object):
class DynamicModule(object):
def __init__(self, name=None):
assert name is None, ("The 'name' parameter of DynamicModule"
assert name is None, (
"The 'name' parameter of DynamicModule"
" cannot be specified anymore. Instead, 'code_hash'"
" will be automatically computed and can be used as"
" the module's name.")
......@@ -429,8 +433,8 @@ def get_safe_part(key):
# Find the md5 hash part.
c_link_key = key[1]
for key_element in c_link_key[1:]:
if (isinstance(key_element, string_types)
and key_element.startswith('md5:')):
if (isinstance(key_element, string_types) and
key_element.startswith('md5:')):
md5 = key_element[4:]
break
......@@ -1149,15 +1153,14 @@ class ModuleCache(object):
# This is to make debugging in pdb easier, by providing
# the offending keys in the local context.
# key_data_keys = list(key_data.keys)
## import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
pass
elif found > 1:
msg = 'Multiple equal keys found in unpickled KeyData file'
if msg:
raise AssertionError(
"%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" %
(msg, key_pkl, key))
"Ops. The file is: %s. The key is: %s" % (msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this
......@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
for line in default_lines:
if line.startswith(part[0]):
part2 = [p for p in join_options(line.split())
if (not 'march' in p and
not 'mtune' in p and
not 'target-cpu' in p)]
if ('march' not in p and
'mtune' not in p and
'target-cpu' not in p)]
new_flags = [p for p in part if p not in part2]
# Replace '-target-cpu value', which is an option
# of clang, with '-march=value', for g++
......@@ -2021,13 +2024,12 @@ class GCC_compiler(Compiler):
cmd.append(cppfilename)
cmd.extend(['-L%s' % ldir for ldir in lib_dirs])
cmd.extend(['-l%s' % l for l in libs])
#print >> sys.stderr, 'COMPILING W CMD', cmd
# print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug('Running cmd: %s', ' '.join(cmd))
def print_command_line_error():
# Print command line when a problem occurred.
print((
"Problem occurred during compilation with the "
print(("Problem occurred during compilation with the "
"command line below:"), file=sys.stderr)
print(' '.join(cmd), file=sys.stderr)
......
......@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
"""
# These are lists of Variable instances
inputs = fgraph.inputs
outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py
......@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
iset = set(inputs)
# IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys.
# This caused a slowdown.
......@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if
not isinstance(i, graph.Constant)
and not fgraph.destroyers(i)
and i not in protected_inputs]
not isinstance(i, graph.Constant) and
not fgraph.destroyers(i) and
i not in protected_inputs]
return inputs
if 0:
......@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
# Do the checking #
already_there = False
if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve"
......@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in"
" conflict with another plugin.")
####### end of checking ############
# end of checking #
def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact()
......@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
#input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact)
# input_impact = set([input_root])
# add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
......@@ -390,7 +387,7 @@ if 0:
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
# if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
......@@ -421,7 +418,7 @@ if 0:
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app)
# UPDATE self.clients
......@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph.
pass
else:
#if app not in self.debug_all_apps: raise ProtocolError("change without import")
# if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients
self.clients[old_r][app] -= 1
......@@ -529,7 +526,8 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if
illegal_destroy = [
r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
......@@ -621,7 +619,7 @@ if 0:
return rval
class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
......@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
# Do the checking #
already_there = False
if self.fgraph is fgraph:
already_there = True
......@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present"
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############
# Annotate the FunctionGraph #
self.unpickle(fgraph)
fgraph.destroy_handler = self
......@@ -945,11 +943,12 @@ class DestroyHandler(toolbox.Bookkeeper):
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if \
getattr(r.tag, 'indestructible', False) or \
illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
raise InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies
......@@ -995,12 +994,14 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
......@@ -1010,7 +1011,8 @@ class DestroyHandler(toolbox.Bookkeeper):
if i in ignored:
continue
if input in root_impact \
and (i not in tolerated or input is not destroyed_variable):
and (i not in tolerated or
input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i))
......
......@@ -13,7 +13,6 @@ from theano.gof import graph
from theano.gof import utils
from theano.gof import toolbox
from theano import config
import warnings
from theano.compat import OrderedDict
from six import iteritems, itervalues
......@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
NullType = None
class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
......@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
self.variable_locks = {}
self.profile = None
### Setup a Variable ###
# Setup a Variable #
def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
if getattr(r, 'cached', False):
......@@ -157,7 +157,7 @@ class FunctionGraph(utils.object2):
raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self
r.clients = []
#self.execute_callbacks('on_setup_variable', r)
# self.execute_callbacks('on_setup_variable', r)
def __setup_node__(self, node):
# sets up node so it belongs to this fgraph
......@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
str(node.op), str(node.op.destroy_map)))
node.fgraph = self
node.deps = {}
#self.execute_callbacks('on_setup_node', node)
# self.execute_callbacks('on_setup_node', node)
def disown(self):
""" WRITEME
......@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
self.inputs = None
self.outputs = None
### clients ###
# clients #
def clients(self, r):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
......@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
return True
return False
### import ###
# import #
def __import_r__(self, variable, reason):
global NullType
if NullType is None:
......@@ -281,7 +281,6 @@ class FunctionGraph(utils.object2):
if (r.owner is None and
not isinstance(r, graph.Constant) and
r not in self.inputs):
# Verbose error message
# Show a complete chain of variables from the missing input to an output
if config.exception_verbosity == 'high':
......@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
assert node.fgraph is self
self.execute_callbacks('on_import', node, reason)
### prune ###
# prune #
def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
......@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs)
### change input ###
# change input #
def change_input(self, node, i, new_r, reason=None):
"""WRITEME
Changes node.inputs[i] to new_r.
......@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
if prune:
self.__prune_r__(r, reason=reason)
### replace ###
# replace #
def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
......@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
if detach is not None:
detach(self)
### callback utils ###
# callback utils #
def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME
Calls
......@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args)
return d
### misc ###
# misc #
def toposort(self):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
......@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
def __repr__(self):
return self.__str__()
### clone ###
# clone #
def clone(self, check_integrity=True):
"""WRITEME"""
return self.clone_get_equiv(check_integrity)[0]
......
......@@ -7,14 +7,14 @@ import traceback
import numpy
import theano
from theano.compat import PY3, izip
from theano.compat import izip
from six import reraise
from six.moves import StringIO
from theano.gof import utils
from theano.gof import graph
from theano.gof.type import Type
from .utils import MethodNotDefined, undef
from .utils import undef
__excepthook = sys.excepthook
......@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else:
detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
total_size, total_size/1024./1024/1024)
total_size, total_size / 1024. / 1024 / 1024)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % (
total_size_inputs, total_size_inputs/1024./1024/1024)
total_size_inputs, total_size_inputs / 1024. / 1024 / 1024)
else:
hints.append(
......@@ -326,7 +326,7 @@ class Linker(object):
raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__)
## DELETEME ##
# DELETEME #
def make_function(self, unpack_single=True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
......@@ -350,8 +350,8 @@ class Linker(object):
def execute(*args):
def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \
% (takes, ['argument', 'arguments'][takes > 1], got)
return 'Function call takes exactly %i %s (%i given)' % (
takes, ['argument', 'arguments'][takes > 1], got)
if (len(args) != len(inputs)):
raise TypeError(e_arity(len(inputs), len(args)))
for arg, variable in izip(args, inputs):
......@@ -394,7 +394,7 @@ class Container(object):
"""
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
#self.r = r
# self.r = r
if isinstance(r, Type):
self.type = r
else:
......@@ -459,7 +459,6 @@ class Container(object):
# don't return an ndarray.
if (r.storage[0] is not None and
not self.type.is_valid_value(r.storage[0])):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
......@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(allow_gc=self.allow_gc).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
# raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self.fgraph = fgraph
self.no_recycling = no_recycling
return self
......@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
for node in order:
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
post_thunk_old_storage.append(
[storage_map[input]
for input in node.inputs
if (input in computed) and (input not in fgraph.outputs) and node == last_user[input]])
if (input in computed) and (
input not in fgraph.outputs) and (
node == last_user[input])])
if no_recycling is True:
# True seems like some special code for *everything*?? -JB
......
......@@ -29,6 +29,7 @@ from . import destroyhandler as dh
_logger = logging.getLogger('theano.gof.opt')
_optimizer_idx = [0]
def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
......@@ -729,6 +730,7 @@ def pre_constant_merge(vars):
const_sig_inv = {}
if isinstance(vars, graph.Variable):
vars = [vars]
def recursive_merge(var):
if var in seen_var:
return var
......@@ -761,7 +763,7 @@ def pre_constant_merge(vars):
########################
### Local Optimizers ###
# Local Optimizers #
########################
class LocalOptimizer(object):
......@@ -820,9 +822,11 @@ class LocalOptimizer(object):
(' ' * level), self.__class__.__name__, id(self)), file=stream)
theano.configparser.AddConfigVar('metaopt.verbose',
theano.configparser.AddConfigVar(
'metaopt.verbose',
"Enable verbose output for meta optimizers",
theano.configparser.BoolParam(False), in_c_key=False)
theano.configparser.BoolParam(False),
in_c_key=False)
class LocalMetaOptimizer(LocalOptimizer):
......@@ -1003,19 +1007,6 @@ class LocalOptGroup(LocalOptimizer):
opt.add_requirements(fgraph)
class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME"""
def __init__(self, optimizers):
if any(not hasattr(opt, 'op_key'), optimizers):
raise TypeError(
"All LocalOptimizers passed here must have an op_key method.")
CompositeLocalOptimizer.__init__(self, optimizers)
def op_key(self):
return [opt.op_key() for opt in self.opts]
class OpSub(LocalOptimizer):
"""WRITEME
Replaces the application of a certain op by the application of
......@@ -1217,6 +1208,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op:
return False
# TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
# TODO move outside match
def retry_with_equiv():
......@@ -1233,9 +1225,8 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if (not (expr.owner.op == pattern[0])
or (not allow_multiple_clients
and len(expr.clients) > 1)):
if (not (expr.owner.op == pattern[0]) or
(not allow_multiple_clients and len(expr.clients) > 1)):
return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv()
......@@ -1263,16 +1254,16 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif (isinstance(pattern, (int, float))
and isinstance(expr, graph.Constant)):
elif (isinstance(pattern, (int, float)) and
isinstance(expr, graph.Constant)):
if numpy.all(
theano.tensor.constant(pattern).value == expr.value):
return u
else:
return retry_with_equiv()
elif (isinstance(pattern, graph.Constant)
and isinstance(expr, graph.Constant)
and pattern.equals(expr)):
elif (isinstance(pattern, graph.Constant) and
isinstance(expr, graph.Constant) and
pattern.equals(expr)):
return u
else:
return retry_with_equiv()
......@@ -1335,7 +1326,7 @@ class PatternSub(LocalOptimizer):
##################
### Navigators ###
# Navigators #
##################
# Use the following classes to apply LocalOptimizers
......@@ -1811,8 +1802,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
opt_name = (getattr(gopt, "name", None) or
getattr(gopt, "__name__", ""))
global_sub_profs.append(sub_profs)
global_opt_timing.append(float(time.time() - t0))
......@@ -1858,8 +1849,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[lopt] += change_tracker.nb_imported - nb
if global_process_count[lopt] > max_use:
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
opt_name = (getattr(lopt, "name", None) or
getattr(lopt, "__name__", ""))
if node not in fgraph.apply_nodes:
# go to next node
break
......@@ -1884,8 +1875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
opt_name = (getattr(gopt, "name", None) or
getattr(gopt, "__name__", ""))
final_sub_profs.append(sub_profs)
global_opt_timing[-1] += time.time() - t_before_final_opt
......@@ -1896,9 +1887,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
end_nb_nodes = len(fgraph.apply_nodes)
if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name
+ ". You can safely raise the current threshold of "
+ "%f with the theano flag 'optdb.max_use_ratio'." %
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name +
". You can safely raise the current threshold of " +
"%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio)
fgraph.remove_feature(change_tracker)
return (self, loop_timing, loop_process_count,
......@@ -1975,8 +1966,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
not_used_time += time_opts[o]
if count_opt:
print(blanc, \
' times - times applied - nb node created - name:', file=stream)
print(blanc,
' times - times applied - nb node created - name:',
file=stream)
count_opt.sort()
for (t, count, n_created, o) in count_opt[::-1]:
print(blanc, ' %.3fs - %d - %d - %s' % (
......@@ -2010,7 +2002,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod
def merge_profile(prof1, prof2):
#(opt, loop_timing, loop_process_count, max_nb_nodes,
# (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
prof2[0].get_local_optimizers())
......@@ -2085,7 +2077,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
final_sub_profs)
#################
### Utilities ###
# Utilities #
#################
......@@ -2096,7 +2088,7 @@ def _check_chain(r, chain):
while chain:
elem = chain.pop()
if elem is None:
if not r.owner is None:
if r.owner is not None:
return False
elif r.owner is None:
return False
......@@ -2105,20 +2097,20 @@ def _check_chain(r, chain):
return False
else:
try:
if (issubclass(elem, op.Op)
and not isinstance(r.owner.op, elem)):
if (issubclass(elem, op.Op) and
not isinstance(r.owner.op, elem)):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
# print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1
# _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance)
return (r is not None)
#_check_chain.n_calls = 0
# _check_chain.n_calls = 0
def check_chain(r, *chain):
......
......@@ -3,9 +3,11 @@ import linecache
import traceback
import sys
import numpy
from six import iteritems
from theano import config
from theano.compat import PY3
def simple_extract_stack(f=None, limit=None):
......@@ -435,3 +437,31 @@ def remove(predicate, coll):
[1, 3]
"""
return [x for x in coll if not predicate(x)]
# Define hash_from_code according to the interpreter major version:
# the Python 3 variant must encode unicode text and produce a digest
# that is legal as a module name; the Python 2 variant also handles
# numpy arrays lacking direct buffer support.
if PY3:
    import hashlib

    def hash_from_code(msg):
        """Return 'm' + MD5 hex digest of `msg` (str or bytes).

        hashlib.md5() needs a buffer-compatible object, so unicode
        strings are encoded first; the 'm' prefix keeps the result a
        valid Python 3 module name (which may not start with a digit).
        """
        if isinstance(msg, str):
            msg = msg.encode()
        return 'm' + hashlib.md5(msg).hexdigest()
else:
    import hashlib

    def hash_from_code(msg):
        """Return the MD5 hex digest of `msg` (string or numpy.ndarray)."""
        try:
            return hashlib.md5(msg).hexdigest()
        except TypeError:
            # Fall back to hashing the array's raw memory buffer.
            assert isinstance(msg, numpy.ndarray)
            return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
    """Return the MD5 hash of the contents of the file at `file_path`.

    Reads in binary mode so the digest does not depend on newline
    translation.
    """
    # Close the file deterministically via a context manager; the
    # original `open(...).read()` left closing to the garbage collector.
    with open(file_path, 'rb') as f:
        return hash_from_code(f.read())
......@@ -10,7 +10,7 @@ import numpy
from theano.compat import decode, decode_iter
from theano.gof import local_bitwidth
from theano.gof.cc import hash_from_file
from theano.gof.utils import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs,
std_include_dirs, dlimport,
Compiler,
......
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
def hash_from_sparse(data):
......
......@@ -2,7 +2,7 @@ import numpy
import theano
from theano.compat import izip
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
def hash_from_ndarray(data):
......
......@@ -233,16 +233,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py",
"gof/destroyhandler.py",
"gof/unify.py",
"gof/graph.py",
"gof/__init__.py",
"gof/cc.py",
"gof/opt.py",
"gof/link.py",
"gof/fg.py",
"gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py",
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论