提交 f6e05092 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

function works

上级 80a40cf0
...@@ -9,6 +9,10 @@ import tensor_opt ...@@ -9,6 +9,10 @@ import tensor_opt
def check_equal(x, y): def check_equal(x, y):
"""
Returns True iff x[0] and y[0] are equal (checks the dtype and
shape if x and y are numpy.ndarray instances). Used internally.
"""
x, y = x[0], y[0] x, y = x[0], y[0]
if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
...@@ -18,6 +22,11 @@ def check_equal(x, y): ...@@ -18,6 +22,11 @@ def check_equal(x, y):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
def infer_reuse_pattern(env, outputs_to_disown): def infer_reuse_pattern(env, outputs_to_disown):
"""
Given an env and a list of results, returns the list of all
results which may share the same underlying data storage as any of
the specified results. Used internally by function, FunctionMaker.
"""
do_not_reuse = list() do_not_reuse = list()
seen = set() seen = set()
def walk(r): def walk(r):
...@@ -36,6 +45,9 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -36,6 +45,9 @@ def infer_reuse_pattern(env, outputs_to_disown):
walk(output) walk(output)
return do_not_reuse return do_not_reuse
# If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this
# dictionary
predefined_linkers = { predefined_linkers = {
'py' : gof.PerformLinker(), 'py' : gof.PerformLinker(),
'c' : gof.CLinker(), 'c' : gof.CLinker(),
...@@ -46,6 +58,9 @@ predefined_linkers = { ...@@ -46,6 +58,9 @@ predefined_linkers = {
default_linker = 'c|py' default_linker = 'c|py'
# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
predefined_optimizers = { predefined_optimizers = {
None : lambda env: None, None : lambda env: None,
'merge' : gof.MergeOptimizer(), 'merge' : gof.MergeOptimizer(),
...@@ -56,6 +71,20 @@ default_optimizer = 'merge' ...@@ -56,6 +71,20 @@ default_optimizer = 'merge'
class Mode(object): class Mode(object):
"""
The Mode represents a way to optimize and then link a computation
graph.
* optimizer -> a structure of type Optimizer. An Optimizer may
simplify the math, put similar computations together, improve
numerical stability and various other improvements.
* linker -> a structure of type Linker. A Linker decides which
implementations to use (C or Python, for example) and how to
string them together to perform the computation.
See predefined_linkers, predefined_optimizers and also
predefined_modes.
"""
def __init__(self, linker = default_linker, optimizer = default_optimizer): def __init__(self, linker = default_linker, optimizer = default_optimizer):
self.provided_linker = linker self.provided_linker = linker
...@@ -70,11 +99,13 @@ class Mode(object): ...@@ -70,11 +99,13 @@ class Mode(object):
def __str__(self): def __str__(self):
return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer) return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer)
# If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key
predefined_modes = { predefined_modes = {
'SANITY_CHECK' : Mode('c&py', 'math'), 'SANITY_CHECK' : Mode('c&py', 'math'),
'FAST_COMPILE' : Mode('py', None), 'FAST_COMPILE' : Mode('py', None),
'FAST_RUN' : Mode('c|py', None), #'math'), 'FAST_RUN' : Mode('c|py', 'math'),
'EXPENSIVE_OPTIMIZATIONS' : Mode('c|py', 'math') 'EXPENSIVE_OPTIMIZATIONS' : Mode('c|py', 'math')
} }
...@@ -84,9 +115,9 @@ default_mode = 'FAST_RUN' ...@@ -84,9 +115,9 @@ default_mode = 'FAST_RUN'
class SymbolicInput(object): class SymbolicInput(object):
def __init__(self, result, name=None, update=None, mutable=None, autoname=True):
""" """
Represents a symbolic input for use with function or FunctionMaker.
result: a Result instance. result: a Result instance.
This will be assigned a value before running the function, This will be assigned a value before running the function,
not computed from its owner. not computed from its owner.
...@@ -95,10 +126,9 @@ class SymbolicInput(object): ...@@ -95,10 +126,9 @@ class SymbolicInput(object):
If name is a valid Python identifier, this input can be set by kwarg, and its value If name is a valid Python identifier, this input can be set by kwarg, and its value
can be accessed by self.<name>. can be accessed by self.<name>.
update: Result instance update: Result instance (default: None)
value (see previous) will be replaced with this expression result after each function call. value (see previous) will be replaced with this expression result after each function call.
If update is None, the update will be the default value of the input. If update is None, the update will be the default value of the input.
has_default must be True.
mutable: Bool (default: False if update is None, True if update is not None) mutable: Bool (default: False if update is None, True if update is not None)
True: permit the compiled function to modify the python object being passed as the input True: permit the compiled function to modify the python object being passed as the input
...@@ -107,13 +137,14 @@ class SymbolicInput(object): ...@@ -107,13 +137,14 @@ class SymbolicInput(object):
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
""" """
def __init__(self, result, name=None, update=None, mutable=None, autoname=True):
self.result = result self.result = result
self.name = result.name if (autoname and name is None) else name self.name = result.name if (autoname and name is None) else name
#self.has_default = has_default if (has_default is not None) else (update is not None) if self.name is not None and not isinstance(self.name, str):
self.update = update # if (update or not value) else gof.Constant(result.type, value) raise TypeError("name must be a string! (got: %s)" % self.name)
self.update = update
self.mutable = mutable if (mutable is not None) else (update is not None) self.mutable = mutable if (mutable is not None) else (update is not None)
# if self.update is not None and not self.has_default:
# raise ValueError("If update is not None, the input must accept a default value.")
def __str__(self): def __str__(self):
if self.update: if self.update:
...@@ -126,6 +157,15 @@ class SymbolicInput(object): ...@@ -126,6 +157,15 @@ class SymbolicInput(object):
class SymbolicInputKit(object): class SymbolicInputKit(object):
"""
Represents a group ("kit") of SymbolicInputs. If fed into function or
FunctionMaker, only the inputs which are needed to compile the function
properly will be taken.
A SymbolicInputKit provides the distribute function in order to set or
initialize several inputs from a single value. Specialized Kits should
override it.
"""
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
...@@ -133,13 +173,31 @@ class SymbolicInputKit(object): ...@@ -133,13 +173,31 @@ class SymbolicInputKit(object):
self.results = [] self.results = []
def add_input(self, sinput): def add_input(self, sinput):
"""
Add a SymbolicInput to this SymbolicInputKit. It will be given the
next available index.
"""
self.sinputs.append(sinput) self.sinputs.append(sinput)
self.results.append(sinput.result) self.results.append(sinput.result)
def distribute(self, value, indices, containers): def distribute(self, value, indices, containers):
"""
Given a list of indices corresponding to SymbolicInputs in this kit
as well as a corresponding list of containers, initialize all the
containers using the provided value.
"""
raise NotImplementedError raise NotImplementedError
def complete(self, inputs): def complete(self, inputs):
"""
Given inputs (a list of Result instances), checks through all
the SymbolicInputs in the kit and return a sorted list of
indices and a list of their corresponding SymbolicInputs such
that each of them represents some result in the inputs list.
Not all the provided inputs will have a corresponding
SymbolicInput in the kit.
"""
ret = [] ret = []
for input in inputs: for input in inputs:
try: try:
...@@ -151,15 +209,50 @@ class SymbolicInputKit(object): ...@@ -151,15 +209,50 @@ class SymbolicInputKit(object):
return zip(*ret) return zip(*ret)
class SymbolicOutput(object): class In(SymbolicInput):
"""
Represents a symbolic input for use with function or FunctionMaker.
def __init__(self, result, borrow=False): result: a Result instance.
This will be assigned a value before running the function,
not computed from its owner.
name: Any type. (If autoname=True, defaults to result.name).
If name is a valid Python identifier, this input can be set by kwarg, and its value
can be accessed by self.<name>.
value: Any type.
The initial/default value for this input. If update is None, this input acts just like
an argument with a default value in Python. If update is not None, changes to this
value will "stick around", whether due to an update or a user's explicit action.
update: Result instance (default: None)
value (see previous) will be replaced with this expression result after each function call.
If update is None, the update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is not None)
True: permit the compiled function to modify the python object being passed as the input
False: do not permit the compiled function to modify the python object being passed as the input.
autoname: Bool (default: True)
See the name option.
"""
def __init__(self, result, name=None, value=None, update=None, mutable=None, autoname=True):
super(In, self).__init__(result, name, update, mutable, autoname)
self.value = value
class SymbolicOutput(object):
""" """
Represents a symbolic output for use with function or FunctionMaker.
borrow: set this to True to indicate that a reference to borrow: set this to True to indicate that a reference to
function's internal storage may be returned. A value function's internal storage may be returned. A value
returned for this output might be clobbered by running returned for this output might be clobbered by running
the function again, but the function might be faster. the function again, but the function might be faster.
""" """
def __init__(self, result, borrow=False):
self.result = result self.result = result
self.borrow = borrow self.borrow = borrow
...@@ -168,6 +261,10 @@ Out = SymbolicOutput ...@@ -168,6 +261,10 @@ Out = SymbolicOutput
class Supervisor: class Supervisor:
"""
Listener for Env events which makes sure that no operation overwrites the
contents of protected Results. The outputs of the Env are protected by default.
"""
def __init__(self, protected): def __init__(self, protected):
self.protected = list(protected) self.protected = list(protected)
...@@ -181,6 +278,21 @@ class Supervisor: ...@@ -181,6 +278,21 @@ class Supervisor:
def std_env(input_specs, output_specs, accept_inplace = False): def std_env(input_specs, output_specs, accept_inplace = False):
"""
Makes an Env corresponding to the input specs and the output
specs. Any SymbolicInput in the input_specs, if its update field
is not None, will add an output to the Env corresponding to that
update. The return value is the Env as well as a list of
SymbolicOutput instances corresponding to the updates.
If accept_inplace is False, the graph will be checked for inplace
operations and an exception will be raised if it has any. If
accept_inplace is True, a DestroyHandler will be added to the Env
if there are any inplace operations.
The returned Env is a clone of the graph between the provided
inputs and outputs.
"""
orig_inputs = [spec.result for spec in input_specs] orig_inputs = [spec.result for spec in input_specs]
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
orig_outputs = [spec.result for spec in output_specs] + updates orig_outputs = [spec.result for spec in output_specs] + updates
...@@ -196,6 +308,7 @@ def std_env(input_specs, output_specs, accept_inplace = False): ...@@ -196,6 +308,7 @@ def std_env(input_specs, output_specs, accept_inplace = False):
env.extend(gof.DestroyHandler()) env.extend(gof.DestroyHandler())
break break
# We need to protect all immutable inputs from inplace operations.
env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not spec.mutable)) env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not spec.mutable))
return env, map(SymbolicOutput, updates) return env, map(SymbolicOutput, updates)
...@@ -207,8 +320,10 @@ class FunctionMaker(object): ...@@ -207,8 +320,10 @@ class FunctionMaker(object):
if isinstance(input, (SymbolicInput, SymbolicInputKit)): if isinstance(input, (SymbolicInput, SymbolicInputKit)):
return input return input
elif isinstance(input, gof.Result): elif isinstance(input, gof.Result):
# r -> SymbolicInput(result=r)
return SymbolicInput(input) return SymbolicInput(input)
elif isinstance(input, (list, tuple)): elif isinstance(input, (list, tuple)):
# (r, u) -> SymbolicInput(result=r, update=u)
if len(input) == 2: if len(input) == 2:
return SymbolicInput(input[0], update = input[1]) return SymbolicInput(input[0], update = input[1])
else: else:
...@@ -217,11 +332,16 @@ class FunctionMaker(object): ...@@ -217,11 +332,16 @@ class FunctionMaker(object):
raise TypeError("Unknown input type:", type(input), input) raise TypeError("Unknown input type:", type(input), input)
@staticmethod @staticmethod
def expand_in(input, inputs): def expand_in(sinput, rinputs):
if isinstance(input, SymbolicInputKit): # For SymbolicInputKits, this extracts a list of SymbolicInput instances
return input.complete(inputs) # and corresponding indices such that these SymbolicInputs are representative
elif isinstance(input, SymbolicInput): # of some of the Result instances in inputs.
return [None, [input]] # For SymbolicInput, this returns None as the list of indices and a list with
# just the SymbolicInput.
if isinstance(sinput, SymbolicInputKit):
return sinput.complete(rinputs)
elif isinstance(sinput, SymbolicInput):
return [None, [sinput]]
@staticmethod @staticmethod
def wrap_out(output): def wrap_out(output):
...@@ -233,6 +353,18 @@ class FunctionMaker(object): ...@@ -233,6 +353,18 @@ class FunctionMaker(object):
raise TypeError("Unknown output type:", type(output), output) raise TypeError("Unknown output type:", type(output), output)
def __init__(self, inputs, outputs, mode = 'FAST_RUN', accept_inplace = True): def __init__(self, inputs, outputs, mode = 'FAST_RUN', accept_inplace = True):
"""
Create a FunctionMaker for the specified inputs, outputs and mode.
inputs -> a list of SymbolicInput instances
outputs -> a list of SymbolicOutput instances
outputs may also be a single Result (not a list), in which
case the functions produced by FunctionMaker will return
their output value directly
mode -> a Mode instance telling FunctionMaker how to optimize and link
accept_inplace -> True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs
"""
# Handle the case where inputs and/or outputs is a single Result (not in a list) # Handle the case where inputs and/or outputs is a single Result (not in a list)
unpack_single = False unpack_single = False
...@@ -246,12 +378,13 @@ class FunctionMaker(object): ...@@ -246,12 +378,13 @@ class FunctionMaker(object):
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs) inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.result for o in outputs]) _inputs = gof.graph.inputs([o.result for o in outputs])
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices]) expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
# make the env # make the env
env, additional_outputs = std_env(expanded_inputs, outputs, accept_inplace) env, additional_outputs = std_env(expanded_inputs, outputs, accept_inplace)
self.env = env self.env = env
# Fetch the mode and then the optimizer and linker
mode = predefined_modes.get(mode, mode) mode = predefined_modes.get(mode, mode)
optimizer, linker = mode.optimizer, copy(mode.linker) optimizer, linker = mode.optimizer, copy(mode.linker)
...@@ -275,123 +408,203 @@ class FunctionMaker(object): ...@@ -275,123 +408,203 @@ class FunctionMaker(object):
self.outputs = outputs self.outputs = outputs
self.unpack_single = unpack_single self.unpack_single = unpack_single
def create(self, defaults = [], profiler = None): def create(self, defaults = None, trustme = False):
input_storage = [] """
Create a function.
defaults -> a list matching the inputs list and providing default values
if the default for an input is None, then that input is a
required input. For an input with an update, the default
acts as initialization.
trustme -> disables some exceptions, used internally
"""
if defaults is None:
defaults = [None]*len(self.inputs)
input_storage = [] # list of independent one-element lists, will be passed to the linker
_defaults = [] _defaults = []
for (input, indices, results), default in zip(self.indices, defaults):
# The following loop is to fill in the input_storage and _defaults lists.
for (input, indices, subinputs), default in zip(self.indices, defaults):
__default = default
# If the default is a gof.Filter, this means we want to share
# the same storage. This is done by appending default.storage
# to input_storage
if isinstance(default, gof.Filter): if isinstance(default, gof.Filter):
if indices is not None: if indices is not None:
raise TypeError("Cannot take a Filter instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Filter instance as default for a SymbolicInputKit.")
input_storage.append(default.storage) input_storage.append(default.storage)
continue default = None
if indices is None: # If the input is a SymbolicInputKit, it represents more than
input_storage.append([None]) # one storage unit. The indices and subinputs lists represent which
else: # of the kit's inputs are active in this graph, so we make as many
# storage units as needed
elif isinstance(input, SymbolicInputKit):
input_storage += [[None] for i in indices] input_storage += [[None] for i in indices]
# Normal case: one new, independent storage unit
else:
input_storage.append([None])
# Filling _defaults. Each entry is a tuple of three elements:
# (required, refeed, value)
# - required means that the user must provide a value when calling the function
# - refeed means that we want to put the default back in the storage after each function call
# - value is the value that will be put in the storage initially
# Even though a SymbolicInputKit represents more than one input,
# we still only have one entry for the defaults list.
if isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
if default is None: if default is None:
_defaults.append((True, True, None)) _defaults.append((True, True, None))
else: else:
_defaults.append((False, False, default)) _defaults.append((False, False, default))
elif input.update is not None: elif input.update is not None:
# If the input has an update, then (logically) it is not required since
# it is just a parameter and of course we don't want to refeed the default
# back into the storage as it would defeat the point of updating it. We
# always do this policy.
if default is None: if default is None:
raise ValueError("A default (initial) value is required for an input which can update itself.") if trustme or isinstance(__default, gof.Filter):
_defaults.append((False, False, default))
else:
# This might catch some bugs early
raise ValueError("A default (initial) value is required for an input which can update itself.", input)
else: else:
_defaults.append((False, False, default)) _defaults.append((False, False, default))
else: else:
if default is None: if default is None:
# No default, so this is a required input. Nothing to feed back, initial value is None.
_defaults.append((True, False, None)) _defaults.append((True, False, None))
else: else:
# Default value. It is not required, but we want to put it back into the storage
# everytime so it behaves like most programming languages' default values
_defaults.append((False, True, default)) _defaults.append((False, True, default))
defaults = _defaults defaults = _defaults
# Get a function instance # Get a function instance
if profiler is None:
# some linkers may not support profilers, so we avoid passing the option altogether
_fn, _i, _o = self.linker.make_thunk(input_storage = input_storage) _fn, _i, _o = self.linker.make_thunk(input_storage = input_storage)
else: fn = Function(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self)
_fn, _i, _o = self.linker.make_thunk(input_storage = input_storage,
profiler = profiler)
fn = Function(_fn, _i, _o, self.indices, self.outputs, defaults, self.unpack_single, self, profiler)
return fn return fn
from functools import partial from functools import partial
DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate entries
class Function(object): class Function(object):
"""
Type of the functions returned by theano.function or theano.FunctionMaker.create.
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker, profiler): def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker):
"""
fn -> a function returned by some linker's make_thunk method
input_storage -> list of Filter instances used by fn to fetch the inputs
output_storage -> list of Filter instances used by fn to store the outputs in
indices -> list of (SymbolicInput|SymbolicInputKit, indices, [SymbolicInput,...]), one tuple for each input
defaults -> list of (required (bool), refeed (bool), value), one tuple for each input
required -> whether this input is required or optional
refeed -> whether this input's contents must be reverted to value after each call or not
value -> the initial or default value of the input
unpack_single -> if the function has one output and unpack_single is True, return that output. Else,
return [output].
maker -> FunctionMaker instance used to make this Function (used for copy)
"""
self.fn = fn self.fn = fn
self.input_storage = input_storage self.input_storage = input_storage
self.output_storage = output_storage self.output_storage = output_storage
self.indices = indices
containers = list(self.input_storage) containers = list(self.input_storage)
finder = {} finder = {}
inv_finder = {} inv_finder = {}
defaults_layer = []
def distribute(indices, cs, value): def distribute(indices, cs, value):
input.distribute(value, indices, cs) input.distribute(value, indices, cs)
for c in cs: for c in cs:
c.provided = True c.provided += 1
def set(c, v): def set(c, v):
c.data = v c.data = v
setters = [] setters = []
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(indices, defaults)): # Initialize the storage
if indices is None: for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)):
if indices is None: # this is true iff input is not a SymbolicInputKit
c = containers[0] c = containers[0]
if value: if value is not None:
# always initialize the storage
c.data = value c.data = value
c.required = required c.required = required
c.provided = False c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
# We set an entry in finder for:
# - the index of the input
# - the result instance the input is based on
# - the name of the input
# All entries map to the container or to DUPLICATE if an ambiguity is detected
finder[i] = c finder[i] = c
finder[input.result] = c finder[input.result] = c
finder[input.name] = c finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message)
inv_finder[c] = input inv_finder[c] = input
setters.append(partial(set, c)) setters.append(partial(set, c))
containers[:1] = [] containers[:1] = []
else: else:
# The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs
cs = containers[:len(indices)] cs = containers[:len(indices)]
# distribute does the initialization of the containers
input.distribute(value, indices, cs) input.distribute(value, indices, cs)
f = partial(distribute, indices, cs) f = partial(distribute, indices, cs)
# Like before, we set a finder entry for the kit. Note that
# we are not mapping to a container but to a function which
# can reinitialize all the containers
finder[i] = f finder[i] = f
finder[input] = f finder[input] = f
finder[input.name] = f finder[input.name] = f if input.name not in finder else DUPLICATE
setters.append(f) setters.append(f)
# For each input in the kit and its corresponding container, we put an entry in finder.
# This allows the user to micro-manage elements of the kit if need be.
# All containers inherit the required field and have their own "provided" counter
for c, sin in zip(cs, sinputs): for c, sin in zip(cs, sinputs):
finder[sin.result] = c finder[sin.result] = c
finder[sin.name] = c finder[sin.name] = c
finder[sin.name] = c if sin.name not in finder else DUPLICATE
inv_finder[c] = input inv_finder[c] = input
c.required = required c.required = required
c.provided = False c.provided = 0
containers[:len(indices)] = [] containers[:len(indices)] = []
self.finder = finder self.finder = finder
self.inv_finder = inv_finder self.inv_finder = inv_finder
self.indices = indices
self.outputs = outputs self.outputs = outputs
self.defaults = defaults self.defaults = defaults
self.unpack_single = unpack_single self.unpack_single = unpack_single
self.maker = maker self.maker = maker
self.profiler = profiler
# this class is important in overriding the square-bracket notation: # this class is important in overriding the square-bracket notation:
# fn.value[x] # fn.value[x]
# self reference is available via the closure on the class # self reference is available via the closure on the class
class ValueAttribute(object): class ValueAttribute(object):
def __getitem__(self, item): def __getitem__(self, item):
try:
s = finder[item] s = finder[item]
except KeyError:
raise TypeError("Unknown input or state: %s" % item)
if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names of the inputs of your function for duplicates." % item)
if isinstance(s, gof.Filter): if isinstance(s, gof.Filter):
return s.value return s.value
else: else:
raise NotImplementedError raise NotImplementedError
def __setitem__(self, item, value): def __setitem__(self, item, value):
try:
s = finder[item] s = finder[item]
except KeyError:
raise TypeError("Unknown input or state: %s" % item)
if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names of the inputs of your function for duplicates." % item)
if isinstance(s, gof.Filter): if isinstance(s, gof.Filter):
s.value = value s.value = value
s.provided = True s.provided += 1
else: else:
s(value) s(value)
...@@ -401,11 +614,11 @@ class Function(object): ...@@ -401,11 +614,11 @@ class Function(object):
class ContainerAttribute(object): class ContainerAttribute(object):
def __getitem__(self, item): def __getitem__(self, item):
return finder[item] return finder[item]
# You cannot set the container
self._value = ValueAttribute() self._value = ValueAttribute()
self._container = ContainerAttribute() self._container = ContainerAttribute()
def __getitem__(self, item): def __getitem__(self, item):
return self.value[item] return self.value[item]
...@@ -413,35 +626,47 @@ class Function(object): ...@@ -413,35 +626,47 @@ class Function(object):
self.value[item] = value self.value[item] = value
def __copy__(self): def __copy__(self):
defaults = list(self.defaults) defaults = [default for _1, _2, default in self.defaults]
for i, default in enumerate(defaults): cpy = self.maker.create(defaults, trustme = True)
default[0] = self[i] for (input,_1,_2), here, there in zip(self.indices, self.input_storage, cpy.input_storage):
return self.maker.create(defaults, self.profiler) if input.mutable and here is not None:
there.data = copy(here.data)
else:
there.data = here.data
return cpy
def __call__(self, *args, **kwargs):
    """Feed the arguments in, run the compiled thunk, and return the
    output value(s).  Inputs may be passed positionally or by name."""
    containers = self.input_storage
    # Reset each container's 'provided' counter so that missing or
    # duplicated inputs can be detected below.
    for container in containers:
        container.provided = 0
    # Positional arguments first, then keyword arguments; both go
    # through __setitem__ which bumps 'provided'.
    for index, arg in enumerate(args):
        self[index] = arg
    for keyword, arg in kwargs.iteritems():
        self[keyword] = arg
    # Every required input must have been set exactly once.
    for container in containers:
        if container.required and not container.provided:
            raise TypeError("Missing required input: %s" % self.inv_finder[container].result)
        if container.provided > 1:
            raise TypeError("Multiple values for input: %s" % self.inv_finder[container].result)
    # Do the actual work.
    self.fn()
    outputs = [out.data for out in self.output_storage]
    # The trailing outputs feed back into the inputs that declared an
    # update expression; pop them off so they are not returned.
    for input, storage in reversed(zip(self.maker.expanded_inputs, containers)):
        if input.update:
            storage.data = outputs.pop()
    # Restore the refeedable default values for the next call.
    for index, (required, refeed, value) in enumerate(self.defaults):
        if refeed:
            self[index] = value
    if self.unpack_single and len(outputs) == 1:
        return outputs[0]
    return outputs
value = property( value = property(
lambda self: self._value, lambda self: self._value,
None, #not settable None, #not settable
...@@ -453,873 +678,192 @@ class Function(object): ...@@ -453,873 +678,192 @@ class Function(object):
def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
    """
    Return a function calculating the outputs from the inputs.

    inputs -> list of SymbolicInput or In instances
    outputs -> a SymbolicOutput or a list of SymbolicOutput or Out instances
        The return value of the returned function will match the format of this
        argument (either the value itself or a list of one or more return values)
    mode -> a descriptive string or a Mode instance; descriptive strings can be one of:
        * SANITY_CHECK
        * FAST_COMPILE
        * FAST_RUN (default)
        * EXPENSIVE_OPTIMIZATION
    accept_inplace -> True iff the graph can contain inplace operations
        prior to the optimization phase (default is False)

    Every element of the input list will be upgraded to an In instance if necessary,
    using the following rules:

    * a Result instance r will be upgraded like In(r)
    * a tuple (name, r) will be In(r, name=name)
    * a tuple (r, val) will be In(r, value=value, autoname=True)
    * a tuple ((r,up), val) will be In(r, value=value, update=up, autoname=True)
    * a tuple (name, r, val) will be In(r, name=name, value=value)
    * a tuple (name, (r,up), val) will be In(r, name=name, value=val, update=up, autoname=True)

    Similarly, every element of the output list will be upgraded to an
    Out instance if necessary:

    * a Result instance r will be upgraded like Out(r)
    """
def wrap_in(input):
    """
    Coerce *input* into a SymbolicInput (or SymbolicInputKit).

    Accepts a SymbolicInput/SymbolicInputKit (returned as is), a bare
    gof.Result (wrapped as In(r)), or a tuple/list spec with an
    optional leading name string, an optional (result, update) pair,
    and an optional default value — see the rules documented on
    `function`.  Raises TypeError for anything else.
    """
    if isinstance(input, (SymbolicInput, SymbolicInputKit)):
        return input
    elif isinstance(input, gof.Result):
        return In(input)
    elif isinstance(input, (list, tuple)):
        orig = input
        if not input:
            raise TypeError("Nonsensical input specification: %s" % input)
        # An optional leading string names the input.
        if isinstance(input[0], str):
            name = input[0]
            input = input[1:]
        else:
            name = None
        if isinstance(input[0], (list, tuple)):
            # ((result, update), value) form
            if len(input[0]) != 2 or len(input) != 2:
                raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
            (result, update), value = input
        elif isinstance(input[0], gof.Result):
            if len(input) == 1:
                result, update, value = input[0], None, None
            elif len(input) == 2:
                (result, value), update = input, None
            else:
                raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
        elif isinstance(input[0], (SymbolicInput, SymbolicInputKit)):
            if len(input) == 1:
                return input[0]
            elif len(input) == 2:
                input, value = input
                if name is not None: input.name = name
                input.value = value
                return input
            else:
                # BUG FIX: previously fell through to `return In(result, ...)`
                # with `result` unbound, raising NameError instead of a
                # meaningful TypeError.
                raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
        else:
            # BUG FIX: a tuple whose head is neither a name, a Result nor a
            # SymbolicInput also fell through to unbound `result` before.
            raise TypeError("Unknown input type:", type(input), input)
        return In(result, name=name, value=value, update=update)
    else:
        raise TypeError("Unknown input type:", type(input), input)
def wrap_out(output):
    """Coerce *output* into a SymbolicOutput instance."""
    if isinstance(output, SymbolicOutput):
        return output
    if isinstance(output, gof.Result):
        return SymbolicOutput(output)
    raise TypeError("Unknown output type: %s (%s)" % (type(output), output))
# Normalize the input/output specifications.
inputs = map(wrap_in, inputs)
outputs = map(wrap_out, outputs) if isinstance(outputs, (list, tuple)) else wrap_out(outputs)
# Create a fresh subclass of Function for this call so that named-input
# properties can be attached without affecting other compiled functions.
class F(Function):
    pass
fn = FunctionMaker(inputs, outputs, mode, accept_inplace = accept_inplace).create([getattr(input, 'value', None) for input in inputs])
# Expose each named input as a property of F.
def _get(name, self):
    return self[name]
def _set(name, self, value):
    self[name] = value
def _err(name, self):
    # Attached in place of a getter/setter when two inputs share a name.
    raise TypeError("Ambiguous name: %s - please check the names of the inputs of your function for duplicates." % name)
seen = set()
for input in inputs:
    name = input.name
    if name:
        if name in seen:
            # Duplicate name: accessing the property reports the ambiguity.
            f = property(partial(_err, input.name), partial(_err, input.name))
            setattr(F, input.name, f)
        elif not hasattr(F, name):
            f = property(partial(_get, input.name), partial(_set, input.name))
            setattr(F, input.name, f)
        seen.add(input.name)
else:
# class State(object): pass
# def __init__(self, variable, new_state = None):
# self.variable = variable
# if new_state is None:
# self.new_state = variable
# else:
# self.new_state = new_state
# class StateContainer(object):
# def __init__(self, data):
# self.data = data
# def env_with_state(normal_inputs, normal_outputs, states, accept_inplace = False):
# state_inputs = [s.variable for s in states]
# state_outputs = [s.new_state for s in states]
# inputs = normal_inputs + state_inputs
# outputs = normal_outputs + state_outputs
# inputs, outputs = gof.graph.clone(inputs, outputs)
# env = gof.env.Env(inputs, outputs)
# for node in env.nodes:
# if getattr(node.op, 'destroy_map', None):
# if not accept_inplace:
# raise TypeError("Graph must not contain inplace operations", node)
# else:
# env.extend(gof.DestroyHandler())
# break
# env.extend(Supervisor(normal_inputs))
# return env
# def function_with_state(fn, state_containers, unpack_single = True):
# n = len(state_containers)
# nin = len(fn.inputs)
# nout = len(fn.outputs)
# if n == 0:
# if unpack_single and nin == 1:
# return lambda *inputs: fn(*inputs)[0]
# else:
# return fn
# def f(*inputs):
# results = fn(*(list(inputs) + [c.data for c in state_containers]))
# for c, d in zip(state_containers, results[-n:]):
# c.data = d
# results = results[:-n]
# if unpack_single and len(results) == 1:
# return results[0]
# else:
# return results
# class FunctionFactory:
# def __init__(self,
# inputs,
# outputs,
# states = [],
# linker = default_linker,
# optimizer = default_optimizer,
# borrow_outputs = False,
# accept_inplace = False):
# self.states = states
# inputs, outputs = list(inputs), list(outputs)
# # Error checking
# for r in inputs + outputs:
# if not isinstance(r, gof.Result):
# raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
# for state in states:
# if not isinstance(state, State):
# raise TypeError("All states must be State instances", type(state), state)
# if len(inputs) != len(set(inputs)):
# print >>sys.stderr, "Warning: duplicate inputs"
# # make the env
# env = env_with_state(inputs, outputs, states, accept_inplace)
# self.env = env
# # optimize the env
# optimizer = predefined_optimizers.get(optimizer, optimizer)
# optimizer(env)
# # initialize the linker
# linker = copy(predefined_linkers.get(linker, linker))
# if not hasattr(linker, 'accept'):
# raise ValueError("'linker' parameter of FunctionFactory should be a Linker with an accept method " \
# "or one of %s" % predefined_linkers.keys())
# if borrow_outputs:
# self.linker = linker.accept(env)
# else:
# self.linker = linker.accept(env, no_recycling = infer_reuse_pattern(env, env.outputs))
# def create(self,
# states = [],
# profiler = None,
# unpack_single = True,
# strict = 'if_destroyed'):
# # Error checking
# if strict not in [True, False, 'if_destroyed']:
# raise ValueError("'strict' parameter of create should be one of [True, False, 'if_destroyed']")
# if len(states) != len(self.states):
# raise ValueError("not the right number of state initializers (expected %i, got %i)" % (len(self.states), len(states)))
# # Get a function instance
# if profiler is None:
# # some linkers may not support profilers, so we avoid passing the option altogether
# _fn = self.linker.make_function(unpack_single = False)
# else:
# _fn = self.linker.make_function(unpack_single = False,
# profiler = profiler)
# fn = function_with_state(_fn, states, unpack_single)
# # Make the inputs strict accordingly to the specified policy
# for env_input, fn_input in zip(self.env.inputs, _fn.inputs):
# if strict is True or (strict == 'if_destroyed' and self.env.destroyers(env_input)):
# fn_input.strict = True
# return fn
# def function(inputs,
# outputs,
# states = [],
# linker = default_linker,
# optimizer = default_optimizer,
# borrow_outputs = False,
# accept_inplace = False,
# profiler = None,
# unpack_single = True,
# strict = 'if_destroyed'):
# ff = FunctionFactory(inputs,
# outputs,
# states = [s[0] for s in states],
# linker = linker,
# optimizer = optimizer,
# borrow_outputs = borrow_outputs)
# return ff.create(states = [s[1] for s in states],
# profiler = profiler,
# unpack_single = unpack_single,
# strict = strict)
# import numpy
# import gof
# import sys
# from copy import copy
# #TODO: put together some default optimizations (TRAC #67)
# def exec_py_opt(inputs, outputs, features=[]):
# """Return an optimized graph running purely python implementations"""
# return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker(), False)
# exec_py_opt.optimizer = None
# def exec_opt(inputs, outputs, features=[]):
# """Return a fast implementation"""
# return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker(), False)
# exec_opt.optimizer = None
# class _DefaultOptimizer(object):
# #const = gof.opt.ConstantFinder()
# merge = gof.opt.MergeOptimizer()
# def __call__(self, env):
# #self.const(env)
# self.merge(env)
# default_optimizer = _DefaultOptimizer()
# def _mark_indestructible(results):
# for r in results:
# r.tag.indestructible = True
# # def linker_cls_python_and_c(env, **kwargs):
# # """Use this as the linker_cls argument to Function.__init__ to compare
# # python and C implementations"""
# def check_equal(x, y):
# x, y = x[0], y[0]
# if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
# if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
# raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
# else:
# if x != y:
# raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
# # return gof.DualLinker(checker, **kwargs).accept(env)
# def infer_reuse_pattern(env, outputs_to_disown):
# do_not_reuse = list()
# seen = set()
# def walk(r):
# if r.owner is None or r in seen:
# return
# seen.add(r)
# do_not_reuse.append(r)
# node = r.owner
# op = node.op
# dmap = op.destroy_map if hasattr(op, 'destroy_map') else {}
# vmap = op.view_map if hasattr(op, 'view_map') else {}
# for l in dmap.values() + vmap.values():
# for i in l:
# walk(node.inputs[i])
# for output in outputs_to_disown:
# walk(output)
# return do_not_reuse
# def cloned_env(inputs, outputs):
# inputs, outputs = gof.graph.clone(inputs, outputs)
# env = gof.env.Env(inputs, outputs)
# return env
# def std_env(inputs, outputs, disown_inputs = False,
# use_destroy_handler = True):
# inputs, outputs = gof.graph.clone(inputs, outputs)
# _mark_indestructible(outputs)
# env = gof.env.Env(inputs, outputs)
# if use_destroy_handler:
# env.extend(gof.DestroyHandler())
# env.extend(gof.ReplaceValidate())
# env.validate()
# for input in inputs:
# input.destroyed_by_user = use_destroy_handler and len(env.destroyers(input)) != 0
# if not input.destroyed_by_user and not disown_inputs:
# # prevent optimizations from destroying the inputs
# input.tag.indestructible = True
# return env
# def std_opt(env):
# pass
# predefined_linkers = {
# 'py' : gof.PerformLinker(),
# 'c' : gof.CLinker(),
# 'c|py' : gof.OpWiseCLinker(),
# 'c&py' : gof.DualLinker(checker = check_equal)
# }
# class FunctionFactory:
# def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False, disown_inputs = False,
# use_destroy_handler = True):
# if len(inputs) != len(set(inputs)):
# print >>sys.stderr, "Warning: duplicate inputs"
# for r in list(inputs) + list(outputs):
# if not isinstance(r, gof.Result):
# raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
# env = std_env(inputs, outputs, disown_inputs = disown_inputs,
# use_destroy_handler = use_destroy_handler)
# if None is not optimizer:
# optimizer(env)
# env.validate()
# self.env = env
# linker = copy(predefined_linkers.get(linker, linker))
# if not hasattr(linker, 'accept'):
# raise ValueError("'linker' parameter of FunctionFactory should be a Linker with an accept method " \
# "or one of ['py', 'c', 'c|py', 'c&py']")
# if borrow_outputs:
# self.linker = linker.accept(env)
# else:
# self.linker = linker.accept(env, no_recycling = infer_reuse_pattern(env, env.outputs))
# def create(self, profiler = None, unpack_single = True, strict = 'if_destroyed'):
# if strict not in [True, False, 'if_destroyed']:
# raise ValueError("'strict' parameter of create should be one of [True, False, 'if_destroyed']")
# if profiler is None:
# fn = self.linker.make_function(unpack_single=unpack_single)
# else:
# fn = self.linker.make_function(unpack_single=unpack_single,
# profiler=profiler)
# for env_input, fn_input in zip(self.env.inputs, fn.inputs):
# if strict is True or (env_input.destroyed_by_user and strict == 'if_destroyed'):
# fn_input.strict = True
# return fn
# def partial(self, *first, **kwargs):
# fn = self.create(**kwargs)
# return lambda *last: fn(*(first + last))
# def function(inputs,
# outputs,
# linker = 'py',
# optimizer = std_opt,
# borrow_outputs = False,
# disown_inputs = False,
# profiler = None,
# unpack_single = True,
# strict = 'if_destroyed',
# use_destroy_handler = True):
# ff = FunctionFactory(inputs,
# outputs,
# linker = linker,
# optimizer = optimizer,
# borrow_outputs = borrow_outputs,
# disown_inputs = disown_inputs,
# use_destroy_handler = use_destroy_handler)
# return ff.create(profiler = profiler,
# unpack_single = unpack_single,
# strict = strict)
# def eval_outputs(outputs, **kwargs):
# return function([], outputs, **kwargs)()
# _fcache = {} # it would be nice to use weakref.WeakKeyDictionary()
# def fast_compute(*outputs):
# if outputs in _fcache:
# f = _fcache[outputs]
# else:
# f = function([], outputs, linker = 'c')
# _fcache[outputs] = f
# return f()
# class OpFromGraph(gof.Op):
# """
# This create an L{Op} from a list of input results and a list of output
# results.
# The signature is the same as the signature of L{FunctionFactory}
# and/or function and the resulting L{Op}'s perform will do the same
# operation as::
# function(inputs, outputs, **kwargs)
# Take note that the following options, if provided, must take the
# value(s) listed below:
# unpack_single = False
# borrow_outputs = False
# OpFromGraph takes an additional input, grad_depth. If grad_depth
# is n, OpFromGraph will make special Ops for gradients up to the
# nth level, allowing the user to differentiate this op up to n
# times. The parameter defaults to 1. If grad_depth == 0, the op
# will not be differentiable.
# Example:
# x, y, z = tensor.scalars('xyz')
# e = x + y * z
# op = OpFromGraph([x, y, z], [e], linker='c')
# # op behaves like a normal theano op
# e2 = op(x, y, z) + op(z, y, x)
# fn = function([x, y, z], [e2])
# """
# def __init__(self, inputs, outputs, grad_depth = 1, **kwargs):
# if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'):
# raise ValueError('The borrow_outputs and unpack_single options cannot be True')
# kwargs['unpack_single'] = False
# kwargs['borrow_outputs'] = False
# self.fn = function(inputs, outputs, **kwargs)
# self.inputs = inputs
# self.outputs = outputs
# self.input_types = [input.type for input in inputs]
# self.output_types = [output.type for output in outputs]
# if grad_depth > 0:
# import gradient as G
# output_grads = [t() for t in self.output_types]
# gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs)
# gs = map(gd.get, self.inputs)
# self.grad_ops = []
# for g in gs:
# if g is None:
# self.grad_ops.append(lambda *args: None)
# else:
# self.grad_ops.append(OpFromGraph(inputs + output_grads,
# [g],
# grad_depth = grad_depth - 1))
# def make_node(self, *inputs):
# for input, type in zip(inputs, self.input_types):
# if not type == input.type:
# raise TypeError("Wrong type, expected %s but got %s" % type, input.type)
# return gof.Apply(self,
# inputs,
# [type() for type in self.output_types])
# def perform(self, node, inputs, outputs):
# results = self.fn(*inputs)
# for output, result in zip(outputs, results):
# output[0] = result
# def grad(self, inputs, output_grads):
# if hasattr(self, 'grad_ops'):
# return [go(*(inputs + output_grads)) for go in self.grad_ops]
# else:
# raise NotImplementedError
# # class Container(object):
# # def __init__(self, r, value, strict = False):
# # self.r = r
# # self.type = r.type
# # self._value = value
# # self.strict = strict
# # def __get(self):
# # return self._value
# # def __set(self, value):
# # try:
# # if self.strict:
# # self._value = self.type.filter(value, strict = True)
# # else:
# # self._value = self.type.filter(value)
# # except:
# # raise_with_op(self.r)
# # value = property(__get, __set)
# # def __str__(self):
# # return "<" + str(self.value) + ">"
# # def __repr__(self):
# # return "<" + repr(self.value) + ">"
# # # def put(self, filters):
# # # filter = filters.popleft()
# # # filter.data = self.value
# # # return [filter], filters
# #########################aaaaaaaaaaa
# # class State:
# # def __init__(self, init, next = None):
# # self.init = init
# # self.next = next
# # class StateFunctionFactory(Function):
# # def __init__(self, inputs, outputs, states, **kwargs):
# # states_
# # inputs = [state.init for state in states] + inputs
# # outputs = [state.next for ]
# # class Function:
# # """
# # An 'executable' compiled from a graph
# # This class is meant to be used as a function: the idea is to use
# # __call__(*args) and it will compute your graph's function on the args and
# # return the value(s) corresponding to the output(s).
# # @ivar fn: the return value of L{linker.make_function}(False)
# # Additional Attributes if keep_locals == True
# # inputs - inputs in the env
# # outputs - outputs in the env
# # features - features to add to the env
# # linker_cls - the linker class
# # linker - the linker allocated from env
# # env - The env passed to the linker
# # @note: B{Re: Memory ownership, aliasing, re-use:}
# # That the objects returned by L{Function.__call__}(self, *args) are owned
# # by self, and that in general these outputs might be overwritten (in-place)
# # by subsequent calls to L{self.__call__}(*args). Why? This behaviour is
# # necessary for inplace operations to work, and L{Function}'s linker might re-use
# # memory from one execution to the next in order to make each execution faster.
# # """
# # def __init__(self, inputs, outputs,
# # features = [],
# # optimizer = default_optimizer,
# # linker_cls = gof.link.PerformLinker,
# # profiler = None,
# # unpack_single = True,
# # except_unreachable_input = True,
# # keep_locals = True):
# # """
# # Copy the graph, optimize, and link it.
# # @param inputs: a list of results to be this function's inputs
# # @param outputs: a list of results to be this function's outputs
# # @param features: features to add to the env
# # @param optimizer: an optimizer to apply to the copied graph, before linking
# # @param linker_cls: a callable that takes an env and returns a Linker
# # @param profiler: a L{Profiler} for the produced function (only valid if the
# # linker_cls's make_function takes a profiler argument)
# # @param unpack_single: unpack return value lists of length 1. @see: L{Linker.make_function}
# # @param keep_locals: add the local variables from __init__ to the class
# # """
# # _mark_indestructible(outputs)
# # if len(inputs) != len(set(inputs)):
# # raise Exception('duplicate inputs')
# # if len(outputs) != len(set(outputs)):
# # raise Exception('duplicate outputs')
# # #evaluate the orphans, and put these values into the clone of the env
# # orphans = list(gof.graph.results_and_orphans(inputs, outputs,
# # except_unreachable_input=except_unreachable_input)[1])
# # orphan_data = eval_outputs(orphans, unpack_single=False)
# # #print 'orphans', orphans
# # #print 'ops', gof.graph.ops(inputs, outputs)
# # env = gof.env.Env(inputs, outputs)
# # #print 'orphans in env', env.orphans()
# # env, equiv = env.clone_get_equiv(clone_inputs=True)
# # for feature in features:
# # env.extend(feature(env))
# # env.extend(gof.DestroyHandler(env))
# # #print 'orphans after clone', env.orphans()
# # for d, o in zip(orphan_data, [equiv[orphan] for orphan in orphans]):
# # #print 'assigning orphan value', d
# # #o.data = d
# # new_o = gof.Constant(o.type, d)
# # env.replace(o, new_o)
# # assert new_o in env.orphans
# # # optimize and link the cloned env
# # if None is not optimizer:
# # optimizer(env)
# # linker = linker_cls(env)
# # if keep_locals:# useful flag for debugging!
# # self.__dict__.update(locals())
# # if profiler is None:
# # self.fn = linker.make_function(unpack_single=unpack_single)
# # else:
# # self.fn = linker.make_function(unpack_single=unpack_single,
# # profiler=profiler)
# # self.inputs = env.inputs
# # self.outputs = env.outputs
# # self.features = features
# # self.optimizer = optimizer
# # self.linker_cls = linker_cls
# # self.profiler = profiler
# # self.unpack_single = unpack_single
# # self.except_unreachable_input = except_unreachable_input
# # self.keep_locals = keep_locals
# # def __call__(self, *args):
# # return self.fn(*args)
# # def eval_outputs(outputs,
# # features = [],
# # optimizer = None,
# # linker_cls = gof.link.PerformLinker,
# # unpack_single = True,
# # keep_locals = True):
# # if len(outputs) == 0:
# # #print 'returning with no inputs'
# # if unpack_single:
# # return None
# # else:
# # return []
# # inputs = gof.graph.inputs(outputs)
# # if any(not isinstance(input, gof.Constant) for input in inputs):
# # raise TypeError("Cannot evaluate outputs because some of the leaves are not Constant.", outputs)
# # in_data = [i.data for i in inputs]
# # #print 'in_data = ', in_data
# # if len(inputs) != len(in_data):
# # raise Exception('some input data is unknown')
# # env = gof.env.Env(inputs, outputs)
# # env.replace_all(dict([(i, i.type()) for i in inputs]))
# # env = env.clone(clone_inputs=True)
# # _mark_indestructible(env.outputs)
# # if None is not optimizer:
# # optimizer(env)
# # linker = linker_cls(env)
# # fn = linker.make_function(unpack_single=unpack_single)
# # rval = fn(*in_data)
# # return rval
# # StateFunction([x, y], [e], (w, w + lr * bla()))
# # class _Function:
# # def __init__(self,
# # inputs,
# # outputs,
# # optimizer,
# # linker_type = 'py',
# # unpack_single = True,
# # except_unreachable_input = True,
# # disposable_inputs = [],
# # borrow_outputs = []):
# # _mark_indestructible(outputs)
# # if len(inputs) != len(set(inputs)):
# # raise Exception('duplicate inputs')
# # if len(outputs) != len(set(outputs)):
# # raise Exception('duplicate outputs')
# # orphans = list(gof.graph.results_and_orphans(inputs, outputs,
# # except_unreachable_input=except_unreachable_input)[1])
# # orphan_data = eval_outputs(orphans, unpack_single=False)
# # env = gof.env.Env(inputs, outputs, features + [gof.EquivTool], consistency_check = True)
# # env = env.clone(clone_inputs=True)
# # for d, o in zip(orphan_data, [env.equiv(orphan) for orphan in orphans]):
# # o.data = d
# # # optimize and link the cloned env
# # if None is not optimizer:
# # optimizer(env)
# # linker = linker_cls(env)
# # if keep_locals:# useful flag for debugging!
# # self.__dict__.update(locals())
# # if profiler is None:
# # self.fn = linker.make_function(inplace=True,
# # unpack_single=unpack_single)
# # else:
# # self.fn = linker.make_function(inplace=True,
# # unpack_single=unpack_single,
# # profiler=profiler)
# # self.inputs = env.inputs
# # self.outputs = env.outputs
# # self.features = features
# # self.optimizer = optimizer
# # self.linker_cls = linker_cls
# # self.profiler = profiler
# # self.unpack_single = unpack_single
# # self.except_unreachable_input = except_unreachable_input
# # self.keep_locals = keep_locals
# # def __call__(self, *args):
# # return self.fn(*args)
# # def __copy__(self):
# # return Function(self.inputs, self.outputs,
# # features = self.features,
# # optimizer = self.optimizer,
# # linker_cls = self.linker_cls,
# # profiler = self.profiler,
# # unpack_single = self.unpack_single,
# # except_unreachable_input = self.except_unreachable_input,
# # keep_locals = self.keep_locals)
fn.__class__ = F
return fn
class OpFromGraph(gof.Op):
    """
    This create an L{Op} from a list of input results and a list of output
    results.

    The signature is the same as the signature of L{FunctionFactory}
    and/or function and the resulting L{Op}'s perform will do the same
    operation as::
      function(inputs, outputs, **kwargs)

    Take note that the following options, if provided, must take the
    value(s) listed below:
      unpack_single = False
      borrow_outputs = False

    OpFromGraph takes an additional input, grad_depth. If grad_depth
    is n, OpFromGraph will make special Ops for gradients up to the
    nth level, allowing the user to differentiate this op up to n
    times. The parameter defaults to 1. If grad_depth == 0, the op
    will not be differentiable.

    Example:
      x, y, z = tensor.scalars('xyz')
      e = x + y * z
      op = OpFromGraph([x, y, z], [e], linker='c')
      # op behaves like a normal theano op
      e2 = op(x, y, z) + op(z, y, x)
      fn = function([x, y, z], [e2])
    """

    def __init__(self, inputs, outputs, grad_depth = 1, **kwargs):
        if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'):
            raise ValueError('The borrow_outputs and unpack_single options cannot be True')
        kwargs['unpack_single'] = False
        kwargs['borrow_outputs'] = False
        self.fn = function(inputs, outputs, **kwargs)
        self.inputs = inputs
        self.outputs = outputs
        self.input_types = [input.type for input in inputs]
        self.output_types = [output.type for output in outputs]
        if grad_depth > 0:
            # Recursively build OpFromGraph instances for the gradients,
            # one level shallower each time.
            import gradient as G
            output_grads = [t() for t in self.output_types]
            gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs)
            gs = map(gd.get, self.inputs)
            self.grad_ops = []
            for g in gs:
                if g is None:
                    # No gradient flows to this input.
                    self.grad_ops.append(lambda *args: None)
                else:
                    self.grad_ops.append(OpFromGraph(inputs + output_grads,
                                                     [g],
                                                     grad_depth = grad_depth - 1))

    def make_node(self, *inputs):
        for input, type in zip(inputs, self.input_types):
            if not type == input.type:
                # BUG FIX: the format arguments must be parenthesized as a
                # tuple; the original `% type, input.type` fed a single
                # argument to a two-placeholder format string, which raised
                # "not enough arguments for format string" instead of the
                # intended message.
                raise TypeError("Wrong type, expected %s but got %s" % (type, input.type))
        return gof.Apply(self,
                         inputs,
                         [type() for type in self.output_types])

    def perform(self, node, inputs, outputs):
        # Run the compiled inner function and write each result into the
        # corresponding output storage cell.
        results = self.fn(*inputs)
        for output, result in zip(outputs, results):
            output[0] = result

    def grad(self, inputs, output_grads):
        if hasattr(self, 'grad_ops'):
            return [go(*(inputs + output_grads)) for go in self.grad_ops]
        else:
            # grad_depth was 0: this op is not differentiable.
            raise NotImplementedError
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论