提交 80aba223 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6271 from nouiz/dareneiri-master

Use sha256 for config hash to be FIPS compliant and make Theano cache more forward compatible
......@@ -1216,9 +1216,7 @@ AddConfigVar('cmodule.age_thresh_use',
AddConfigVar('cmodule.debug',
"If True, define a DEBUG macro (if not exists) for any compiled C code.",
BoolParam(False),
# Do not add it in the c key when we keep use the old default.
# To do not recompile for no good reason.
in_c_key=lambda: theano.config.cmodule.debug)
in_c_key=True)
def default_blas_ldflags():
......
......@@ -183,24 +183,17 @@ def _config_print(thing, buf, print_doc=True):
print("", file=buf)
def get_config_md5():
def get_config_hash():
"""
Return a string md5 of the current config options. It should be such that
we can safely assume that two different config setups will lead to two
different strings.
Return a string sha256 of the current config options. In the past,
it was md5.
The string should be such that we can safely assume that two different
config setups will lead to two different strings.
We only take into account config options for which `in_c_key` is True.
"""
all_opts = []
for c in _config_var_list:
if callable(c.in_c_key):
i = c.in_c_key()
else:
i = c.in_c_key
if i:
all_opts.append(c)
all_opts = sorted(all_opts,
all_opts = sorted([c for c in _config_var_list if c.in_c_key],
key=lambda cv: cv.fullname)
return theano.gof.utils.hash_from_code('\n'.join(
['%s = %s' % (cv.fullname, cv.__get__(True, None)) for cv in all_opts]))
......
......@@ -1229,17 +1229,19 @@ class CLinker(link.Linker):
The signature has the following form:
{{{
'CLinker.cmodule_key', compilation args, libraries,
header_dirs, numpy ABI version, config md5,
header_dirs, numpy ABI version, config hash,
(op0, input_signature0, output_signature0),
(op1, input_signature1, output_signature1),
...
(opK, input_signatureK, output_signatureK),
}}}
Note that config hash now uses sha256, and not md5.
The signature is a tuple, some elements of which are sub-tuples.
The outer tuple has a brief header, containing the compilation options
passed to the compiler, the libraries to link against, an md5 hash
passed to the compiler, the libraries to link against, a sha256 hash
of theano.config (for all config options where "in_c_key" is True).
It is followed by elements for every node in the topological ordering
of `self.fgraph`.
......@@ -1298,7 +1300,7 @@ class CLinker(link.Linker):
def cmodule_key_variables(self, inputs, outputs, no_recycling,
compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True,
header_dirs=None, insert_config_hash=True,
c_compiler=None):
# Assemble a dummy fgraph using the provided inputs and outputs. It is
......@@ -1321,11 +1323,11 @@ class CLinker(link.Linker):
fgraph = FakeFunctionGraph(inputs, outputs)
return self.cmodule_key_(fgraph, no_recycling, compile_args,
libraries, header_dirs, insert_config_md5,
libraries, header_dirs, insert_config_hash,
c_compiler)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True,
libraries=None, header_dirs=None, insert_config_hash=True,
c_compiler=None):
"""
Do the actual computation of cmodule_key in a static method
......@@ -1347,7 +1349,7 @@ class CLinker(link.Linker):
constant_ids = dict()
op_pos = {} # Apply -> topological position
# First we put the header, compile_args, library names and config md5
# First we put the header, compile_args, library names and config hash
# into the signature.
sig = ['CLinker.cmodule_key'] # will be cast to tuple on return
if compile_args is not None:
......@@ -1380,8 +1382,11 @@ class CLinker(link.Linker):
# parameters from the rest of the key. If you want to add more key
# elements, they should be before this md5 hash if and only if they
# can lead to a different compiled file with the same source code.
if insert_config_md5:
sig.append('md5:' + theano.configparser.get_config_md5())
# NOTE: config md5 is not using md5 hash, but sha256 instead. Function
# string instances of md5 will be updated at a later release.
if insert_config_hash:
sig.append('md5:' + theano.configparser.get_config_hash())
else:
sig.append('md5: <omitted>')
......@@ -1439,6 +1444,8 @@ class CLinker(link.Linker):
for node_pos, node in enumerate(order):
if hasattr(node.op, 'c_code_cache_version_apply'):
version.append(node.op.c_code_cache_version_apply(node))
if hasattr(node.op, '__props__'):
version.append(node.op.__props__)
for i in node.inputs:
version.append(i.type.c_code_cache_version())
for o in node.outputs:
......
......@@ -377,7 +377,7 @@ def is_same_entry(entry_1, entry_2):
def get_module_hash(src_code, key):
"""
Return an MD5 hash that uniquely identifies a module.
Return a SHA256 hash that uniquely identifies a module.
This hash takes into account:
1. The C source code of the module (`src_code`).
......@@ -415,9 +415,12 @@ def get_module_hash(src_code, key):
# libraries to link against.
to_hash += list(key_element)
elif isinstance(key_element, string_types):
if key_element.startswith('md5:'):
# This is the md5 hash of the config options. We can stop
# here.
if (key_element.startswith('md5:') or
key_element.startswith('hash:')):
# This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old Theano.
# We add 'hash:' so that when we change it in
# the futur, it won't break this version of Theano.
break
elif (key_element.startswith('NPY_ABI_VERSION=0x') or
key_element.startswith('c_compiler_str=')):
......@@ -435,29 +438,36 @@ def get_safe_part(key):
This tuple should only contain objects whose __eq__ and __hash__ methods
can be trusted (currently: the version part of the key, as well as the
md5 hash of the config options).
SHA256 hash of the config options).
It is used to reduce the amount of key comparisons one has to go through
in order to find broken keys (i.e. keys with bad implementations of __eq__
or __hash__).
"""
version = key[0]
# This function should only be called on versioned keys.
assert version
# Find the md5 hash part.
# Find the hash part. This is actually a sha256 hash of the config
# options. Currently, we still keep md5 to don't break old
# Theano. We add 'hash:' so that when we change it
# in the futur, it won't break this version of Theano.
c_link_key = key[1]
# In case in the future, we don't have an md5 part and we have
# such stuff in the cache. In that case, we can set None, and the
# rest of the cache mechanism will just skip that key.
md5 = None
hash = None
for key_element in c_link_key[1:]:
if (isinstance(key_element, string_types) and
key_element.startswith('md5:')):
md5 = key_element[4:]
break
if isinstance(key_element, string_types):
if key_element.startswith('md5:'):
hash = key_element[4:]
break
elif key_element.startswith('hash:'):
hash = key_element[5:]
break
return key[0] + (md5, )
return key[0] + (hash, )
class KeyData(object):
......@@ -1238,7 +1248,7 @@ class ModuleCache(object):
"Ops. The file is: %s. The key is: %s" % (msg, key_pkl, key))
# Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this
# with the same version part and config hash, since we can assume this
# part of the key is not broken.
for other in self.similar_keys.get(get_safe_part(key), []):
if other is not key and other == key and hash(other) != hash(key):
......
......@@ -295,8 +295,8 @@ class ParamsType(Type):
# (see c_support_code() below).
fields_string = ','.join(self.fields).encode('utf-8')
types_string = ','.join(str(t) for t in self.types).encode('utf-8')
fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest()
fields_hex = hashlib.sha256(fields_string).hexdigest()
types_hex = hashlib.sha256(types_string).hexdigest()
return '_Params_%s_%s' % (fields_hex, types_hex)
def has_type(self, theano_type):
......
......@@ -546,28 +546,28 @@ if PY3:
import hashlib
def hash_from_code(msg):
# hashlib.md5() requires an object that supports buffer interface,
# hashlib.sha256() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if isinstance(msg, str):
msg = msg.encode()
# Python 3 does not like module names that start with
# a digit.
return 'm' + hashlib.md5(msg).hexdigest()
return 'm' + hashlib.sha256(msg).hexdigest()
else:
import hashlib
def hash_from_code(msg):
try:
return hashlib.md5(msg).hexdigest()
return hashlib.sha256(msg).hexdigest()
except TypeError:
assert isinstance(msg, np.ndarray)
return hashlib.md5(np.getbuffer(msg)).hexdigest()
return hashlib.sha256(np.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
"""
Return the MD5 hash of a file.
Return the SHA256 hash of a file.
"""
with open(file_path, 'rb') as f:
......
......@@ -1221,7 +1221,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
name = '<ndarray:'
name += 'strides=[' + ','.join(str(stride)
for stride in obj.strides) + ']'
name += ',digest=' + hashlib.md5(obj).hexdigest() + '>'
name += ',digest=' + hashlib.sha256(obj).hexdigest() + '>'
elif hasattr(obj, 'owner') and obj.owner is not None:
name = str(obj.owner.op) + '('
name += ','.join(var_descriptor(ipt,
......@@ -1265,7 +1265,7 @@ def hex_digest(x):
Returns a short, mostly hexadecimal hash of a numpy ndarray
"""
assert isinstance(x, np.ndarray)
rval = hashlib.md5(x.tostring()).hexdigest()
rval = hashlib.sha256(x.tostring()).hexdigest()
# hex digest must be annotated with strides to avoid collisions
# because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged
......
......@@ -9,7 +9,7 @@ def hash_from_sparse(data):
# We also need to add the dtype to make the distinction between
# uint32 and int32 of zeros with the same shape.
# Python hash is not strong, so I always use md5. To avoid having a too
# Python hash is not strong, so use sha256 instead. To avoid having a too
# long hash, I call it again on the contatenation of all parts.
return hash_from_code(hash_from_code(data.data) +
hash_from_code(data.indices) +
......
......@@ -19,8 +19,9 @@ def hash_from_ndarray(data):
# We also need to add the dtype to make the distinction between
# uint32 and int32 of zeros with the same shape and strides.
# python hash are not strong, so I always use md5 in order not to have a
# too long hash, I call it again on the concatenation of all parts.
# python hash are not strong, so use sha256 (md5 is not
# FIPS compatible). To not have too long of hash, I call it again on
# the concatenation of all parts.
if not data.flags["C_CONTIGUOUS"]:
# hash_from_code needs a C-contiguous array.
data = np.ascontiguousarray(data)
......
......@@ -112,7 +112,7 @@ class Record(object):
class RecordMode(Mode):
"""
Records all computations done with a function in a file at output_path.
Writes into the file the index of each apply node and md5 digests of the
Writes into the file the index of each apply node and sha256 digests of the
numpy ndarrays it receives as inputs and produces as output.
Example:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论