提交 96e48a0e authored 作者: James Bergstra's avatar James Bergstra

merge

"""Provides `DebugMode`, an evaluation mode for debugging theano internals."""
__docformat__ = "restructuredtext en"
import time, copy, sys
import time, copy, sys, copy_reg
from StringIO import StringIO
import numpy
......@@ -1123,6 +1123,9 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self)
return fn
def _pickle_DebugMode_Maker(maker):
raise NotImplementedError('DebugMode is not picklable (yet)')
copy_reg.pickle(_Maker, _pickle_DebugMode_Maker)
########################
#
......
......@@ -95,6 +95,10 @@ def std_env(input_specs, output_specs, accept_inplace = False):
env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(env, 'destroyers') and env.destroyers(input)))))
return env, map(SymbolicOutput, updates)
class AliasedMemoryError(Exception):
"""Memory is aliased that should not be"""
pass
###
### Function
......@@ -140,27 +144,75 @@ class Function(object):
"""
input_storage = None
"""list of Container instances"""
output_storage = None
"""list of Container instances"""
indices = None
"""list of (SymbolicInput|SymbolicInputKit, indices, [SymbolicInput,...]), one tuple for
each input
The first tuple element is the SymbolicInput object for the corresponding function input.
The second and third tuple elements are used only by Kits, which are deprecated.
"""
defaults = None
""" list of 3-tuples, one 3-tuple for each input.
Tuple element 0: Bool: Is this input required at each function call?
Tuple element 1: Bool: Should this inputs value be reverted after each call?
Tuple element 2: Any: The value associated with this input.
"""
unpack_single = None
"""Bool: for outputs lists of length 1, should the 0'th element be returned directly?"""
maker = None
"""FunctionMaker instance"""
fn = None
"""a function that evaluates the graph. Typically a linker's make_thunk method created this
function."""
finder = None
"""Dictionary mapping several kinds of things to containers.
We set an entry in finder for:
- the index of the input
- the variable instance the input is based on
- the name of the input
All entries map to the container or to DUPLICATE if an ambiguity is detected
"""
inv_finder = None
"""Dict. Reverse lookup of `finder`.
It maps container -> SymbolicInput
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker):
"""
fn -> a function returned by some linker's make_thunk method
input_storage -> list of Container instances used by fn to fetch the inputs
output_storage -> list of Container instances used by fn to store the outputs in
indices -> list of (SymbolicInput|SymbolicInputKit, indices, [SymbolicInput,...]), one tuple for each input
defaults -> list of (required (bool), refeed (bool), value), one tuple for each input
required -> whether this input is required or optional
refeed -> whether this input's contents must be reverted to value after each call or not
value -> the initial or default value of the input
unpack_single -> if the function has one output and unpack_single is True, return that output. Else,
return [output].
maker -> FunctionMaker instance used to make this Function (used for copy)
Initialize attributes. create finder, inv_finder.
"""
self.fn = fn
self.input_storage = input_storage
self.output_storage = output_storage
self.indices = indices
self.outputs = outputs
self.defaults = defaults
self.unpack_single = unpack_single
self.maker = maker
containers = list(self.input_storage)
# we'll be popping stuff off this `containers` object. It's a copy
containers = list(self.input_storage)
finder = {}
inv_finder = {}
......@@ -183,11 +235,6 @@ class Function(object):
c.data = value
c.required = required
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
# We set an entry in finder for:
# - the index of the input
# - the variable instance the input is based on
# - the name of the input
# All entries map to the container or to DUPLICATE if an ambiguity is detected
finder[i] = c
finder[input.variable] = c
finder[input.name] = c if input.name not in finder else DUPLICATE
......@@ -222,10 +269,6 @@ class Function(object):
self.finder = finder
self.inv_finder = inv_finder
self.outputs = outputs
self.defaults = defaults
self.unpack_single = unpack_single
self.maker = maker
# this class is important in overriding the square-bracket notation:
# fn.value[x]
......@@ -254,6 +297,8 @@ class Function(object):
s.provided += 1
else:
s(value)
def __contains__(self, item):
return finder.__contains__(item)
# this class is important in overriding the square-bracket notation:
# fn.container[x]
......@@ -261,11 +306,16 @@ class Function(object):
class ContainerAttribute(object):
def __getitem__(self, item):
return finder[item]
def __contains__(self, item):
return finder.__contains__(item)
# You cannot set the container
self._value = ValueAttribute()
self._container = ContainerAttribute()
def __contains__(self, item):
return self.value.__contains__(item)
def __getitem__(self, item):
return self.value[item]
......@@ -336,18 +386,20 @@ class Function(object):
value = property(
lambda self: self._value,
None, #not settable
doc="""TODOC""")
None, # this property itself is not settable
doc="""dictionary-like access to the values associated with Variables""")
container = property(
lambda self: self._container,
None,
doc="""TODOC""")
None, # this property itself is not settable
doc="""dictionary-like access to the containers associated with Variables""")
# pickling/deepcopy support for Function
def _pickle_Function(f):
#copy of the input storage list
ins = list(f.input_storage)
defaults = []
for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit):
li = len(indices)
......@@ -362,7 +414,7 @@ def _pickle_Function(f):
inputs_data = [x.data for x in f.input_storage]
#HACK to detect aliased storage.
# HACK to detect aliased storage.
# aliased relationships will not be preserved across the pickle operation
if not (f.pickle_aliased_memory_strategy == 'ignore'):
all_data = defaults + inputs_data
......@@ -381,8 +433,6 @@ def _pickle_Function(f):
rval = (_constructor_Function, (f.maker, defaults, inputs_data))
return rval
class AliasedMemoryError(Exception):pass
def _constructor_Function(maker, defaults, data):
f = maker.create(defaults, trustme = True)
assert len(f.input_storage) == len(data)
......@@ -659,7 +709,8 @@ class FunctionMaker(object):
def _pickle_FunctionMaker(fm):
return (_constructor_FunctionMaker, (fm.inputs, fm.outputs[0] if fm.unpack_single else fm.outputs, fm.mode, fm.accept_inplace))
rval = (_constructor_FunctionMaker, (fm.inputs, fm.outputs[0] if fm.unpack_single else fm.outputs, fm.mode, fm.accept_inplace))
return rval
def _constructor_FunctionMaker(*args):
return FunctionMaker(*args)
......
......@@ -381,18 +381,27 @@ class T_picklefunction(unittest.TestCase):
x,s = T.scalars('xs')
f = function([x, In(a, value=1.0,name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x)
print f.maker.function_builder
g = copy.deepcopy(f)
try:
g = copy.deepcopy(f)
except NotImplementedError, e:
if e[0].startswith('DebugMode is not picklable'):
return
else:
raise
#if they both return, assume that they return equivalent things.
self.failIf(g.container[x].storage is f.container[x].storage)
self.failIf(g.container[a].storage is f.container[a].storage)
self.failIf(g.container[s].storage is f.container[s].storage)
self.failIf(g.value[a] is f.value[a]) # should not have been copied
self.failIf(g.value[s] is f.value[s]) # should have been copied because it is mutable.
self.failIf((g.value[s] != f.value[s]).any()) # its contents should be identical
#print [(k,id(k)) for k in f.finder.keys()]
#print [(k,id(k)) for k in g.finder.keys()]
self.failIf(g.container[0].storage is f.container[0].storage)
self.failIf(g.container[1].storage is f.container[1].storage)
self.failIf(g.container[2].storage is f.container[2].storage)
self.failIf(x in g.container)
self.failIf(x in g.value)
self.failIf(g.value[1] is f.value[1]) # should not have been copied
self.failIf(g.value[2] is f.value[2]) # should have been copied because it is mutable.
self.failIf((g.value[2] != f.value[2]).any()) # its contents should be identical
self.failUnless(f(2, 1) == g(2)) #they should be in sync, default value should be copied.
self.failUnless(f(2, 1) == g(2)) #they should be in sync, default value should be copied.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论