提交 cba9c812 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3130 from harlouci/flake8_gof

Flake8 gof
......@@ -177,7 +177,7 @@ def get_config_md5():
"""
all_opts = sorted([c for c in _config_var_list if c.in_c_key],
key=lambda cv: cv.fullname)
return theano.gof.cc.hash_from_code('\n'.join(
return theano.gof.utils.hash_from_code('\n'.join(
['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts]))
......
"""
Defines Linkers that deal with C implementations.
"""
from __future__ import print_function
# Python imports
from copy import copy
import os
import sys
from theano.compat import izip
import logging
import numpy
import theano
from theano import config
from theano.compat import PY3
from theano.compat import izip
from six import string_types, reraise
from six.moves import StringIO, xrange
from theano.gof.utils import MethodNotDefined
import theano
from theano import config
if PY3:
import hashlib
def hash_from_code(msg):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if isinstance(msg, str):
msg = msg.encode()
# Python 3 does not like module names that start with
# a digit.
return 'm' + hashlib.md5(msg).hexdigest()
else:
import hashlib
def hash_from_code(msg):
try:
return hashlib.md5(msg).hexdigest()
except TypeError:
assert isinstance(msg, numpy.ndarray)
return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
"""Return the MD5 hash of a file."""
return hash_from_code(open(file_path, 'rb').read())
# Note that we need to do this before importing cutils, since when there is
# no theano cache dir initialized yet, importing cutils may require compilation
# of cutils_ext.
from theano.configparser import AddConfigVar, StrParam
AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc",
StrParam(""))
# gof imports
from theano.gof import graph
from theano.gof import link
from theano.gof import utils
from theano.gof import cmodule
from theano.gof.compilelock import get_lock, release_lock
from theano.gof.callcache import CallCache
from theano.gof import cmodule
AddConfigVar('gcc.cxxflags',
"Extra compiler flags for gcc",
StrParam(""))
import logging
_logger = logging.getLogger("theano.gof.cc")
from theano.gof.callcache import CallCache
run_cthunk = None # Will be imported only when needed.
......@@ -314,9 +284,8 @@ def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name"""
if any([c != "output" and getattr(c.op, 'check_input',
config.check_input) for (c, _) in r.clients]) or (r.owner
and getattr(r.owner.op, 'check_input', True)):
config.check_input) for (c, _) in r.clients]) or (
r.owner and getattr(r.owner.op, 'check_input', True)):
c_declare = r.type.c_declare(name, sub, True)
else:
c_declare = r.type.c_declare(name, sub, False)
......@@ -532,7 +501,7 @@ class CLinker(link.Linker):
if isinstance(r, graph.Constant) and
r not in self.inputs)
self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans))
self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
def code_gen(self):
......@@ -821,7 +790,7 @@ class CLinker(link.Linker):
ret = []
# generic support code
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
ret.append(x.c_support_code())
except utils.MethodNotDefined:
......@@ -840,11 +809,11 @@ class CLinker(link.Linker):
# FillMissing must disable some of them. Putting -ffast-math would
# make it disable all other parameter at the same time.
ret += ["-fno-math-errno",
#"-funsafe-math-optimizations",
#"-fno-signaling-nans",
#"-fcx-limited-range",
#"-fno-rounding-math",
#"-ffinite-math-only",
# "-funsafe-math-optimizations",
# "-fno-signaling-nans",
# "-fcx-limited-range",
# "-fno-rounding-math",
# "-ffinite-math-only",
# the current code generate label event if they are not used.
# Could use gcc attribute for those label only
......@@ -853,7 +822,7 @@ class CLinker(link.Linker):
"-Wno-write-strings", # generated by our code generator...
]
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
ret += x.c_compile_args()
except utils.MethodNotDefined:
......@@ -866,7 +835,7 @@ class CLinker(link.Linker):
# to reorder them
ret += c_compiler.compile_args()
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
for i in x.c_no_compile_args():
try:
......@@ -886,7 +855,7 @@ class CLinker(link.Linker):
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
ret += x.c_headers()
except utils.MethodNotDefined:
......@@ -901,7 +870,7 @@ class CLinker(link.Linker):
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
ret += x.c_init_code()
except utils.MethodNotDefined:
......@@ -911,7 +880,7 @@ class CLinker(link.Linker):
def c_compiler(self):
c_compiler = None
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
if hasattr(x, 'c_compiler'):
x_compiler = x.c_compiler()
else:
......@@ -938,7 +907,7 @@ class CLinker(link.Linker):
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
ret += x.c_header_dirs()
except utils.MethodNotDefined:
......@@ -954,7 +923,7 @@ class CLinker(link.Linker):
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
ret += x.c_libraries()
except utils.MethodNotDefined:
......@@ -970,7 +939,7 @@ class CLinker(link.Linker):
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
y.op for y in self.node_order]:
try:
ret += x.c_lib_dirs()
except utils.MethodNotDefined:
......@@ -1150,7 +1119,7 @@ class CLinker(link.Linker):
libraries=self.libraries(),
header_dirs=self.header_dirs(),
c_compiler=self.c_compiler(),
)
)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True,
......@@ -1335,7 +1304,6 @@ class CLinker(link.Linker):
preargs.remove('-DREPLACE_WITH_AMDLIBM')
if 'amdlibm' in libs:
libs.remove('amdlibm')
src_code = mod.code()
get_lock()
try:
_logger.debug("LOCATION %s", str(location))
......@@ -1371,9 +1339,9 @@ class CLinker(link.Linker):
code = self.instantiate_code(1 + len(self.args))
instantiate = cmodule.ExtFunction('instantiate', code,
method=cmodule.METH_VARARGS)
#['error_storage'] + argnames,
#local_dict = d,
# global_dict = {})
# ['error_storage'] + argnames,
# local_dict = d,
# global_dict = {})
# Static methods that can run and destroy the struct built by
# instantiate.
......@@ -1498,7 +1466,7 @@ class _CThunk(object):
global run_cthunk
if run_cthunk is None:
# Lazy import to avoid compilation when importing theano.
from theano.gof.cutils import run_cthunk
from theano.gof.cutils import run_cthunk # noqa
self.cthunk = cthunk
self.init_tasks = init_tasks
self.tasks = tasks
......@@ -1534,7 +1502,8 @@ class _CThunk(object):
exc_value.__thunk_trace__ = trace
except Exception:
print(('ERROR retrieving error_storage.'
' Was the error set in the c code?'), end=' ', file=sys.stderr)
'Was the error set in the c code?'),
end=' ', file=sys.stderr)
print(self.error_storage, file=sys.stderr)
raise
reraise(exc_type, exc_value, exc_trace)
......@@ -1641,11 +1610,11 @@ class OpWiseCLinker(link.LocalLinker):
for node in order:
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if ((input in computed) and
(input not in fgraph.outputs) and
node == last_user[input])])
post_thunk_old_storage.append(
[storage_map[input] for input in node.inputs
if ((input in computed) and
(input not in fgraph.outputs) and
node == last_user[input])])
if no_recycling is True:
no_recycling = list(storage_map.values())
......@@ -1741,12 +1710,12 @@ class DualLinker(link.Linker):
no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = (
link.PerformLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
link.PerformLinker(schedule=self.schedule).accept(
fgraph, no_recycling=no_recycling).make_all(**kwargs))
kwargs.pop('input_storage', None)
_f, i2, o2, thunks2, order2 = (
OpWiseCLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
OpWiseCLinker(schedule=self.schedule).accept(
fgraph, no_recycling=no_recycling).make_all(**kwargs))
def f():
for input1, input2 in izip(i1, i2):
......
"""Generate and compile C modules for Python,
"""
from __future__ import print_function
import atexit
import six.moves.cPickle as pickle
import logging
......@@ -15,12 +17,6 @@ import time
import platform
import distutils.sysconfig
importlib = None
try:
import importlib
except ImportError:
pass
import numpy.distutils # TODO: TensorType should handle this
import theano
......@@ -28,7 +24,7 @@ from theano.compat import PY3, decode, decode_iter
from six import b, BytesIO, StringIO, string_types, iteritems
from theano.gof.utils import flatten
from theano.configparser import config
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
from theano.misc.windows import (subprocess_Popen,
output_subprocess_Popen)
......@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error",
BoolParam(False))
importlib = None
try:
import importlib
except ImportError:
pass
AddConfigVar(
'cmodule.mac_framework_link',
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error",
BoolParam(False))
AddConfigVar('cmodule.warn_no_version',
"If True, will print a warning when compiling one or more Op "
......@@ -131,15 +134,16 @@ class ExtFunction(object):
It goes into the DynamicModule's method table.
"""
return '\t{"%s", %s, %s, "%s"}' % (
self.name, self.name, self.method, self.doc)
self.name, self.name, self.method, self.doc)
class DynamicModule(object):
def __init__(self, name=None):
assert name is None, ("The 'name' parameter of DynamicModule"
" cannot be specified anymore. Instead, 'code_hash'"
" will be automatically computed and can be used as"
" the module's name.")
assert name is None, (
"The 'name' parameter of DynamicModule"
" cannot be specified anymore. Instead, 'code_hash'"
" will be automatically computed and can be used as"
" the module's name.")
# While the module is not finalized, we can call add_...
# when it is finalized, a hash is computed and used instead of
# the placeholder, and as module name.
......@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{
}};
""".format(name=self.hash_placeholder), file=stream)
print(("PyMODINIT_FUNC PyInit_%s(void) {" %
self.hash_placeholder), file=stream)
self.hash_placeholder), file=stream)
for block in self.init_blocks:
print(' ', block, file=stream)
print(" PyObject *m = PyModule_Create(&moduledef);", file=stream)
print(" return m;", file=stream)
else:
print(("PyMODINIT_FUNC init%s(void){" %
self.hash_placeholder), file=stream)
self.hash_placeholder), file=stream)
for block in self.init_blocks:
print(' ', block, file=stream)
print(' ', ('(void) Py_InitModule("%s", MyMethods);'
% self.hash_placeholder), file=stream)
% self.hash_placeholder), file=stream)
print("}", file=stream)
def add_include(self, str):
......@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2):
if os.path.realpath(entry_1) == os.path.realpath(entry_2):
return True
if (os.path.basename(entry_1) == os.path.basename(entry_2) and
(os.path.basename(os.path.dirname(entry_1)) ==
os.path.basename(os.path.dirname(entry_2))) and
os.path.basename(os.path.dirname(entry_1)).startswith('tmp')):
(os.path.basename(os.path.dirname(entry_1)) ==
os.path.basename(os.path.dirname(entry_2))) and
os.path.basename(os.path.dirname(entry_1)).startswith('tmp')):
return True
return False
......@@ -429,8 +433,8 @@ def get_safe_part(key):
# Find the md5 hash part.
c_link_key = key[1]
for key_element in c_link_key[1:]:
if (isinstance(key_element, string_types)
and key_element.startswith('md5:')):
if (isinstance(key_element, string_types) and
key_element.startswith('md5:')):
md5 = key_element[4:]
break
......@@ -761,9 +765,9 @@ class ModuleCache(object):
# simpler to implement).
rmtree(root, ignore_nocleanup=True,
msg=(
'invalid cache entry format -- this '
'should not happen unless your cache '
'was really old'),
'invalid cache entry format -- this '
'should not happen unless your cache '
'was really old'),
level=logging.WARN)
continue
......@@ -964,7 +968,7 @@ class ModuleCache(object):
# process that could be changing the file at the same
# time.
if (key[0] and not key_broken and
self.check_for_broken_eq):
self.check_for_broken_eq):
self.check_key(key, key_data.key_pkl)
self._update_mappings(key, key_data, module.__file__, check_in_keys=not key_broken)
return module
......@@ -1149,15 +1153,14 @@ class ModuleCache(object):
# This is to make debugging in pdb easier, by providing
# the offending keys in the local context.
# key_data_keys = list(key_data.keys)
## import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
pass
elif found > 1:
msg = 'Multiple equal keys found in unpickled KeyData file'
if msg:
raise AssertionError(
"%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" %
(msg, key_pkl, key))
"%s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is: %s. The key is: %s" % (msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this
......@@ -1195,10 +1198,10 @@ class ModuleCache(object):
if age_thresh_del < self.age_thresh_use:
if age_thresh_del > 0:
_logger.warning("Clearing modules that were not deemed "
"too old to use: age_thresh_del=%d, "
"self.age_thresh_use=%d",
age_thresh_del,
self.age_thresh_use)
"too old to use: age_thresh_del=%d, "
"self.age_thresh_use=%d",
age_thresh_del,
self.age_thresh_use)
else:
_logger.info("Clearing all modules.")
age_thresh_use = age_thresh_del
......@@ -1210,8 +1213,8 @@ class ModuleCache(object):
# processes and get all module that are too old to use
# (not loaded in self.entry_from_key).
too_old_to_use = self.refresh(
age_thresh_use=age_thresh_use,
delete_if_problem=delete_if_problem)
age_thresh_use=age_thresh_use,
delete_if_problem=delete_if_problem)
for entry in too_old_to_use:
# TODO: we are assuming that modules that haven't been
......@@ -1242,8 +1245,8 @@ class ModuleCache(object):
"""
with compilelock.lock_ctx():
self.clear_old(
age_thresh_del=-1.0,
delete_if_problem=delete_if_problem)
age_thresh_del=-1.0,
delete_if_problem=delete_if_problem)
self.clear_unversioned(min_age=unversioned_min_age)
if clear_base_files:
self.clear_base_files()
......@@ -1333,7 +1336,7 @@ class ModuleCache(object):
if filename.startswith('tmp'):
try:
open(os.path.join(self.dirname, filename, 'key.pkl')
).close()
).close()
has_key = True
except IOError:
has_key = False
......@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None):
'was created prior to this call')
if _module_cache.dirname != dirname:
_logger.warning("Returning module cache instance with different "
"dirname (%s) than you requested (%s)",
_module_cache.dirname, dirname)
"dirname (%s) than you requested (%s)",
_module_cache.dirname, dirname)
return _module_cache
......@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler):
break
if ('g++' not in theano.config.cxx and
'clang++' not in theano.config.cxx):
'clang++' not in theano.config.cxx):
_logger.warn(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization"
......@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler):
selected_lines = []
for line in lines:
if ("COLLECT_GCC_OPTIONS=" in line or
"CFLAGS=" in line or
"CXXFLAGS=" in line or
"-march=native" in line):
"CFLAGS=" in line or
"CXXFLAGS=" in line or
"-march=native" in line):
continue
elif "-march=" in line:
selected_lines.append(line.strip())
......@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
for line in default_lines:
if line.startswith(part[0]):
part2 = [p for p in join_options(line.split())
if (not 'march' in p and
not 'mtune' in p and
not 'target-cpu' in p)]
if ('march' not in p and
'mtune' not in p and
'target-cpu' not in p)]
new_flags = [p for p in part if p not in part2]
# Replace '-target-cpu value', which is an option
# of clang, with '-march=value', for g++
......@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler):
cmd.append(cppfilename)
cmd.extend(['-L%s' % ldir for ldir in lib_dirs])
cmd.extend(['-l%s' % l for l in libs])
#print >> sys.stderr, 'COMPILING W CMD', cmd
# print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug('Running cmd: %s', ' '.join(cmd))
def print_command_line_error():
# Print command line when a problem occurred.
print((
"Problem occurred during compilation with the "
"command line below:"), file=sys.stderr)
print(("Problem occurred during compilation with the "
"command line below:"), file=sys.stderr)
print(' '.join(cmd), file=sys.stderr)
try:
......
......@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
"""
# These are lists of Variable instances
inputs = fgraph.inputs
outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py
......@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
iset = set(inputs)
# IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys.
# This caused a slowdown.
......@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if
not isinstance(i, graph.Constant)
and not fgraph.destroyers(i)
and i not in protected_inputs]
not isinstance(i, graph.Constant) and
not fgraph.destroyers(i) and
i not in protected_inputs]
return inputs
if 0:
......@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
# Do the checking #
already_there = False
if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve"
......@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in"
" conflict with another plugin.")
####### end of checking ############
# end of checking #
def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact()
......@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
#input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact)
# input_impact = set([input_root])
# add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
......@@ -390,7 +387,7 @@ if 0:
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
# if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
......@@ -421,7 +418,7 @@ if 0:
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app)
# UPDATE self.clients
......@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph.
pass
else:
#if app not in self.debug_all_apps: raise ProtocolError("change without import")
# if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients
self.clients[old_r][app] -= 1
......@@ -529,9 +526,10 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
illegal_destroy = [
r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
# print 'destroying illegally'
raise InconsistencyError(
......@@ -603,7 +601,7 @@ if 0:
if input in root_impact \
and (i not in tolerated or input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i))
% (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
......@@ -621,7 +619,7 @@ if 0:
return rval
class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
......@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
# Do the checking #
already_there = False
if self.fgraph is fgraph:
already_there = True
......@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present"
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############
# Annotate the FunctionGraph #
self.unpickle(fgraph)
fgraph.destroy_handler = self
......@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if \
getattr(r.tag, 'indestructible', False) or \
isinstance(r, graph.Constant)]
illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
illegal_destroy)
raise InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies
for app in self.destroyers:
......@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
if idx0 == destroyed_idx)
# print 'tolerated', tolerated
# print 'ignored', ignored
for i, input in enumerate(app.inputs):
if i in ignored:
continue
if input in root_impact \
and (i not in tolerated or input is not destroyed_variable):
and (i not in tolerated or
input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i))
% (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
......
......@@ -13,7 +13,6 @@ from theano.gof import graph
from theano.gof import utils
from theano.gof import toolbox
from theano import config
import warnings
from theano.compat import OrderedDict
from six import iteritems, itervalues
......@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
NullType = None
class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
......@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
self.variable_locks = {}
self.profile = None
### Setup a Variable ###
# Setup a Variable #
def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
if getattr(r, 'cached', False):
......@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
" graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph.")
if (hasattr(r, 'fgraph') and
r.fgraph is not None and
r.fgraph is not self):
r.fgraph is not None and
r.fgraph is not self):
raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self
r.clients = []
#self.execute_callbacks('on_setup_variable', r)
# self.execute_callbacks('on_setup_variable', r)
def __setup_node__(self, node):
# sets up node so it belongs to this fgraph
......@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
str(node.op), str(node.op.destroy_map)))
node.fgraph = self
node.deps = {}
#self.execute_callbacks('on_setup_node', node)
# self.execute_callbacks('on_setup_node', node)
def disown(self):
""" WRITEME
......@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
self.inputs = None
self.outputs = None
### clients ###
# clients #
def clients(self, r):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
......@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
if set(r.clients).intersection(set(new_clients)):
print('ERROR: clients intersect!', file=sys.stderr)
print(' RCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in r.clients], file=sys.stderr)
for n, i in r.clients], file=sys.stderr)
print(' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients], file=sys.stderr)
for n, i in new_clients], file=sys.stderr)
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
......@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
return True
return False
### import ###
# import #
def __import_r__(self, variable, reason):
global NullType
if NullType is None:
......@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
if hasattr(r, 'fgraph') and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r)
if (r.owner is None and
not isinstance(r, graph.Constant) and
r not in self.inputs):
not isinstance(r, graph.Constant) and
r not in self.inputs):
# Verbose error message
# Show a complete chain of variables from the missing input to an output
if config.exception_verbosity == 'high':
......@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
assert node.fgraph is self
self.execute_callbacks('on_import', node, reason)
### prune ###
# prune #
def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
......@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs)
### change input ###
# change input #
def change_input(self, node, i, new_r, reason=None):
"""WRITEME
Changes node.inputs[i] to new_r.
......@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
if prune:
self.__prune_r__(r, reason=reason)
### replace ###
# replace #
def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
......@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
if detach is not None:
detach(self)
### callback utils ###
# callback utils #
def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME
Calls
......@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args)
return d
### misc ###
# misc #
def toposort(self):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
......@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
missing, excess)
for variable in variables:
if (variable.owner is None and
variable not in self.inputs and
not isinstance(variable, graph.Constant)):
variable not in self.inputs and
not isinstance(variable, graph.Constant)):
raise Exception("Undeclared input.", variable)
if variable.fgraph is not self:
raise Exception("Variable should belong to the FunctionGraph.",
......@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
def __repr__(self):
return self.__str__()
### clone ###
# clone #
def clone(self, check_integrity=True):
"""WRITEME"""
return self.clone_get_equiv(check_integrity)[0]
......
......@@ -7,14 +7,14 @@ import traceback
import numpy
import theano
from theano.compat import PY3, izip
from theano.compat import izip
from six import reraise
from six.moves import StringIO
from theano.gof import utils
from theano.gof import graph
from theano.gof.type import Type
from .utils import MethodNotDefined, undef
from .utils import undef
__excepthook = sys.excepthook
......@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else:
detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
total_size, total_size/1024./1024/1024)
total_size, total_size / 1024. / 1024 / 1024)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % (
total_size_inputs, total_size_inputs/1024./1024/1024)
total_size_inputs, total_size_inputs / 1024. / 1024 / 1024)
else:
hints.append(
......@@ -326,7 +326,7 @@ class Linker(object):
raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__)
## DELETEME ##
# DELETEME #
def make_function(self, unpack_single=True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
......@@ -350,8 +350,8 @@ class Linker(object):
def execute(*args):
def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \
% (takes, ['argument', 'arguments'][takes > 1], got)
return 'Function call takes exactly %i %s (%i given)' % (
takes, ['argument', 'arguments'][takes > 1], got)
if (len(args) != len(inputs)):
raise TypeError(e_arity(len(inputs), len(args)))
for arg, variable in izip(args, inputs):
......@@ -394,7 +394,7 @@ class Container(object):
"""
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
#self.r = r
# self.r = r
if isinstance(r, Type):
self.type = r
else:
......@@ -454,12 +454,11 @@ class Container(object):
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
)
)
# Work around NumPy deepcopy of ndarray with 0 dimention that
# don't return an ndarray.
if (r.storage[0] is not None and
not self.type.is_valid_value(r.storage[0])):
not self.type.is_valid_value(r.storage[0])):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
......@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(allow_gc=self.allow_gc).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
# raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self.fgraph = fgraph
self.no_recycling = no_recycling
return self
......@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
for node in order:
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if (input in computed) and (input not in fgraph.outputs) and node == last_user[input]])
post_thunk_old_storage.append(
[storage_map[input]
for input in node.inputs
if (input in computed) and (
input not in fgraph.outputs) and (
node == last_user[input])])
if no_recycling is True:
# True seems like some special code for *everything*?? -JB
......@@ -855,7 +857,7 @@ class WrapLinker(Linker):
make_all += [l.make_all(**kwargs) for l in self.linkers[1:]]
fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all)
= zip(*make_all)
order_list0 = order_lists[0]
for order_list in order_lists[1:]:
......
......@@ -29,6 +29,7 @@ from . import destroyhandler as dh
_logger = logging.getLogger('theano.gof.opt')
_optimizer_idx = [0]
def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
......@@ -99,7 +100,7 @@ class Optimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None)
print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream)
(' ' * level), self.__class__.__name__, name, id(self)), file=stream)
@staticmethod
def print_profile(stream, prof, level=0):
......@@ -121,9 +122,9 @@ class FromFunctionOptimizer(Optimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % (
' ' * level,
str(self.apply),
id(self)), file=stream)
' ' * level,
str(self.apply),
id(self)), file=stream)
def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
......@@ -222,7 +223,7 @@ class SeqOptimizer(Optimizer, list):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None)
print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream)
(' ' * level), self.__class__.__name__, name, id(self)), file=stream)
# This way, -1 will do all depth
if depth != 0:
depth -= 1
......@@ -241,8 +242,8 @@ class SeqOptimizer(Optimizer, list):
elif hasattr(opts, "__name__"):
print(blanc, opts.__name__, end=' ', file=stream)
print((" time %.3fs for %d/%d nodes"
" before/after optimization" % (
sum(prof), nb_node_before, nb_node_after)), file=stream)
" before/after optimization" % (
sum(prof), nb_node_before, nb_node_after)), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
print(blanc, " %.3fs for callback" % (callback_time), file=stream)
if level == 0:
......@@ -324,7 +325,7 @@ class SeqOptimizer(Optimizer, list):
new_t[idx] += p[1][p[0].index(l)]
if hasattr(l, 'merge_profile'):
assert len(p[6][p[0].index(l)]) == \
len(new_sub_profile[idx])
len(new_sub_profile[idx])
new_sub_profile[idx] = l.merge_profile(
new_sub_profile[idx], p[6][p[0].index(l)])
else:
......@@ -729,6 +730,7 @@ def pre_constant_merge(vars):
const_sig_inv = {}
if isinstance(vars, graph.Variable):
vars = [vars]
def recursive_merge(var):
if var in seen_var:
return var
......@@ -761,7 +763,7 @@ def pre_constant_merge(vars):
########################
### Local Optimizers ###
# Local Optimizers #
########################
class LocalOptimizer(object):
......@@ -817,12 +819,14 @@ class LocalOptimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream)
(' ' * level), self.__class__.__name__, id(self)), file=stream)
theano.configparser.AddConfigVar('metaopt.verbose',
"Enable verbose output for meta optimizers",
theano.configparser.BoolParam(False), in_c_key=False)
theano.configparser.AddConfigVar(
'metaopt.verbose',
"Enable verbose output for meta optimizers",
theano.configparser.BoolParam(False),
in_c_key=False)
class LocalMetaOptimizer(LocalOptimizer):
......@@ -933,9 +937,9 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % (
' ' * level,
str(self.transform),
id(self)), file=stream)
' ' * level,
str(self.transform),
id(self)), file=stream)
def local_optimizer(tracks, inplace=False):
......@@ -992,7 +996,7 @@ class LocalOptGroup(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream)
(' ' * level), self.__class__.__name__, id(self)), file=stream)
if depth != 0:
depth -= 1
for lopt in self.opts:
......@@ -1003,19 +1007,6 @@ class LocalOptGroup(LocalOptimizer):
opt.add_requirements(fgraph)
class _LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME"""
def __init__(self, optimizers):
if any(not hasattr(opt, 'op_key'), optimizers):
raise TypeError(
"All LocalOptimizers passed here must have an op_key method.")
CompositeLocalOptimizer.__init__(self, optimizers)
def op_key(self):
return [opt.op_key() for opt in self.opts]
class OpSub(LocalOptimizer):
"""WRITEME
Replaces the application of a certain op by the application of
......@@ -1086,10 +1077,10 @@ class OpRemove(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s(%s) id=%i" % (
' ' * level,
self.__class__.__name__,
str(self.op),
id(self)), file=stream)
' ' * level,
self.__class__.__name__,
str(self.op),
id(self)), file=stream)
class PatternSub(LocalOptimizer):
......@@ -1217,6 +1208,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op:
return False
# TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
# TODO move outside match
def retry_with_equiv():
......@@ -1233,9 +1225,8 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if (not (expr.owner.op == pattern[0])
or (not allow_multiple_clients
and len(expr.clients) > 1)):
if (not (expr.owner.op == pattern[0]) or
(not allow_multiple_clients and len(expr.clients) > 1)):
return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv()
......@@ -1263,16 +1254,16 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif (isinstance(pattern, (int, float))
and isinstance(expr, graph.Constant)):
elif (isinstance(pattern, (int, float)) and
isinstance(expr, graph.Constant)):
if numpy.all(
theano.tensor.constant(pattern).value == expr.value):
return u
else:
return retry_with_equiv()
elif (isinstance(pattern, graph.Constant)
and isinstance(expr, graph.Constant)
and pattern.equals(expr)):
elif (isinstance(pattern, graph.Constant) and
isinstance(expr, graph.Constant) and
pattern.equals(expr)):
return u
else:
return retry_with_equiv()
......@@ -1308,17 +1299,17 @@ class PatternSub(LocalOptimizer):
def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (
str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]]))
str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (
pattern_to_str(pattern['pattern']),
str(pattern.get('constraint', 'no conditions')))
pattern_to_str(pattern['pattern']),
str(pattern.get('constraint', 'no conditions')))
else:
return str(pattern)
return "%s -> %s" % (
pattern_to_str(self.in_pattern),
pattern_to_str(self.out_pattern))
pattern_to_str(self.in_pattern),
pattern_to_str(self.out_pattern))
def __repr__(self):
return str(self)
......@@ -1326,16 +1317,16 @@ class PatternSub(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, '__name__', getattr(self, 'name', None))
print("%s%s %s(%s, %s) id=%i" % (
' ' * level,
self.__class__.__name__,
name,
str(self.in_pattern),
str(self.out_pattern),
id(self)), file=stream)
' ' * level,
self.__class__.__name__,
name,
str(self.in_pattern),
str(self.out_pattern),
id(self)), file=stream)
##################
### Navigators ###
# Navigators #
##################
# Use the following classes to apply LocalOptimizers
......@@ -1545,7 +1536,7 @@ class NavigatorOptimizer(Optimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s (%i)" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream)
(' ' * level), self.__class__.__name__, id(self)), file=stream)
if depth != 0:
self.local_opt.print_summary(stream, level=(level + 2),
depth=(depth - 1))
......@@ -1734,7 +1725,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.final_optimizers = final_optimizers
self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number')
'max_use_ratio has to be a number')
def get_local_optimizers(self):
for opt in self.local_optimizers_all:
......@@ -1811,8 +1802,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
opt_name = (getattr(gopt, "name", None) or
getattr(gopt, "__name__", ""))
global_sub_profs.append(sub_profs)
global_opt_timing.append(float(time.time() - t0))
......@@ -1858,8 +1849,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[lopt] += change_tracker.nb_imported - nb
if global_process_count[lopt] > max_use:
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
opt_name = (getattr(lopt, "name", None) or
getattr(lopt, "__name__", ""))
if node not in fgraph.apply_nodes:
# go to next node
break
......@@ -1884,8 +1875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
opt_name = (getattr(gopt, "name", None) or
getattr(gopt, "__name__", ""))
final_sub_profs.append(sub_profs)
global_opt_timing[-1] += time.time() - t_before_final_opt
......@@ -1896,9 +1887,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
end_nb_nodes = len(fgraph.apply_nodes)
if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name
+ ". You can safely raise the current threshold of "
+ "%f with the theano flag 'optdb.max_use_ratio'." %
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name +
". You can safely raise the current threshold of " +
"%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio)
fgraph.remove_feature(change_tracker)
return (self, loop_timing, loop_process_count,
......@@ -1909,7 +1900,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None)
print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream)
(' ' * level), self.__class__.__name__, name, id(self)), file=stream)
if depth != 0:
for lopt in self.get_local_optimizers():
lopt.print_summary(stream, level=(level + 2),
......@@ -1925,11 +1916,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
blanc = (' ' * level)
print(blanc, "EquilibriumOptimizer", end=' ', file=stream)
print(blanc, getattr(opt, "name",
getattr(opt, "__name__", "")), file=stream)
getattr(opt, "__name__", "")), file=stream)
print(blanc, " time %.3fs for %d passes" % (
sum(loop_timing), len(loop_timing)), file=stream)
sum(loop_timing), len(loop_timing)), file=stream)
print(blanc, " nb nodes (start, end, max) %d %d %d" % (
start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream)
start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream)
print(blanc, " time io_toposort %.3fs" % sum(
io_toposort_timing), file=stream)
s = sum([time_opts[o] for o in opt.get_local_optimizers()])
......@@ -1948,12 +1939,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if len(d) > 5:
lopt += " ..."
print(blanc, (' %2d - %.3fs %d (%.3fs in global opts, '
'%.3fs io_toposort) - %d nodes - %s' % (
i, loop_timing[i],
sum(loop_process_count[i].values()),
global_opt_timing[i],
io_toposort_timing[i], nb_nodes[i],
lopt)), file=stream)
'%.3fs io_toposort) - %d nodes - %s' % (
i, loop_timing[i],
sum(loop_process_count[i].values()),
global_opt_timing[i],
io_toposort_timing[i], nb_nodes[i],
lopt)), file=stream)
count_opt = []
not_used = []
......@@ -1975,8 +1966,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
not_used_time += time_opts[o]
if count_opt:
print(blanc, \
' times - times applied - nb node created - name:', file=stream)
print(blanc,
' times - times applied - nb node created - name:',
file=stream)
count_opt.sort()
for (t, count, n_created, o) in count_opt[::-1]:
print(blanc, ' %.3fs - %d - %d - %s' % (
......@@ -2010,7 +2002,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod
def merge_profile(prof1, prof2):
#(opt, loop_timing, loop_process_count, max_nb_nodes,
# (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
prof2[0].get_local_optimizers())
......@@ -2085,7 +2077,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
final_sub_profs)
#################
### Utilities ###
# Utilities #
#################
......@@ -2096,7 +2088,7 @@ def _check_chain(r, chain):
while chain:
elem = chain.pop()
if elem is None:
if not r.owner is None:
if r.owner is not None:
return False
elif r.owner is None:
return False
......@@ -2105,20 +2097,20 @@ def _check_chain(r, chain):
return False
else:
try:
if (issubclass(elem, op.Op)
and not isinstance(r.owner.op, elem)):
if (issubclass(elem, op.Op) and
not isinstance(r.owner.op, elem)):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
# print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1
# _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance)
return (r is not None)
#_check_chain.n_calls = 0
# _check_chain.n_calls = 0
def check_chain(r, *chain):
......
......@@ -3,9 +3,11 @@ import linecache
import traceback
import sys
import numpy
from six import iteritems
from theano import config
from theano.compat import PY3
def simple_extract_stack(f=None, limit=None):
......@@ -435,3 +437,31 @@ def remove(predicate, coll):
[1, 3]
"""
return [x for x in coll if not predicate(x)]
if PY3:
import hashlib
def hash_from_code(msg):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if isinstance(msg, str):
msg = msg.encode()
# Python 3 does not like module names that start with
# a digit.
return 'm' + hashlib.md5(msg).hexdigest()
else:
import hashlib
def hash_from_code(msg):
try:
return hashlib.md5(msg).hexdigest()
except TypeError:
assert isinstance(msg, numpy.ndarray)
return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
"""Return the MD5 hash of a file."""
return hash_from_code(open(file_path, 'rb').read())
......@@ -10,7 +10,7 @@ import numpy
from theano.compat import decode, decode_iter
from theano.gof import local_bitwidth
from theano.gof.cc import hash_from_file
from theano.gof.utils import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs,
std_include_dirs, dlimport,
Compiler,
......
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
def hash_from_sparse(data):
......
......@@ -2,7 +2,7 @@ import numpy
import theano
from theano.compat import izip
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
def hash_from_ndarray(data):
......
......@@ -233,16 +233,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py",
"gof/destroyhandler.py",
"gof/unify.py",
"gof/graph.py",
"gof/__init__.py",
"gof/cc.py",
"gof/opt.py",
"gof/link.py",
"gof/fg.py",
"gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论