提交 7ba9c052 authored 作者: abergeron's avatar abergeron

Merge pull request #3055 from ChienliMa/swapSV

Swap SharedVariable
......@@ -83,7 +83,7 @@ Reference
.. function:: function(inputs, outputs, mode=None, updates=None, givens=None, no_default_updates=False, accept_inplace=False, name=None, rebuild_strict=True, allow_input_downcast=None, profile=None, on_unused_input='raise')
Return a callable object that will calculate `outputs` from `inputs`.
Return a :class:`callable object <theano.compile.function_module.Function>` that will calculate `outputs` from `inputs`.
:type params: list of either Variable or Param instances, but not shared
variables.
......@@ -189,3 +189,6 @@ Reference
equivalent to Var1.
.. autofunction:: theano.compile.function.function_dump
.. autoclass:: theano.compile.function_module.Function
:members: free, copy
\ No newline at end of file
......@@ -1827,7 +1827,7 @@ class _Linker(gof.link.LocalLinker):
return self
def make_all(self, profiler=None, input_storage=None,
output_storage=None):
output_storage=None, storage_map=None):
# can't import at toplevel because of circular import TODO:
# don't do this ugly hacky way of setting the
# filter_checks_isfinite
......@@ -1857,7 +1857,7 @@ class _Linker(gof.link.LocalLinker):
no_recycling = []
input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage_, output_storage_)
fgraph, order, input_storage_, output_storage_, storage_map)
thunks_py = [] # python thunks
thunks_c = [] # c thunks
......@@ -2525,7 +2525,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
self.mode = mode
self.output_keys = output_keys
def create(self, defaults=None, trustme=False):
def create(self, defaults=None, trustme=False, storage_map=None):
"""
Create a function.
......@@ -2633,7 +2633,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
defaults = _defaults
# Get a function instance
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage)
_fn, _i, _o = self.linker.make_thunk(input_storage=input_storage,
storage_map=storage_map)
fn = self.function_builder(_fn, _i, _o, self.indices,
self.outputs, defaults, self.unpack_single,
self.return_none, self.output_keys, self)
......
......@@ -293,7 +293,7 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode,
accept_inplace=accept_inplace, name=name)
else:
# note: pfunc will also call orig_function-- orig_function is
# note: pfunc will also call orig_function -- orig_function is
# a choke point that all compilation must pass through
fn = pfunc(params=inputs,
outputs=outputs,
......
......@@ -5,7 +5,7 @@ Driver of graph construction, optimization, and linking.
from __future__ import print_function
import copy
from six import string_types, iteritems
from six import string_types, iteritems, iterkeys
from six.moves import xrange
import six.moves.copyreg as copyreg
import six.moves.cPickle as pickle
......@@ -15,9 +15,10 @@ import warnings
import numpy
import theano
from theano import gof
from theano import config, gof
from functools import partial
from theano.compat import izip
from theano.gof import graph
import theano.compile.mode
from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
......@@ -136,8 +137,8 @@ class Supervisor:
return True
for r in self.protected + list(fgraph.outputs):
if fgraph.destroyers(r):
raise gof.InconsistencyError(
"Trying to destroy a protected Variable.", r)
raise gof.InconsistencyError("Trying to destroy a protected"
"Variable.", r)
def std_fgraph(input_specs, output_specs, accept_inplace=False):
......@@ -535,16 +536,201 @@ class Function(object):
self.value[item] = value
def __copy__(self):
    """
    Copy this function (hook for ``copy.copy``).

    The copied function has separate intermediate and output storage
    from the original function. Delegates to :meth:`copy` with its
    default arguments (``share_memory=False``).
    """
    return self.copy()
def copy(self, share_memory=False, swap=None, delete_updates=False,
         name=None, profile=None):
    """
    Copy this function. The copied function will have a separate maker
    and fgraph from the original. The user can choose whether to
    separate intermediate storage via the `share_memory` argument.

    Parameters
    ----------
    share_memory : bool
        Default is False. When True, the two functions share
        intermediate storage (storage except input and output storage).
        Otherwise the two functions only share partial storage and the
        same maker. If the two functions share memory and
        ``allow_gc=False``, this will increase execution speed and save
        memory.
    swap : dict
        Dictionary that maps old SharedVariables to new SharedVariables.
        Default is None.
        NOTE: The shared-variable swap is only done in the new returned
        function, not in the user graph.
    delete_updates : bool
        Default is False. If True, the copied function will have no
        updates.
    name : str
        If provided, will be the name of the new Function. Otherwise,
        it will be the old name + " copy".
    profile
        As the `profile` parameter of `theano.function`.

    Returns
    -------
    Function
        The copied theano.Function.
    """
    # helper function
    def checkSV(sv_ori, sv_rpl):
        """
        Assert that two SharedVariables follow these restrictions:
          1. same type
          2. same shape or dim?
        """
        SharedVariable = theano.tensor.sharedvar.SharedVariable
        assert isinstance(sv_ori, SharedVariable), (
            "Key of swap should be SharedVariable, given:", sv_ori,
            " type", type(sv_ori))
        assert isinstance(sv_rpl, SharedVariable), (
            "Value of swap should be SharedVariable, given:", sv_rpl,
            "type", type(sv_ori))
        assert sv_ori.type == sv_rpl.type, (
            "Type of given SharedVariable conflicts with original one",
            "Type of given SharedVariable:", sv_rpl.type,
            "Type of original SharedVariable:", sv_ori.type)

    maker = self.maker

    # Copy Ins and their storage.
    # so that they have different storage as their value
    ins = [copy.copy(input) for input in maker.inputs]

    # Delete update output in fgraph and updates In instances if needed
    if delete_updates:
        # The first len(maker.outputs) variables are original variables.
        # The rest are the updates.
        out_vars = maker.fgraph.outputs[:len(maker.outputs)]
    else:
        out_vars = maker.fgraph.outputs

    # Init new fgraph using copied variables and get memo
    # memo: a dict that maps old variables to new variables
    memo = graph.clone_get_equiv(maker.fgraph.inputs, out_vars)
    fg_cpy = gof.fg.FunctionGraph([memo[i] for i in maker.fgraph.inputs],
                                  [memo[o] for o in out_vars],
                                  clone=False)

    # Re-initialize Outs and swap update and variable in Ins
    # By doing this, we can pass FunctionMaker._check_unused_inputs()
    outs = list(map(SymbolicOutput, fg_cpy.outputs[:len(maker.outputs)]))
    for out_ori, out_cpy in zip(maker.outputs, outs):
        out_cpy.borrow = out_ori.borrow

    # swap SharedVariable
    if swap is not None:
        exist_svs = [i.variable for i in maker.inputs]

        # Check if the given SharedVariables exist
        for sv in iterkeys(swap):
            if sv not in exist_svs:
                raise ValueError("SharedVariable: %s not found" %
                                 (sv.name))

        # Swap SharedVariable in fgraph and In instances
        for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)):
            # Variables in maker.inputs are defined by the user,
            # therefore we use them to make the comparison and do the
            # mapping. Otherwise we don't touch them.
            var = maker.inputs[index].variable

            if var in swap:
                swap_sv = swap[var]
                checkSV(i.variable, swap_sv)

                # swap variable and value of In instances
                i.variable = swap_sv
                i.value = swap_sv.container

                # In the fgraph we use the cloned SharedVariable
                swap_sv = swap_sv.clone()

                # Swap SharedVariable in fgraph
                # if inputs was replaced, change self.inputs
                fg_cpy.inputs[index] = swap_sv
                fg_cpy.replace(in_v, swap_sv, reason="Swap SV")

    # Delete update if needed
    update_i = len(outs)
    for i, in_var in zip(ins, fg_cpy.inputs):
        i.variable = in_var
        if not delete_updates and i.update is not None:
            i.update = fg_cpy.outputs[update_i]
            update_i += 1
        else:
            i.update = None

    # Construct a new storage_map that maps new variables to old
    # storage, so that the ensuing function shares storage with the
    # original one.
    storage_map = self.fn.storage_map
    new_storage_map = {}
    # TODO: We could share the output storage, but we must make sure
    # 2 different function call won't override each other values. This
    # is already done elsewhere, so to reuse it the user would need to
    # use Out(var, borrow=True) and maybe the mutable=True flag too.
    # But to be safe for now as it isn't documented and we aren't sure
    # it is well tested, we don't share the part of the storage_map.
    if share_memory:
        i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
        for key in storage_map.keys():
            if key not in i_o_vars:
                new_storage_map[memo[key]] = storage_map[key]

    if not name and self.name:
        name = self.name + " copy"

    input_storage = [i.value for i in ins]
    # reinitialize new maker and create new function
    if profile is None:
        profile = config.profile
    # profile -> True or False
    if profile is True:
        if name:
            message = name
        else:
            message = str(maker.profile.message) + " copy"
        profile = theano.compile.profiling.ProfileStats(message=message)
    # profile -> object
    elif type(profile) == str:
        profile = theano.compile.profiling.ProfileStats(message=profile)

    f_cpy = maker.__class__(inputs=ins, outputs=outs, fgraph=fg_cpy,
                            mode=maker.mode, profile=profile,
                            on_unused_input=maker.on_unused_input,
                            function_builder=maker.function_builder,
                            accept_inplace=maker.accept_inplace
                            ).create(input_storage,
                                     storage_map=new_storage_map)

    for in_ori, in_cpy, ori, cpy in zip(maker.inputs, f_cpy.maker.inputs,
                                        self.input_storage,
                                        f_cpy.input_storage):
        # Share immutable SharedVariable and constant input's storage
        swapped = swap is not None and in_ori.variable in swap

        # Using the original storage if the SharedVariable will not be
        # updated and is not swapped
        if not in_ori.mutable and not swapped:
            cpy.data = ori.data
            in_cpy.value = in_ori.value

        # Reconstruct Function.finder which maps Variables defined by
        # the user to containers, so that Function.value and
        # Function.data work well. Replace variables in the new
        # maker.inputs by the original ones, so the user can swap
        # SharedVariables in an already-swapped function.
        container = f_cpy.finder.pop(in_cpy.variable)
        if not swapped:
            f_cpy.finder[in_ori.variable] = container
            # BUG FIX: was "in_cpy.vairable", which silently set a
            # misspelled attribute instead of replacing the variable.
            in_cpy.variable = in_ori.variable
        else:
            f_cpy.finder[swap[in_ori.variable]] = container
            in_cpy.variable = swap[in_ori.variable]

    f_cpy.name = name
    f_cpy.maker.fgraph.name = name
    return f_cpy
def __call__(self, *args, **kwargs):
profile = self.profile
......@@ -1232,8 +1418,8 @@ class FunctionMaker(object):
else:
# fgraph is already an optimized one
need_opt = False
_, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
pass
updates = [spec.update for spec in inputs if spec.update]
additional_outputs = list(map(SymbolicOutput, updates))
self.fgraph = fgraph
......@@ -1355,7 +1541,7 @@ class FunctionMaker(object):
"'%s'.\nValid values are 'raise', "
"'warn', and 'ignore'." % on_unused_input)
def create(self, input_storage=None, trustme=False):
def create(self, input_storage=None, trustme=False, storage_map=None):
"""
Create a function.
......@@ -1436,7 +1622,7 @@ class FunctionMaker(object):
try:
theano.config.traceback.limit = 0
_fn, _i, _o = self.linker.make_thunk(
input_storage=input_storage_lists)
input_storage=input_storage_lists, storage_map=storage_map)
finally:
theano.config.traceback.limit = limit_orig
......
......@@ -421,7 +421,7 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
if profile is True:
profile = ProfileStats(message=name)
# profile -> object
if type(profile) == str:
elif type(profile) == str:
profile = ProfileStats(message=profile)
# profile is typically either False or an object at this point.
# No need to block other objects being passed through though. It might be
......
......@@ -43,8 +43,9 @@ AddConfigVar('ProfileMode.profile_memory',
class Profile_Maker(FunctionMaker):
def create(self, input_storage=None, trustme=False):
ret = super(Profile_Maker, self).create(input_storage, trustme)
def create(self, input_storage=None, trustme=False, storage_map=None):
ret = super(Profile_Maker, self).create(input_storage, trustme,
storage_map)
if (hasattr(theano, 'sandbox') and
hasattr(theano.sandbox, 'cuda') and
......
......@@ -690,7 +690,8 @@ class ProfileStats(object):
print('', file=file)
# The validation time is a subset of optimizer_time
assert self.validate_time < self.optimizer_time
if self.optimizer_time > 0:
assert self.validate_time < self.optimizer_time
def summary_globals(self, file):
print('Time in all call to theano.grad() %es' %
......
......@@ -241,6 +241,143 @@ class T_function(unittest.TestCase):
f(1, 2) # put them out of sync
self.assertFalse(f(1, 2) == g(1, 2)) # they should not be equal anymore.
def test_copy_share_memory(self):
    """
    Check that Function.copy(share_memory=True) shares intermediate and
    constant storage with the original function, but does not share the
    input/output storage.
    """
    x = T.fscalar('x')
    # SharedVariable for tests, one of them has update
    y = theano.shared(value=1)
    z = theano.shared(value=2)
    out = T.tanh((x+y+2)/(x+z-0.2)**2)

    # Test for different linkers
    for mode in ["FAST_RUN","FAST_COMPILE"]:
        ori = theano.function([x], [out], mode=mode,updates={z:z+1})
        cpy = ori.copy(share_memory=True)

        # Test if memories are shared
        storage_map_ori = ori.fn.storage_map
        storage_map_cpy = cpy.fn.storage_map
        fgraph_ori = ori.maker.fgraph
        fgraph_cpy = cpy.maker.fgraph

        # Assert intermediate and Constant storages are shared,
        # and output storages are not shared
        i_o_variables = fgraph_cpy.inputs + fgraph_cpy.outputs
        ori_storages = storage_map_ori.values()
        for key in storage_map_cpy.keys():
            storage = storage_map_cpy[key]
            if key not in i_o_variables or isinstance(key, theano.tensor.Constant):
                self.assertTrue(any([ storage is s for s in ori_storages]))

        # Assert storages of SharedVariable without updates are shared
        for (input, _1, _2), here, there in zip(ori.indices,
                                                ori.input_storage,
                                                cpy.input_storage):
            self.assertTrue(here.data is there.data)
def test_swap_SharedVariable(self):
    """
    Check that Function.copy(swap=...) updates the swapped-in
    SharedVariables in the copy while the original function keeps
    updating the original SharedVariables.
    """
    i = T.iscalar()
    x_list = theano.shared(value=numpy.random.rand(10).astype(config.floatX))

    x = T.scalar('x')
    # SharedVariable for tests, one of them has update
    y = theano.shared(value=1, name='y')
    z = theano.shared(value=2, name='z')
    m = theano.shared(value=0, name='m')

    # SharedVariable to replace
    y_rpl = theano.shared(value=3,name ='y_rpl')
    z_rpl = theano.shared(value=4, name='z_rpl')
    swap = {y:y_rpl, z:z_rpl}
    map_SV = {'y_rpl':y_rpl, 'z_rpl':z_rpl}

    out = x+y+z+m

    # Test for different linkers
    # for mode in ["FAST_RUN","FAST_COMPILE"]:
    # NOTE: updates accumulate across the two mode iterations, hence
    # the second_time flag doubling the expected values below.
    second_time = False
    for mode in ["FAST_RUN","FAST_COMPILE"]:
        ori = theano.function([i], [out], mode=mode,
                              updates=[(z,z+1),(m,m+2)],
                              givens={x:x_list[i]})
        cpy = ori.copy(swap=swap)

        # run each function several times
        ori(1), cpy(1),cpy(2)

        # assert the same SharedVariables are updated in the two functions
        if not second_time:
            # m should be updated 3 times
            assert m.get_value() == 6
            # z should be updated once
            assert z.get_value() == 3
            # z_rpl should be updated twice
            assert z_rpl.get_value() == 6
            # y and y_rpl should not be updated
            assert y_rpl.get_value() == 3
            assert y.get_value() == 1
        elif second_time:
            # double update for each shared variable
            assert m.get_value() == 12
            assert z.get_value() == 4
            assert z_rpl.get_value() == 8
            assert y_rpl.get_value() == 3

        # test cpy function:
        # 2. SharedVariable is updatable -> values did update(z == 5)
        # 1. sharedvariable is swapped -> replacement shared variables share storage
        names = map_SV.keys()
        for key in cpy.fn.storage_map:
            if key.name in names:
                assert map_SV[key.name].container.storage[0] ==\
                    cpy.fn.storage_map[key][0]

        second_time = True
def test_swap_SharedVaraile_with_given(self):
    """
    A special testcase for logistic_sgd.py in the Deep Learning Tutorial.
    This test asserts that the SharedVariables swapped into the copied
    function share storage with the corresponding inputs of a function
    built directly on the replacement variables.
    """
    train_x = theano.shared(value=numpy.random.rand(10, 10).astype(config.floatX))
    test_x = theano.shared(value=numpy.random.rand(10, 10).astype(config.floatX))

    train_y = theano.shared(value=numpy.random.rand(10, 1).astype(config.floatX))
    test_y = theano.shared(value=numpy.random.rand(10, 1).astype(config.floatX))

    i = T.iscalar('index')
    x = T.vector('x')
    y = T.vector('y')
    # this formula has no sense, it is only for a test
    out = (T.sum(x) - y) ** 2
    train = theano.function([i], out,
                            givens={x: train_x[i], y: train_y[i]},
                            updates={train_x: train_x + 0.1})

    test_def = theano.function([i], out, givens={x: test_x[i], y: test_y[i]})
    test_cpy = train.copy(swap={train_x: test_x, train_y: test_y},
                          delete_updates=True)

    # BUG FIX: the original zipped test_def.maker.inputs with ITSELF,
    # making the assertion vacuously true; compare against the copied
    # function's inputs instead.
    for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs):
        assert in1.value is in2.value
def test_copy_delete_updates(self):
    """
    Check that Function.copy(delete_updates=True) removes the updates:
    repeated calls of the copy must not modify the shared variables.
    """
    x = T.fscalar('x')
    # SharedVariable for tests; `z` has an update in the original function.
    y = theano.shared(value=1, name='y')
    z = theano.shared(value=2, name='z')
    out = x+y+z

    # Test for different linkers
    for mode in ["FAST_RUN","FAST_COMPILE"]:
        ori = theano.function([x], out, mode=mode,updates={z:z*2})
        cpy = ori.copy(delete_updates=True)

        # With the z update deleted, the result stays 1 + 1 + 2 == 4
        # no matter how many times the copy is called.
        assert cpy(1)[0] == 4
        assert cpy(1)[0] == 4
        assert cpy(1)[0] == 4
def test_shared_state0(self):
a = T.scalar() # the a is for 'anonymous' (un-named).
x, s = T.scalars('xs')
......
......@@ -1069,11 +1069,9 @@ class CLinker(link.Linker):
pass
return utils.uniq(ret)
def __compile__(self, input_storage=None,
output_storage=None, keep_lock=False):
"""
WRITEME
def __compile__(self, input_storage=None, output_storage=None,
storage_map=None, keep_lock=False):
"""WRITEME
Compiles this linker's fgraph.
Parameters
......@@ -1111,6 +1109,7 @@ class CLinker(link.Linker):
thunk = self.cthunk_factory(error_storage,
input_storage,
output_storage,
storage_map,
keep_lock=keep_lock)
return (thunk,
[link.Container(input, storage) for input, storage in
......@@ -1143,10 +1142,8 @@ class CLinker(link.Linker):
return init_tasks, tasks
def make_thunk(self, input_storage=None, output_storage=None,
keep_lock=False):
"""
WRITEME
storage_map=None, keep_lock=False):
"""WRITEME
Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the inputs
and outputs.
......@@ -1157,25 +1154,24 @@ class CLinker(link.Linker):
List of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
output_storage: list of lists of length 1
The thunk returned by __compile__ will put the variables of the
computation in these lists. If None, storage will be allocated.
Returns
-------
object
Thunk, input_storage, output_storage.
The return values can be used as follows:
f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input
istor[1].data = second_input
f()
first_output = ostor[0].data
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the variables of the computation in these
lists. If None, storage will be allocated.
@param storage_map: dict that map variables to storages. This is used
when you need to customize the storage of this thunk.
Returns: thunk, input_storage, output_storage
The return values can be used as follows:
f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input
istor[1].data = second_input
f()
first_output = ostor[0].data
"""
init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__(
input_storage, output_storage,
input_storage, output_storage, storage_map,
keep_lock=keep_lock)
res = _CThunk(cthunk, init_tasks, tasks, error_storage)
......@@ -1529,25 +1525,17 @@ class CLinker(link.Linker):
return self._mod
def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False):
"""
WRITEME
Parameters
----------
error_storage : list of length 3
in_storage : list of lists of length 1, one per input
out_storage : list of lists of length 1, one per output
Returns
-------
object
A thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph. That thunk,
when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage.
storage_map=None, keep_lock=False):
"""WRITEME
error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
out_storage -> list of lists of length 1, one per output
Returns a thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph. That thunk,
when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage.
"""
try:
key = self.cmodule_key()
......@@ -1569,7 +1557,10 @@ class CLinker(link.Linker):
out_storage = [x for i, x in enumerate(out_storage)
if (i + len(in_storage)) not in dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
orphd = [[orphan.data] for orphan in self.orphans]
if storage_map is None:
orphd = [[orphan.data] for orphan in self.orphans]
else:
orphd = [storage_map[orphan] for orphan in self.orphans]
ret = module.instantiate(error_storage,
*(in_storage + out_storage + orphd))
......@@ -1727,7 +1718,8 @@ class OpWiseCLinker(link.LocalLinker):
self.no_recycling = no_recycling
return self
def make_all(self, profiler=None, input_storage=None, output_storage=None):
def make_all(self, profiler=None, input_storage=None, output_storage=None,
storage_map=None):
# The lock will be acquired when we compile the first
# C code. We will keep the lock untill all the function
......@@ -1741,7 +1733,7 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage)
fgraph, order, input_storage, output_storage, storage_map)
if self.allow_gc:
computed, last_user = link.gc_helper(order)
post_thunk_old_storage = []
......
......@@ -278,6 +278,10 @@ class FunctionGraph(utils.object2):
# import #
def __import_r__(self, variable, reason):
"""
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph.
"""
global NullType
if NullType is None:
from .null_type import NullType
......@@ -296,6 +300,12 @@ class FunctionGraph(utils.object2):
self.variables.add(variable)
def __import__(self, apply_node, check=True, reason=None):
"""
Given an apply_node, recursively search from this node to know graph,
and then add all unknown variables and apply_nodes to this graph.
"""
node = apply_node
# We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to
......@@ -806,20 +816,32 @@ class FunctionGraph(utils.object2):
"""
return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True, attach_feature=True):
    """
    Clone the graph and return a memo dict that maps old nodes to the
    new (cloned) nodes.

    Parameters
    ----------
    check_integrity : bool
        Whether to check the integrity of the original graph and of the
        clone. Default is True.
    attach_feature : bool
        Whether to attach the features of the original graph to the
        cloned graph. Default is True.

    Returns
    -------
    e : FunctionGraph
        The cloned fgraph. Every node in the cloned graph is cloned.
    equiv : dict
        A dict that maps old nodes to new nodes.
    """
    equiv = graph.clone_get_equiv(self.inputs, self.outputs)

    if check_integrity:
        self.check_integrity()
    # clone=False: the variables were already cloned by clone_get_equiv
    # above, so the constructor must not clone them a second time.
    e = FunctionGraph([equiv[i] for i in self.inputs],
                      [equiv[o] for o in self.outputs],
                      clone=False)

    if check_integrity:
        e.check_integrity()

    if attach_feature:
        for feature in self._features:
            e.attach_feature(feature)

    return e, equiv
def __getstate__(self):
......
......@@ -496,10 +496,16 @@ class Container(object):
return r
def map_storage(fgraph, order, input_storage, output_storage):
"""
Ensure there is storage (a length-1 list) for inputs, outputs, and
interior nodes.
def map_storage(fgraph, order, input_storage, output_storage, storage_map=None):
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
:param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
:param order: an iterable over Apply instances (in program running order)
:param input_storage: None or existing input storage (see below)
:param output_storage: None or existing output storage (see below)
:rtype: 3-tuple
:returns: (list of storage for inputs, list of storage for outputs, and the `storage_map`)
Parameters
----------
......@@ -533,25 +539,45 @@ def map_storage(fgraph, order, input_storage, output_storage):
"""
# each Apply argument's data is stored in a list of length 1 (these lists act like pointers)
if storage_map is None:
storage_map = {}
# input_storage is a list of data-containers for the inputs.
if input_storage is None:
input_storage = [[None] for input in fgraph.inputs]
else:
assert len(fgraph.inputs) == len(input_storage)
storage_map = {}
for r, storage in izip(fgraph.inputs, input_storage):
storage_map[r] = storage
# add input storage into storage_map
for r, storage in zip(fgraph.inputs, input_storage):
if r in storage_map:
assert storage_map[r] is storage, ("Given input_storage conflicts "
"with storage in given storage_"
"map. Given input_storage: ",
storage, "Storage in storage_ma"
"p: ", storage_map[r])
else:
storage_map[r] = storage
# for orphan in fgraph.orphans:
# if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
# storage_map[orphan] = [orphan.data]
# allocate output storage
if output_storage is not None:
assert len(fgraph.outputs) == len(output_storage)
for r, storage in izip(fgraph.outputs, output_storage):
storage_map[r] = storage
for r, storage in zip(fgraph.outputs, output_storage):
if r in storage_map:
assert storage_map[r] is storage, ("Given output_storage confl"
"icts with storage in given"
" storage_map. Given output"
"_storage: ", storage, "Sto"
"rage in storage_map: ",
storage_map[r])
else:
storage_map[r] = storage
# allocate storage for intermediate computation
for node in order:
for r in node.inputs:
if r not in storage_map:
......@@ -563,6 +589,7 @@ def map_storage(fgraph, order, input_storage, output_storage):
if isinstance(r, graph.Constant):
storage_map.setdefault(r, [r.data])
# extract output storage
if output_storage is None:
output_storage = [storage_map[r] for r in fgraph.outputs]
......@@ -650,9 +677,10 @@ class LocalLinker(Linker):
"""
def make_thunk(self, input_storage=None, output_storage=None, storage_map=None):
    """
    Compile the fgraph and return (thunk, input_storage, output_storage).

    `storage_map`, if given, lets the caller pre-assign the storage
    cells used for particular variables; it is forwarded unchanged to
    make_all().
    """
    return self.make_all(input_storage=input_storage,
                         output_storage=output_storage,
                         storage_map=storage_map)[:3]
def make_all(self, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
......@@ -746,7 +774,7 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling
return self
def make_all(self, input_storage=None, output_storage=None):
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
"""
Parameters
......@@ -768,7 +796,7 @@ class PerformLinker(LocalLinker):
order = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage)
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage, storage_map)
compute_map = {}
for k in storage_map:
......
......@@ -230,8 +230,6 @@ if run_memory_usage_tests:
a = cuda.CudaNdarray(n)
a.sum()
assert c == sys.getrefcount(n)
# This is to confuse flake8
a = a
del a
if not i % 1000:
print('.', end=' ')
......
......@@ -1000,14 +1000,14 @@ class VM_Linker(link.LocalLinker):
return vm
def make_all(self, profiler=None, input_storage=None,
output_storage=None,
output_storage=None, storage_map=None,
):
fgraph = self.fgraph
order = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(
fgraph, order, input_storage, output_storage)
fgraph, order, input_storage, output_storage, storage_map)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论