提交 7ba9c052 authored 作者: abergeron's avatar abergeron

Merge pull request #3055 from ChienliMa/swapSV

Swap SharedVariable
...@@ -83,7 +83,7 @@ Reference ...@@ -83,7 +83,7 @@ Reference
.. function:: function(inputs, outputs, mode=None, updates=None, givens=None, no_default_updates=False, accept_inplace=False, name=None, rebuild_strict=True, allow_input_downcast=None, profile=None, on_unused_input='raise') .. function:: function(inputs, outputs, mode=None, updates=None, givens=None, no_default_updates=False, accept_inplace=False, name=None, rebuild_strict=True, allow_input_downcast=None, profile=None, on_unused_input='raise')
Return a callable object that will calculate `outputs` from `inputs`. Return a :class:`callable object <theano.compile.function_module.Function>` that will calculate `outputs` from `inputs`.
:type params: list of either Variable or Param instances, but not shared :type params: list of either Variable or Param instances, but not shared
variables. variables.
...@@ -189,3 +189,6 @@ Reference ...@@ -189,3 +189,6 @@ Reference
equivalent to Var1. equivalent to Var1.
.. autofunction:: theano.compile.function.function_dump .. autofunction:: theano.compile.function.function_dump
.. autoclass:: theano.compile.function_module.Function
:members: free, copy
\ No newline at end of file
...@@ -1827,7 +1827,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1827,7 +1827,7 @@ class _Linker(gof.link.LocalLinker):
return self return self
def make_all(self, profiler=None, input_storage=None, def make_all(self, profiler=None, input_storage=None,
output_storage=None): output_storage=None, storage_map=None):
# can't import at toplevel because of circular import TODO: # can't import at toplevel because of circular import TODO:
# don't do this ugly hacky way of setting the # don't do this ugly hacky way of setting the
# filter_checks_isfinite # filter_checks_isfinite
...@@ -1857,7 +1857,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1857,7 +1857,7 @@ class _Linker(gof.link.LocalLinker):
no_recycling = [] no_recycling = []
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage_, output_storage_) fgraph, order, input_storage_, output_storage_, storage_map)
thunks_py = [] # python thunks thunks_py = [] # python thunks
thunks_c = [] # c thunks thunks_c = [] # c thunks
...@@ -2525,7 +2525,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2525,7 +2525,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.mode = mode self.mode = mode
self.output_keys = output_keys self.output_keys = output_keys
def create(self, defaults=None, trustme=False): def create(self, defaults=None, trustme=False, storage_map=None):
""" """
Create a function. Create a function.
...@@ -2633,7 +2633,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2633,7 +2633,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
defaults = _defaults defaults = _defaults
# Get a function instance # Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage) _fn, _i, _o = self.linker.make_thunk(input_storage=input_storage,
storage_map=storage_map)
fn = self.function_builder(_fn, _i, _o, self.indices, fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single, self.outputs, defaults, self.unpack_single,
self.return_none, self.output_keys, self) self.return_none, self.output_keys, self)
......
...@@ -293,7 +293,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -293,7 +293,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
# note: pfunc will also call orig_function-- orig_function is # note: pfunc will also call orig_function -- orig_function is
# a choke point that all compilation must pass through # a choke point that all compilation must pass through
fn = pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
......
...@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if profile is True: if profile is True:
profile = ProfileStats(message=name) profile = ProfileStats(message=name)
# profile -> object # profile -> object
if type(profile) == str: elif type(profile) == str:
profile = ProfileStats(message=profile) profile = ProfileStats(message=profile)
# profile is typically either False or an object at this point. # profile is typically either False or an object at this point.
# No need to block other objects being passed through though. It might be # No need to block other objects being passed through though. It might be
......
...@@ -43,8 +43,9 @@ AddConfigVar('ProfileMode.profile_memory', ...@@ -43,8 +43,9 @@ AddConfigVar('ProfileMode.profile_memory',
class Profile_Maker(FunctionMaker): class Profile_Maker(FunctionMaker):
def create(self, input_storage=None, trustme=False): def create(self, input_storage=None, trustme=False, storage_map=None):
ret = super(Profile_Maker, self).create(input_storage, trustme) ret = super(Profile_Maker, self).create(input_storage, trustme,
storage_map)
if (hasattr(theano, 'sandbox') and if (hasattr(theano, 'sandbox') and
hasattr(theano.sandbox, 'cuda') and hasattr(theano.sandbox, 'cuda') and
......
...@@ -690,6 +690,7 @@ class ProfileStats(object): ...@@ -690,6 +690,7 @@ class ProfileStats(object):
print('', file=file) print('', file=file)
# The validation time is a subset of optimizer_time # The validation time is a subset of optimizer_time
if self.optimizer_time > 0:
assert self.validate_time < self.optimizer_time assert self.validate_time < self.optimizer_time
def summary_globals(self, file): def summary_globals(self, file):
......
...@@ -241,6 +241,143 @@ class T_function(unittest.TestCase): ...@@ -241,6 +241,143 @@ class T_function(unittest.TestCase):
f(1, 2) # put them out of sync f(1, 2) # put them out of sync
self.assertFalse(f(1, 2) == g(1, 2)) # they should not be equal anymore. self.assertFalse(f(1, 2) == g(1, 2)) # they should not be equal anymore.
def test_copy_share_memory(self):
    """Check Function.copy(share_memory=True).

    The copy must share storage for intermediate results and Constants
    with the original, must NOT share input/output storage, and must
    share the containers of SharedVariables that have no update.
    """
    x = T.fscalar('x')
    # SharedVariables for the test graph; z gets an update below.
    y = theano.shared(value=1)
    z = theano.shared(value=2)
    out = T.tanh((x + y + 2) / (x + z - 0.2) ** 2)

    # Exercise both linkers (C and Python).
    for mode in ["FAST_RUN", "FAST_COMPILE"]:
        ori = theano.function([x], [out], mode=mode, updates={z: z + 1})
        cpy = ori.copy(share_memory=True)

        storage_map_ori = ori.fn.storage_map
        storage_map_cpy = cpy.fn.storage_map
        fgraph_cpy = cpy.maker.fgraph

        # Intermediate and Constant storages are shared with the
        # original; input/output storages are not.
        i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
        ori_storages = storage_map_ori.values()
        for key, storage in storage_map_cpy.items():
            if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
                self.assertTrue(any(storage is s for s in ori_storages))

        # SharedVariables without updates share their data containers.
        for (input, _1, _2), here, there in zip(ori.indices,
                                                ori.input_storage,
                                                cpy.input_storage):
            self.assertTrue(here.data is there.data)
def test_swap_SharedVariable(self):
    """Check Function.copy(swap=...).

    After the swap, updates on the swapped SharedVariables (z -> z_rpl)
    must be applied to the replacements in the copy, while unswapped
    shared variables (m) keep being updated by both functions, and
    variables without updates (y, y_rpl) are never modified.
    """
    i = T.iscalar()
    x_list = theano.shared(value=numpy.random.rand(10).astype(config.floatX))
    x = T.scalar('x')
    # SharedVariables for the test graph; z and m have updates.
    y = theano.shared(value=1, name='y')
    z = theano.shared(value=2, name='z')
    m = theano.shared(value=0, name='m')
    # Replacement SharedVariables.
    y_rpl = theano.shared(value=3, name='y_rpl')
    z_rpl = theano.shared(value=4, name='z_rpl')
    swap = {y: y_rpl, z: z_rpl}
    map_SV = {'y_rpl': y_rpl, 'z_rpl': z_rpl}
    out = x + y + z + m

    # Exercise both linkers (C and Python). second_time distinguishes
    # the first iteration's expected counters from the second's, since
    # the shared values carry over between iterations.
    second_time = False
    for mode in ["FAST_RUN", "FAST_COMPILE"]:
        ori = theano.function([i], [out], mode=mode,
                              updates=[(z, z + 1), (m, m + 2)],
                              givens={x: x_list[i]})
        cpy = ori.copy(swap=swap)

        # Run each function several times: ori once, cpy twice.
        ori(1), cpy(1), cpy(2)

        # Assert the right SharedVariables were updated by each function.
        if not second_time:
            # m is shared by both functions: updated 3 times in total.
            assert m.get_value() == 6
            # z is updated only by ori: once.
            assert z.get_value() == 3
            # z_rpl is updated only by cpy: twice.
            assert z_rpl.get_value() == 6
            # y and y_rpl have no updates.
            assert y_rpl.get_value() == 3
            assert y.get_value() == 1
        else:
            # Second linker iteration: every counter advanced once more.
            assert m.get_value() == 12
            assert z.get_value() == 4
            assert z_rpl.get_value() == 8
            assert y_rpl.get_value() == 3

        # Check the copied function:
        # 1. swapped SharedVariables share storage with the replacements
        # 2. updatable SharedVariables did update (z_rpl checked above)
        names = map_SV.keys()
        for key in cpy.fn.storage_map:
            if key.name in names:
                assert map_SV[key.name].container.storage[0] == \
                    cpy.fn.storage_map[key][0]

        second_time = True
def test_swap_SharedVaraile_with_given(self):
    """A special testcase modeled on logistic_sgd.py from the Deep
    Learning Tutorial.

    Asserts that after copy(swap=..., delete_updates=True), the inputs
    of the copied function share their value containers with a function
    compiled directly on the swapped-in SharedVariables.

    NOTE(review): the method name keeps its original typo
    ("SharedVaraile") so test discovery/selection stays unchanged.
    """
    train_x = theano.shared(value=numpy.random.rand(10, 10).astype(config.floatX))
    test_x = theano.shared(value=numpy.random.rand(10, 10).astype(config.floatX))
    train_y = theano.shared(value=numpy.random.rand(10, 1).astype(config.floatX))
    test_y = theano.shared(value=numpy.random.rand(10, 1).astype(config.floatX))
    i = T.iscalar('index')
    x = T.vector('x')
    y = T.vector('y')
    # The formula is meaningless; it only needs to use x and y.
    out = (T.sum(x) - y) ** 2
    train = theano.function([i], out,
                            givens={x: train_x[i], y: train_y[i]},
                            updates={train_x: train_x + 0.1})
    test_def = theano.function([i], out, givens={x: test_x[i], y: test_y[i]})
    test_cpy = train.copy(swap={train_x: test_x, train_y: test_y},
                          delete_updates=True)

    # BUGFIX: the original zipped test_def.maker.inputs with itself,
    # which made the assertion trivially true. Compare the directly
    # compiled function against the swapped copy instead.
    for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs):
        assert in1.value is in2.value
def test_copy_delete_updates(self):
    """Check Function.copy(delete_updates=True).

    The copy must not apply the original's updates: z keeps its initial
    value, so repeated calls keep returning the same result.
    """
    x = T.fscalar('x')
    # SharedVariables for the test graph; z has an update in the original.
    y = theano.shared(value=1, name='y')
    z = theano.shared(value=2, name='z')
    out = x + y + z

    for mode in ["FAST_RUN", "FAST_COMPILE"]:
        ori = theano.function([x], out, mode=mode, updates={z: z * 2})
        cpy = ori.copy(delete_updates=True)
        # Without the z update, every call yields 1 + 1 + 2 == 4.
        for _ in range(3):
            assert cpy(1)[0] == 4
def test_shared_state0(self): def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named). a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs') x, s = T.scalars('xs')
......
...@@ -1069,11 +1069,9 @@ class CLinker(link.Linker): ...@@ -1069,11 +1069,9 @@ class CLinker(link.Linker):
pass pass
return utils.uniq(ret) return utils.uniq(ret)
def __compile__(self, input_storage=None, def __compile__(self, input_storage=None, output_storage=None,
output_storage=None, keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME
Compiles this linker's fgraph. Compiles this linker's fgraph.
Parameters Parameters
...@@ -1111,6 +1109,7 @@ class CLinker(link.Linker): ...@@ -1111,6 +1109,7 @@ class CLinker(link.Linker):
thunk = self.cthunk_factory(error_storage, thunk = self.cthunk_factory(error_storage,
input_storage, input_storage,
output_storage, output_storage,
storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
return (thunk, return (thunk,
[link.Container(input, storage) for input, storage in [link.Container(input, storage) for input, storage in
...@@ -1143,10 +1142,8 @@ class CLinker(link.Linker): ...@@ -1143,10 +1142,8 @@ class CLinker(link.Linker):
return init_tasks, tasks return init_tasks, tasks
def make_thunk(self, input_storage=None, output_storage=None, def make_thunk(self, input_storage=None, output_storage=None,
keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME
Compiles this linker's fgraph and returns a function to perform the Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the inputs computations, as well as lists of storage cells for both the inputs
and outputs. and outputs.
...@@ -1157,25 +1154,24 @@ class CLinker(link.Linker): ...@@ -1157,25 +1154,24 @@ class CLinker(link.Linker):
List of lists of length 1. In order to use List of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated. that storage. If None, storage will be allocated.
output_storage: list of lists of length 1 @param output_storage: list of lists of length 1. The thunk returned
The thunk returned by __compile__ will put the variables of the by __compile__ will put the variables of the computation in these
computation in these lists. If None, storage will be allocated. lists. If None, storage will be allocated.
@param storage_map: dict that map variables to storages. This is used
when you need to customize the storage of this thunk.
Returns: thunk, input_storage, output_storage
Returns
-------
object
Thunk, input_storage, output_storage.
The return values can be used as follows: The return values can be used as follows:
f, istor, ostor = clinker.make_thunk() f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input istor[0].data = first_input
istor[1].data = second_input istor[1].data = second_input
f() f()
first_output = ostor[0].data first_output = ostor[0].data
""" """
init_tasks, tasks = self.get_init_tasks() init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__( cthunk, in_storage, out_storage, error_storage = self.__compile__(
input_storage, output_storage, input_storage, output_storage, storage_map,
keep_lock=keep_lock) keep_lock=keep_lock)
res = _CThunk(cthunk, init_tasks, tasks, error_storage) res = _CThunk(cthunk, init_tasks, tasks, error_storage)
...@@ -1529,25 +1525,17 @@ class CLinker(link.Linker): ...@@ -1529,25 +1525,17 @@ class CLinker(link.Linker):
return self._mod return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage, def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False): storage_map=None, keep_lock=False):
""" """WRITEME
WRITEME error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
Parameters out_storage -> list of lists of length 1, one per output
----------
error_storage : list of length 3
in_storage : list of lists of length 1, one per input
out_storage : list of lists of length 1, one per output
Returns Returns a thunk that points to an instance of a C struct that
-------
object
A thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph. That thunk, can carry on the computation of this linker's fgraph. That thunk,
when executed, will fetch its inputs from in_storage, put its when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage. type, value and traceback of the exception in error_storage.
""" """
try: try:
key = self.cmodule_key() key = self.cmodule_key()
...@@ -1569,7 +1557,10 @@ class CLinker(link.Linker): ...@@ -1569,7 +1557,10 @@ class CLinker(link.Linker):
out_storage = [x for i, x in enumerate(out_storage) out_storage = [x for i, x in enumerate(out_storage)
if (i + len(in_storage)) not in dupidx] if (i + len(in_storage)) not in dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx] in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
if storage_map is None:
orphd = [[orphan.data] for orphan in self.orphans] orphd = [[orphan.data] for orphan in self.orphans]
else:
orphd = [storage_map[orphan] for orphan in self.orphans]
ret = module.instantiate(error_storage, ret = module.instantiate(error_storage,
*(in_storage + out_storage + orphd)) *(in_storage + out_storage + orphd))
...@@ -1727,7 +1718,8 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1727,7 +1718,8 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler=None, input_storage=None, output_storage=None): def make_all(self, profiler=None, input_storage=None, output_storage=None,
storage_map=None):
# The lock will be acquired when we compile the first # The lock will be acquired when we compile the first
# C code. We will keep the lock untill all the function # C code. We will keep the lock untill all the function
...@@ -1741,7 +1733,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1741,7 +1733,7 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage, storage_map)
if self.allow_gc: if self.allow_gc:
computed, last_user = link.gc_helper(order) computed, last_user = link.gc_helper(order)
post_thunk_old_storage = [] post_thunk_old_storage = []
......
...@@ -278,6 +278,10 @@ class FunctionGraph(utils.object2): ...@@ -278,6 +278,10 @@ class FunctionGraph(utils.object2):
# import # # import #
def __import_r__(self, variable, reason): def __import_r__(self, variable, reason):
"""
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph.
"""
global NullType global NullType
if NullType is None: if NullType is None:
from .null_type import NullType from .null_type import NullType
...@@ -296,6 +300,12 @@ class FunctionGraph(utils.object2): ...@@ -296,6 +300,12 @@ class FunctionGraph(utils.object2):
self.variables.add(variable) self.variables.add(variable)
def __import__(self, apply_node, check=True, reason=None): def __import__(self, apply_node, check=True, reason=None):
"""
Given an apply_node, recursively search from this node to know graph,
and then add all unknown variables and apply_nodes to this graph.
"""
node = apply_node
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set. # in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
...@@ -806,18 +816,30 @@ class FunctionGraph(utils.object2): ...@@ -806,18 +816,30 @@ class FunctionGraph(utils.object2):
""" """
return self.clone_get_equiv(check_integrity)[0] return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True): def clone_get_equiv(self, check_integrity=True, attach_feature=True):
""" """Clone the graph and get a memo( a dict )that map old node to new node
WRITEME ----------------------------
Parameters:
check_integrity - { bool } Whether to check integrity.
Default is True.
attach_feature - { bool } Whether to attach feature of origin graph to
cloned graph. Default is True.
----------------------------
Returns:
e - { FunctionGraph } Cloned fgraph. Every node in cloned graph is cloned.
equiv - { dict } A dict that map old node to new node.
""" """
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
if check_integrity: if check_integrity:
self.check_integrity() self.check_integrity()
e = FunctionGraph([equiv[i] for i in self.inputs], e = FunctionGraph([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs]) [equiv[o] for o in self.outputs],
clone=False)
if check_integrity: if check_integrity:
e.check_integrity() e.check_integrity()
if attach_feature:
for feature in self._features: for feature in self._features:
e.attach_feature(feature) e.attach_feature(feature)
return e, equiv return e, equiv
......
...@@ -496,10 +496,16 @@ class Container(object): ...@@ -496,10 +496,16 @@ class Container(object):
return r return r
def map_storage(fgraph, order, input_storage, output_storage): def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
""" """Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
Ensure there is storage (a length-1 list) for inputs, outputs, and
interior nodes. :param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
:param order: an iterable over Apply instances (in program running order)
:param input_storage: None or existing input storage (see below)
:param output_storage: None or existing output storage (see below)
:rtype: 3-tuple
:returns: (list of storage for inputs, list of storage for outputs, and the `storage_map`)
Parameters Parameters
---------- ----------
...@@ -533,25 +539,45 @@ def map_storage(fgraph, order, input_storage, output_storage): ...@@ -533,25 +539,45 @@ def map_storage(fgraph, order, input_storage, output_storage):
""" """
# each Apply argument's data is stored in a list of length 1 (these lists act like pointers) # each Apply argument's data is stored in a list of length 1 (these lists act like pointers)
if storage_map is None:
storage_map = {}
# input_storage is a list of data-containers for the inputs. # input_storage is a list of data-containers for the inputs.
if input_storage is None: if input_storage is None:
input_storage = [[None] for input in fgraph.inputs] input_storage = [[None] for input in fgraph.inputs]
else: else:
assert len(fgraph.inputs) == len(input_storage) assert len(fgraph.inputs) == len(input_storage)
storage_map = {} # add input storage into storage_map
for r, storage in izip(fgraph.inputs, input_storage): for r, storage in zip(fgraph.inputs, input_storage):
if r in storage_map:
assert storage_map[r] is storage, ("Given input_storage conflicts "
"with storage in given storage_"
"map. Given input_storage: ",
storage, "Storage in storage_ma"
"p: ", storage_map[r])
else:
storage_map[r] = storage storage_map[r] = storage
# for orphan in fgraph.orphans: # for orphan in fgraph.orphans:
# if not isinstance(orphan, Constant): # if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan) # raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
# storage_map[orphan] = [orphan.data] # storage_map[orphan] = [orphan.data]
# allocate output storage
if output_storage is not None: if output_storage is not None:
assert len(fgraph.outputs) == len(output_storage) assert len(fgraph.outputs) == len(output_storage)
for r, storage in izip(fgraph.outputs, output_storage): for r, storage in zip(fgraph.outputs, output_storage):
if r in storage_map:
assert storage_map[r] is storage, ("Given output_storage confl"
"icts with storage in given"
" storage_map. Given output"
"_storage: ", storage, "Sto"
"rage in storage_map: ",
storage_map[r])
else:
storage_map[r] = storage storage_map[r] = storage
# allocate storage for intermediate computation
for node in order: for node in order:
for r in node.inputs: for r in node.inputs:
if r not in storage_map: if r not in storage_map:
...@@ -563,6 +589,7 @@ def map_storage(fgraph, order, input_storage, output_storage): ...@@ -563,6 +589,7 @@ def map_storage(fgraph, order, input_storage, output_storage):
if isinstance(r, graph.Constant): if isinstance(r, graph.Constant):
storage_map.setdefault(r, [r.data]) storage_map.setdefault(r, [r.data])
# extract output storage
if output_storage is None: if output_storage is None:
output_storage = [storage_map[r] for r in fgraph.outputs] output_storage = [storage_map[r] for r in fgraph.outputs]
...@@ -650,9 +677,10 @@ class LocalLinker(Linker): ...@@ -650,9 +677,10 @@ class LocalLinker(Linker):
""" """
def make_thunk(self, input_storage=None, output_storage=None): def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
return self.make_all(input_storage=input_storage, return self.make_all(input_storage=input_storage,
output_storage=output_storage)[:3] output_storage=output_storage,
storage_map=storage_map)[:3]
def make_all(self, input_storage, output_storage): def make_all(self, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function! # By convention, subclasses of LocalLinker should implement this function!
...@@ -746,7 +774,7 @@ class PerformLinker(LocalLinker): ...@@ -746,7 +774,7 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, input_storage=None, output_storage=None): def make_all(self, input_storage=None, output_storage=None, storage_map=None):
""" """
Parameters Parameters
...@@ -768,7 +796,7 @@ class PerformLinker(LocalLinker): ...@@ -768,7 +796,7 @@ class PerformLinker(LocalLinker):
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage) input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage, storage_map)
compute_map = {} compute_map = {}
for k in storage_map: for k in storage_map:
......
...@@ -230,8 +230,6 @@ if run_memory_usage_tests: ...@@ -230,8 +230,6 @@ if run_memory_usage_tests:
a = cuda.CudaNdarray(n) a = cuda.CudaNdarray(n)
a.sum() a.sum()
assert c == sys.getrefcount(n) assert c == sys.getrefcount(n)
# This is to confuse flake8
a = a
del a del a
if not i % 1000: if not i % 1000:
print('.', end=' ') print('.', end=' ')
......
...@@ -1000,14 +1000,14 @@ class VM_Linker(link.LocalLinker): ...@@ -1000,14 +1000,14 @@ class VM_Linker(link.LocalLinker):
return vm return vm
def make_all(self, profiler=None, input_storage=None, def make_all(self, profiler=None, input_storage=None,
output_storage=None, output_storage=None, storage_map=None,
): ):
fgraph = self.fgraph fgraph = self.fgraph
order = self.schedule(fgraph) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage) fgraph, order, input_storage, output_storage, storage_map)
compute_map = {} compute_map = {}
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论