提交 cba9c812 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3130 from harlouci/flake8_gof

Flake8 gof
...@@ -177,7 +177,7 @@ def get_config_md5(): ...@@ -177,7 +177,7 @@ def get_config_md5():
""" """
all_opts = sorted([c for c in _config_var_list if c.in_c_key], all_opts = sorted([c for c in _config_var_list if c.in_c_key],
key=lambda cv: cv.fullname) key=lambda cv: cv.fullname)
return theano.gof.cc.hash_from_code('\n'.join( return theano.gof.utils.hash_from_code('\n'.join(
['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts])) ['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts]))
......
""" """
Defines Linkers that deal with C implementations. Defines Linkers that deal with C implementations.
""" """
from __future__ import print_function from __future__ import print_function
# Python imports # Python imports
from copy import copy from copy import copy
import os import os
import sys import sys
from theano.compat import izip import logging
import numpy import numpy
import theano
from theano import config
from theano.compat import PY3 from theano.compat import PY3
from theano.compat import izip
from six import string_types, reraise from six import string_types, reraise
from six.moves import StringIO, xrange from six.moves import StringIO, xrange
from theano.gof.utils import MethodNotDefined
import theano
from theano import config
if PY3:
import hashlib
def hash_from_code(msg):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if isinstance(msg, str):
msg = msg.encode()
# Python 3 does not like module names that start with
# a digit.
return 'm' + hashlib.md5(msg).hexdigest()
else:
import hashlib
def hash_from_code(msg):
try:
return hashlib.md5(msg).hexdigest()
except TypeError:
assert isinstance(msg, numpy.ndarray)
return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
"""Return the MD5 hash of a file."""
return hash_from_code(open(file_path, 'rb').read())
# Note that we need to do this before importing cutils, since when there is # Note that we need to do this before importing cutils, since when there is
# no theano cache dir initialized yet, importing cutils may require compilation # no theano cache dir initialized yet, importing cutils may require compilation
# of cutils_ext. # of cutils_ext.
from theano.configparser import AddConfigVar, StrParam from theano.configparser import AddConfigVar, StrParam
AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc",
StrParam(""))
# gof imports # gof imports
from theano.gof import graph from theano.gof import graph
from theano.gof import link from theano.gof import link
from theano.gof import utils from theano.gof import utils
from theano.gof import cmodule
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
from theano.gof.callcache import CallCache
from theano.gof import cmodule AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc",
StrParam(""))
import logging
_logger = logging.getLogger("theano.gof.cc") _logger = logging.getLogger("theano.gof.cc")
from theano.gof.callcache import CallCache
run_cthunk = None # Will be imported only when needed. run_cthunk = None # Will be imported only when needed.
...@@ -314,9 +284,8 @@ def get_c_declare(r, name, sub): ...@@ -314,9 +284,8 @@ def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name""" """Wrapper around c_declare that declares py_name"""
if any([c != "output" and getattr(c.op, 'check_input', if any([c != "output" and getattr(c.op, 'check_input',
config.check_input) for (c, _) in r.clients]) or (r.owner config.check_input) for (c, _) in r.clients]) or (
and getattr(r.owner.op, 'check_input', True)): r.owner and getattr(r.owner.op, 'check_input', True)):
c_declare = r.type.c_declare(name, sub, True) c_declare = r.type.c_declare(name, sub, True)
else: else:
c_declare = r.type.c_declare(name, sub, False) c_declare = r.type.c_declare(name, sub, False)
...@@ -532,7 +501,7 @@ class CLinker(link.Linker): ...@@ -532,7 +501,7 @@ class CLinker(link.Linker):
if isinstance(r, graph.Constant) and if isinstance(r, graph.Constant) and
r not in self.inputs) r not in self.inputs)
self.temps = list(set(self.variables).difference( self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans)) self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = [] self.consts = []
def code_gen(self): def code_gen(self):
...@@ -821,7 +790,7 @@ class CLinker(link.Linker): ...@@ -821,7 +790,7 @@ class CLinker(link.Linker):
ret = [] ret = []
# generic support code # generic support code
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
ret.append(x.c_support_code()) ret.append(x.c_support_code())
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -840,11 +809,11 @@ class CLinker(link.Linker): ...@@ -840,11 +809,11 @@ class CLinker(link.Linker):
# FillMissing must disable some of them. Putting -ffast-math would # FillMissing must disable some of them. Putting -ffast-math would
# make it disable all other parameter at the same time. # make it disable all other parameter at the same time.
ret += ["-fno-math-errno", ret += ["-fno-math-errno",
#"-funsafe-math-optimizations", # "-funsafe-math-optimizations",
#"-fno-signaling-nans", # "-fno-signaling-nans",
#"-fcx-limited-range", # "-fcx-limited-range",
#"-fno-rounding-math", # "-fno-rounding-math",
#"-ffinite-math-only", # "-ffinite-math-only",
# the current code generate label event if they are not used. # the current code generate label event if they are not used.
# Could use gcc attribute for those label only # Could use gcc attribute for those label only
...@@ -853,7 +822,7 @@ class CLinker(link.Linker): ...@@ -853,7 +822,7 @@ class CLinker(link.Linker):
"-Wno-write-strings", # generated by our code generator... "-Wno-write-strings", # generated by our code generator...
] ]
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
ret += x.c_compile_args() ret += x.c_compile_args()
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -866,7 +835,7 @@ class CLinker(link.Linker): ...@@ -866,7 +835,7 @@ class CLinker(link.Linker):
# to reorder them # to reorder them
ret += c_compiler.compile_args() ret += c_compiler.compile_args()
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
for i in x.c_no_compile_args(): for i in x.c_no_compile_args():
try: try:
...@@ -886,7 +855,7 @@ class CLinker(link.Linker): ...@@ -886,7 +855,7 @@ class CLinker(link.Linker):
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
ret += x.c_headers() ret += x.c_headers()
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -901,7 +870,7 @@ class CLinker(link.Linker): ...@@ -901,7 +870,7 @@ class CLinker(link.Linker):
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
ret += x.c_init_code() ret += x.c_init_code()
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -911,7 +880,7 @@ class CLinker(link.Linker): ...@@ -911,7 +880,7 @@ class CLinker(link.Linker):
def c_compiler(self): def c_compiler(self):
c_compiler = None c_compiler = None
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
if hasattr(x, 'c_compiler'): if hasattr(x, 'c_compiler'):
x_compiler = x.c_compiler() x_compiler = x.c_compiler()
else: else:
...@@ -938,7 +907,7 @@ class CLinker(link.Linker): ...@@ -938,7 +907,7 @@ class CLinker(link.Linker):
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
ret += x.c_header_dirs() ret += x.c_header_dirs()
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -954,7 +923,7 @@ class CLinker(link.Linker): ...@@ -954,7 +923,7 @@ class CLinker(link.Linker):
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
ret += x.c_libraries() ret += x.c_libraries()
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -970,7 +939,7 @@ class CLinker(link.Linker): ...@@ -970,7 +939,7 @@ class CLinker(link.Linker):
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]: y.op for y in self.node_order]:
try: try:
ret += x.c_lib_dirs() ret += x.c_lib_dirs()
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -1150,7 +1119,7 @@ class CLinker(link.Linker): ...@@ -1150,7 +1119,7 @@ class CLinker(link.Linker):
libraries=self.libraries(), libraries=self.libraries(),
header_dirs=self.header_dirs(), header_dirs=self.header_dirs(),
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True, libraries=None, header_dirs=None, insert_config_md5=True,
...@@ -1335,7 +1304,6 @@ class CLinker(link.Linker): ...@@ -1335,7 +1304,6 @@ class CLinker(link.Linker):
preargs.remove('-DREPLACE_WITH_AMDLIBM') preargs.remove('-DREPLACE_WITH_AMDLIBM')
if 'amdlibm' in libs: if 'amdlibm' in libs:
libs.remove('amdlibm') libs.remove('amdlibm')
src_code = mod.code()
get_lock() get_lock()
try: try:
_logger.debug("LOCATION %s", str(location)) _logger.debug("LOCATION %s", str(location))
...@@ -1371,9 +1339,9 @@ class CLinker(link.Linker): ...@@ -1371,9 +1339,9 @@ class CLinker(link.Linker):
code = self.instantiate_code(1 + len(self.args)) code = self.instantiate_code(1 + len(self.args))
instantiate = cmodule.ExtFunction('instantiate', code, instantiate = cmodule.ExtFunction('instantiate', code,
method=cmodule.METH_VARARGS) method=cmodule.METH_VARARGS)
#['error_storage'] + argnames, # ['error_storage'] + argnames,
#local_dict = d, # local_dict = d,
# global_dict = {}) # global_dict = {})
# Static methods that can run and destroy the struct built by # Static methods that can run and destroy the struct built by
# instantiate. # instantiate.
...@@ -1498,7 +1466,7 @@ class _CThunk(object): ...@@ -1498,7 +1466,7 @@ class _CThunk(object):
global run_cthunk global run_cthunk
if run_cthunk is None: if run_cthunk is None:
# Lazy import to avoid compilation when importing theano. # Lazy import to avoid compilation when importing theano.
from theano.gof.cutils import run_cthunk from theano.gof.cutils import run_cthunk # noqa
self.cthunk = cthunk self.cthunk = cthunk
self.init_tasks = init_tasks self.init_tasks = init_tasks
self.tasks = tasks self.tasks = tasks
...@@ -1534,7 +1502,8 @@ class _CThunk(object): ...@@ -1534,7 +1502,8 @@ class _CThunk(object):
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
except Exception: except Exception:
print(('ERROR retrieving error_storage.' print(('ERROR retrieving error_storage.'
' Was the error set in the c code?'), end=' ', file=sys.stderr) 'Was the error set in the c code?'),
end=' ', file=sys.stderr)
print(self.error_storage, file=sys.stderr) print(self.error_storage, file=sys.stderr)
raise raise
reraise(exc_type, exc_value, exc_trace) reraise(exc_type, exc_value, exc_trace)
...@@ -1641,11 +1610,11 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1641,11 +1610,11 @@ class OpWiseCLinker(link.LocalLinker):
for node in order: for node in order:
if self.allow_gc: if self.allow_gc:
post_thunk_old_storage.append([storage_map[input] post_thunk_old_storage.append(
for input in node.inputs [storage_map[input] for input in node.inputs
if ((input in computed) and if ((input in computed) and
(input not in fgraph.outputs) and (input not in fgraph.outputs) and
node == last_user[input])]) node == last_user[input])])
if no_recycling is True: if no_recycling is True:
no_recycling = list(storage_map.values()) no_recycling = list(storage_map.values())
...@@ -1741,12 +1710,12 @@ class DualLinker(link.Linker): ...@@ -1741,12 +1710,12 @@ class DualLinker(link.Linker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = ( _f, i1, o1, thunks1, order1 = (
link.PerformLinker(schedule=self.schedule).accept(fgraph, link.PerformLinker(schedule=self.schedule).accept(
no_recycling=no_recycling).make_all(**kwargs)) fgraph, no_recycling=no_recycling).make_all(**kwargs))
kwargs.pop('input_storage', None) kwargs.pop('input_storage', None)
_f, i2, o2, thunks2, order2 = ( _f, i2, o2, thunks2, order2 = (
OpWiseCLinker(schedule=self.schedule).accept(fgraph, OpWiseCLinker(schedule=self.schedule).accept(
no_recycling=no_recycling).make_all(**kwargs)) fgraph, no_recycling=no_recycling).make_all(**kwargs))
def f(): def f():
for input1, input2 in izip(i1, i2): for input1, input2 in izip(i1, i2):
......
"""Generate and compile C modules for Python, """Generate and compile C modules for Python,
""" """
from __future__ import print_function from __future__ import print_function
import atexit import atexit
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
import logging import logging
...@@ -15,12 +17,6 @@ import time ...@@ -15,12 +17,6 @@ import time
import platform import platform
import distutils.sysconfig import distutils.sysconfig
importlib = None
try:
import importlib
except ImportError:
pass
import numpy.distutils # TODO: TensorType should handle this import numpy.distutils # TODO: TensorType should handle this
import theano import theano
...@@ -28,7 +24,7 @@ from theano.compat import PY3, decode, decode_iter ...@@ -28,7 +24,7 @@ from theano.compat import PY3, decode, decode_iter
from six import b, BytesIO, StringIO, string_types, iteritems from six import b, BytesIO, StringIO, string_types, iteritems
from theano.gof.utils import flatten from theano.gof.utils import flatten
from theano.configparser import config from theano.configparser import config
from theano.gof.cc import hash_from_code from theano.gof.utils import hash_from_code
from theano.misc.windows import (subprocess_Popen, from theano.misc.windows import (subprocess_Popen,
output_subprocess_Popen) output_subprocess_Popen)
...@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth ...@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
from theano.configparser import AddConfigVar, BoolParam from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('cmodule.mac_framework_link', importlib = None
"If set to True, breaks certain MacOS installations with the infamous " try:
"Bus Error", import importlib
BoolParam(False)) except ImportError:
pass
AddConfigVar(
'cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error",
BoolParam(False))
AddConfigVar('cmodule.warn_no_version', AddConfigVar('cmodule.warn_no_version',
"If True, will print a warning when compiling one or more Op " "If True, will print a warning when compiling one or more Op "
...@@ -131,15 +134,16 @@ class ExtFunction(object): ...@@ -131,15 +134,16 @@ class ExtFunction(object):
It goes into the DynamicModule's method table. It goes into the DynamicModule's method table.
""" """
return '\t{"%s", %s, %s, "%s"}' % ( return '\t{"%s", %s, %s, "%s"}' % (
self.name, self.name, self.method, self.doc) self.name, self.name, self.method, self.doc)
class DynamicModule(object): class DynamicModule(object):
def __init__(self, name=None): def __init__(self, name=None):
assert name is None, ("The 'name' parameter of DynamicModule" assert name is None, (
" cannot be specified anymore. Instead, 'code_hash'" "The 'name' parameter of DynamicModule"
" will be automatically computed and can be used as" " cannot be specified anymore. Instead, 'code_hash'"
" the module's name.") " will be automatically computed and can be used as"
" the module's name.")
# While the module is not finalized, we can call add_... # While the module is not finalized, we can call add_...
# when it is finalized, a hash is computed and used instead of # when it is finalized, a hash is computed and used instead of
# the placeholder, and as module name. # the placeholder, and as module name.
...@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{ ...@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{
}}; }};
""".format(name=self.hash_placeholder), file=stream) """.format(name=self.hash_placeholder), file=stream)
print(("PyMODINIT_FUNC PyInit_%s(void) {" % print(("PyMODINIT_FUNC PyInit_%s(void) {" %
self.hash_placeholder), file=stream) self.hash_placeholder), file=stream)
for block in self.init_blocks: for block in self.init_blocks:
print(' ', block, file=stream) print(' ', block, file=stream)
print(" PyObject *m = PyModule_Create(&moduledef);", file=stream) print(" PyObject *m = PyModule_Create(&moduledef);", file=stream)
print(" return m;", file=stream) print(" return m;", file=stream)
else: else:
print(("PyMODINIT_FUNC init%s(void){" % print(("PyMODINIT_FUNC init%s(void){" %
self.hash_placeholder), file=stream) self.hash_placeholder), file=stream)
for block in self.init_blocks: for block in self.init_blocks:
print(' ', block, file=stream) print(' ', block, file=stream)
print(' ', ('(void) Py_InitModule("%s", MyMethods);' print(' ', ('(void) Py_InitModule("%s", MyMethods);'
% self.hash_placeholder), file=stream) % self.hash_placeholder), file=stream)
print("}", file=stream) print("}", file=stream)
def add_include(self, str): def add_include(self, str):
...@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2): ...@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2):
if os.path.realpath(entry_1) == os.path.realpath(entry_2): if os.path.realpath(entry_1) == os.path.realpath(entry_2):
return True return True
if (os.path.basename(entry_1) == os.path.basename(entry_2) and if (os.path.basename(entry_1) == os.path.basename(entry_2) and
(os.path.basename(os.path.dirname(entry_1)) == (os.path.basename(os.path.dirname(entry_1)) ==
os.path.basename(os.path.dirname(entry_2))) and os.path.basename(os.path.dirname(entry_2))) and
os.path.basename(os.path.dirname(entry_1)).startswith('tmp')): os.path.basename(os.path.dirname(entry_1)).startswith('tmp')):
return True return True
return False return False
...@@ -429,8 +433,8 @@ def get_safe_part(key): ...@@ -429,8 +433,8 @@ def get_safe_part(key):
# Find the md5 hash part. # Find the md5 hash part.
c_link_key = key[1] c_link_key = key[1]
for key_element in c_link_key[1:]: for key_element in c_link_key[1:]:
if (isinstance(key_element, string_types) if (isinstance(key_element, string_types) and
and key_element.startswith('md5:')): key_element.startswith('md5:')):
md5 = key_element[4:] md5 = key_element[4:]
break break
...@@ -761,9 +765,9 @@ class ModuleCache(object): ...@@ -761,9 +765,9 @@ class ModuleCache(object):
# simpler to implement). # simpler to implement).
rmtree(root, ignore_nocleanup=True, rmtree(root, ignore_nocleanup=True,
msg=( msg=(
'invalid cache entry format -- this ' 'invalid cache entry format -- this '
'should not happen unless your cache ' 'should not happen unless your cache '
'was really old'), 'was really old'),
level=logging.WARN) level=logging.WARN)
continue continue
...@@ -964,7 +968,7 @@ class ModuleCache(object): ...@@ -964,7 +968,7 @@ class ModuleCache(object):
# process that could be changing the file at the same # process that could be changing the file at the same
# time. # time.
if (key[0] and not key_broken and if (key[0] and not key_broken and
self.check_for_broken_eq): self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl) self.check_key(key, key_data.key_pkl)
self._update_mappings(key, key_data, module.__file__, check_in_keys=not key_broken) self._update_mappings(key, key_data, module.__file__, check_in_keys=not key_broken)
return module return module
...@@ -1149,15 +1153,14 @@ class ModuleCache(object): ...@@ -1149,15 +1153,14 @@ class ModuleCache(object):
# This is to make debugging in pdb easier, by providing # This is to make debugging in pdb easier, by providing
# the offending keys in the local context. # the offending keys in the local context.
# key_data_keys = list(key_data.keys) # key_data_keys = list(key_data.keys)
## import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
pass pass
elif found > 1: elif found > 1:
msg = 'Multiple equal keys found in unpickled KeyData file' msg = 'Multiple equal keys found in unpickled KeyData file'
if msg: if msg:
raise AssertionError( raise AssertionError(
"%s. Verify the __eq__ and __hash__ functions of your " "%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" % "Ops. The file is: %s. The key is: %s" % (msg, key_pkl, key))
(msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal # Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys # to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this # with the same version part and config md5, since we can assume this
...@@ -1195,10 +1198,10 @@ class ModuleCache(object): ...@@ -1195,10 +1198,10 @@ class ModuleCache(object):
if age_thresh_del < self.age_thresh_use: if age_thresh_del < self.age_thresh_use:
if age_thresh_del > 0: if age_thresh_del > 0:
_logger.warning("Clearing modules that were not deemed " _logger.warning("Clearing modules that were not deemed "
"too old to use: age_thresh_del=%d, " "too old to use: age_thresh_del=%d, "
"self.age_thresh_use=%d", "self.age_thresh_use=%d",
age_thresh_del, age_thresh_del,
self.age_thresh_use) self.age_thresh_use)
else: else:
_logger.info("Clearing all modules.") _logger.info("Clearing all modules.")
age_thresh_use = age_thresh_del age_thresh_use = age_thresh_del
...@@ -1210,8 +1213,8 @@ class ModuleCache(object): ...@@ -1210,8 +1213,8 @@ class ModuleCache(object):
# processes and get all module that are too old to use # processes and get all module that are too old to use
# (not loaded in self.entry_from_key). # (not loaded in self.entry_from_key).
too_old_to_use = self.refresh( too_old_to_use = self.refresh(
age_thresh_use=age_thresh_use, age_thresh_use=age_thresh_use,
delete_if_problem=delete_if_problem) delete_if_problem=delete_if_problem)
for entry in too_old_to_use: for entry in too_old_to_use:
# TODO: we are assuming that modules that haven't been # TODO: we are assuming that modules that haven't been
...@@ -1242,8 +1245,8 @@ class ModuleCache(object): ...@@ -1242,8 +1245,8 @@ class ModuleCache(object):
""" """
with compilelock.lock_ctx(): with compilelock.lock_ctx():
self.clear_old( self.clear_old(
age_thresh_del=-1.0, age_thresh_del=-1.0,
delete_if_problem=delete_if_problem) delete_if_problem=delete_if_problem)
self.clear_unversioned(min_age=unversioned_min_age) self.clear_unversioned(min_age=unversioned_min_age)
if clear_base_files: if clear_base_files:
self.clear_base_files() self.clear_base_files()
...@@ -1333,7 +1336,7 @@ class ModuleCache(object): ...@@ -1333,7 +1336,7 @@ class ModuleCache(object):
if filename.startswith('tmp'): if filename.startswith('tmp'):
try: try:
open(os.path.join(self.dirname, filename, 'key.pkl') open(os.path.join(self.dirname, filename, 'key.pkl')
).close() ).close()
has_key = True has_key = True
except IOError: except IOError:
has_key = False has_key = False
...@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None): ...@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None):
'was created prior to this call') 'was created prior to this call')
if _module_cache.dirname != dirname: if _module_cache.dirname != dirname:
_logger.warning("Returning module cache instance with different " _logger.warning("Returning module cache instance with different "
"dirname (%s) than you requested (%s)", "dirname (%s) than you requested (%s)",
_module_cache.dirname, dirname) _module_cache.dirname, dirname)
return _module_cache return _module_cache
...@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler): ...@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler):
break break
if ('g++' not in theano.config.cxx and if ('g++' not in theano.config.cxx and
'clang++' not in theano.config.cxx): 'clang++' not in theano.config.cxx):
_logger.warn( _logger.warn(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be" "OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization" " the g++ compiler. So we disable the compiler optimization"
...@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler): ...@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler):
selected_lines = [] selected_lines = []
for line in lines: for line in lines:
if ("COLLECT_GCC_OPTIONS=" in line or if ("COLLECT_GCC_OPTIONS=" in line or
"CFLAGS=" in line or "CFLAGS=" in line or
"CXXFLAGS=" in line or "CXXFLAGS=" in line or
"-march=native" in line): "-march=native" in line):
continue continue
elif "-march=" in line: elif "-march=" in line:
selected_lines.append(line.strip()) selected_lines.append(line.strip())
...@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler): ...@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
for line in default_lines: for line in default_lines:
if line.startswith(part[0]): if line.startswith(part[0]):
part2 = [p for p in join_options(line.split()) part2 = [p for p in join_options(line.split())
if (not 'march' in p and if ('march' not in p and
not 'mtune' in p and 'mtune' not in p and
not 'target-cpu' in p)] 'target-cpu' not in p)]
new_flags = [p for p in part if p not in part2] new_flags = [p for p in part if p not in part2]
# Replace '-target-cpu value', which is an option # Replace '-target-cpu value', which is an option
# of clang, with '-march=value', for g++ # of clang, with '-march=value', for g++
...@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler): ...@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler):
cmd.append(cppfilename) cmd.append(cppfilename)
cmd.extend(['-L%s' % ldir for ldir in lib_dirs]) cmd.extend(['-L%s' % ldir for ldir in lib_dirs])
cmd.extend(['-l%s' % l for l in libs]) cmd.extend(['-l%s' % l for l in libs])
#print >> sys.stderr, 'COMPILING W CMD', cmd # print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug('Running cmd: %s', ' '.join(cmd)) _logger.debug('Running cmd: %s', ' '.join(cmd))
def print_command_line_error(): def print_command_line_error():
# Print command line when a problem occurred. # Print command line when a problem occurred.
print(( print(("Problem occurred during compilation with the "
"Problem occurred during compilation with the " "command line below:"), file=sys.stderr)
"command line below:"), file=sys.stderr)
print(' '.join(cmd), file=sys.stderr) print(' '.join(cmd), file=sys.stderr)
try: try:
......
...@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
""" """
# These are lists of Variable instances # These are lists of Variable instances
inputs = fgraph.inputs
outputs = fgraph.outputs outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py # this is hard-coded reimplementation of functions from graph.py
...@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key # (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython) # is not in the dictionary, at least in CPython)
iset = set(inputs)
# IG: I tried converting parent_counts to use an id for the key, # IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys. # so that the dict would do reference counting on its keys.
# This caused a slowdown. # This caused a slowdown.
...@@ -236,9 +233,9 @@ def fast_inplace_check(inputs): ...@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs.extend(fgraph.outputs) protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if inputs = [i for i in inputs if
not isinstance(i, graph.Constant) not isinstance(i, graph.Constant) and
and not fgraph.destroyers(i) not fgraph.destroyers(i) and
and i not in protected_inputs] i not in protected_inputs]
return inputs return inputs
if 0: if 0:
...@@ -293,7 +290,7 @@ if 0: ...@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
""" """
####### Do the checking ########### # Do the checking #
already_there = False already_there = False
if self.fgraph not in [None, fgraph]: if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve" raise Exception("A DestroyHandler instance can only serve"
...@@ -309,7 +306,7 @@ if 0: ...@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in" "DestroyHandler feature is already present or in"
" conflict with another plugin.") " conflict with another plugin.")
####### end of checking ############ # end of checking #
def get_destroyers_of(r): def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact() droot, impact, root_destroyer = self.refresh_droot_impact()
...@@ -362,8 +359,8 @@ if 0: ...@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of %s" % input_root) "Multiple destroyers of %s" % input_root)
droot[input_root] = input_root droot[input_root] = input_root
root_destroyer[input_root] = app root_destroyer[input_root] = app
#input_impact = set([input_root]) # input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact) # add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o) input_impact = get_impact(input_root, self.view_o)
for v in input_impact: for v in input_impact:
assert v not in droot assert v not in droot
...@@ -390,7 +387,7 @@ if 0: ...@@ -390,7 +387,7 @@ if 0:
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import") # if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app) # self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...@@ -421,7 +418,7 @@ if 0: ...@@ -421,7 +418,7 @@ if 0:
def on_prune(self, fgraph, app, reason): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import") # if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app) # self.debug_all_apps.remove(app)
# UPDATE self.clients # UPDATE self.clients
...@@ -458,7 +455,7 @@ if 0: ...@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
pass pass
else: else:
#if app not in self.debug_all_apps: raise ProtocolError("change without import") # if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients # UPDATE self.clients
self.clients[old_r][app] -= 1 self.clients[old_r][app] -= 1
...@@ -529,9 +526,10 @@ if 0: ...@@ -529,9 +526,10 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if illegal_destroy = [
getattr(r.tag, 'indestructible', False) or r for r in droot if
isinstance(r, graph.Constant)] getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
# print 'destroying illegally' # print 'destroying illegally'
raise InconsistencyError( raise InconsistencyError(
...@@ -603,7 +601,7 @@ if 0: ...@@ -603,7 +601,7 @@ if 0:
if input in root_impact \ if input in root_impact \
and (i not in tolerated or input is not destroyed_variable): and (i not in tolerated or input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)" raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i)) % (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
...@@ -621,7 +619,7 @@ if 0: ...@@ -621,7 +619,7 @@ if 0:
return rval return rval
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper): # noqa
""" """
The DestroyHandler class detects when a graph is impossible to evaluate The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations. because of aliasing and destructive operations.
...@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
""" """
####### Do the checking ########### # Do the checking #
already_there = False already_there = False
if self.fgraph is fgraph: if self.fgraph is fgraph:
already_there = True already_there = True
...@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present" "DestroyHandler feature is already present"
" or in conflict with another plugin.") " or in conflict with another plugin.")
####### Annotate the FunctionGraph ############ # Annotate the FunctionGraph #
self.unpickle(fgraph) self.unpickle(fgraph)
fgraph.destroy_handler = self fgraph.destroy_handler = self
...@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if \ illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or \ getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)] isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
raise InconsistencyError("Attempting to destroy indestructible variables: %s" % raise InconsistencyError(
illegal_destroy) "Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies # add destroyed variable clients as computational dependencies
for app in self.destroyers: for app in self.destroyers:
...@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING # CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list) assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list) assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
# print 'tolerated', tolerated # print 'tolerated', tolerated
# print 'ignored', ignored # print 'ignored', ignored
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
if i in ignored: if i in ignored:
continue continue
if input in root_impact \ if input in root_impact \
and (i not in tolerated or input is not destroyed_variable): and (i not in tolerated or
input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)" raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i)) % (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
......
...@@ -13,7 +13,6 @@ from theano.gof import graph ...@@ -13,7 +13,6 @@ from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof import toolbox from theano.gof import toolbox
from theano import config from theano import config
import warnings
from theano.compat import OrderedDict from theano.compat import OrderedDict
from six import iteritems, itervalues from six import iteritems, itervalues
...@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet ...@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
class CachedConstantError(Exception): class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant """An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this that is cached. This should not happen as the user can reuse this
...@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2): ...@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
self.variable_locks = {} self.variable_locks = {}
self.profile = None self.profile = None
### Setup a Variable ### # Setup a Variable #
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph # sets up r so it belongs to this fgraph
if getattr(r, 'cached', False): if getattr(r, 'cached', False):
...@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2): ...@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
" graph that has a cached constant. This should not happen." " graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph.") " Clone the graph before building the FunctionGraph.")
if (hasattr(r, 'fgraph') and if (hasattr(r, 'fgraph') and
r.fgraph is not None and r.fgraph is not None and
r.fgraph is not self): r.fgraph is not self):
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self r.fgraph = self
r.clients = [] r.clients = []
#self.execute_callbacks('on_setup_variable', r) # self.execute_callbacks('on_setup_variable', r)
def __setup_node__(self, node): def __setup_node__(self, node):
# sets up node so it belongs to this fgraph # sets up node so it belongs to this fgraph
...@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2): ...@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
str(node.op), str(node.op.destroy_map))) str(node.op), str(node.op.destroy_map)))
node.fgraph = self node.fgraph = self
node.deps = {} node.deps = {}
#self.execute_callbacks('on_setup_node', node) # self.execute_callbacks('on_setup_node', node)
def disown(self): def disown(self):
""" WRITEME """ WRITEME
...@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2): ...@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
self.inputs = None self.inputs = None
self.outputs = None self.outputs = None
### clients ### # clients #
def clients(self, r): def clients(self, r):
""" """
Set of all the (node, i) pairs such that node.inputs[i] is r. Set of all the (node, i) pairs such that node.inputs[i] is r.
...@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2): ...@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
if set(r.clients).intersection(set(new_clients)): if set(r.clients).intersection(set(new_clients)):
print('ERROR: clients intersect!', file=sys.stderr) print('ERROR: clients intersect!', file=sys.stderr)
print(' RCLIENTS of', r, [(n, i, type(n), id(n)) print(' RCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in r.clients], file=sys.stderr) for n, i in r.clients], file=sys.stderr)
print(' NCLIENTS of', r, [(n, i, type(n), id(n)) print(' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients], file=sys.stderr) for n, i in new_clients], file=sys.stderr)
assert not set(r.clients).intersection(set(new_clients)) assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
...@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2): ...@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
return True return True
return False return False
### import ### # import #
def __import_r__(self, variable, reason): def __import_r__(self, variable, reason):
global NullType global NullType
if NullType is None: if NullType is None:
...@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2): ...@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
if hasattr(r, 'fgraph') and r.fgraph is not self: if hasattr(r, 'fgraph') and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
if (r.owner is None and if (r.owner is None and
not isinstance(r, graph.Constant) and not isinstance(r, graph.Constant) and
r not in self.inputs): r not in self.inputs):
# Verbose error message # Verbose error message
# Show a complete chain of variables from the missing input to an output # Show a complete chain of variables from the missing input to an output
if config.exception_verbosity == 'high': if config.exception_verbosity == 'high':
...@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2): ...@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node, reason) self.execute_callbacks('on_import', node, reason)
### prune ### # prune #
def __prune_r__(self, variable, reason=None): def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore: """Should be called for variable that aren't used anymore:
len(var.clients) == 0 len(var.clients) == 0
...@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2): ...@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
self.__remove_clients__(input, [(apply_node, i)], reason=reason) self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs) # self.__prune_r__(apply_node.inputs)
### change input ### # change input #
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
"""WRITEME """WRITEME
Changes node.inputs[i] to new_r. Changes node.inputs[i] to new_r.
...@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2): ...@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
if prune: if prune:
self.__prune_r__(r, reason=reason) self.__prune_r__(r, reason=reason)
### replace ### # replace #
def replace(self, r, new_r, reason=None, verbose=None): def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME """ WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph. This is the main interface to manipulate the subgraph in FunctionGraph.
...@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2): ...@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
if detach is not None: if detach is not None:
detach(self) detach(self)
### callback utils ### # callback utils #
def execute_callbacks(self, name, *args, **kwargs): def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME """WRITEME
Calls Calls
...@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2): ...@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args) d[feature] = fn(*args)
return d return d
### misc ### # misc #
def toposort(self): def toposort(self):
"""WRITEME """WRITEME
Returns an ordering of the graph's Apply nodes such that: Returns an ordering of the graph's Apply nodes such that:
...@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2): ...@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
missing, excess) missing, excess)
for variable in variables: for variable in variables:
if (variable.owner is None and if (variable.owner is None and
variable not in self.inputs and variable not in self.inputs and
not isinstance(variable, graph.Constant)): not isinstance(variable, graph.Constant)):
raise Exception("Undeclared input.", variable) raise Exception("Undeclared input.", variable)
if variable.fgraph is not self: if variable.fgraph is not self:
raise Exception("Variable should belong to the FunctionGraph.", raise Exception("Variable should belong to the FunctionGraph.",
...@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2): ...@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
### clone ### # clone #
def clone(self, check_integrity=True): def clone(self, check_integrity=True):
"""WRITEME""" """WRITEME"""
return self.clone_get_equiv(check_integrity)[0] return self.clone_get_equiv(check_integrity)[0]
......
...@@ -7,14 +7,14 @@ import traceback ...@@ -7,14 +7,14 @@ import traceback
import numpy import numpy
import theano import theano
from theano.compat import PY3, izip from theano.compat import izip
from six import reraise from six import reraise
from six.moves import StringIO from six.moves import StringIO
from theano.gof import utils from theano.gof import utils
from theano.gof import graph from theano.gof import graph
from theano.gof.type import Type from theano.gof.type import Type
from .utils import MethodNotDefined, undef from .utils import undef
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else: else:
detailed_err_msg += "\n" detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % ( detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
total_size, total_size/1024./1024/1024) total_size, total_size / 1024. / 1024 / 1024)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % ( detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % (
total_size_inputs, total_size_inputs/1024./1024/1024) total_size_inputs, total_size_inputs / 1024. / 1024 / 1024)
else: else:
hints.append( hints.append(
...@@ -326,7 +326,7 @@ class Linker(object): ...@@ -326,7 +326,7 @@ class Linker(object):
raise utils.MethodNotDefined("make_thunk", type(self), raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__) self.__class__.__name__)
## DELETEME ## # DELETEME #
def make_function(self, unpack_single=True, **kwargs): def make_function(self, unpack_single=True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
...@@ -350,8 +350,8 @@ class Linker(object): ...@@ -350,8 +350,8 @@ class Linker(object):
def execute(*args): def execute(*args):
def e_arity(takes, got): def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \ return 'Function call takes exactly %i %s (%i given)' % (
% (takes, ['argument', 'arguments'][takes > 1], got) takes, ['argument', 'arguments'][takes > 1], got)
if (len(args) != len(inputs)): if (len(args) != len(inputs)):
raise TypeError(e_arity(len(inputs), len(args))) raise TypeError(e_arity(len(inputs), len(args)))
for arg, variable in izip(args, inputs): for arg, variable in izip(args, inputs):
...@@ -394,7 +394,7 @@ class Container(object): ...@@ -394,7 +394,7 @@ class Container(object):
""" """
if not isinstance(storage, list) or not len(storage) >= 1: if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one") raise TypeError("storage must be a list of length at least one")
#self.r = r # self.r = r
if isinstance(r, Type): if isinstance(r, Type):
self.type = r self.type = r
else: else:
...@@ -454,12 +454,11 @@ class Container(object): ...@@ -454,12 +454,11 @@ class Container(object):
deepcopy(self.strict, memo=memo), deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo), deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo), deepcopy(self.name, memo=memo),
) )
# Work around NumPy deepcopy of ndarray with 0 dimention that # Work around NumPy deepcopy of ndarray with 0 dimention that
# don't return an ndarray. # don't return an ndarray.
if (r.storage[0] is not None and if (r.storage[0] is not None and
not self.type.is_valid_value(r.storage[0])): not self.type.is_valid_value(r.storage[0])):
assert not data_was_in_memo assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0]) assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container. # This should also work for read only container.
...@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker): ...@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(allow_gc=self.allow_gc).accept(fgraph, no_recycling) return type(self)(allow_gc=self.allow_gc).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.") # raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
...@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker): ...@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
for node in order: for node in order:
if self.allow_gc: if self.allow_gc:
post_thunk_old_storage.append([storage_map[input] post_thunk_old_storage.append(
for input in node.inputs [storage_map[input]
if (input in computed) and (input not in fgraph.outputs) and node == last_user[input]]) for input in node.inputs
if (input in computed) and (
input not in fgraph.outputs) and (
node == last_user[input])])
if no_recycling is True: if no_recycling is True:
# True seems like some special code for *everything*?? -JB # True seems like some special code for *everything*?? -JB
...@@ -855,7 +857,7 @@ class WrapLinker(Linker): ...@@ -855,7 +857,7 @@ class WrapLinker(Linker):
make_all += [l.make_all(**kwargs) for l in self.linkers[1:]] make_all += [l.make_all(**kwargs) for l in self.linkers[1:]]
fns, input_lists, output_lists, thunk_lists, order_lists \ fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all) = zip(*make_all)
order_list0 = order_lists[0] order_list0 = order_lists[0]
for order_list in order_lists[1:]: for order_list in order_lists[1:]:
......
...@@ -29,6 +29,7 @@ from . import destroyhandler as dh ...@@ -29,6 +29,7 @@ from . import destroyhandler as dh
_logger = logging.getLogger('theano.gof.opt') _logger = logging.getLogger('theano.gof.opt')
_optimizer_idx = [0] _optimizer_idx = [0]
def _list_of_nodes(fgraph): def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs)) return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
...@@ -99,7 +100,7 @@ class Optimizer(object): ...@@ -99,7 +100,7 @@ class Optimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print("%s%s %s id=%i" % ( print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream) (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
...@@ -121,9 +122,9 @@ class FromFunctionOptimizer(Optimizer): ...@@ -121,9 +122,9 @@ class FromFunctionOptimizer(Optimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
' ' * level, ' ' * level,
str(self.apply), str(self.apply),
id(self)), file=stream) id(self)), file=stream)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs) return self.fn(*args, **kwargs)
...@@ -222,7 +223,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -222,7 +223,7 @@ class SeqOptimizer(Optimizer, list):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print("%s%s %s id=%i" % ( print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream) (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
# This way, -1 will do all depth # This way, -1 will do all depth
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
...@@ -241,8 +242,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -241,8 +242,8 @@ class SeqOptimizer(Optimizer, list):
elif hasattr(opts, "__name__"): elif hasattr(opts, "__name__"):
print(blanc, opts.__name__, end=' ', file=stream) print(blanc, opts.__name__, end=' ', file=stream)
print((" time %.3fs for %d/%d nodes" print((" time %.3fs for %d/%d nodes"
" before/after optimization" % ( " before/after optimization" % (
sum(prof), nb_node_before, nb_node_after)), file=stream) sum(prof), nb_node_before, nb_node_after)), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream) print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
print(blanc, " %.3fs for callback" % (callback_time), file=stream) print(blanc, " %.3fs for callback" % (callback_time), file=stream)
if level == 0: if level == 0:
...@@ -324,7 +325,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -324,7 +325,7 @@ class SeqOptimizer(Optimizer, list):
new_t[idx] += p[1][p[0].index(l)] new_t[idx] += p[1][p[0].index(l)]
if hasattr(l, 'merge_profile'): if hasattr(l, 'merge_profile'):
assert len(p[6][p[0].index(l)]) == \ assert len(p[6][p[0].index(l)]) == \
len(new_sub_profile[idx]) len(new_sub_profile[idx])
new_sub_profile[idx] = l.merge_profile( new_sub_profile[idx] = l.merge_profile(
new_sub_profile[idx], p[6][p[0].index(l)]) new_sub_profile[idx], p[6][p[0].index(l)])
else: else:
...@@ -729,6 +730,7 @@ def pre_constant_merge(vars): ...@@ -729,6 +730,7 @@ def pre_constant_merge(vars):
const_sig_inv = {} const_sig_inv = {}
if isinstance(vars, graph.Variable): if isinstance(vars, graph.Variable):
vars = [vars] vars = [vars]
def recursive_merge(var): def recursive_merge(var):
if var in seen_var: if var in seen_var:
return var return var
...@@ -761,7 +763,7 @@ def pre_constant_merge(vars): ...@@ -761,7 +763,7 @@ def pre_constant_merge(vars):
######################## ########################
### Local Optimizers ### # Local Optimizers #
######################## ########################
class LocalOptimizer(object): class LocalOptimizer(object):
...@@ -817,12 +819,14 @@ class LocalOptimizer(object): ...@@ -817,12 +819,14 @@ class LocalOptimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream) (' ' * level), self.__class__.__name__, id(self)), file=stream)
theano.configparser.AddConfigVar('metaopt.verbose', theano.configparser.AddConfigVar(
"Enable verbose output for meta optimizers", 'metaopt.verbose',
theano.configparser.BoolParam(False), in_c_key=False) "Enable verbose output for meta optimizers",
theano.configparser.BoolParam(False),
in_c_key=False)
class LocalMetaOptimizer(LocalOptimizer): class LocalMetaOptimizer(LocalOptimizer):
...@@ -933,9 +937,9 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -933,9 +937,9 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
' ' * level, ' ' * level,
str(self.transform), str(self.transform),
id(self)), file=stream) id(self)), file=stream)
def local_optimizer(tracks, inplace=False): def local_optimizer(tracks, inplace=False):
...@@ -992,7 +996,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -992,7 +996,7 @@ class LocalOptGroup(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream) (' ' * level), self.__class__.__name__, id(self)), file=stream)
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
for lopt in self.opts: for lopt in self.opts:
...@@ -1003,19 +1007,6 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1003,19 +1007,6 @@ class LocalOptGroup(LocalOptimizer):
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME"""
def __init__(self, optimizers):
if any(not hasattr(opt, 'op_key'), optimizers):
raise TypeError(
"All LocalOptimizers passed here must have an op_key method.")
CompositeLocalOptimizer.__init__(self, optimizers)
def op_key(self):
return [opt.op_key() for opt in self.opts]
class OpSub(LocalOptimizer): class OpSub(LocalOptimizer):
"""WRITEME """WRITEME
Replaces the application of a certain op by the application of Replaces the application of a certain op by the application of
...@@ -1086,10 +1077,10 @@ class OpRemove(LocalOptimizer): ...@@ -1086,10 +1077,10 @@ class OpRemove(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s(%s) id=%i" % ( print("%s%s(%s) id=%i" % (
' ' * level, ' ' * level,
self.__class__.__name__, self.__class__.__name__,
str(self.op), str(self.op),
id(self)), file=stream) id(self)), file=stream)
class PatternSub(LocalOptimizer): class PatternSub(LocalOptimizer):
...@@ -1217,6 +1208,7 @@ class PatternSub(LocalOptimizer): ...@@ -1217,6 +1208,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op: if node.op != self.op:
return False return False
# TODO: if we remove pdb, do this speed things up? # TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False): def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
# TODO move outside match # TODO move outside match
def retry_with_equiv(): def retry_with_equiv():
...@@ -1233,9 +1225,8 @@ class PatternSub(LocalOptimizer): ...@@ -1233,9 +1225,8 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if (not (expr.owner.op == pattern[0]) if (not (expr.owner.op == pattern[0]) or
or (not allow_multiple_clients (not allow_multiple_clients and len(expr.clients) > 1)):
and len(expr.clients) > 1)):
return retry_with_equiv() return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv() return retry_with_equiv()
...@@ -1263,16 +1254,16 @@ class PatternSub(LocalOptimizer): ...@@ -1263,16 +1254,16 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv() return retry_with_equiv()
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif (isinstance(pattern, (int, float)) elif (isinstance(pattern, (int, float)) and
and isinstance(expr, graph.Constant)): isinstance(expr, graph.Constant)):
if numpy.all( if numpy.all(
theano.tensor.constant(pattern).value == expr.value): theano.tensor.constant(pattern).value == expr.value):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
elif (isinstance(pattern, graph.Constant) elif (isinstance(pattern, graph.Constant) and
and isinstance(expr, graph.Constant) isinstance(expr, graph.Constant) and
and pattern.equals(expr)): pattern.equals(expr)):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
...@@ -1308,17 +1299,17 @@ class PatternSub(LocalOptimizer): ...@@ -1308,17 +1299,17 @@ class PatternSub(LocalOptimizer):
def pattern_to_str(pattern): def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
return "%s(%s)" % ( return "%s(%s)" % (
str(pattern[0]), str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]])) ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "%s subject to %s" % ( return "%s subject to %s" % (
pattern_to_str(pattern['pattern']), pattern_to_str(pattern['pattern']),
str(pattern.get('constraint', 'no conditions'))) str(pattern.get('constraint', 'no conditions')))
else: else:
return str(pattern) return str(pattern)
return "%s -> %s" % ( return "%s -> %s" % (
pattern_to_str(self.in_pattern), pattern_to_str(self.in_pattern),
pattern_to_str(self.out_pattern)) pattern_to_str(self.out_pattern))
def __repr__(self): def __repr__(self):
return str(self) return str(self)
...@@ -1326,16 +1317,16 @@ class PatternSub(LocalOptimizer): ...@@ -1326,16 +1317,16 @@ class PatternSub(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, '__name__', getattr(self, 'name', None)) name = getattr(self, '__name__', getattr(self, 'name', None))
print("%s%s %s(%s, %s) id=%i" % ( print("%s%s %s(%s, %s) id=%i" % (
' ' * level, ' ' * level,
self.__class__.__name__, self.__class__.__name__,
name, name,
str(self.in_pattern), str(self.in_pattern),
str(self.out_pattern), str(self.out_pattern),
id(self)), file=stream) id(self)), file=stream)
################## ##################
### Navigators ### # Navigators #
################## ##################
# Use the following classes to apply LocalOptimizers # Use the following classes to apply LocalOptimizers
...@@ -1545,7 +1536,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1545,7 +1536,7 @@ class NavigatorOptimizer(Optimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Print this navigator to `stream`, then recurse into its local_opt.

    The recursion indents by two extra spaces and decrements `depth`;
    it stops once `depth` reaches 0 (a negative `depth` means unbounded).
    """
    header = "%s%s (%i)" % (' ' * level, self.__class__.__name__, id(self))
    print(header, file=stream)
    if depth == 0:
        return
    self.local_opt.print_summary(stream, level=level + 2, depth=depth - 1)
...@@ -1734,7 +1725,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1734,7 +1725,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.final_optimizers = final_optimizers self.final_optimizers = final_optimizers
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, ( assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number') 'max_use_ratio has to be a number')
def get_local_optimizers(self): def get_local_optimizers(self):
for opt in self.local_optimizers_all: for opt in self.local_optimizers_all:
...@@ -1811,8 +1802,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1811,8 +1802,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use: if global_process_count[gopt] > max_use:
max_use_abort = True max_use_abort = True
opt_name = (getattr(gopt, "name", None) opt_name = (getattr(gopt, "name", None) or
or getattr(gopt, "__name__", "")) getattr(gopt, "__name__", ""))
global_sub_profs.append(sub_profs) global_sub_profs.append(sub_profs)
global_opt_timing.append(float(time.time() - t0)) global_opt_timing.append(float(time.time() - t0))
...@@ -1858,8 +1849,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1858,8 +1849,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[lopt] += change_tracker.nb_imported - nb node_created[lopt] += change_tracker.nb_imported - nb
if global_process_count[lopt] > max_use: if global_process_count[lopt] > max_use:
max_use_abort = True max_use_abort = True
opt_name = (getattr(lopt, "name", None) opt_name = (getattr(lopt, "name", None) or
or getattr(lopt, "__name__", "")) getattr(lopt, "__name__", ""))
if node not in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
# go to next node # go to next node
break break
...@@ -1884,8 +1875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1884,8 +1875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use: if global_process_count[gopt] > max_use:
max_use_abort = True max_use_abort = True
opt_name = (getattr(gopt, "name", None) opt_name = (getattr(gopt, "name", None) or
or getattr(gopt, "__name__", "")) getattr(gopt, "__name__", ""))
final_sub_profs.append(sub_profs) final_sub_profs.append(sub_profs)
global_opt_timing[-1] += time.time() - t_before_final_opt global_opt_timing[-1] += time.time() - t_before_final_opt
...@@ -1896,9 +1887,9 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1896,9 +1887,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
end_nb_nodes = len(fgraph.apply_nodes) end_nb_nodes = len(fgraph.apply_nodes)
if max_use_abort: if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name _logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name +
+ ". You can safely raise the current threshold of " ". You can safely raise the current threshold of " +
+ "%f with the theano flag 'optdb.max_use_ratio'." % "%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio) config.optdb.max_use_ratio)
fgraph.remove_feature(change_tracker) fgraph.remove_feature(change_tracker)
return (self, loop_timing, loop_process_count, return (self, loop_timing, loop_process_count,
...@@ -1909,7 +1900,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1909,7 +1900,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print("%s%s %s id=%i" % ( print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream) (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
if depth != 0: if depth != 0:
for lopt in self.get_local_optimizers(): for lopt in self.get_local_optimizers():
lopt.print_summary(stream, level=(level + 2), lopt.print_summary(stream, level=(level + 2),
...@@ -1925,11 +1916,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1925,11 +1916,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
blanc = (' ' * level) blanc = (' ' * level)
print(blanc, "EquilibriumOptimizer", end=' ', file=stream) print(blanc, "EquilibriumOptimizer", end=' ', file=stream)
print(blanc, getattr(opt, "name", print(blanc, getattr(opt, "name",
getattr(opt, "__name__", "")), file=stream) getattr(opt, "__name__", "")), file=stream)
print(blanc, " time %.3fs for %d passes" % ( print(blanc, " time %.3fs for %d passes" % (
sum(loop_timing), len(loop_timing)), file=stream) sum(loop_timing), len(loop_timing)), file=stream)
print(blanc, " nb nodes (start, end, max) %d %d %d" % ( print(blanc, " nb nodes (start, end, max) %d %d %d" % (
start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream) start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream)
print(blanc, " time io_toposort %.3fs" % sum( print(blanc, " time io_toposort %.3fs" % sum(
io_toposort_timing), file=stream) io_toposort_timing), file=stream)
s = sum([time_opts[o] for o in opt.get_local_optimizers()]) s = sum([time_opts[o] for o in opt.get_local_optimizers()])
...@@ -1948,12 +1939,12 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1948,12 +1939,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if len(d) > 5: if len(d) > 5:
lopt += " ..." lopt += " ..."
print(blanc, (' %2d - %.3fs %d (%.3fs in global opts, ' print(blanc, (' %2d - %.3fs %d (%.3fs in global opts, '
'%.3fs io_toposort) - %d nodes - %s' % ( '%.3fs io_toposort) - %d nodes - %s' % (
i, loop_timing[i], i, loop_timing[i],
sum(loop_process_count[i].values()), sum(loop_process_count[i].values()),
global_opt_timing[i], global_opt_timing[i],
io_toposort_timing[i], nb_nodes[i], io_toposort_timing[i], nb_nodes[i],
lopt)), file=stream) lopt)), file=stream)
count_opt = [] count_opt = []
not_used = [] not_used = []
...@@ -1975,8 +1966,9 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1975,8 +1966,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
not_used_time += time_opts[o] not_used_time += time_opts[o]
if count_opt: if count_opt:
print(blanc, \ print(blanc,
' times - times applied - nb node created - name:', file=stream) ' times - times applied - nb node created - name:',
file=stream)
count_opt.sort() count_opt.sort()
for (t, count, n_created, o) in count_opt[::-1]: for (t, count, n_created, o) in count_opt[::-1]:
print(blanc, ' %.3fs - %d - %d - %s' % ( print(blanc, ' %.3fs - %d - %d - %s' % (
...@@ -2010,7 +2002,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2010,7 +2002,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod @staticmethod
def merge_profile(prof1, prof2): def merge_profile(prof1, prof2):
#(opt, loop_timing, loop_process_count, max_nb_nodes, # (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1 # global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union( local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
prof2[0].get_local_optimizers()) prof2[0].get_local_optimizers())
...@@ -2085,7 +2077,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2085,7 +2077,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
final_sub_profs) final_sub_profs)
################# #################
### Utilities ### # Utilities #
################# #################
...@@ -2096,7 +2088,7 @@ def _check_chain(r, chain): ...@@ -2096,7 +2088,7 @@ def _check_chain(r, chain):
while chain: while chain:
elem = chain.pop() elem = chain.pop()
if elem is None: if elem is None:
if not r.owner is None: if r.owner is not None:
return False return False
elif r.owner is None: elif r.owner is None:
return False return False
...@@ -2105,20 +2097,20 @@ def _check_chain(r, chain): ...@@ -2105,20 +2097,20 @@ def _check_chain(r, chain):
return False return False
else: else:
try: try:
if (issubclass(elem, op.Op) if (issubclass(elem, op.Op) and
and not isinstance(r.owner.op, elem)): not isinstance(r.owner.op, elem)):
return False return False
except TypeError: except TypeError:
return False return False
if chain: if chain:
r = r.owner.inputs[chain.pop()] r = r.owner.inputs[chain.pop()]
# print 'check_chain', _check_chain.n_calls # print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1 # _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot # The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance) # be used as Booleans (the results of comparisons, for instance)
return (r is not None) return (r is not None)
#_check_chain.n_calls = 0 # _check_chain.n_calls = 0
def check_chain(r, *chain): def check_chain(r, *chain):
......
...@@ -3,9 +3,11 @@ import linecache ...@@ -3,9 +3,11 @@ import linecache
import traceback import traceback
import sys import sys
import numpy
from six import iteritems from six import iteritems
from theano import config from theano import config
from theano.compat import PY3
def simple_extract_stack(f=None, limit=None): def simple_extract_stack(f=None, limit=None):
...@@ -435,3 +437,31 @@ def remove(predicate, coll): ...@@ -435,3 +437,31 @@ def remove(predicate, coll):
[1, 3] [1, 3]
""" """
return [x for x in coll if not predicate(x)] return [x for x in coll if not predicate(x)]
import hashlib


if PY3:
    def hash_from_code(msg):
        """Return an MD5-based identifier for `msg`.

        Parameters
        ----------
        msg : str or bytes
            The code to hash.  A unicode ``str`` is UTF-8 encoded first,
            since hashlib.md5() requires an object supporting the buffer
            interface, which Python 3 (unicode) strings do not.

        Returns
        -------
        str
            ``'m'`` + the hex digest.  The leading letter is added because
            the result is used as a module name and Python 3 does not like
            module names that start with a digit.
        """
        if isinstance(msg, str):
            msg = msg.encode()
        return 'm' + hashlib.md5(msg).hexdigest()
else:
    def hash_from_code(msg):
        """Return the MD5 hex digest of `msg` (Python 2 variant).

        Parameters
        ----------
        msg : str or numpy.ndarray
            The data to hash.  Plain strings hash directly; ndarrays do
            not support md5() on Python 2, so they go through a buffer.

        Returns
        -------
        str
            The hex digest (no prefix, unlike the Python 3 variant).
        """
        try:
            return hashlib.md5(msg).hexdigest()
        except TypeError:
            assert isinstance(msg, numpy.ndarray)
            return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
    """Return the MD5 hash of a file.

    Parameters
    ----------
    file_path : str
        Path of the file to hash; read in binary mode.

    Returns
    -------
    str
        The digest string produced by ``hash_from_code`` for the file's
        bytes.
    """
    # Use a context manager so the file handle is always closed, instead
    # of relying on the garbage collector as the original code did.
    with open(file_path, 'rb') as f:
        return hash_from_code(f.read())
...@@ -10,7 +10,7 @@ import numpy ...@@ -10,7 +10,7 @@ import numpy
from theano.compat import decode, decode_iter from theano.compat import decode, decode_iter
from theano.gof import local_bitwidth from theano.gof import local_bitwidth
from theano.gof.cc import hash_from_file from theano.gof.utils import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs, from theano.gof.cmodule import (std_libs, std_lib_dirs,
std_include_dirs, dlimport, std_include_dirs, dlimport,
Compiler, Compiler,
......
from theano.gof.cc import hash_from_code from theano.gof.utils import hash_from_code
def hash_from_sparse(data): def hash_from_sparse(data):
......
...@@ -2,7 +2,7 @@ import numpy ...@@ -2,7 +2,7 @@ import numpy
import theano import theano
from theano.compat import izip from theano.compat import izip
from theano.gof.cc import hash_from_code from theano.gof.utils import hash_from_code
def hash_from_ndarray(data): def hash_from_ndarray(data):
......
...@@ -233,16 +233,10 @@ whitelist_flake8 = [ ...@@ -233,16 +233,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py", "sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py", "sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py", "sparse/sandbox/sp.py",
"gof/destroyhandler.py",
"gof/unify.py", "gof/unify.py",
"gof/graph.py", "gof/graph.py",
"gof/__init__.py", "gof/__init__.py",
"gof/cc.py",
"gof/opt.py",
"gof/link.py",
"gof/fg.py",
"gof/op.py", "gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py", "gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py", "gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py", "gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论