提交 7ba9c052 authored 作者: abergeron's avatar abergeron

Merge pull request #3055 from ChienliMa/swapSV

Swap SharedVariable
...@@ -83,7 +83,7 @@ Reference ...@@ -83,7 +83,7 @@ Reference
.. function:: function(inputs, outputs, mode=None, updates=None, givens=None, no_default_updates=False, accept_inplace=False, name=None, rebuild_strict=True, allow_input_downcast=None, profile=None, on_unused_input='raise') .. function:: function(inputs, outputs, mode=None, updates=None, givens=None, no_default_updates=False, accept_inplace=False, name=None, rebuild_strict=True, allow_input_downcast=None, profile=None, on_unused_input='raise')
Return a callable object that will calculate `outputs` from `inputs`. Return a :class:`callable object <theano.compile.function_module.Function>` that will calculate `outputs` from `inputs`.
:type params: list of either Variable or Param instances, but not shared :type params: list of either Variable or Param instances, but not shared
variables. variables.
...@@ -189,3 +189,6 @@ Reference ...@@ -189,3 +189,6 @@ Reference
equivalent to Var1. equivalent to Var1.
.. autofunction:: theano.compile.function.function_dump .. autofunction:: theano.compile.function.function_dump
.. autoclass:: theano.compile.function_module.Function
:members: free, copy
\ No newline at end of file
...@@ -1827,7 +1827,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1827,7 +1827,7 @@ class _Linker(gof.link.LocalLinker):
return self return self
def make_all(self, profiler=None, input_storage=None, def make_all(self, profiler=None, input_storage=None,
output_storage=None): output_storage=None, storage_map=None):
# can't import at toplevel because of circular import TODO: # can't import at toplevel because of circular import TODO:
# don't do this ugly hacky way of setting the # don't do this ugly hacky way of setting the
# filter_checks_isfinite # filter_checks_isfinite
...@@ -1857,7 +1857,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1857,7 +1857,7 @@ class _Linker(gof.link.LocalLinker):
no_recycling = [] no_recycling = []
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage_, output_storage_) fgraph, order, input_storage_, output_storage_, storage_map)
thunks_py = [] # python thunks thunks_py = [] # python thunks
thunks_c = [] # c thunks thunks_c = [] # c thunks
...@@ -2525,7 +2525,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2525,7 +2525,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.mode = mode self.mode = mode
self.output_keys = output_keys self.output_keys = output_keys
def create(self, defaults=None, trustme=False): def create(self, defaults=None, trustme=False, storage_map=None):
""" """
Create a function. Create a function.
...@@ -2633,7 +2633,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2633,7 +2633,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
defaults = _defaults defaults = _defaults
# Get a function instance # Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage) _fn, _i, _o = self.linker.make_thunk(input_storage=input_storage,
storage_map=storage_map)
fn = self.function_builder(_fn, _i, _o, self.indices, fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single, self.outputs, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self)
......
...@@ -293,7 +293,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -293,7 +293,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
# note: pfunc will also call orig_function-- orig_function is # note: pfunc will also call orig_function -- orig_function is
# a choke point that all compilation must pass through # a choke point that all compilation must pass through
fn = pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
......
...@@ -5,7 +5,7 @@ Driver of graph construction, optimization, and linking. ...@@ -5,7 +5,7 @@ Driver of graph construction, optimization, and linking.
from __future__ import print_function from __future__ import print_function
import copy import copy
from six import string_types, iteritems from six import string_types, iteritems, iterkeys
from six.moves import xrange from six.moves import xrange
import six.moves.copyreg as copyreg import six.moves.copyreg as copyreg
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
...@@ -15,9 +15,10 @@ import warnings ...@@ -15,9 +15,10 @@ import warnings
import numpy import numpy
import theano import theano
from theano import gof from theano import config, gof
from functools import partial from functools import partial
from theano.compat import izip from theano.compat import izip
from theano.gof import graph
import theano.compile.mode import theano.compile.mode
from theano.compile.io import ( from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput) In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
...@@ -136,8 +137,8 @@ class Supervisor: ...@@ -136,8 +137,8 @@ class Supervisor:
return True return True
for r in self.protected + list(fgraph.outputs): for r in self.protected + list(fgraph.outputs):
if fgraph.destroyers(r): if fgraph.destroyers(r):
raise gof.InconsistencyError( raise gof.InconsistencyError("Trying to destroy a protected"
"Trying to destroy a protected Variable.", r) "Variable.", r)
def std_fgraph(input_specs, output_specs, accept_inplace=False): def std_fgraph(input_specs, output_specs, accept_inplace=False):
...@@ -535,16 +536,201 @@ class Function(object): ...@@ -535,16 +536,201 @@ class Function(object):
self.value[item] = value self.value[item] = value
def __copy__(self): def __copy__(self):
defaults = [default for _1, _2, default in self.defaults] """
cpy = self.maker.create(defaults, trustme=True) Copy a function. Copied function have separate intermediate
for (input, _1, _2), here, there in zip(self.indices, storages and output storages with original function
"""
return self.copy()
def copy(self, share_memory=False, swap=None, delete_updates=False,
name=None, profile=None):
"""
Copy this function. Copied function will have separated maker and
fgraph with original function. User can choose whether to separate
storage by changing the share_memory arguments.
---------------------
Params:
share_memory -- { boolean } Default is False. When True, two
function share intermediate storages(storages except input and
output storages). Otherwise two functions will only share partial
storages and same maker. If two functions share memory and
allow_gc=False, this will increase executing speed and save memory.
swap -- { dict } Dictionary that map old SharedVariables to new
SharedVariables. Default is None.
NOTE: The shared variable swap in only done in the new returned
function, not in the user graph.
delete_updates -- { boolean } Default is False. If True, Copied
function will not have update.
name -- { string } If provided, will be the name of the new
Function. Otherwise, it will be old + " copy"
profile -- as theano.function profile parameter
---------------------
Returns:
func -- Copied theano.Function
"""
# helper function
def checkSV(sv_ori, sv_rpl):
"""
Assert two SharedVariable follow some restirctions:
1. same type
2. same shape or dim?
"""
SharedVariable = theano.tensor.sharedvar.SharedVariable
assert isinstance(sv_ori, SharedVariable), (
"Key of swap should be SharedVariable, given:", sv_ori,
" type", type(sv_ori))
assert isinstance(sv_rpl, SharedVariable), (
"Value of swap should be SharedVariable, given:", sv_rpl,
"type", type(sv_ori))
assert sv_ori.type == sv_rpl.type, (
"Type of given SharedVariable conflicts with original one",
"Type of given SharedVariable:", sv_rpl.type,
"Type of original SharedVariable:", sv_ori.type)
maker = self.maker
# Copy Ins and their storage.
# so that they have different storage as their value
ins = [copy.copy(input) for input in maker.inputs]
# Delete update output in fgraph and updates In instances if needed
if delete_updates:
# The first len(maker.outputs) variables are original variables.
# The rest are the updates.
out_vars = maker.fgraph.outputs[:len(maker.outputs)]
else:
out_vars = maker.fgraph.outputs
# Init new fgraph using copied variables and get memo
# memo: a dict that map old variables to new variables
memo = graph.clone_get_equiv(maker.fgraph.inputs, out_vars)
fg_cpy = gof.fg.FunctionGraph([memo[i] for i in maker.fgraph.inputs],
[memo[o] for o in out_vars],
clone=False)
# Re initialize Outs and swap update and variable in Ins
# By doing this, we can pass FunctionMaker._check_unused_inputs()
outs = list(map(SymbolicOutput, fg_cpy.outputs[:len(maker.outputs)]))
for out_ori, out_cpy in zip(maker.outputs, outs):
out_cpy.borrow = out_ori.borrow
# swap SharedVariable
if swap is not None:
exist_svs = [i.variable for i in maker.inputs]
# Check if given ShareVariables exist
for sv in iterkeys(swap):
if sv not in exist_svs:
raise ValueError("SharedVariable: %s not found" %
(sv.name))
# Swap SharedVariable in fgraph and In instances
for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
# Variables in maker.inputs are defined by user, therefore we
# use them to make comparision and do the mapping.
# Otherwise we don't touch them.
var = maker.inputs[index].variable
if var in swap:
swap_sv = swap[var]
checkSV(i.variable, swap_sv)
# swap variable and value of In instances
i.variable = swap_sv
i.value = swap_sv.container
# In the fgraph we use the cloned SharedVariable
swap_sv = swap_sv.clone()
# Swap SharedVariable in fgraph
# if inputs was replaced, change self.inputs
fg_cpy.inputs[index] = swap_sv
fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# Delete update if needed
update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var
if not delete_updates and i.update is not None:
i.update = fg_cpy.outputs[update_i]
update_i += 1
else:
i.update = None
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
storage_map = self.fn.storage_map
new_storage_map = {}
# TODO: We could share the output storage, but we must make sure
# 2 different function call won't override each other values. This
# is already done elsewhere, so to reuse it the user would need to
# use Out(var, borrow=True) and maybe the mutable=True flag too.
# But to be safe for now as it isn't documented and we aren't sure
# it is well tested, we don't share the part of the storage_map.
if share_memory:
i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
for key in storage_map.keys():
if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key]
if not name and self.name:
name = self.name + " copy"
input_storage = [i.value for i in ins]
# reinitialize new maker and create new function
if profile is None:
profile = config.profile
# profile -> True or False
if profile is True:
if name:
message = name
else:
message = str(maker.profile.message) + " copy"
profile = theano.compile.profiling.ProfileStats(message=message)
# profile -> object
elif type(profile) == str:
profile = theano.compile.profiling.ProfileStats(message=profile)
f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
mode=maker.mode, profile=profile,
on_unused_input=maker.on_unused_input,
function_builder=maker.function_builder,
accept_inplace=maker.accept_inplace
).create(input_storage,
storage_map=new_storage_map)
for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs,
self.input_storage, self.input_storage,
cpy.input_storage): f_cpy.input_storage):
if input.mutable and here is not None:
there.data = copy.copy(here.data) # Share immutable ShareVariable and constant input's storage
swapped = swap is not None and in_ori.variable in swap
# Using the original storage if SharedVariable will not be updated
# and is not swapped
if not in_ori.mutable and not swapped:
cpy.data = ori.data
in_cpy.value = in_ori.value
# Reconstruct Function.finder which map Variable defined by user
# to container, to make Function.value and Function.data work well.
# Replace variable in new maker.inputs by the original ones.
# So that user can swap SharedVariable in a swapped function
container = f_cpy.finder.pop(in_cpy.variable)
if not swapped:
f_cpy.finder[in_ori.variable] = container
in_cpy.vairable = in_ori.variable
else: else:
there.data = here.data f_cpy.finder[swap[in_ori.variable]] = container
return cpy in_cpy.variable = swap[in_ori.variable]
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
profile = self.profile profile = self.profile
...@@ -1232,8 +1418,8 @@ class FunctionMaker(object): ...@@ -1232,8 +1418,8 @@ class FunctionMaker(object):
else: else:
# fgraph is already an optimized one # fgraph is already an optimized one
need_opt = False need_opt = False
_, additional_outputs = std_fgraph(inputs, outputs, accept_inplace) updates = [spec.update for spec in inputs if spec.update]
pass additional_outputs = list(map(SymbolicOutput, updates))
self.fgraph = fgraph self.fgraph = fgraph
...@@ -1355,7 +1541,7 @@ class FunctionMaker(object): ...@@ -1355,7 +1541,7 @@ class FunctionMaker(object):
"'%s'.\nValid values are 'raise', " "'%s'.\nValid values are 'raise', "
"'warn', and 'ignore'." % on_unused_input) "'warn', and 'ignore'." % on_unused_input)
def create(self, input_storage=None, trustme=False): def create(self, input_storage=None, trustme=False, storage_map=None):
""" """
Create a function. Create a function.
...@@ -1436,7 +1622,7 @@ class FunctionMaker(object): ...@@ -1436,7 +1622,7 @@ class FunctionMaker(object):
try: try:
theano.config.traceback.limit = 0 theano.config.traceback.limit = 0
_fn, _i, _o = self.linker.make_thunk( _fn, _i, _o = self.linker.make_thunk(
input_storage=input_storage_lists) input_storage=input_storage_lists, storage_map=storage_map)
finally: finally:
theano.config.traceback.limit = limit_orig theano.config.traceback.limit = limit_orig
......
...@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if profile is True: if profile is True:
profile = ProfileStats(message=name) profile = ProfileStats(message=name)
# profile -> object # profile -> object
if type(profile) == str: elif type(profile) == str:
profile = ProfileStats(message=profile) profile = ProfileStats(message=profile)
# profile is typically either False or an object at this point. # profile is typically either False or an object at this point.
# No need to block other objects being passed through though. It might be # No need to block other objects being passed through though. It might be
......
...@@ -43,8 +43,9 @@ AddConfigVar('ProfileMode.profile_memory', ...@@ -43,8 +43,9 @@ AddConfigVar('ProfileMode.profile_memory',
class Profile_Maker(FunctionMaker): class Profile_Maker(FunctionMaker):
def create(self, input_storage=None, trustme=False): def create(self, input_storage=None, trustme=False, storage_map=None):
ret = super(Profile_Maker, self).create(input_storage, trustme) ret = super(Profile_Maker, self).create(input_storage, trustme,
storage_map)
if (hasattr(theano, 'sandbox') and if (hasattr(theano, 'sandbox') and
hasattr(theano.sandbox, 'cuda') and hasattr(theano.sandbox, 'cuda') and
......
...@@ -690,6 +690,7 @@ class ProfileStats(object): ...@@ -690,6 +690,7 @@ class ProfileStats(object):
print('', file=file) print('', file=file)
# The validation time is a subset of optimizer_time # The validation time is a subset of optimizer_time
if self.optimizer_time > 0:
assert self.validate_time < self.optimizer_time assert self.validate_time < self.optimizer_time
def summary_globals(self, file): def summary_globals(self, file):
......
...@@ -241,6 +241,143 @@ class T_function(unittest.TestCase): ...@@ -241,6 +241,143 @@ class T_function(unittest.TestCase):
f(1, 2) # put them out of sync f(1, 2) # put them out of sync
self.assertFalse(f(1, 2) == g(1, 2)) # they should not be equal anymore. self.assertFalse(f(1, 2) == g(1, 2)) # they should not be equal anymore.
def test_copy_share_memory(self):
x = T.fscalar('x')
# SharedVariable for tests, one of them has update
y = theano.shared(value=1)
z = theano.shared(value=2)
out = T.tanh((x+y+2)/(x+z-0.2)**2)
# Test for different linkers
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], [out], mode=mode,updates={z:z+1})
cpy = ori.copy(share_memory=True)
# Test if memories shared
storage_map_ori = ori.fn.storage_map
storage_map_cpy = cpy.fn.storage_map
fgraph_ori = ori.maker.fgraph
fgraph_cpy = cpy.maker.fgraph
# Assert intermediate and Constants storages are shared.
# and output stoarges are not shared
i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
ori_storages = storage_map_ori.values()
for key in storage_map_cpy.keys():
storage = storage_map_cpy[key]
if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
self.assertTrue(any([ storage is s for s in ori_storages]))
# Assert storages of SharedVariable without updates are shared
for (input, _1, _2), here, there in zip(ori.indices,
ori.input_storage,
cpy.input_storage):
self.assertTrue(here.data is there.data)
def test_swap_SharedVariable(self):
i = T.iscalar()
x_list = theano.shared(value=numpy.random.rand(10).astype(config.floatX))
x = T.scalar('x')
# SharedVariable for tests, one of them has update
y = theano.shared(value=1, name='y')
z = theano.shared(value=2, name='z')
m = theano.shared(value=0, name='m')
# SharedVariable to replace
y_rpl = theano.shared(value=3,name ='y_rpl')
z_rpl = theano.shared(value=4, name='z_rpl')
swap = {y:y_rpl, z:z_rpl}
map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}
out = x+y+z+m
# Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
second_time = False
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([i], [out], mode=mode,
updates=[(z,z+1),(m,m+2)],
givens={x:x_list[i]})
cpy = ori.copy(swap=swap)
# run fuction several time
ori(1), cpy(1),cpy(2)
# assert same SharedVariable are update in different function
if not second_time:
# m should be updated 3 times
assert m.get_value() == 6
# z should be updated once
assert z.get_value() == 3
# z_rpl should be updated twice
assert z_rpl.get_value() == 6
# y and y_rpl should not be updated
assert y_rpl.get_value() == 3
assert y.get_value() == 1
elif second_time:
# doule update for sharedvariable
assert m.get_value() == 12
assert z.get_value() == 4
assert z_rpl.get_value() == 8
assert y_rpl.get_value() == 3
# test cpy function:
# 2. SharedVariable is updatable -> values did update(z == 5)
# 1. sharedvariable is swap -> Rpl sharedvariables share storage
names = map_SV.keys()
for key in cpy.fn.storage_map:
if key.name in names:
assert map_SV[key.name].container.storage[0] ==\
cpy.fn.storage_map[key][0]
second_time = True
def test_swap_SharedVaraile_with_given(self):
"""
A special testcase for logistic_sgd.py in Deep Learning Tutorial
This test assert that SharedVariable in different function have same storage
"""
train_x = theano.shared(value=numpy.random.rand(10,10).astype(config.floatX))
test_x = theano.shared(value=numpy.random.rand(10,10).astype(config.floatX))
train_y = theano.shared(value=numpy.random.rand(10,1).astype(config.floatX))
test_y = theano.shared(value=numpy.random.rand(10,1).astype(config.floatX))
i = T.iscalar('index')
x = T.vector('x')
y = T.vector('y')
# this formular has no sense but for a test
out = (T.sum(x) - y) ** 2
train = theano.function([i], out,
givens={x:train_x[i], y:train_y[i]},
updates={train_x:train_x+0.1})
test_def = theano.function([i], out, givens={x:test_x[i], y:test_y[i]})
test_cpy = train.copy(swap={train_x:test_x, train_y:test_y},
delete_updates=True)
for in1, in2 in zip( test_def.maker.inputs, test_def.maker.inputs):
assert in1.value is in2.value
def test_copy_delete_updates(self):
x = T.fscalar('x')
# SharedVariable for tests, one of them has update
y = theano.shared(value=1, name='y')
z = theano.shared(value=2, name='z')
out = x+y+z
# Test for different linkers
# for mode in ["FAST_RUN","FAST_COMPILE"]:
second_time = False
for mode in ["FAST_RUN","FAST_COMPILE"]:
ori = theano.function([x], out, mode=mode,updates={z:z*2})
cpy = ori.copy(delete_updates=True)
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs') x, s = T.scalars('xs')
......
...@@ -1069,11 +1069,9 @@ class CLinker(link.Linker): ...@@ -1069,11 +1069,9 @@ class CLinker(link.Linker):
pass pass
return utils.uniq(ret) return utils.uniq(ret)
def __compile__(self, input_storage=None, def __compile__(self, input_storage=None, output_storage=None,
output_storage=None, keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME
Compiles this linker's fgraph. Compiles this linker's fgraph.
Parameters Parameters
...@@ -1111,6 +1109,7 @@ class CLinker(link.Linker): ...@@ -1111,6 +1109,7 @@ class CLinker(link.Linker):
thunk = self.cthunk_factory(error_storage, thunk = self.cthunk_factory(error_storage,
input_storage, input_storage,
output_storage, output_storage,
storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
return (thunk, return (thunk,
[link.Container(input, storage) for input, storage in [link.Container(input, storage) for input, storage in
...@@ -1143,10 +1142,8 @@ class CLinker(link.Linker): ...@@ -1143,10 +1142,8 @@ class CLinker(link.Linker):
return init_tasks, tasks return init_tasks, tasks
def make_thunk(self, input_storage=None, output_storage=None, def make_thunk(self, input_storage=None, output_storage=None,
keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME
Compiles this linker's fgraph and returns a function to perform the Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the inputs computations, as well as lists of storage cells for both the inputs
and outputs. and outputs.
...@@ -1157,25 +1154,24 @@ class CLinker(link.Linker): ...@@ -1157,25 +1154,24 @@ class CLinker(link.Linker):
List of lists of length 1. In order to use List of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated. that storage. If None, storage will be allocated.
output_storage: list of lists of length 1 @param output_storage: list of lists of length 1. The thunk returned
The thunk returned by __compile__ will put the variables of the by __compile__ will put the variables of the computation in these
computation in these lists. If None, storage will be allocated. lists. If None, storage will be allocated.
@param storage_map: dict that map variables to storages. This is used
when you need to customize the storage of this thunk.
Returns: thunk, input_storage, output_storage
Returns
-------
object
Thunk, input_storage, output_storage.
The return values can be used as follows: The return values can be used as follows:
f, istor, ostor = clinker.make_thunk() f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input istor[0].data = first_input
istor[1].data = second_input istor[1].data = second_input
f() f()
first_output = ostor[0].data first_output = ostor[0].data
""" """
init_tasks, tasks = self.get_init_tasks() init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__( cthunk, in_storage, out_storage, error_storage = self.__compile__(
input_storage, output_storage, input_storage, output_storage, storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
res = _CThunk(cthunk, init_tasks, tasks, error_storage) res = _CThunk(cthunk, init_tasks, tasks, error_storage)
...@@ -1529,25 +1525,17 @@ class CLinker(link.Linker): ...@@ -1529,25 +1525,17 @@ class CLinker(link.Linker):
return self._mod return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage, def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
Parameters out_storage -> list of lists of length 1, one per output
----------
error_storage : list of length 3
in_storage : list of lists of length 1, one per input
out_storage : list of lists of length 1, one per output
Returns Returns a thunk that points to an instance of a C struct that
-------
object
A thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph. That thunk, can carry on the computation of this linker's fgraph. That thunk,
when executed, will fetch its inputs from in_storage, put its when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage. type, value and traceback of the exception in error_storage.
""" """
try: try:
key = self.cmodule_key() key = self.cmodule_key()
...@@ -1569,7 +1557,10 @@ class CLinker(link.Linker): ...@@ -1569,7 +1557,10 @@ class CLinker(link.Linker):
out_storage = [x for i, x in enumerate(out_storage) out_storage = [x for i, x in enumerate(out_storage)
if (i + len(in_storage)) not in dupidx] if (i + len(in_storage)) not in dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx] in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
if storage_map is None:
orphd = [[orphan.data] for orphan in self.orphans] orphd = [[orphan.data] for orphan in self.orphans]
else:
orphd = [storage_map[orphan] for orphan in self.orphans]
ret = module.instantiate(error_storage, ret = module.instantiate(error_storage,
*(in_storage + out_storage + orphd)) *(in_storage + out_storage + orphd))
...@@ -1727,7 +1718,8 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1727,7 +1718,8 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler=None, input_storage=None, output_storage=None): def make_all(self, profiler=None, input_storage=None, output_storage=None,
storage_map=None):
# The lock will be acquired when we compile the first # The lock will be acquired when we compile the first
# C code. We will keep the lock untill all the function # C code. We will keep the lock untill all the function
...@@ -1741,7 +1733,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1741,7 +1733,7 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage, storage_map)
if self.allow_gc: if self.allow_gc:
computed, last_user = link.gc_helper(order) computed, last_user = link.gc_helper(order)
post_thunk_old_storage = [] post_thunk_old_storage = []
......
...@@ -278,6 +278,10 @@ class FunctionGraph(utils.object2): ...@@ -278,6 +278,10 @@ class FunctionGraph(utils.object2):
# import # # import #
def __import_r__(self, variable, reason): def __import_r__(self, variable, reason):
"""
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph.
"""
global NullType global NullType
if NullType is None: if NullType is None:
from .null_type import NullType from .null_type import NullType
...@@ -296,6 +300,12 @@ class FunctionGraph(utils.object2): ...@@ -296,6 +300,12 @@ class FunctionGraph(utils.object2):
self.variables.add(variable) self.variables.add(variable)
def __import__(self, apply_node, check=True, reason=None): def __import__(self, apply_node, check=True, reason=None):
"""
Given an apply_node, recursively search from this node to know graph,
and then add all unknown variables and apply_nodes to this graph.
"""
node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set. # in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
...@@ -806,18 +816,30 @@ class FunctionGraph(utils.object2): ...@@ -806,18 +816,30 @@ class FunctionGraph(utils.object2):
""" """
return self.clone_get_equiv(check_integrity)[0] return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True): def clone_get_equiv(self, check_integrity=True, attach_feature=True):
""" """Clone the graph and get a memo( a dict )that map old node to new node
WRITEME ----------------------------
Parameters:
check_integrity - { bool } Whether to check integrity.
Default is True.
attach_feature - { bool } Whether to attach feature of origin graph to
cloned graph. Default is True.
----------------------------
Returns:
e - { FunctionGraph } Cloned fgraph. Every node in cloned graph is cloned.
equiv - { dict } A dict that map old node to new node.
""" """
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
if check_integrity: if check_integrity:
self.check_integrity() self.check_integrity()
e = FunctionGraph([equiv[i] for i in self.inputs], e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs]) [equiv[o] for o in self.outputs],
clone=False)
if check_integrity: if check_integrity:
e.check_integrity() e.check_integrity()
if attach_feature:
for feature in self._features: for feature in self._features:
e.attach_feature(feature) e.attach_feature(feature)
return e, equiv return e, equiv
......
...@@ -496,10 +496,16 @@ class Container(object): ...@@ -496,10 +496,16 @@ class Container(object):
return r return r
def map_storage(fgraph, order, input_storage, output_storage): def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
""" """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
Ensure there is storage (a length-1 list) for inputs, outputs, and
interior nodes. :param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
:param order: an iterable over Apply instances (in program running order)
:param input_storage: None or existing input storage (see below)
:param output_storage: None or existing output storage (see below)
:rtype: 3-tuple
:returns: (list of storage for inputs, list of storage for outputs, and the `storage_map`)
Parameters Parameters
---------- ----------
...@@ -533,25 +539,45 @@ def map_storage(fgraph, order, input_storage, output_storage): ...@@ -533,25 +539,45 @@ def map_storage(fgraph, order, input_storage, output_storage):
""" """
# each Apply argument's data is stored in a list of length 1 (these lists act like pointers) # each Apply argument's data is stored in a list of length 1 (these lists act like pointers)
if storage_map is None:
storage_map = {}
# input_storage is a list of data-containers for the inputs. # input_storage is a list of data-containers for the inputs.
if input_storage is None: if input_storage is None:
input_storage = [[None] for input in fgraph.inputs] input_storage = [[None] for input in fgraph.inputs]
else: else:
assert len(fgraph.inputs) == len(input_storage) assert len(fgraph.inputs) == len(input_storage)
storage_map = {} # add input storage into storage_map
for r, storage in izip(fgraph.inputs, input_storage): for r, storage in zip(fgraph.inputs, input_storage):
if r in storage_map:
assert storage_map[r] is storage, ("Given input_storage conflicts "
"with storage in given storage_"
"map. Given input_storage: ",
storage, "Storage in storage_ma"
"p: ", storage_map[r])
else:
storage_map[r] = storage storage_map[r] = storage
# for orphan in fgraph.orphans: # for orphan in fgraph.orphans:
# if not isinstance(orphan, Constant): # if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan) # raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
# storage_map[orphan] = [orphan.data] # storage_map[orphan] = [orphan.data]
# allocate output storage
if output_storage is not None: if output_storage is not None:
assert len(fgraph.outputs) == len(output_storage) assert len(fgraph.outputs) == len(output_storage)
for r, storage in izip(fgraph.outputs, output_storage): for r, storage in zip(fgraph.outputs, output_storage):
if r in storage_map:
assert storage_map[r] is storage, ("Given output_storage confl"
"icts with storage in given"
" storage_map. Given output"
"_storage: ", storage, "Sto"
"rage in storage_map: ",
storage_map[r])
else:
storage_map[r] = storage storage_map[r] = storage
# allocate storage for intermediate computation
for node in order: for node in order:
for r in node.inputs: for r in node.inputs:
if r not in storage_map: if r not in storage_map:
...@@ -563,6 +589,7 @@ def map_storage(fgraph, order, input_storage, output_storage): ...@@ -563,6 +589,7 @@ def map_storage(fgraph, order, input_storage, output_storage):
if isinstance(r, graph.Constant): if isinstance(r, graph.Constant):
storage_map.setdefault(r, [r.data]) storage_map.setdefault(r, [r.data])
# extract output storage
if output_storage is None: if output_storage is None:
output_storage = [storage_map[r] for r in fgraph.outputs] output_storage = [storage_map[r] for r in fgraph.outputs]
...@@ -650,9 +677,10 @@ class LocalLinker(Linker): ...@@ -650,9 +677,10 @@ class LocalLinker(Linker):
""" """
def make_thunk(self, input_storage=None, output_storage=None): def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
return self.make_all(input_storage=input_storage, return self.make_all(input_storage=input_storage,
output_storage=output_storage)[:3] output_storage=output_storage,
storage_map=storage_map)[:3]
def make_all(self, input_storage, output_storage): def make_all(self, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function! # By convention, subclasses of LocalLinker should implement this function!
...@@ -746,7 +774,7 @@ class PerformLinker(LocalLinker): ...@@ -746,7 +774,7 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, input_storage=None, output_storage=None): def make_all(self, input_storage=None, output_storage=None, storage_map=None):
""" """
Parameters Parameters
...@@ -768,7 +796,7 @@ class PerformLinker(LocalLinker): ...@@ -768,7 +796,7 @@ class PerformLinker(LocalLinker):
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage) input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage, storage_map)
compute_map = {} compute_map = {}
for k in storage_map: for k in storage_map:
......
...@@ -230,8 +230,6 @@ if run_memory_usage_tests: ...@@ -230,8 +230,6 @@ if run_memory_usage_tests:
a = cuda.CudaNdarray(n) a = cuda.CudaNdarray(n)
a.sum() a.sum()
assert c == sys.getrefcount(n) assert c == sys.getrefcount(n)
# This is to confuse flake8
a = a
del a del a
if not i % 1000: if not i % 1000:
print('.', end=' ') print('.', end=' ')
......
...@@ -1000,14 +1000,14 @@ class VM_Linker(link.LocalLinker): ...@@ -1000,14 +1000,14 @@ class VM_Linker(link.LocalLinker):
return vm return vm
def make_all(self, profiler=None, input_storage=None, def make_all(self, profiler=None, input_storage=None,
output_storage=None, output_storage=None, storage_map=None,
): ):
fgraph = self.fgraph fgraph = self.fgraph
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage, storage_map)
compute_map = {} compute_map = {}
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论